1use blake3::hash;
2use serde::{Deserialize, Serialize};
3
4use crate::{MemvidError, Result, types::FrameId};
5
6#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
7use hnsw::{Hnsw, Params, Searcher};
8#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
9use rand_pcg::Pcg64;
10#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
11use space::Metric;
12
13fn vec_config() -> impl bincode::config::Config {
14 bincode::config::standard()
15 .with_fixed_int_encoding()
16 .with_little_endian()
17}
18
/// Byte limit handed to bincode (`with_limit`) when decoding a stored vector index,
/// taken from `crate::MAX_INDEX_BYTES`.
// NOTE(review): the `as usize` cast could truncate if `MAX_INDEX_BYTES` is wider than
// `usize` on the target (presumably only relevant on 32-bit builds) — confirm; that is
// what the lint allowance below is for.
#[allow(clippy::cast_possible_truncation)]
const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
21
/// Document count at which `VecIndexBuilder::finish` stops emitting the plain document
/// list and builds an HNSW graph instead (the comparison is `>=`).
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
const HNSW_THRESHOLD: usize = 1000;
/// Factor used to turn an `f32` distance into the `u32` unit of the HNSW metric.
///
/// Distances are multiplied by this value before being cast to `u32` and divided by it
/// again when search hits are reported, so reported HNSW distances are quantised to
/// steps of `1 / HNSW_DISTANCE_SCALE`.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
const HNSW_DISTANCE_SCALE: f32 = 100_000.0;
29
/// One indexed vector: an embedding together with the frame it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VecDocument {
    /// Identifier of the frame this embedding was produced for.
    pub frame_id: FrameId,
    /// The embedding values. Nothing in this type enforces a particular length.
    pub embedding: Vec<f32>,
}
35
/// Accumulates [`VecDocument`]s and serialises them into a [`VecIndexArtifact`].
#[derive(Default)]
pub struct VecIndexBuilder {
    /// Documents in the order they were added.
    documents: Vec<VecDocument>,
}
40
41impl VecIndexBuilder {
42 #[must_use]
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn add_document<I>(&mut self, frame_id: FrameId, embedding: I)
48 where
49 I: Into<Vec<f32>>,
50 {
51 self.documents.push(VecDocument {
52 frame_id,
53 embedding: embedding.into(),
54 });
55 }
56
57 pub fn finish(self) -> Result<VecIndexArtifact> {
58 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
59 if self.documents.len() >= HNSW_THRESHOLD {
60 return self.finish_hnsw();
61 }
62
63 let bytes = bincode::serde::encode_to_vec(&self.documents, vec_config())?;
64
65 let checksum = *hash(&bytes).as_bytes();
66 let dimension = self
67 .documents
68 .first()
69 .map_or(0, |doc| u32::try_from(doc.embedding.len()).unwrap_or(0));
70 #[cfg(feature = "parallel_segments")]
71 let bytes_uncompressed = self
72 .documents
73 .iter()
74 .map(|doc| doc.embedding.len() * std::mem::size_of::<f32>())
75 .sum::<usize>() as u64;
76 Ok(VecIndexArtifact {
77 bytes,
78 vector_count: self.documents.len() as u64,
79 dimension,
80 checksum,
81 #[cfg(feature = "parallel_segments")]
82 bytes_uncompressed,
83 })
84 }
85
86 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
87 #[allow(clippy::cast_possible_truncation)]
88 fn finish_hnsw(self) -> Result<VecIndexArtifact> {
89 let count = self.documents.len() as u64;
90 let dimension = self
91 .documents
92 .first()
93 .map(|d| d.embedding.len() as u32)
94 .unwrap_or(0);
95
96 #[cfg(feature = "parallel_segments")]
97 let bytes_uncompressed = self
98 .documents
99 .iter()
100 .map(|doc| doc.embedding.len() * std::mem::size_of::<f32>())
101 .sum::<usize>() as u64;
102
103 let index = HnswVecIndex::build(&self.documents)?;
104 let bytes = bincode::serde::encode_to_vec(&index, vec_config())?;
105 let checksum = *hash(&bytes).as_bytes();
106
107 Ok(VecIndexArtifact {
108 bytes,
109 vector_count: count,
110 dimension,
111 checksum,
112 #[cfg(feature = "parallel_segments")]
113 bytes_uncompressed,
114 })
115 }
116}
117
/// Serialised vector index plus the metadata computed by [`VecIndexBuilder::finish`].
#[derive(Debug, Clone)]
pub struct VecIndexArtifact {
    /// Bincode-encoded index payload (plain document list or HNSW graph).
    pub bytes: Vec<u8>,
    /// Number of documents that were in the builder.
    pub vector_count: u64,
    /// Embedding length of the first document, or 0 when the builder was empty.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Total size of all embeddings in bytes (element count times `size_of::<f32>()`).
    #[cfg(feature = "parallel_segments")]
    pub bytes_uncompressed: u64,
}
127
/// A decoded vector index in one of the supported encodings.
#[derive(Debug, Clone)]
pub enum VecIndex {
    /// Plain list of documents; searched by computing the distance to every document.
    Uncompressed {
        /// The stored documents.
        documents: Vec<VecDocument>,
    },
    /// Index backed by `crate::vec_pq::QuantizedVecIndex`.
    Compressed(crate::vec_pq::QuantizedVecIndex),
    /// Index backed by an HNSW graph (see [`HnswVecIndex`]).
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    Hnsw(HnswVecIndex),
}
137
138impl VecIndex {
139 pub fn decode(bytes: &[u8]) -> Result<Self> {
142 Self::decode_with_compression(bytes, crate::VectorCompression::None)
143 }
144
145 pub fn decode_with_compression(
152 bytes: &[u8],
153 _compression: crate::VectorCompression,
154 ) -> Result<Self> {
155 match bincode::serde::decode_from_slice::<Vec<VecDocument>, _>(
159 bytes,
160 bincode::config::standard()
161 .with_fixed_int_encoding()
162 .with_little_endian()
163 .with_limit::<VEC_DECODE_LIMIT>(),
164 ) {
165 Ok((documents, read)) if read == bytes.len() => {
166 tracing::debug!(
167 bytes_len = bytes.len(),
168 docs_count = documents.len(),
169 "decoded as uncompressed"
170 );
171 return Ok(Self::Uncompressed { documents });
172 }
173 Ok((_, read)) => {
174 tracing::debug!(
175 bytes_len = bytes.len(),
176 read = read,
177 "uncompressed decode partial read, trying HNSW/PQ"
178 );
179 }
180 Err(err) => {
181 tracing::debug!(
182 error = %err,
183 bytes_len = bytes.len(),
184 "uncompressed decode failed, trying HNSW/PQ"
185 );
186 }
187 }
188
189 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
190 {
191 match bincode::serde::decode_from_slice::<HnswVecIndex, _>(
192 bytes,
193 bincode::config::standard()
194 .with_fixed_int_encoding()
195 .with_little_endian()
196 .with_limit::<VEC_DECODE_LIMIT>(),
197 ) {
198 Ok((index, _)) => {
199 tracing::debug!(bytes_len = bytes.len(), "decoded as HNSW");
200 return Ok(Self::Hnsw(index));
201 }
202 Err(err) => {
203 tracing::debug!(
204 error = %err,
205 bytes_len = bytes.len(),
206 "HNSW decode failed, trying PQ"
207 );
208 }
209 }
210 }
211
212 match crate::vec_pq::QuantizedVecIndex::decode(bytes) {
214 Ok(quantized_index) => {
215 tracing::debug!(bytes_len = bytes.len(), "decoded as PQ");
216 Ok(Self::Compressed(quantized_index))
217 }
218 Err(err) => {
219 tracing::debug!(
220 error = %err,
221 bytes_len = bytes.len(),
222 "PQ decode also failed"
223 );
224 Err(MemvidError::InvalidToc {
225 reason: "unsupported vector index encoding".into(),
226 })
227 }
228 }
229 }
230
231 #[must_use]
232 pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
233 if query.is_empty() {
234 return Vec::new();
235 }
236 match self {
237 VecIndex::Uncompressed { documents } => {
238 let mut hits: Vec<VecSearchHit> = documents
239 .iter()
240 .map(|doc| {
241 let distance = l2_distance(query, &doc.embedding);
242 VecSearchHit {
243 frame_id: doc.frame_id,
244 distance,
245 }
246 })
247 .collect();
248 hits.sort_by(|a, b| {
249 a.distance
250 .partial_cmp(&b.distance)
251 .unwrap_or(std::cmp::Ordering::Equal)
252 });
253 hits.truncate(limit);
254 hits
255 }
256 VecIndex::Compressed(quantized) => quantized.search(query, limit),
257 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
258 VecIndex::Hnsw(index) => index.search(query, limit),
259 }
260 }
261
262 #[must_use]
263 pub fn entries(&self) -> Box<dyn Iterator<Item = (FrameId, &[f32])> + '_> {
264 match self {
265 VecIndex::Uncompressed { documents } => Box::new(
266 documents
267 .iter()
268 .map(|doc| (doc.frame_id, doc.embedding.as_slice())),
269 ),
270 VecIndex::Compressed(_) => {
271 Box::new(std::iter::empty())
273 }
274 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
275 VecIndex::Hnsw(_) => {
276 Box::new(std::iter::empty())
278 }
279 }
280 }
281
282 #[must_use]
283 pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
284 match self {
285 VecIndex::Uncompressed { documents } => documents
286 .iter()
287 .find(|doc| doc.frame_id == frame_id)
288 .map(|doc| doc.embedding.as_slice()),
289 VecIndex::Compressed(_) => {
290 None
292 }
293 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
294 VecIndex::Hnsw(_) => {
295 None
298 }
299 }
300 }
301
302 pub fn remove(&mut self, frame_id: FrameId) {
303 match self {
304 VecIndex::Uncompressed { documents } => {
305 documents.retain(|doc| doc.frame_id != frame_id);
306 }
307 VecIndex::Compressed(_quantized) => {
308 }
310 #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
311 VecIndex::Hnsw(_) => {
312 }
314 }
315 }
316}
317
/// One search result: a frame and its distance from the query.
#[derive(Debug, Clone, PartialEq)]
pub struct VecSearchHit {
    /// Frame the matching embedding belongs to.
    pub frame_id: FrameId,
    /// Distance between the query and the stored embedding; results are returned in
    /// ascending order of this value.
    pub distance: f32,
}
323
/// Distance between `a` and `b`, delegated to `crate::simd::l2_distance_simd`.
///
/// The unit test in this file expects `[0, 0]` and `[3, 4]` to be `5` apart, i.e. the
/// plain (non-squared) Euclidean distance.
// NOTE(review): behaviour for slices of different lengths is decided by the SIMD
// helper, which is not visible here — confirm before relying on it.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    crate::simd::l2_distance_simd(a, b)
}
327
/// Marker type that supplies the distance metric used by the HNSW graph.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Euclidean;
331
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
impl Metric<Vec<f32>> for Euclidean {
    type Unit = u32;

