Skip to main content

memvid_core/
vec.rs

1use blake3::hash;
2use serde::{Deserialize, Serialize};
3
4use crate::{MemvidError, Result, types::FrameId};
5
6#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
7use hnsw::{Hnsw, Params, Searcher};
8#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
9use rand_pcg::Pcg64;
10#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
11use space::Metric;
12
13fn vec_config() -> impl bincode::config::Config {
14    bincode::config::standard()
15        .with_fixed_int_encoding()
16        .with_little_endian()
17}
18
/// Byte limit handed to bincode's `with_limit` when decoding a vector index.
/// Derived from `crate::MAX_INDEX_BYTES`; the `as usize` cast is why the
/// truncation lint is silenced here.
#[allow(clippy::cast_possible_truncation)]
const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;

/// Document count at or above which `VecIndexBuilder::finish` serializes an
/// HNSW graph instead of the flat document list.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
const HNSW_THRESHOLD: usize = 1000;
/// Fixed-point scaling factor for HNSW distances.
/// Necessary because `space::Metric` requires `Unit: Unsigned`, but we use f32 L2 distances.
/// 100,000.0 gives 1e-5 precision and a maximum representable distance of
/// `u32::MAX / 100_000` (about 42,949), which is enough for high-dim embeddings.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
const HNSW_DISTANCE_SCALE: f32 = 100_000.0;
29
/// One indexed vector: an embedding together with the frame it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VecDocument {
    /// Identifier of the frame this embedding was produced for.
    pub frame_id: FrameId,
    /// The embedding components.
    pub embedding: Vec<f32>,
}
35
/// Accumulates [`VecDocument`]s and serializes them into a
/// [`VecIndexArtifact`] via `finish`.
#[derive(Default)]
pub struct VecIndexBuilder {
    /// Documents in insertion order.
    documents: Vec<VecDocument>,
}
40
41impl VecIndexBuilder {
42    #[must_use]
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    pub fn add_document<I>(&mut self, frame_id: FrameId, embedding: I)
48    where
49        I: Into<Vec<f32>>,
50    {
51        self.documents.push(VecDocument {
52            frame_id,
53            embedding: embedding.into(),
54        });
55    }
56
57    pub fn finish(self) -> Result<VecIndexArtifact> {
58        #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
59        if self.documents.len() >= HNSW_THRESHOLD {
60            return self.finish_hnsw();
61        }
62
63        let bytes = bincode::serde::encode_to_vec(&self.documents, vec_config())?;
64
65        let checksum = *hash(&bytes).as_bytes();
66        let dimension = self
67            .documents
68            .first()
69            .map_or(0, |doc| u32::try_from(doc.embedding.len()).unwrap_or(0));
70        #[cfg(feature = "parallel_segments")]
71        let bytes_uncompressed = self
72            .documents
73            .iter()
74            .map(|doc| doc.embedding.len() * std::mem::size_of::<f32>())
75            .sum::<usize>() as u64;
76        Ok(VecIndexArtifact {
77            bytes,
78            vector_count: self.documents.len() as u64,
79            dimension,
80            checksum,
81            #[cfg(feature = "parallel_segments")]
82            bytes_uncompressed,
83        })
84    }
85
86    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
87    #[allow(clippy::cast_possible_truncation)]
88    fn finish_hnsw(self) -> Result<VecIndexArtifact> {
89        let count = self.documents.len() as u64;
90        let dimension = self
91            .documents
92            .first()
93            .map(|d| d.embedding.len() as u32)
94            .unwrap_or(0);
95
96        #[cfg(feature = "parallel_segments")]
97        let bytes_uncompressed = self
98            .documents
99            .iter()
100            .map(|doc| doc.embedding.len() * std::mem::size_of::<f32>())
101            .sum::<usize>() as u64;
102
103        let index = HnswVecIndex::build(&self.documents)?;
104        let bytes = bincode::serde::encode_to_vec(&index, vec_config())?;
105        let checksum = *hash(&bytes).as_bytes();
106
107        Ok(VecIndexArtifact {
108            bytes,
109            vector_count: count,
110            dimension,
111            checksum,
112            #[cfg(feature = "parallel_segments")]
113            bytes_uncompressed,
114        })
115    }
116}
117
/// Serialized vector index plus the metadata computed while building it.
#[derive(Debug, Clone)]
pub struct VecIndexArtifact {
    /// Bincode-encoded index payload (flat document list or HNSW graph).
    pub bytes: Vec<u8>,
    /// Number of documents that went into the index.
    pub vector_count: u64,
    /// Embedding length of the first document, or 0 for an empty index.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Total size of the raw `f32` embedding data, in bytes.
    #[cfg(feature = "parallel_segments")]
    pub bytes_uncompressed: u64,
}
127
/// In-memory vector index, in one of the encodings `decode` understands.
#[derive(Debug, Clone)]
pub enum VecIndex {
    /// Flat list of documents, searched by brute force.
    Uncompressed {
        documents: Vec<VecDocument>,
    },
    /// Product-quantized index; raw `f32` embeddings are not accessible.
    Compressed(crate::vec_pq::QuantizedVecIndex),
    /// HNSW graph index for approximate nearest-neighbour search.
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    Hnsw(HnswVecIndex),
}
137
138impl VecIndex {
139    /// Decode vector index from bytes
140    /// For backward compatibility, defaults to uncompressed if no manifest provided
141    pub fn decode(bytes: &[u8]) -> Result<Self> {
142        Self::decode_with_compression(bytes, crate::VectorCompression::None)
143    }
144
145    /// Decode vector index with compression mode from manifest
146    ///
147    /// ALWAYS tries uncompressed format first, regardless of compression flag.
148    /// This is necessary because `MIN_VECTORS_FOR_PQ` threshold (100 vectors)
149    /// causes most segments to be stored as uncompressed even when Pq96 is requested.
150    /// Falls back to PQ format for true compressed segments.
151    pub fn decode_with_compression(
152        bytes: &[u8],
153        _compression: crate::VectorCompression,
154    ) -> Result<Self> {
155        // Try uncompressed format first, regardless of compression flag.
156        // This is necessary because MIN_VECTORS_FOR_PQ threshold (100 vectors)
157        // causes most segments to be stored as uncompressed even when Pq96 is requested.
158        match bincode::serde::decode_from_slice::<Vec<VecDocument>, _>(
159            bytes,
160            bincode::config::standard()
161                .with_fixed_int_encoding()
162                .with_little_endian()
163                .with_limit::<VEC_DECODE_LIMIT>(),
164        ) {
165            Ok((documents, read)) if read == bytes.len() => {
166                tracing::debug!(
167                    bytes_len = bytes.len(),
168                    docs_count = documents.len(),
169                    "decoded as uncompressed"
170                );
171                return Ok(Self::Uncompressed { documents });
172            }
173            Ok((_, read)) => {
174                tracing::debug!(
175                    bytes_len = bytes.len(),
176                    read = read,
177                    "uncompressed decode partial read, trying HNSW/PQ"
178                );
179            }
180            Err(err) => {
181                tracing::debug!(
182                    error = %err,
183                    bytes_len = bytes.len(),
184                    "uncompressed decode failed, trying HNSW/PQ"
185                );
186            }
187        }
188
189        #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
190        {
191            match bincode::serde::decode_from_slice::<HnswVecIndex, _>(
192                bytes,
193                bincode::config::standard()
194                    .with_fixed_int_encoding()
195                    .with_little_endian()
196                    .with_limit::<VEC_DECODE_LIMIT>(),
197            ) {
198                Ok((index, _)) => {
199                    tracing::debug!(bytes_len = bytes.len(), "decoded as HNSW");
200                    return Ok(Self::Hnsw(index));
201                }
202                Err(err) => {
203                    tracing::debug!(
204                        error = %err,
205                        bytes_len = bytes.len(),
206                        "HNSW decode failed, trying PQ"
207                    );
208                }
209            }
210        }
211
212        // Try Product Quantization format
213        match crate::vec_pq::QuantizedVecIndex::decode(bytes) {
214            Ok(quantized_index) => {
215                tracing::debug!(bytes_len = bytes.len(), "decoded as PQ");
216                Ok(Self::Compressed(quantized_index))
217            }
218            Err(err) => {
219                tracing::debug!(
220                    error = %err,
221                    bytes_len = bytes.len(),
222                    "PQ decode also failed"
223                );
224                Err(MemvidError::InvalidToc {
225                    reason: "unsupported vector index encoding".into(),
226                })
227            }
228        }
229    }
230
231    #[must_use]
232    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
233        if query.is_empty() {
234            return Vec::new();
235        }
236        match self {
237            VecIndex::Uncompressed { documents } => {
238                let mut hits: Vec<VecSearchHit> = documents
239                    .iter()
240                    .map(|doc| {
241                        let distance = l2_distance(query, &doc.embedding);
242                        VecSearchHit {
243                            frame_id: doc.frame_id,
244                            distance,
245                        }
246                    })
247                    .collect();
248                hits.sort_by(|a, b| {
249                    a.distance
250                        .partial_cmp(&b.distance)
251                        .unwrap_or(std::cmp::Ordering::Equal)
252                });
253                hits.truncate(limit);
254                hits
255            }
256            VecIndex::Compressed(quantized) => quantized.search(query, limit),
257            #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
258            VecIndex::Hnsw(index) => index.search(query, limit),
259        }
260    }
261
262    #[must_use]
263    pub fn entries(&self) -> Box<dyn Iterator<Item = (FrameId, &[f32])> + '_> {
264        match self {
265            VecIndex::Uncompressed { documents } => Box::new(
266                documents
267                    .iter()
268                    .map(|doc| (doc.frame_id, doc.embedding.as_slice())),
269            ),
270            VecIndex::Compressed(_) => {
271                // Compressed vectors don't have direct f32 access
272                Box::new(std::iter::empty())
273            }
274            #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
275            VecIndex::Hnsw(_) => {
276                // HNSW graph doesn't easily iterate all embeddings
277                Box::new(std::iter::empty())
278            }
279        }
280    }
281
282    #[must_use]
283    pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
284        match self {
285            VecIndex::Uncompressed { documents } => documents
286                .iter()
287                .find(|doc| doc.frame_id == frame_id)
288                .map(|doc| doc.embedding.as_slice()),
289            VecIndex::Compressed(_) => {
290                // Compressed vectors don't have direct f32 access
291                None
292            }
293            #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
294            VecIndex::Hnsw(_) => {
295                // HNSW storage is internal, would need traversal to find exact embedding
296                // For now, return None as we do for Compressed
297                None
298            }
299        }
300    }
301
302    pub fn remove(&mut self, frame_id: FrameId) {
303        match self {
304            VecIndex::Uncompressed { documents } => {
305                documents.retain(|doc| doc.frame_id != frame_id);
306            }
307            VecIndex::Compressed(_quantized) => {
308                // Compressed indices are immutable
309            }
310            #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
311            VecIndex::Hnsw(_) => {
312                // HNSW indices are immutable in this implementation
313            }
314        }
315    }
316}
317
/// A single search result.
#[derive(Debug, Clone, PartialEq)]
pub struct VecSearchHit {
    /// Frame whose embedding matched.
    pub frame_id: FrameId,
    /// Distance between the query and the matched embedding; smaller is closer.
    pub distance: f32,
}
323
324fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
325    crate::simd::l2_distance_simd(a, b)
326}
327
/// Marker type implementing `space::Metric` for `Vec<f32>` points, used as
/// the distance metric of the HNSW graph.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Euclidean;
331
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
impl Metric<Vec<f32>> for Euclidean {
    type Unit = u32;

