1use std::collections::HashMap;
12use std::fs;
13use std::io::Write;
14use std::path::{Path, PathBuf};
15
16use anyhow::{Context, Result};
17use ndarray::{Array2, ArrayView1};
18use serde::{Deserialize, Serialize};
19
20use crate::backend::{self, BackendConfig, BackendIndex};
21use crate::hnsw::search::SearchParams;
22use crate::index::DistanceMetric;
23
/// Label for one indexed token row, serialized into the `<name>.labels.json`
/// sidecar written by `MultiVectorBuilder::build`.
///
/// Element `i` of the serialized `Vec<TokenLabel>` describes row `i` of the
/// flattened token matrix handed to the backend and written to `<name>.emb.npy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenLabel {
    /// Caller-supplied id of the document this token belongs to.
    pub doc_id: u32,
    /// 0-based position of this token within its document's embedding matrix.
    pub seq_id: u32,
    /// Document-level metadata, copied onto every token of the document.
    /// Deserializes to an empty map when the key is absent.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
39
/// A document buffered by `MultiVectorBuilder::insert` until `build` runs.
struct PendingDoc {
    /// Caller-supplied document id.
    doc_id: u32,
    /// One row per token; column count equals the builder's `dim`
    /// (enforced by `insert`).
    embeddings: Array2<f32>,
    /// Arbitrary JSON metadata attached to every token label of this document.
    metadata: HashMap<String, serde_json::Value>,
}
50
/// Accumulates multi-vector documents in memory and writes them out as an
/// HNSW index plus label and embedding sidecar files.
pub struct MultiVectorBuilder {
    /// Expected number of columns in every inserted embedding matrix.
    dim: usize,
    /// Documents queued by `insert`, in insertion order.
    pending: Vec<PendingDoc>,
    /// Configuration passed to `backend::build_backend`.
    backend_config: BackendConfig,
}
57
58impl MultiVectorBuilder {
59 pub fn new(dim: usize) -> Self {
61 let mut config = BackendConfig::hnsw_default();
62 config.set_distance_metric(DistanceMetric::Mips);
64 config.set_recompute(false);
66 config.set_compact(false);
67 Self {
68 dim,
69 pending: Vec::new(),
70 backend_config: config,
71 }
72 }
73
74 pub fn set_m(&mut self, m: usize) -> &mut Self {
76 self.backend_config.set_m(m);
77 self
78 }
79
80 pub fn set_ef_construction(&mut self, ef: usize) -> &mut Self {
82 self.backend_config.set_ef_construction(ef);
83 self
84 }
85
86 pub fn insert(
90 &mut self,
91 doc_id: u32,
92 embeddings: Array2<f32>,
93 metadata: HashMap<String, serde_json::Value>,
94 ) -> &mut Self {
95 assert_eq!(
96 embeddings.ncols(),
97 self.dim,
98 "embedding dim {} != expected {}",
99 embeddings.ncols(),
100 self.dim
101 );
102 self.pending.push(PendingDoc {
103 doc_id,
104 embeddings,
105 metadata,
106 });
107 self
108 }
109
110 pub fn build(&self, index_path: &Path) -> Result<()> {
117 anyhow::ensure!(!self.pending.is_empty(), "no documents inserted");
118
119 let total_tokens: usize = self.pending.iter().map(|d| d.embeddings.nrows()).sum();
121 let mut flat = Array2::<f32>::zeros((total_tokens, self.dim));
122 let mut labels = Vec::with_capacity(total_tokens);
123
124 let mut row = 0;
125 for doc in &self.pending {
126 for seq_id in 0..doc.embeddings.nrows() {
127 flat.row_mut(row).assign(&doc.embeddings.row(seq_id));
128 labels.push(TokenLabel {
129 doc_id: doc.doc_id,
130 seq_id: seq_id as u32,
131 metadata: doc.metadata.clone(),
132 });
133 row += 1;
134 }
135 }
136
137 let index_file = with_ext(index_path, "index");
139 backend::build_backend(&self.backend_config, &flat, &index_file, None)?;
140
141 let labels_file = with_ext(index_path, "labels.json");
143 let labels_json = serde_json::to_string(&labels)?;
144 fs::write(&labels_file, labels_json)
145 .with_context(|| format!("writing {}", labels_file.display()))?;
146
147 let npy_file = with_ext(index_path, "emb.npy");
149 write_npy(&flat, &npy_file)?;
150
151 Ok(())
152 }
153}
154
/// Read-only handle on an index written by `MultiVectorBuilder::build`,
/// supporting approximate and exact (re-scored) MaxSim search.
pub struct MultiVectorSearcher {
    /// Backend index loaded from `<name>.index`.
    index: BackendIndex,
    /// Token labels from `<name>.labels.json`; `labels[i]` describes row `i`.
    labels: Vec<TokenLabel>,
    /// For each document id, the row indices (into `labels`) of its tokens.
    doc_to_rows: HashMap<u32, Vec<usize>>,
    /// Memory-mapped bytes of `<name>.emb.npy` (with the `multi-vector` feature).
    #[cfg(feature = "multi-vector")]
    emb_mmap: memmap2::Mmap,
    /// Bytes of `<name>.emb.npy` read fully into memory (without the feature).
    #[cfg(not(feature = "multi-vector"))]
    emb_data: Vec<u8>,
    /// Embedding dimension as reported by the backend index.
    dim: usize,
    /// Number of indexed token rows (equal to `labels.len()`).
    total_tokens: usize,
}
173
174impl MultiVectorSearcher {
175 pub fn open(index_path: &Path) -> Result<Self> {
177 let index_file = with_ext(index_path, "index");
179 let index = backend::read_backend_index("hnsw", &index_file)?;
180
181 let labels_file = with_ext(index_path, "labels.json");
183 let labels_data = fs::read_to_string(&labels_file)
184 .with_context(|| format!("reading {}", labels_file.display()))?;
185 let labels: Vec<TokenLabel> = serde_json::from_str(&labels_data)?;
186
187 let mut doc_to_rows: HashMap<u32, Vec<usize>> = HashMap::new();
189 for (i, label) in labels.iter().enumerate() {
190 doc_to_rows.entry(label.doc_id).or_default().push(i);
191 }
192
193 let dim = index.dimensions();
194 let total_tokens = labels.len();
195
196 let npy_file = with_ext(index_path, "emb.npy");
198
199 #[cfg(feature = "multi-vector")]
200 let emb_mmap = {
201 let file = fs::File::open(&npy_file)
202 .with_context(|| format!("opening {}", npy_file.display()))?;
203 unsafe { memmap2::Mmap::map(&file)? }
204 };
205
206 Ok(Self {
207 index,
208 labels,
209 doc_to_rows,
210 #[cfg(feature = "multi-vector")]
211 emb_mmap,
212 #[cfg(not(feature = "multi-vector"))]
213 emb_data: fs::read(&npy_file)?,
214 dim,
215 total_tokens,
216 })
217 }
218
219 pub fn num_docs(&self) -> usize {
221 self.doc_to_rows.len()
222 }
223
224 pub fn num_tokens(&self) -> usize {
226 self.total_tokens
227 }
228
229 pub fn search(
236 &self,
237 query_tokens: &Array2<f32>,
238 top_k: usize,
239 ) -> Result<Vec<MultiVectorResult>> {
240 self.search_with_params(query_tokens, top_k, 50)
241 }
242
243 pub fn search_with_params(
245 &self,
246 query_tokens: &Array2<f32>,
247 top_k: usize,
248 per_token_k: usize,
249 ) -> Result<Vec<MultiVectorResult>> {
250 let params = SearchParams::default();
251
252 let mut doc_scores: HashMap<u32, f32> = HashMap::new();
255
256 for qi in 0..query_tokens.nrows() {
257 let query_vec = query_tokens.row(qi);
258 let query_slice = query_vec.as_slice().unwrap();
259
260 let (labels_idx, distances) =
261 backend::search_backend(&self.index, query_slice, per_token_k, ¶ms);
262
263 let mut best_per_doc: HashMap<u32, f32> = HashMap::new();
266 for (idx, dist) in labels_idx.into_iter().zip(distances) {
267 if idx >= self.labels.len() {
268 continue;
269 }
270 let doc_id = self.labels[idx].doc_id;
271 let sim = -dist; let entry = best_per_doc.entry(doc_id).or_insert(f32::NEG_INFINITY);
273 if sim > *entry {
274 *entry = sim;
275 }
276 }
277
278 for (doc_id, score) in best_per_doc {
280 *doc_scores.entry(doc_id).or_insert(0.0) += score;
281 }
282 }
283
284 Ok(top_k_results(
285 &doc_scores,
286 top_k,
287 &self.doc_to_rows,
288 &self.labels,
289 ))
290 }
291
292 pub fn search_exact(
297 &self,
298 query_tokens: &Array2<f32>,
299 top_k: usize,
300 first_stage_k: usize,
301 ) -> Result<Vec<MultiVectorResult>> {
302 let approx = self.search_with_params(query_tokens, first_stage_k, 50)?;
304 let candidate_docs: Vec<u32> = approx.iter().map(|r| r.doc_id).collect();
305
306 if candidate_docs.is_empty() {
307 return Ok(Vec::new());
308 }
309
310 let emb_bytes = self.emb_bytes();
312 let (header_len, _rows, _cols) = parse_npy_header(emb_bytes)?;
313 let data_start = header_len;
314 let float_data = &emb_bytes[data_start..];
315
316 let mut doc_scores: HashMap<u32, f32> = HashMap::new();
318 for &doc_id in &candidate_docs {
319 if let Some(row_indices) = self.doc_to_rows.get(&doc_id) {
320 let score = exact_max_sim(query_tokens, float_data, row_indices, self.dim);
321 doc_scores.insert(doc_id, score);
322 }
323 }
324
325 Ok(top_k_results(
326 &doc_scores,
327 top_k,
328 &self.doc_to_rows,
329 &self.labels,
330 ))
331 }
332
333 fn emb_bytes(&self) -> &[u8] {
334 #[cfg(feature = "multi-vector")]
335 {
336 &self.emb_mmap
337 }
338 #[cfg(not(feature = "multi-vector"))]
339 {
340 &self.emb_data
341 }
342 }
343}
344
/// One ranked document returned by the `MultiVectorSearcher` search methods.
#[derive(Debug, Clone)]
pub struct MultiVectorResult {
    /// Document id as passed to `MultiVectorBuilder::insert`.
    pub doc_id: u32,
    /// Sum over query tokens of the best similarity to any token of this
    /// document. It is approximate or exact depending on the search method used.
    pub score: f32,
    /// Metadata copied from the label of the document's first indexed token
    /// (empty if none is found).
    pub metadata: HashMap<String, serde_json::Value>,
}
357
358fn exact_max_sim(
364 query_tokens: &Array2<f32>,
365 float_data: &[u8],
366 doc_row_indices: &[usize],
367 dim: usize,
368) -> f32 {
369 let mut total = 0.0f32;
370 for qi in 0..query_tokens.nrows() {
371 let q = query_tokens.row(qi);
372 let mut best = f32::NEG_INFINITY;
373 for &row_idx in doc_row_indices {
374 let offset = row_idx * dim * 4;
375 let end = offset + dim * 4;
376 if end > float_data.len() {
377 continue;
378 }
379 let dot = dot_product_bytes(q, &float_data[offset..end]);
380 if dot > best {
381 best = dot;
382 }
383 }
384 if best > f32::NEG_INFINITY {
385 total += best;
386 }
387 }
388 total
389}
390
391#[inline]
393fn dot_product_bytes(a: ArrayView1<f32>, b_bytes: &[u8]) -> f32 {
394 let mut sum = 0.0f32;
395 for (i, &ai) in a.iter().enumerate() {
396 let offset = i * 4;
397 let bi = f32::from_le_bytes(b_bytes[offset..offset + 4].try_into().unwrap());
398 sum += ai * bi;
399 }
400 sum
401}
402
403fn top_k_results(
405 doc_scores: &HashMap<u32, f32>,
406 top_k: usize,
407 doc_to_rows: &HashMap<u32, Vec<usize>>,
408 labels: &[TokenLabel],
409) -> Vec<MultiVectorResult> {
410 let mut entries: Vec<(u32, f32)> = doc_scores.iter().map(|(&d, &s)| (d, s)).collect();
411 entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
412 entries.truncate(top_k);
413
414 entries
415 .into_iter()
416 .map(|(doc_id, score)| {
417 let metadata = doc_to_rows
418 .get(&doc_id)
419 .and_then(|rows| rows.first())
420 .map(|&idx| labels[idx].metadata.clone())
421 .unwrap_or_default();
422 MultiVectorResult {
423 doc_id,
424 score,
425 metadata,
426 }
427 })
428 .collect()
429}
430
431fn write_npy(arr: &Array2<f32>, path: &Path) -> Result<()> {
437 let (rows, cols) = arr.dim();
438 let header = format!(
439 "{{'descr': '<f4', 'fortran_order': False, 'shape': ({}, {}), }}",
440 rows, cols
441 );
442 let prefix_len = 10; let total_unpadded = prefix_len + header.len() + 1; let padding = (64 - (total_unpadded % 64)) % 64;
446 let header_content_len = header.len() + padding + 1; let mut file = fs::File::create(path)?;
449 file.write_all(&[0x93, b'N', b'U', b'M', b'P', b'Y'])?;
451 file.write_all(&[1, 0])?;
453 file.write_all(&(header_content_len as u16).to_le_bytes())?;
455 file.write_all(header.as_bytes())?;
457 for _ in 0..padding {
458 file.write_all(b" ")?;
459 }
460 file.write_all(b"\n")?;
461
462 for val in arr.iter() {
464 file.write_all(&val.to_le_bytes())?;
465 }
466
467 Ok(())
468}
469
470fn parse_npy_header(data: &[u8]) -> Result<(usize, usize, usize)> {
472 anyhow::ensure!(data.len() >= 10, "npy file too small");
473 anyhow::ensure!(&data[0..6] == b"\x93NUMPY", "invalid npy magic");
474
475 let header_len = u16::from_le_bytes([data[8], data[9]]) as usize;
476 let header_end = 10 + header_len;
477 anyhow::ensure!(data.len() >= header_end, "npy header truncated");
478
479 let header_str = std::str::from_utf8(&data[10..header_end])?;
480 let shape_start = header_str
482 .find("'shape': (")
483 .context("no shape in npy header")?
484 + "'shape': (".len();
485 let shape_end = header_str[shape_start..]
486 .find(')')
487 .context("unclosed shape tuple")?
488 + shape_start;
489 let shape_str = &header_str[shape_start..shape_end];
490 let dims: Vec<usize> = shape_str
491 .split(',')
492 .filter_map(|s| s.trim().parse().ok())
493 .collect();
494
495 anyhow::ensure!(dims.len() == 2, "expected 2D shape, got {:?}", dims);
496
497 Ok((header_end, dims[0], dims[1]))
498}
499
/// Returns `base` with `.{ext}` appended to its final component.
///
/// Unlike `Path::with_extension`, this never replaces an existing extension:
/// `foo.v1` with `index` becomes `foo.v1.index`. The name is built as an
/// `OsString`, so non-UTF-8 file names are preserved byte-for-byte instead of
/// being lossily converted. If `base` has no file name (e.g. `..`), `.{ext}` is
/// pushed as a new component.
fn with_ext(base: &Path, ext: &str) -> PathBuf {
    let mut name = base.file_name().unwrap_or_default().to_os_string();
    name.push(".");
    name.push(ext);
    base.with_file_name(name)
}
514
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Two small 4-d documents and a two-token query. doc0's tokens lean toward
    // axis 0 and doc1's toward axis 1; the query has one token on each axis.
    fn make_test_data() -> (Array2<f32>, Array2<f32>, Array2<f32>) {
        let doc0 = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.9, 0.1, 0.0, 0.0],
            [0.8, 0.2, 0.0, 0.0],
        ];
        let doc1 = array![[0.0, 1.0, 0.0, 0.0], [0.1, 0.9, 0.0, 0.0],];
        let query = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],];
        (doc0, doc1, query)
    }

    #[test]
    fn test_build_and_search() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_mv");

        let (doc0, doc1, query) = make_test_data();

        let mut builder = MultiVectorBuilder::new(4);
        builder.insert(0, doc0, HashMap::new());
        builder.insert(1, doc1, HashMap::new());
        builder.build(&index_path).unwrap();

        // build() writes three sidecar files next to `index_path`.
        assert!(with_ext(&index_path, "index").exists());
        assert!(with_ext(&index_path, "labels.json").exists());
        assert!(with_ext(&index_path, "emb.npy").exists());

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        assert_eq!(searcher.num_docs(), 2);
        // 3 tokens in doc0 + 2 tokens in doc1.
        assert_eq!(searcher.num_tokens(), 5);

        let results = searcher.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        let exact_results = searcher.search_exact(&query, 2, 10).unwrap();
        assert_eq!(exact_results.len(), 2);
    }

    #[test]
    fn test_max_sim_scoring() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_scoring");

        // Orthogonal single-token documents; the query equals doc0's token.
        let doc0 = array![[1.0, 0.0, 0.0, 0.0]];
        let doc1 = array![[0.0, 1.0, 0.0, 0.0]];
        let query = array![[1.0, 0.0, 0.0, 0.0]];

        let mut builder = MultiVectorBuilder::new(4);
        builder.insert(0, doc0, HashMap::new());
        builder.insert(1, doc1, HashMap::new());
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        let results = searcher.search_exact(&query, 2, 10).unwrap();

        // Exact scores are plain dot products: 1.0 for doc0, 0.0 for doc1.
        assert_eq!(results[0].doc_id, 0);
        assert!(results[0].score > results[1].score);
        assert!((results[0].score - 1.0).abs() < 1e-5);
        assert!((results[1].score - 0.0).abs() < 1e-5);
    }

    #[test]
    fn test_npy_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.npy");

        let arr = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        write_npy(&arr, &path).unwrap();

        let data = fs::read(&path).unwrap();
        let (header_len, rows, cols) = parse_npy_header(&data).unwrap();
        assert_eq!(rows, 2);
        assert_eq!(cols, 3);

        // Payload is exactly rows * cols * 4 bytes, starting with arr[[0, 0]].
        let float_data = &data[header_len..];
        assert_eq!(float_data.len(), 2 * 3 * 4);
        let first = f32::from_le_bytes(float_data[0..4].try_into().unwrap());
        assert!((first - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_metadata_propagation() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_meta");

        let doc0 = array![[1.0, 0.0]];
        let mut meta = HashMap::new();
        meta.insert("filepath".to_string(), serde_json::json!("/tmp/page1.png"));

        let mut builder = MultiVectorBuilder::new(2);
        builder.insert(42, doc0, meta);
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        let query = array![[1.0, 0.0]];
        let results = searcher.search(&query, 1).unwrap();

        // Metadata given at insert time comes back on the search result.
        assert_eq!(results[0].doc_id, 42);
        assert_eq!(results[0].metadata["filepath"], "/tmp/page1.png");
    }

    #[test]
    fn test_many_docs_ranking() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_many");
        let dim = 16;

        // Doc `d` has 3 tokens; token `t` has 1.0 on axis `d` and 0.1 * t on
        // axis `(d + 1) % dim`.
        let mut builder = MultiVectorBuilder::new(dim);
        for doc_id in 0..10u32 {
            let mut tokens = Array2::<f32>::zeros((3, dim));
            for t in 0..3 {
                tokens[[t, doc_id as usize]] = 1.0;
                tokens[[t, (doc_id as usize + 1) % dim]] = 0.1 * (t as f32);
            }
            builder.insert(doc_id, tokens, HashMap::new());
        }
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        assert_eq!(searcher.num_docs(), 10);

        // Query along axis 5: doc 5 scores 1.0, doc 4 at most 0.2, others 0.
        let mut query = Array2::<f32>::zeros((1, dim));
        query[[0, 5]] = 1.0;

        let results = searcher.search_exact(&query, 3, 30).unwrap();
        assert_eq!(results[0].doc_id, 5);
    }

    #[test]
    fn test_multi_token_query_aggregation() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_agg");

        // doc0 covers axes 0 and 1; doc1 lives almost entirely on axis 2.
        let doc0 = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],];
        let doc1 = array![[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.9, 0.1],];

        let mut builder = MultiVectorBuilder::new(4);
        builder.insert(0, doc0, HashMap::new());
        builder.insert(1, doc1, HashMap::new());
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();

        // Each query token matches one doc0 token exactly, so doc0's summed
        // MaxSim is 1.0 + 1.0 = 2.0; doc1 is orthogonal to both query tokens.
        let query = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],];
        let results = searcher.search_exact(&query, 2, 10).unwrap();
        assert_eq!(results[0].doc_id, 0);
        assert!((results[0].score - 2.0).abs() < 1e-5);
        assert!(results[1].score < 0.2);
    }

    #[test]
    fn test_single_doc_single_token() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_single");

        let doc = array![[0.6, 0.8]];
        let mut builder = MultiVectorBuilder::new(2);
        builder.insert(0, doc, HashMap::new());
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        assert_eq!(searcher.num_docs(), 1);
        assert_eq!(searcher.num_tokens(), 1);

        // 0.6^2 + 0.8^2 = 1.0: the approximate score is the inner product.
        let query = array![[0.6, 0.8]];
        let results = searcher.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_top_k_limits_results() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_topk");

        // Five identical single-token documents.
        let mut builder = MultiVectorBuilder::new(4);
        for i in 0..5u32 {
            let doc = array![[1.0, 0.0, 0.0, 0.0]];
            builder.insert(i, doc, HashMap::new());
        }
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        let query = array![[1.0, 0.0, 0.0, 0.0]];

        let results = searcher.search(&query, 3).unwrap();
        assert_eq!(results.len(), 3);

        // Asking for more than exist returns every document once.
        let results_all = searcher.search(&query, 10).unwrap();
        assert_eq!(results_all.len(), 5);
    }

    #[test]
    fn test_variable_token_counts() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_vartok");

        // Documents with 1, 3 and 2 tokens respectively.
        let doc0 = array![[1.0, 0.0]];
        let doc1 = array![[0.0, 1.0], [0.5, 0.5], [0.3, 0.7]];
        let doc2 = array![[0.7, 0.7], [0.8, 0.6]];
        let mut builder = MultiVectorBuilder::new(2);
        builder.insert(0, doc0, HashMap::new());
        builder.insert(1, doc1, HashMap::new());
        builder.insert(2, doc2, HashMap::new());
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        assert_eq!(searcher.num_docs(), 3);
        assert_eq!(searcher.num_tokens(), 6);

        // doc1 contains [0, 1] itself (score 1.0); doc2's best is 0.7, doc0's 0.
        let query = array![[0.0, 1.0]];
        let results = searcher.search_exact(&query, 3, 10).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 1);
    }

    #[test]
    fn test_labels_sidecar_format() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_labels");

        let doc0 = array![[1.0, 0.0], [0.0, 1.0]];
        let doc1 = array![[0.5, 0.5]];

        let mut meta0 = HashMap::new();
        meta0.insert("page".to_string(), serde_json::json!(1));

        let mut builder = MultiVectorBuilder::new(2);
        builder.insert(10, doc0, meta0);
        builder.insert(20, doc1, HashMap::new());
        builder.build(&index_path).unwrap();

        let labels_path = with_ext(&index_path, "labels.json");
        let data = fs::read_to_string(&labels_path).unwrap();
        let labels: Vec<TokenLabel> = serde_json::from_str(&data).unwrap();

        // One label per token, in insertion order, with per-document seq_ids
        // and the document's metadata copied onto each of its tokens.
        assert_eq!(labels.len(), 3);
        assert_eq!(labels[0].doc_id, 10);
        assert_eq!(labels[0].seq_id, 0);
        assert_eq!(labels[0].metadata["page"], 1);
        assert_eq!(labels[1].doc_id, 10);
        assert_eq!(labels[1].seq_id, 1);
        assert_eq!(labels[2].doc_id, 20);
        assert_eq!(labels[2].seq_id, 0);
        assert!(labels[2].metadata.is_empty());
    }

    #[test]
    fn test_exact_vs_approximate_consistency() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_consistency");

        // Eight one-hot single-token documents: doc `i` is the unit vector e_i.
        let dim = 8;
        let mut builder = MultiVectorBuilder::new(dim);
        for i in 0..8u32 {
            let mut emb = Array2::<f32>::zeros((1, dim));
            emb[[0, i as usize]] = 1.0;
            builder.insert(i, emb, HashMap::new());
        }
        builder.build(&index_path).unwrap();

        let searcher = MultiVectorSearcher::open(&index_path).unwrap();
        let mut query = Array2::<f32>::zeros((1, dim));
        query[[0, 2]] = 1.0;

        // Both search paths must agree that e_2 is the best match.
        let exact = searcher.search_exact(&query, 1, 10).unwrap();
        assert_eq!(exact[0].doc_id, 2);
        assert!((exact[0].score - 1.0).abs() < 1e-5);

        let approx = searcher.search(&query, 1).unwrap();
        assert_eq!(approx[0].doc_id, 2);
    }

    // build() returns an error on an empty builder; unwrap() turns it into a
    // panic whose message includes the error text.
    #[test]
    #[should_panic(expected = "no documents inserted")]
    fn test_build_empty_panics() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_empty");
        let builder = MultiVectorBuilder::new(4);
        builder.build(&index_path).unwrap();
    }

    // insert() asserts that every embedding matrix has `dim` columns.
    #[test]
    #[should_panic(expected = "embedding dim 3 != expected 4")]
    fn test_dimension_mismatch_panics() {
        let mut builder = MultiVectorBuilder::new(4);
        builder.insert(0, array![[1.0, 2.0, 3.0]], HashMap::new());
    }
}