    /// Distance between `a` and `b` as a fixed-point `u32`.
    ///
    /// The `f32` distance is multiplied by `HNSW_DISTANCE_SCALE`, capped at `u32::MAX`
    /// and then cast to `u32`.
    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> u32 {
        let scaled = l2_distance(a, b) * HNSW_DISTANCE_SCALE;
        // `f32::min` returns the non-NaN operand, so a NaN distance ends up at the cap
        // rather than at 0.
        let capped = scaled.min(u32::MAX as f32);
        capped as u32
    }
}
341
/// Approximate nearest-neighbour index backed by an HNSW graph.
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
#[derive(Clone, Serialize, Deserialize)]
#[allow(clippy::unsafe_derive_deserialize)]
pub struct HnswVecIndex {
    /// The graph holding the embeddings, using the [`Euclidean`] metric.
    // NOTE(review): the const parameters 16 and 32 are presumably the per-layer and
    // layer-zero neighbour counts of the hnsw crate — confirm against its docs.
    graph: Hnsw<Euclidean, Vec<f32>, Pcg64, 16, 32>,
    /// Frame id of each inserted embedding, in insertion order; search results index
    /// into this table with the neighbour index reported by the graph.
    ids: Vec<FrameId>,
    /// Embedding length of the first document the index was built from (0 if none).
    dimension: u32,
}
350
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
impl std::fmt::Debug for HnswVecIndex {
    /// Prints only the dimension and the number of stored ids; the graph itself is
    /// left out and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let vector_count = self.ids.len();
        let mut out = f.debug_struct("HnswVecIndex");
        out.field("dimension", &self.dimension);
        out.field("vector_count", &vector_count);
        out.finish_non_exhaustive()
    }
}
360
#[cfg(any(feature = "vec", feature = "hnsw_bench"))]
impl HnswVecIndex {
    /// Smallest candidate-pool size (`ef`) used for a search; requests for more results
    /// than this enlarge the pool accordingly.
    const EF_SEARCH: usize = 50;