    /// Distance between `a` and `b` as a fixed-point `u32`
    /// (the float distance multiplied by `HNSW_DISTANCE_SCALE`).
    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
        let scaled = l2_distance(a, b) * HNSW_DISTANCE_SCALE;
        // Clamp before converting so oversized distances pin to u32::MAX.
        // `f32::min` returns the non-NaN operand, so NaN also maps to the cap.
        let ceiling = u32::MAX as f32;
        scaled.min(ceiling) as u32
    }
}
341
/// HNSW graph over embeddings plus the frame id of each inserted point.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
#[derive(Clone, Serialize, Deserialize)]
#[allow(clippy::unsafe_derive_deserialize)]
pub struct HnswVecIndex {
    /// The graph; points are stored as owned `Vec<f32>` values.
    graph: Hnsw<Euclidean, Vec<f32>, Pcg64, 16, 32>,
    /// Frame ids in insertion order; `search` indexes this with the
    /// neighbour index reported by the graph.
    ids: Vec<FrameId>,
    /// Embedding length of the first document, or 0 when built from none.
    dimension: u32,
}
350
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
impl std::fmt::Debug for HnswVecIndex {
    /// Prints only the summary fields; the graph itself is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let vector_count = self.ids.len();
        let mut summary = f.debug_struct("HnswVecIndex");
        summary.field("dimension", &self.dimension);
        summary.field("vector_count", &vector_count);
        summary.finish_non_exhaustive()
    }
}
360
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
impl HnswVecIndex {
    /// Builds an HNSW graph by inserting every document in order.
    ///
    /// `ids[i]` is the frame id of the i-th inserted embedding. The recorded
    /// dimension is the length of the first embedding (0 if there are no
    /// documents or the length does not fit in `u32`).
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for interface stability.
    pub fn build(documents: &[VecDocument]) -> Result<Self> {
        let params = Params::new().ef_construction(100);
        let mut graph = Hnsw::new_params(Euclidean, params);
        let mut ids = Vec::with_capacity(documents.len());
        let mut searcher = Searcher::default();

        for doc in documents {
            graph.insert(doc.embedding.clone(), &mut searcher);
            ids.push(doc.frame_id);
        }

        let dimension = documents
            .first()
            .map_or(0, |doc| u32::try_from(doc.embedding.len()).unwrap_or(0));

        Ok(Self {
            graph,
            ids,
            dimension,
        })
    }

