memvid_core/
vec.rs

1use blake3::hash;
2use serde::{Deserialize, Serialize};
3
4use crate::{MemvidError, Result, types::FrameId};
5
6fn vec_config() -> impl bincode::config::Config {
7    bincode::config::standard()
8        .with_fixed_int_encoding()
9        .with_little_endian()
10}
11
12const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct VecDocument {
16    pub frame_id: FrameId,
17    pub embedding: Vec<f32>,
18}
19
20#[derive(Default)]
21pub struct VecIndexBuilder {
22    documents: Vec<VecDocument>,
23}
24
25impl VecIndexBuilder {
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn add_document<I>(&mut self, frame_id: FrameId, embedding: I)
31    where
32        I: Into<Vec<f32>>,
33    {
34        self.documents.push(VecDocument {
35            frame_id,
36            embedding: embedding.into(),
37        });
38    }
39
40    pub fn finish(self) -> Result<VecIndexArtifact> {
41        let bytes = bincode::serde::encode_to_vec(&self.documents, vec_config())?;
42
43        let checksum = *hash(&bytes).as_bytes();
44        let dimension = self
45            .documents
46            .first()
47            .map(|doc| doc.embedding.len() as u32)
48            .unwrap_or(0);
49        #[cfg(feature = "parallel_segments")]
50        let bytes_uncompressed = self
51            .documents
52            .iter()
53            .map(|doc| doc.embedding.len() * std::mem::size_of::<f32>())
54            .sum::<usize>() as u64;
55        Ok(VecIndexArtifact {
56            bytes,
57            vector_count: self.documents.len() as u64,
58            dimension,
59            checksum,
60            #[cfg(feature = "parallel_segments")]
61            bytes_uncompressed,
62        })
63    }
64}
65
/// Serialized output of [`VecIndexBuilder::finish`] plus summary metadata.
#[derive(Debug, Clone)]
pub struct VecIndexArtifact {
    /// Bincode-encoded `Vec<VecDocument>`.
    pub bytes: Vec<u8>,
    /// Number of documents encoded in `bytes`.
    pub vector_count: u64,
    /// Embedding length of the first document, or 0 when there are none.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Total size in bytes of all raw `f32` embedding components.
    #[cfg(feature = "parallel_segments")]
    pub bytes_uncompressed: u64,
}
75
76#[derive(Debug, Clone)]
77pub enum VecIndex {
78    Uncompressed { documents: Vec<VecDocument> },
79    Compressed(crate::vec_pq::QuantizedVecIndex),
80}
81
82impl VecIndex {
83    /// Decode vector index from bytes
84    /// For backward compatibility, defaults to uncompressed if no manifest provided
85    pub fn decode(bytes: &[u8]) -> Result<Self> {
86        Self::decode_with_compression(bytes, crate::VectorCompression::None)
87    }
88
89    /// Decode vector index with compression mode from manifest
90    ///
91    /// ALWAYS tries uncompressed format first, regardless of compression flag.
92    /// This is necessary because MIN_VECTORS_FOR_PQ threshold (100 vectors)
93    /// causes most segments to be stored as uncompressed even when Pq96 is requested.
94    /// Falls back to PQ format for true compressed segments.
95    pub fn decode_with_compression(
96        bytes: &[u8],
97        _compression: crate::VectorCompression,
98    ) -> Result<Self> {
99        // Try uncompressed format first, regardless of compression flag.
100        // This is necessary because MIN_VECTORS_FOR_PQ threshold (100 vectors)
101        // causes most segments to be stored as uncompressed even when Pq96 is requested.
102        match bincode::serde::decode_from_slice::<Vec<VecDocument>, _>(
103            bytes,
104            bincode::config::standard()
105                .with_fixed_int_encoding()
106                .with_little_endian()
107                .with_limit::<VEC_DECODE_LIMIT>(),
108        ) {
109            Ok((documents, read)) if read == bytes.len() => {
110                tracing::debug!(
111                    bytes_len = bytes.len(),
112                    docs_count = documents.len(),
113                    "decoded as uncompressed"
114                );
115                return Ok(Self::Uncompressed { documents });
116            }
117            Ok((_, read)) => {
118                tracing::debug!(
119                    bytes_len = bytes.len(),
120                    read = read,
121                    "uncompressed decode partial read, trying PQ"
122                );
123            }
124            Err(err) => {
125                tracing::debug!(
126                    error = %err,
127                    bytes_len = bytes.len(),
128                    "uncompressed decode failed, trying PQ"
129                );
130            }
131        }
132
133        // Try Product Quantization format
134        match crate::vec_pq::QuantizedVecIndex::decode(bytes) {
135            Ok(quantized_index) => {
136                tracing::debug!(bytes_len = bytes.len(), "decoded as PQ");
137                Ok(Self::Compressed(quantized_index))
138            }
139            Err(err) => {
140                tracing::debug!(
141                    error = %err,
142                    bytes_len = bytes.len(),
143                    "PQ decode also failed"
144                );
145                Err(MemvidError::InvalidToc {
146                    reason: "unsupported vector index encoding".into(),
147                })
148            }
149        }
150    }
151
152    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
153        if query.is_empty() {
154            return Vec::new();
155        }
156        match self {
157            VecIndex::Uncompressed { documents } => {
158                let mut hits: Vec<VecSearchHit> = documents
159                    .iter()
160                    .map(|doc| {
161                        let distance = l2_distance(query, &doc.embedding);
162                        VecSearchHit {
163                            frame_id: doc.frame_id,
164                            distance,
165                        }
166                    })
167                    .collect();
168                hits.sort_by(|a, b| {
169                    a.distance
170                        .partial_cmp(&b.distance)
171                        .unwrap_or(std::cmp::Ordering::Equal)
172                });
173                hits.truncate(limit);
174                hits
175            }
176            VecIndex::Compressed(quantized) => quantized.search(query, limit),
177        }
178    }
179
180    pub fn entries(&self) -> Box<dyn Iterator<Item = (FrameId, &[f32])> + '_> {
181        match self {
182            VecIndex::Uncompressed { documents } => Box::new(
183                documents
184                    .iter()
185                    .map(|doc| (doc.frame_id, doc.embedding.as_slice())),
186            ),
187            VecIndex::Compressed(_) => {
188                // Compressed vectors don't have direct f32 access
189                Box::new(std::iter::empty())
190            }
191        }
192    }
193
194    pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
195        match self {
196            VecIndex::Uncompressed { documents } => documents
197                .iter()
198                .find(|doc| doc.frame_id == frame_id)
199                .map(|doc| doc.embedding.as_slice()),
200            VecIndex::Compressed(_) => {
201                // Compressed vectors don't have direct f32 access
202                None
203            }
204        }
205    }
206
207    pub fn remove(&mut self, frame_id: FrameId) {
208        match self {
209            VecIndex::Uncompressed { documents } => {
210                documents.retain(|doc| doc.frame_id != frame_id);
211            }
212            VecIndex::Compressed(_quantized) => {
213                // TODO: Implement removal for compressed indices
214                // For now, compressed indices are immutable
215            }
216        }
217    }
218}
219
220#[derive(Debug, Clone, PartialEq)]
221pub struct VecSearchHit {
222    pub frame_id: FrameId,
223    pub distance: f32,
224}
225
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so if the slices differ in length only
/// the shared prefix contributes; extra trailing components are ignored.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-document index: frame 1 at [0, 1, 2], frame 2 at [1, 2, 3].
    fn sample_index() -> VecIndex {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![0.0, 1.0, 2.0]);
        builder.add_document(2, vec![1.0, 2.0, 3.0]);
        let artifact = builder.finish().expect("finish");
        VecIndex::decode(&artifact.bytes).expect("decode")
    }

    #[test]
    fn builder_roundtrip() {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![0.0, 1.0, 2.0]);
        builder.add_document(2, vec![1.0, 2.0, 3.0]);
        let artifact = builder.finish().expect("finish");
        assert_eq!(artifact.vector_count, 2);
        assert_eq!(artifact.dimension, 3);
        assert_eq!(artifact.checksum, *blake3::hash(&artifact.bytes).as_bytes());

        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(matches!(index, VecIndex::Uncompressed { .. }));
        let hits = index.search(&[0.0, 1.0, 2.0], 10);
        // Check the length first so a regression fails with a clear message
        // instead of an index-out-of-bounds panic.
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].frame_id, 1);
    }

    #[test]
    fn empty_builder_roundtrips() {
        let artifact = VecIndexBuilder::new().finish().expect("finish");
        assert_eq!(artifact.vector_count, 0);
        assert_eq!(artifact.dimension, 0);
        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(index.search(&[1.0], 5).is_empty());
    }

    #[test]
    fn search_orders_by_distance_and_respects_limit() {
        let index = sample_index();
        let hits = index.search(&[1.0, 2.0, 3.0], 10);
        let order: Vec<_> = hits.iter().map(|hit| hit.frame_id).collect();
        assert_eq!(order, vec![2, 1]);
        assert!(hits[0].distance.abs() < 1e-6);
        assert_eq!(index.search(&[1.0, 2.0, 3.0], 1).len(), 1);
        assert!(index.search(&[1.0, 2.0, 3.0], 0).is_empty());
    }

    #[test]
    fn empty_query_returns_no_hits() {
        assert!(sample_index().search(&[], 10).is_empty());
    }

    #[test]
    fn embedding_lookup_entries_and_removal() {
        let mut index = sample_index();
        assert_eq!(index.embedding_for(2), Some(&[1.0_f32, 2.0, 3.0][..]));
        assert_eq!(index.entries().count(), 2);

        index.remove(1);
        assert_eq!(index.embedding_for(1), None);
        assert_eq!(index.entries().count(), 1);
        let hits = index.search(&[0.0, 1.0, 2.0], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].frame_id, 2);
    }

    #[test]
    fn l2_distance_behaves() {
        let d = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-6);
    }
}
257}