    /// Builds a graph containing every document's embedding.
    ///
    /// Embeddings are inserted in slice order and `ids` records the matching frame ids
    /// in the same order. The stored dimension is the length of the first embedding,
    /// or 0 for an empty slice or a length that does not fit in `u32`.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is part of the existing interface.
    pub fn build(documents: &[VecDocument]) -> Result<Self> {
        let params = Params::new().ef_construction(100);
        let mut graph = Hnsw::new_params(Euclidean, params);
        let mut ids = Vec::with_capacity(documents.len());
        let mut searcher = Searcher::default();

        for doc in documents {
            // The graph owns its points, so each embedding has to be copied in.
            graph.insert(doc.embedding.clone(), &mut searcher);
            ids.push(doc.frame_id);
        }

        let dimension = documents
            .first()
            .map_or(0, |doc| u32::try_from(doc.embedding.len()).unwrap_or(0));

        Ok(Self {
            graph,
            ids,
            dimension,
        })
    }

    /// Returns up to `limit` approximate nearest neighbours of `query`, closest first.
    ///
    /// Reported distances are the graph's fixed-point distances divided by
    /// `HNSW_DISTANCE_SCALE`. A `limit` of 0 or an index without ids yields no hits.
    #[must_use]
    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
        // Per-thread scratch space so repeated searches reuse the same allocations.
        thread_local! {
            static SEARCHER: std::cell::RefCell<Searcher<u32>> = std::cell::RefCell::new(Searcher::new());
            static DEST: std::cell::RefCell<Vec<space::Neighbor<u32>>> = const { std::cell::RefCell::new(Vec::new()) };
        }