    /// Returns up to `limit` approximate nearest neighbours of `query`,
    /// closest first, with distances converted back from fixed point.
    ///
    /// `limit` is clamped to the number of indexed vectors, so passing a very
    /// large value is safe. Neighbours whose graph index has no matching
    /// entry in `ids` (possible only for a corrupt deserialized index) are
    /// skipped rather than panicking.
    #[must_use]
    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
        // Use thread-local searcher and dest buffer to avoid per-query allocations
        thread_local! {
            static SEARCHER: std::cell::RefCell<Searcher<u32>> = std::cell::RefCell::new(Searcher::new());
            static DEST: std::cell::RefCell<Vec<space::Neighbor<u32>>> = const { std::cell::RefCell::new(Vec::new()) };
        }

        // Minimum query-time search width. Higher = better recall, slower search.
        // Default: 50 as per maintainer specification. Can be exposed as option later.
        const EF_SEARCH: usize = 50;

        // Never ask for more hits than exist; this also bounds the size of
        // the thread-local buffer regardless of what the caller passes.
        let wanted = limit.min(self.ids.len());
        if wanted == 0 {
            return Vec::new();
        }
        // The graph returns at most `ef` neighbours, so `ef` must be at
        // least the number of hits requested or results get silently capped.
        let ef = wanted.max(EF_SEARCH);

        SEARCHER.with(|searcher_cell| {
            DEST.with(|dest_cell| {
                let mut searcher = searcher_cell.borrow_mut();
                let mut dest = dest_cell.borrow_mut();

                // Grow (never shrink) the reusable output buffer.
                if dest.len() < ef {
                    dest.resize(
                        ef,
                        space::Neighbor {
                            index: !0,
                            distance: 0,
                        },
                    );
                }

                // The graph stores points as Vec<f32>, so the query must match.
                let query_vec: Vec<f32> = query.to_vec();

                let found = self
                    .graph
                    .nearest(&query_vec, ef, &mut searcher, &mut dest[..ef]);

                found
                    .iter()
                    .take(wanted)
                    .filter_map(|neighbor| {
                        let frame_id = *self.ids.get(neighbor.index)?;
                        Some(VecSearchHit {
                            frame_id,
                            distance: (neighbor.distance as f32) / HNSW_DISTANCE_SCALE,
                        })
                    })
                    .collect()
            })
        })
    }
}
436
// Unit tests for the builder, decoding, brute-force search and (feature-gated) HNSW paths.
#[cfg(test)]
mod tests {
    use super::*;

    /// A small index survives encode/decode and finds the exact-match vector first.
    #[test]
    fn builder_roundtrip() {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![0.0, 1.0, 2.0]);
        builder.add_document(2, vec![1.0, 2.0, 3.0]);
        let artifact = builder.finish().expect("finish");
        assert_eq!(artifact.vector_count, 2);
        assert_eq!(artifact.dimension, 3);

        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        let hits = index.search(&[0.0, 1.0, 2.0], 10);
        assert_eq!(hits[0].frame_id, 1);
    }

    /// Sanity check on a 3-4-5 triangle.
    #[test]
    fn l2_distance_behaves() {
        let d = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-6);
    }

    /// Test that HNSW is used once the index holds at least HNSW_THRESHOLD (1000) vectors
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_threshold_triggers_hnsw_index() {
        use super::HNSW_THRESHOLD;

        // Create index with exactly HNSW_THRESHOLD vectors
        let mut builder = VecIndexBuilder::new();
        let dim = 32;
        for i in 0..HNSW_THRESHOLD {
            let embedding: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 1000.0).collect();
            builder.add_document(i as FrameId, embedding);
        }

        let artifact = builder.finish().expect("finish hnsw");
        assert_eq!(artifact.vector_count, HNSW_THRESHOLD as u64);

        // Decode and verify it's an HNSW index
        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(
            matches!(index, VecIndex::Hnsw(_)),
            "Expected HNSW index for {} vectors",
            HNSW_THRESHOLD
        );
    }

    /// Test that brute force is used for indices below threshold
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn below_threshold_uses_brute_force() {
        use super::HNSW_THRESHOLD;

        // Create index with fewer than HNSW_THRESHOLD vectors
        let mut builder = VecIndexBuilder::new();
        let count = HNSW_THRESHOLD - 1;
        let dim = 32;
        for i in 0..count {
            let embedding: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 1000.0).collect();
            builder.add_document(i as FrameId, embedding);
        }

        let artifact = builder.finish().expect("finish brute force");
        assert_eq!(artifact.vector_count, count as u64);

        // Decode and verify it's NOT an HNSW index
        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(
            matches!(index, VecIndex::Uncompressed { .. }),
            "Expected Uncompressed index for {} vectors",
            count
        );
    }

    /// Test HNSW search returns correct nearest neighbors
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_search_finds_nearest_neighbors() {
        use super::HNSW_THRESHOLD;

        let mut builder = VecIndexBuilder::new();
        let dim = 32;

        // Insert HNSW_THRESHOLD vectors with predictable embeddings:
        // vector i has every component equal to i.
        for i in 0..HNSW_THRESHOLD {
            let embedding: Vec<f32> = (0..dim).map(|_| i as f32).collect();
            builder.add_document(i as FrameId, embedding);
        }

        let artifact = builder.finish().expect("finish");
        let index = VecIndex::decode(&artifact.bytes).expect("decode");

        // Query with a vector identical to frame_id=500
        let query: Vec<f32> = (0..dim).map(|_| 500.0_f32).collect();
        let hits = index.search(&query, 5);

        assert!(!hits.is_empty(), "Should find at least one hit");
        assert_eq!(
            hits[0].frame_id, 500,
            "Nearest neighbor should be exact match"
        );
        assert!(
            hits[0].distance < 0.001,
            "Distance to exact match should be near zero"
        );
    }

    /// Test that decoding the same serialized HNSW bytes twice yields
    /// indexes that answer a query identically
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_serialization_roundtrip() {
        use super::HNSW_THRESHOLD;

        let mut builder = VecIndexBuilder::new();
        let dim = 64;

        for i in 0..HNSW_THRESHOLD {
            let embedding: Vec<f32> = (0..dim).map(|j| ((i + j) % 100) as f32 / 100.0).collect();
            builder.add_document(i as FrameId, embedding);
        }

        let artifact = builder.finish().expect("finish");
        let original_bytes = artifact.bytes.clone();

        // Decode
        let index = VecIndex::decode(&original_bytes).expect("decode");
        assert!(matches!(index, VecIndex::Hnsw(_)));

        // Search on the first decoded copy
        let query: Vec<f32> = (0..dim).map(|j| (j % 100) as f32 / 100.0).collect();
        let hits_1 = index.search(&query, 10);

        // Decode again from same bytes (simulates loading from disk)
        let index_2 = VecIndex::decode(&original_bytes).expect("decode again");
        let hits_2 = index_2.search(&query, 10);

        // Results should be identical
        assert_eq!(hits_1.len(), hits_2.len());
        for (h1, h2) in hits_1.iter().zip(hits_2.iter()) {
            assert_eq!(h1.frame_id, h2.frame_id);
            assert!((h1.distance - h2.distance).abs() < 1e-6);
        }
    }

    /// Test HNSW with larger dataset to verify approximate search quality
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_recall_quality() {
        use super::HNSW_THRESHOLD;

        let count = HNSW_THRESHOLD + 500; // 1500 vectors
        let dim = 32;

        // Build HNSW index
        let mut builder = VecIndexBuilder::new();
        let embeddings: Vec<Vec<f32>> = (0..count)
            .map(|i| {
                (0..dim)
                    .map(|j| ((i * 7 + j * 13) % 1000) as f32 / 1000.0)
                    .collect()
            })
            .collect();

        for (i, emb) in embeddings.iter().enumerate() {
            builder.add_document(i as FrameId, emb.clone());
        }

        let artifact = builder.finish().expect("finish");
        let hnsw_index = VecIndex::decode(&artifact.bytes).expect("decode");

        // Also build brute force index for ground truth
        let brute_index = VecIndex::Uncompressed {
            documents: embeddings
                .iter()
                .enumerate()
                .map(|(i, emb)| VecDocument {
                    frame_id: i as FrameId,
                    embedding: emb.clone(),
                })
                .collect(),
        };

        // Query with the exact vector stored at index 750
        let query = embeddings[750].clone();
        let k = 10;

        let hnsw_hits = hnsw_index.search(&query, k);
        let brute_hits = brute_index.search(&query, k);

        // HNSW should find the exact match first
        assert_eq!(hnsw_hits[0].frame_id, 750, "HNSW should find exact match");
        assert_eq!(
            brute_hits[0].frame_id, 750,
            "Brute force should find exact match"
        );

        // Calculate recall: how many of top-k from HNSW are in top-k from brute force
        let brute_set: std::collections::HashSet<_> =
            brute_hits.iter().map(|h| h.frame_id).collect();
        let recall = hnsw_hits
            .iter()
            .filter(|h| brute_set.contains(&h.frame_id))
            .count();
        let recall_ratio = recall as f32 / k as f32;

        // HNSW should achieve at least 80% recall on this simple dataset
        assert!(
            recall_ratio >= 0.8,
            "HNSW recall {} should be >= 0.8",
            recall_ratio
        );
    }
}