        // Never ask for more hits than there are ids: this bounds the scratch buffer
        // even when callers pass a huge `limit` to mean "everything".
        let wanted = limit.min(self.ids.len());
        if wanted == 0 {
            return Vec::new();
        }
        // The hnsw crate fills at most `ef` results, so the candidate pool must be at
        // least as large as the number of hits requested.
        let ef = wanted.max(Self::EF_SEARCH);

        // The graph stores `Vec<f32>` points, so the query has to be an owned vector.
        let query_vec: Vec<f32> = query.to_vec();

        SEARCHER.with(|searcher_cell| {
            DEST.with(|dest_cell| {
                let mut searcher = searcher_cell.borrow_mut();
                let mut dest = dest_cell.borrow_mut();

                if dest.len() < ef {
                    dest.resize(
                        ef,
                        space::Neighbor {
                            index: !0,
                            distance: 0,
                        },
                    );
                }

                let found = self
                    .graph
                    .nearest(&query_vec, ef, &mut searcher, &mut dest[..ef]);

                found
                    .iter()
                    .take(wanted)
                    .filter_map(|neighbor| {
                        // A deserialised index may carry an id table that does not
                        // match the graph; skip such neighbours instead of panicking.
                        let frame_id = self.ids.get(neighbor.index).copied()?;
                        Some(VecSearchHit {
                            frame_id,
                            distance: (neighbor.distance as f32) / HNSW_DISTANCE_SCALE,
                        })
                    })
                    .collect()
            })
        })
    }
}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// Two documents survive an encode/decode cycle and the exact match ranks first.
    #[test]
    fn builder_roundtrip() {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![0.0_f32, 1.0, 2.0]);
        builder.add_document(2, vec![1.0_f32, 2.0, 3.0]);

        let artifact = builder.finish().expect("finish");
        assert_eq!(artifact.vector_count, 2);
        assert_eq!(artifact.dimension, 3);

        let decoded = VecIndex::decode(&artifact.bytes).expect("decode");
        let probe = [0.0_f32, 1.0, 2.0];
        let hits = decoded.search(&probe, 10);
        assert_eq!(hits[0].frame_id, 1);
    }

    /// The classic 3-4-5 triangle: the distance is Euclidean, not squared.
    #[test]
    fn l2_distance_behaves() {
        let origin = [0.0_f32, 0.0];
        let point = [3.0_f32, 4.0];
        let error = l2_distance(&origin, &point) - 5.0;
        assert!(error.abs() < 1e-6);
    }

    /// Exactly `HNSW_THRESHOLD` documents must produce an HNSW-encoded artifact.
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_threshold_triggers_hnsw_index() {
        const DIM: usize = 32;

        let mut builder = VecIndexBuilder::new();
        for i in 0..HNSW_THRESHOLD {
            let start = i * DIM;
            let embedding: Vec<f32> = (start..start + DIM)
                .map(|value| value as f32 / 1000.0)
                .collect();
            builder.add_document(i as FrameId, embedding);
        }

        let artifact = builder.finish().expect("finish hnsw");
        assert_eq!(artifact.vector_count, HNSW_THRESHOLD as u64);

        let decoded = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(
            matches!(decoded, VecIndex::Hnsw(_)),
            "Expected HNSW index for {} vectors",
            HNSW_THRESHOLD
        );
    }

    /// One document short of the threshold keeps the plain document-list encoding.
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn below_threshold_uses_brute_force() {
        const DIM: usize = 32;
        let count = HNSW_THRESHOLD - 1;

        let mut builder = VecIndexBuilder::new();
        for i in 0..count {
            let start = i * DIM;
            let embedding: Vec<f32> = (start..start + DIM)
                .map(|value| value as f32 / 1000.0)
                .collect();
            builder.add_document(i as FrameId, embedding);
        }

        let artifact = builder.finish().expect("finish brute force");
        assert_eq!(artifact.vector_count, count as u64);

        let decoded = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(
            matches!(decoded, VecIndex::Uncompressed { .. }),
            "Expected Uncompressed index for {} vectors",
            count
        );
    }

    /// Document `i` is the constant vector `[i; DIM]`, so a query of all 500s must hit
    /// frame 500 at (near) zero distance.
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_search_finds_nearest_neighbors() {
        const DIM: usize = 32;

        let mut builder = VecIndexBuilder::new();
        for i in 0..HNSW_THRESHOLD {
            builder.add_document(i as FrameId, vec![i as f32; DIM]);
        }

        let artifact = builder.finish().expect("finish");
        let decoded = VecIndex::decode(&artifact.bytes).expect("decode");

        let query = vec![500.0_f32; DIM];
        let hits = decoded.search(&query, 5);

        assert!(!hits.is_empty(), "Should find at least one hit");
        assert_eq!(
            hits[0].frame_id, 500,
            "Nearest neighbor should be exact match"
        );
        assert!(
            hits[0].distance < 0.001,
            "Distance to exact match should be near zero"
        );
    }

    /// Decoding the same HNSW bytes twice must give indexes that answer identically.
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_serialization_roundtrip() {
        const DIM: usize = 64;

        let mut builder = VecIndexBuilder::new();
        for i in 0..HNSW_THRESHOLD {
            let embedding: Vec<f32> = (i..i + DIM)
                .map(|value| (value % 100) as f32 / 100.0)
                .collect();
            builder.add_document(i as FrameId, embedding);
        }

        let original_bytes = builder.finish().expect("finish").bytes;

        let first = VecIndex::decode(&original_bytes).expect("decode");
        assert!(matches!(first, VecIndex::Hnsw(_)));

        let query: Vec<f32> = (0..DIM).map(|j| (j % 100) as f32 / 100.0).collect();
        let hits_1 = first.search(&query, 10);

        let second = VecIndex::decode(&original_bytes).expect("decode again");
        let hits_2 = second.search(&query, 10);

        assert_eq!(hits_1.len(), hits_2.len());
        for (left, right) in hits_1.iter().zip(&hits_2) {
            assert_eq!(left.frame_id, right.frame_id);
            assert!((left.distance - right.distance).abs() < 1e-6);
        }
    }

    /// Compares the HNSW top-10 against a brute-force search over the same
    /// embeddings and requires at least 80% overlap.
    #[test]
    #[cfg(any(feature = "vec", feature = "hnsw_bench"))]
    fn hnsw_recall_quality() {
        const DIM: usize = 32;
        let count = HNSW_THRESHOLD + 500;

        let embeddings: Vec<Vec<f32>> = (0..count)
            .map(|i| {
                (0..DIM)
                    .map(|j| ((i * 7 + j * 13) % 1000) as f32 / 1000.0)
                    .collect()
            })
            .collect();

        let mut builder = VecIndexBuilder::new();
        for (i, embedding) in embeddings.iter().enumerate() {
            builder.add_document(i as FrameId, embedding.clone());
        }

        let artifact = builder.finish().expect("finish");
        let hnsw_index = VecIndex::decode(&artifact.bytes).expect("decode");

        let documents: Vec<VecDocument> = embeddings
            .iter()
            .enumerate()
            .map(|(i, embedding)| VecDocument {
                frame_id: i as FrameId,
                embedding: embedding.clone(),
            })
            .collect();
        let brute_index = VecIndex::Uncompressed { documents };

        let query = embeddings[750].clone();
        let k = 10;

        let hnsw_hits = hnsw_index.search(&query, k);
        let brute_hits = brute_index.search(&query, k);

        assert_eq!(hnsw_hits[0].frame_id, 750, "HNSW should find exact match");
        assert_eq!(
            brute_hits[0].frame_id, 750,
            "Brute force should find exact match"
        );

        let expected: std::collections::HashSet<_> =
            brute_hits.iter().map(|hit| hit.frame_id).collect();
        let recall = hnsw_hits
            .iter()
            .filter(|hit| expected.contains(&hit.frame_id))
            .count();
        let recall_ratio = recall as f32 / k as f32;

        assert!(
            recall_ratio >= 0.8,
            "HNSW recall {} should be >= 0.8",
            recall_ratio
        );
    }
}