Skip to main content

leann_core/
multi_vector.rs

1//! Multi-vector (ColBERT) index support.
2//!
3//! Builds on the existing HNSW backend by flattening all token vectors from all
4//! documents into a single index (one HNSW node per token), then aggregating
5//! results per-document at query time using the ColBERT MaxSim formula:
6//!
7//! ```text
8//! score(Q, D) = Σ_i max_j (q_i · d_j)
9//! ```
10
11use std::collections::HashMap;
12use std::fs;
13use std::io::Write;
14use std::path::{Path, PathBuf};
15
16use anyhow::{Context, Result};
17use ndarray::{Array2, ArrayView1};
18use serde::{Deserialize, Serialize};
19
20use crate::backend::{self, BackendConfig, BackendIndex};
21use crate::hnsw::search::SearchParams;
22use crate::index::DistanceMetric;
23
24// ---------------------------------------------------------------------------
25// Token label — one per HNSW node
26// ---------------------------------------------------------------------------
27
28/// Metadata for a single token vector in the flattened index.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TokenLabel {
31    /// Document this token belongs to.
32    pub doc_id: u32,
33    /// Token position within the document.
34    pub seq_id: u32,
35    /// Arbitrary per-document metadata (filepath, image_path, etc.).
36    #[serde(default)]
37    pub metadata: HashMap<String, serde_json::Value>,
38}
39
40// ---------------------------------------------------------------------------
41// Builder
42// ---------------------------------------------------------------------------
43
44/// Pending document: token embeddings + metadata, buffered before index build.
45struct PendingDoc {
46    doc_id: u32,
47    embeddings: Array2<f32>,
48    metadata: HashMap<String, serde_json::Value>,
49}
50
51/// Builds a multi-vector index from per-document token embeddings.
52pub struct MultiVectorBuilder {
53    dim: usize,
54    pending: Vec<PendingDoc>,
55    backend_config: BackendConfig,
56}
57
58impl MultiVectorBuilder {
59    /// Create a builder for the given embedding dimension.
60    pub fn new(dim: usize) -> Self {
61        let mut config = BackendConfig::hnsw_default();
62        // Multi-vector indexes use MIPS (inner product) for ColBERT scoring.
63        config.set_distance_metric(DistanceMetric::Mips);
64        // Store vectors so we can do exact MaxSim reranking.
65        config.set_recompute(false);
66        config.set_compact(false);
67        Self {
68            dim,
69            pending: Vec::new(),
70            backend_config: config,
71        }
72    }
73
74    /// Set the HNSW M parameter.
75    pub fn set_m(&mut self, m: usize) -> &mut Self {
76        self.backend_config.set_m(m);
77        self
78    }
79
80    /// Set the HNSW efConstruction parameter.
81    pub fn set_ef_construction(&mut self, ef: usize) -> &mut Self {
82        self.backend_config.set_ef_construction(ef);
83        self
84    }
85
86    /// Insert a document's token embeddings.
87    ///
88    /// `embeddings` has shape `[num_tokens, dim]`.
89    pub fn insert(
90        &mut self,
91        doc_id: u32,
92        embeddings: Array2<f32>,
93        metadata: HashMap<String, serde_json::Value>,
94    ) -> &mut Self {
95        assert_eq!(
96            embeddings.ncols(),
97            self.dim,
98            "embedding dim {} != expected {}",
99            embeddings.ncols(),
100            self.dim
101        );
102        self.pending.push(PendingDoc {
103            doc_id,
104            embeddings,
105            metadata,
106        });
107        self
108    }
109
110    /// Build the index and write it to `index_path`.
111    ///
112    /// Produces:
113    /// - `<index_path>.index` — HNSW binary index
114    /// - `<index_path>.labels.json` — per-node token labels
115    /// - `<index_path>.emb.npy` — raw embedding matrix for exact reranking
116    pub fn build(&self, index_path: &Path) -> Result<()> {
117        anyhow::ensure!(!self.pending.is_empty(), "no documents inserted");
118
119        // Flatten all token vectors + build labels.
120        let total_tokens: usize = self.pending.iter().map(|d| d.embeddings.nrows()).sum();
121        let mut flat = Array2::<f32>::zeros((total_tokens, self.dim));
122        let mut labels = Vec::with_capacity(total_tokens);
123
124        let mut row = 0;
125        for doc in &self.pending {
126            for seq_id in 0..doc.embeddings.nrows() {
127                flat.row_mut(row).assign(&doc.embeddings.row(seq_id));
128                labels.push(TokenLabel {
129                    doc_id: doc.doc_id,
130                    seq_id: seq_id as u32,
131                    metadata: doc.metadata.clone(),
132                });
133                row += 1;
134            }
135        }
136
137        // Build HNSW index.
138        let index_file = with_ext(index_path, "index");
139        backend::build_backend(&self.backend_config, &flat, &index_file, None)?;
140
141        // Write labels sidecar.
142        let labels_file = with_ext(index_path, "labels.json");
143        let labels_json = serde_json::to_string(&labels)?;
144        fs::write(&labels_file, labels_json)
145            .with_context(|| format!("writing {}", labels_file.display()))?;
146
147        // Write .emb.npy for exact reranking.
148        let npy_file = with_ext(index_path, "emb.npy");
149        write_npy(&flat, &npy_file)?;
150
151        Ok(())
152    }
153}
154
155// ---------------------------------------------------------------------------
156// Searcher
157// ---------------------------------------------------------------------------
158
159/// A loaded multi-vector index, ready for MaxSim search.
160pub struct MultiVectorSearcher {
161    index: BackendIndex,
162    labels: Vec<TokenLabel>,
163    /// doc_id → list of flat row indices into the embedding matrix.
164    doc_to_rows: HashMap<u32, Vec<usize>>,
165    /// Memory-mapped embedding matrix for exact reranking.
166    #[cfg(feature = "multi-vector")]
167    emb_mmap: memmap2::Mmap,
168    #[cfg(not(feature = "multi-vector"))]
169    emb_data: Vec<u8>,
170    dim: usize,
171    total_tokens: usize,
172}
173
174impl MultiVectorSearcher {
175    /// Open a multi-vector index from disk.
176    pub fn open(index_path: &Path) -> Result<Self> {
177        // Read HNSW index.
178        let index_file = with_ext(index_path, "index");
179        let index = backend::read_backend_index("hnsw", &index_file)?;
180
181        // Read labels.
182        let labels_file = with_ext(index_path, "labels.json");
183        let labels_data = fs::read_to_string(&labels_file)
184            .with_context(|| format!("reading {}", labels_file.display()))?;
185        let labels: Vec<TokenLabel> = serde_json::from_str(&labels_data)?;
186
187        // Build doc_id → rows mapping.
188        let mut doc_to_rows: HashMap<u32, Vec<usize>> = HashMap::new();
189        for (i, label) in labels.iter().enumerate() {
190            doc_to_rows.entry(label.doc_id).or_default().push(i);
191        }
192
193        let dim = index.dimensions();
194        let total_tokens = labels.len();
195
196        // Mmap the .emb.npy file.
197        let npy_file = with_ext(index_path, "emb.npy");
198
199        #[cfg(feature = "multi-vector")]
200        let emb_mmap = {
201            let file = fs::File::open(&npy_file)
202                .with_context(|| format!("opening {}", npy_file.display()))?;
203            unsafe { memmap2::Mmap::map(&file)? }
204        };
205
206        Ok(Self {
207            index,
208            labels,
209            doc_to_rows,
210            #[cfg(feature = "multi-vector")]
211            emb_mmap,
212            #[cfg(not(feature = "multi-vector"))]
213            emb_data: fs::read(&npy_file)?,
214            dim,
215            total_tokens,
216        })
217    }
218
219    /// Number of documents in the index.
220    pub fn num_docs(&self) -> usize {
221        self.doc_to_rows.len()
222    }
223
224    /// Total number of token vectors in the index.
225    pub fn num_tokens(&self) -> usize {
226        self.total_tokens
227    }
228
229    /// Approximate MaxSim search.
230    ///
231    /// For each query token, runs HNSW ANN search, then aggregates per-document
232    /// using the MaxSim formula.
233    ///
234    /// `query_tokens` has shape `[num_query_tokens, dim]`.
235    pub fn search(
236        &self,
237        query_tokens: &Array2<f32>,
238        top_k: usize,
239    ) -> Result<Vec<MultiVectorResult>> {
240        self.search_with_params(query_tokens, top_k, 50)
241    }
242
243    /// Approximate MaxSim search with configurable per-token k.
244    pub fn search_with_params(
245        &self,
246        query_tokens: &Array2<f32>,
247        top_k: usize,
248        per_token_k: usize,
249    ) -> Result<Vec<MultiVectorResult>> {
250        let params = SearchParams::default();
251
252        // For each query token, find nearest neighbors in the HNSW index.
253        // Accumulate MaxSim: for each doc, sum of max scores across query tokens.
254        let mut doc_scores: HashMap<u32, f32> = HashMap::new();
255
256        for qi in 0..query_tokens.nrows() {
257            let query_vec = query_tokens.row(qi);
258            let query_slice = query_vec.as_slice().unwrap();
259
260            let (labels_idx, distances) =
261                backend::search_backend(&self.index, query_slice, per_token_k, &params);
262
263            // For this query token, find best score per doc.
264            // HNSW inner_product_distance returns -dot(a,b), so negate to get similarity.
265            let mut best_per_doc: HashMap<u32, f32> = HashMap::new();
266            for (idx, dist) in labels_idx.into_iter().zip(distances) {
267                if idx >= self.labels.len() {
268                    continue;
269                }
270                let doc_id = self.labels[idx].doc_id;
271                let sim = -dist; // negate HNSW's negated inner product
272                let entry = best_per_doc.entry(doc_id).or_insert(f32::NEG_INFINITY);
273                if sim > *entry {
274                    *entry = sim;
275                }
276            }
277
278            // Accumulate into global scores.
279            for (doc_id, score) in best_per_doc {
280                *doc_scores.entry(doc_id).or_insert(0.0) += score;
281            }
282        }
283
284        Ok(top_k_results(
285            &doc_scores,
286            top_k,
287            &self.doc_to_rows,
288            &self.labels,
289        ))
290    }
291
292    /// Two-stage exact MaxSim search.
293    ///
294    /// Stage 1: approximate HNSW search to find candidate doc_ids.
295    /// Stage 2: exact MaxSim reranking using mmap'd embeddings.
296    pub fn search_exact(
297        &self,
298        query_tokens: &Array2<f32>,
299        top_k: usize,
300        first_stage_k: usize,
301    ) -> Result<Vec<MultiVectorResult>> {
302        // Stage 1: collect candidate doc_ids via approximate search.
303        let approx = self.search_with_params(query_tokens, first_stage_k, 50)?;
304        let candidate_docs: Vec<u32> = approx.iter().map(|r| r.doc_id).collect();
305
306        if candidate_docs.is_empty() {
307            return Ok(Vec::new());
308        }
309
310        // Parse the npy data to get the embedding slice.
311        let emb_bytes = self.emb_bytes();
312        let (header_len, _rows, _cols) = parse_npy_header(emb_bytes)?;
313        let data_start = header_len;
314        let float_data = &emb_bytes[data_start..];
315
316        // Stage 2: exact MaxSim for each candidate.
317        let mut doc_scores: HashMap<u32, f32> = HashMap::new();
318        for &doc_id in &candidate_docs {
319            if let Some(row_indices) = self.doc_to_rows.get(&doc_id) {
320                let score = exact_max_sim(query_tokens, float_data, row_indices, self.dim);
321                doc_scores.insert(doc_id, score);
322            }
323        }
324
325        Ok(top_k_results(
326            &doc_scores,
327            top_k,
328            &self.doc_to_rows,
329            &self.labels,
330        ))
331    }
332
333    fn emb_bytes(&self) -> &[u8] {
334        #[cfg(feature = "multi-vector")]
335        {
336            &self.emb_mmap
337        }
338        #[cfg(not(feature = "multi-vector"))]
339        {
340            &self.emb_data
341        }
342    }
343}
344
345// ---------------------------------------------------------------------------
346// Search result
347// ---------------------------------------------------------------------------
348
349/// A multi-vector search result (one per document).
350#[derive(Debug, Clone)]
351pub struct MultiVectorResult {
352    pub doc_id: u32,
353    pub score: f32,
354    /// Per-document metadata from the first token label.
355    pub metadata: HashMap<String, serde_json::Value>,
356}
357
358// ---------------------------------------------------------------------------
359// MaxSim helpers
360// ---------------------------------------------------------------------------
361
362/// Compute exact MaxSim: Σ_i max_j (q_i · d_j).
363fn exact_max_sim(
364    query_tokens: &Array2<f32>,
365    float_data: &[u8],
366    doc_row_indices: &[usize],
367    dim: usize,
368) -> f32 {
369    let mut total = 0.0f32;
370    for qi in 0..query_tokens.nrows() {
371        let q = query_tokens.row(qi);
372        let mut best = f32::NEG_INFINITY;
373        for &row_idx in doc_row_indices {
374            let offset = row_idx * dim * 4;
375            let end = offset + dim * 4;
376            if end > float_data.len() {
377                continue;
378            }
379            let dot = dot_product_bytes(q, &float_data[offset..end]);
380            if dot > best {
381                best = dot;
382            }
383        }
384        if best > f32::NEG_INFINITY {
385            total += best;
386        }
387    }
388    total
389}
390
391/// Dot product between an ndarray row and raw LE f32 bytes.
392#[inline]
393fn dot_product_bytes(a: ArrayView1<f32>, b_bytes: &[u8]) -> f32 {
394    let mut sum = 0.0f32;
395    for (i, &ai) in a.iter().enumerate() {
396        let offset = i * 4;
397        let bi = f32::from_le_bytes(b_bytes[offset..offset + 4].try_into().unwrap());
398        sum += ai * bi;
399    }
400    sum
401}
402
403/// Extract top-k docs from score map, sorted descending.
404fn top_k_results(
405    doc_scores: &HashMap<u32, f32>,
406    top_k: usize,
407    doc_to_rows: &HashMap<u32, Vec<usize>>,
408    labels: &[TokenLabel],
409) -> Vec<MultiVectorResult> {
410    let mut entries: Vec<(u32, f32)> = doc_scores.iter().map(|(&d, &s)| (d, s)).collect();
411    entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
412    entries.truncate(top_k);
413
414    entries
415        .into_iter()
416        .map(|(doc_id, score)| {
417            let metadata = doc_to_rows
418                .get(&doc_id)
419                .and_then(|rows| rows.first())
420                .map(|&idx| labels[idx].metadata.clone())
421                .unwrap_or_default();
422            MultiVectorResult {
423                doc_id,
424                score,
425                metadata,
426            }
427        })
428        .collect()
429}
430
431// ---------------------------------------------------------------------------
432// Minimal NPY v1.0 read/write
433// ---------------------------------------------------------------------------
434
435/// Write an Array2<f32> as a NumPy .npy v1.0 file.
436fn write_npy(arr: &Array2<f32>, path: &Path) -> Result<()> {
437    let (rows, cols) = arr.dim();
438    let header = format!(
439        "{{'descr': '<f4', 'fortran_order': False, 'shape': ({}, {}), }}",
440        rows, cols
441    );
442    // Pad header to 64-byte alignment (magic(6) + version(2) + header_len(2) + header + \n).
443    let prefix_len = 10; // magic + version + header_len
444    let total_unpadded = prefix_len + header.len() + 1; // +1 for trailing \n
445    let padding = (64 - (total_unpadded % 64)) % 64;
446    let header_content_len = header.len() + padding + 1; // header + spaces + \n
447
448    let mut file = fs::File::create(path)?;
449    // Magic
450    file.write_all(&[0x93, b'N', b'U', b'M', b'P', b'Y'])?;
451    // Version 1.0
452    file.write_all(&[1, 0])?;
453    // Header length (little-endian u16)
454    file.write_all(&(header_content_len as u16).to_le_bytes())?;
455    // Header string + padding + newline
456    file.write_all(header.as_bytes())?;
457    for _ in 0..padding {
458        file.write_all(b" ")?;
459    }
460    file.write_all(b"\n")?;
461
462    // Data: row-major f32 LE
463    for val in arr.iter() {
464        file.write_all(&val.to_le_bytes())?;
465    }
466
467    Ok(())
468}
469
470/// Parse a .npy v1.0 header, returning (data_offset, rows, cols).
471fn parse_npy_header(data: &[u8]) -> Result<(usize, usize, usize)> {
472    anyhow::ensure!(data.len() >= 10, "npy file too small");
473    anyhow::ensure!(&data[0..6] == b"\x93NUMPY", "invalid npy magic");
474
475    let header_len = u16::from_le_bytes([data[8], data[9]]) as usize;
476    let header_end = 10 + header_len;
477    anyhow::ensure!(data.len() >= header_end, "npy header truncated");
478
479    let header_str = std::str::from_utf8(&data[10..header_end])?;
480    // Parse shape tuple from the header string.
481    let shape_start = header_str
482        .find("'shape': (")
483        .context("no shape in npy header")?
484        + "'shape': (".len();
485    let shape_end = header_str[shape_start..]
486        .find(')')
487        .context("unclosed shape tuple")?
488        + shape_start;
489    let shape_str = &header_str[shape_start..shape_end];
490    let dims: Vec<usize> = shape_str
491        .split(',')
492        .filter_map(|s| s.trim().parse().ok())
493        .collect();
494
495    anyhow::ensure!(dims.len() == 2, "expected 2D shape, got {:?}", dims);
496
497    Ok((header_end, dims[0], dims[1]))
498}
499
500// ---------------------------------------------------------------------------
501// Path helpers
502// ---------------------------------------------------------------------------
503
/// Returns `base` with `.{ext}` appended to its final path component.
///
/// Unlike `Path::with_extension`, nothing is replaced: `a.b` + `index`
/// becomes `a.b.index`. The file name is handled as an `OsString`, so names
/// that are not valid UTF-8 are preserved byte-for-byte. If `base` has no
/// final component (e.g. `/`), `.{ext}` is pushed as a new component.
fn with_ext(base: &Path, ext: &str) -> PathBuf {
    let mut name = base
        .file_name()
        .map(|n| n.to_os_string())
        .unwrap_or_default();
    name.push(".");
    name.push(ext);

    let mut path = base.to_path_buf();
    path.set_file_name(name);
    path
}
514
515// ---------------------------------------------------------------------------
516// Tests
517// ---------------------------------------------------------------------------
518
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shorthand for the per-document metadata map.
    type Meta = HashMap<String, serde_json::Value>;

    /// Builds an index named `stem` inside a fresh temporary directory from
    /// `(doc_id, embeddings, metadata)` triples.
    ///
    /// The returned directory guard must be kept alive for as long as the
    /// index files are needed.
    fn build_index(
        stem: &str,
        dim: usize,
        docs: Vec<(u32, Array2<f32>, Meta)>,
    ) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join(stem);

        let mut builder = MultiVectorBuilder::new(dim);
        for (doc_id, embeddings, metadata) in docs {
            builder.insert(doc_id, embeddings, metadata);
        }
        builder.build(&index_path).unwrap();

        (dir, index_path)
    }

    /// Opens the index at `index_path`, panicking on failure.
    fn open_index(index_path: &Path) -> MultiVectorSearcher {
        MultiVectorSearcher::open(index_path).unwrap()
    }

    /// A `1 x dim` matrix that is 1.0 at column `axis` and 0.0 elsewhere.
    fn unit_row(dim: usize, axis: usize) -> Array2<f32> {
        let mut row = Array2::<f32>::zeros((1, dim));
        row[[0, axis]] = 1.0;
        row
    }

    /// Two small documents and a two-token query, all with dim = 4.
    fn make_test_data() -> (Array2<f32>, Array2<f32>, Array2<f32>) {
        // Doc 0: three tokens pointing roughly along +x.
        let doc0 = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.9, 0.1, 0.0, 0.0],
            [0.8, 0.2, 0.0, 0.0],
        ];
        // Doc 1: two tokens pointing roughly along +y.
        let doc1 = array![[0.0, 1.0, 0.0, 0.0], [0.1, 0.9, 0.0, 0.0]];
        // Query: one +x token and one +y token, so both docs should score.
        let query = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]];
        (doc0, doc1, query)
    }

    #[test]
    fn test_build_and_search() {
        let (doc0, doc1, query) = make_test_data();
        let (_dir, index_path) = build_index(
            "test_mv",
            4,
            vec![(0, doc0, Meta::new()), (1, doc1, Meta::new())],
        );

        // All three output files must exist.
        assert!(with_ext(&index_path, "index").exists());
        assert!(with_ext(&index_path, "labels.json").exists());
        assert!(with_ext(&index_path, "emb.npy").exists());

        let searcher = open_index(&index_path);
        assert_eq!(searcher.num_docs(), 2);
        assert_eq!(searcher.num_tokens(), 5);

        // Approximate search: both docs appear (doc0 via +x, doc1 via +y).
        let approx = searcher.search(&query, 2).unwrap();
        assert_eq!(approx.len(), 2);

        // Exact search.
        let exact = searcher.search_exact(&query, 2, 10).unwrap();
        assert_eq!(exact.len(), 2);
    }

    #[test]
    fn test_max_sim_scoring() {
        // Doc 0 matches +x perfectly, doc 1 matches +y perfectly.
        let (_dir, index_path) = build_index(
            "test_scoring",
            4,
            vec![
                (0, array![[1.0, 0.0, 0.0, 0.0]], Meta::new()),
                (1, array![[0.0, 1.0, 0.0, 0.0]], Meta::new()),
            ],
        );
        let searcher = open_index(&index_path);

        // A pure +x query must prefer doc 0.
        let query = array![[1.0, 0.0, 0.0, 0.0]];
        let ranked = searcher.search_exact(&query, 2, 10).unwrap();

        assert_eq!(ranked[0].doc_id, 0);
        assert!(ranked[0].score > ranked[1].score);
        assert!((ranked[0].score - 1.0).abs() < 1e-5);
        assert!((ranked[1].score - 0.0).abs() < 1e-5);
    }

    #[test]
    fn test_npy_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let npy_path = dir.path().join("test.npy");

        let matrix = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        write_npy(&matrix, &npy_path).unwrap();

        let bytes = fs::read(&npy_path).unwrap();
        let (data_offset, rows, cols) = parse_npy_header(&bytes).unwrap();
        assert_eq!(rows, 2);
        assert_eq!(cols, 3);

        let payload = &bytes[data_offset..];
        assert_eq!(payload.len(), 2 * 3 * 4);
        let first = f32::from_le_bytes(payload[0..4].try_into().unwrap());
        assert!((first - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_metadata_propagation() {
        let mut meta = Meta::new();
        meta.insert("filepath".to_string(), serde_json::json!("/tmp/page1.png"));

        let (_dir, index_path) = build_index("test_meta", 2, vec![(42, array![[1.0, 0.0]], meta)]);
        let searcher = open_index(&index_path);

        let query = array![[1.0, 0.0]];
        let hits = searcher.search(&query, 1).unwrap();

        assert_eq!(hits[0].doc_id, 42);
        assert_eq!(hits[0].metadata["filepath"], "/tmp/page1.png");
    }

    #[test]
    fn test_many_docs_ranking() {
        // Ten docs, each concentrated on its own basis direction; querying a
        // direction must rank the matching doc first.
        let dim = 16;
        let docs = (0..10u32)
            .map(|doc_id| {
                let axis = doc_id as usize;
                let mut tokens = Array2::<f32>::zeros((3, dim));
                for t in 0..3 {
                    // Main energy in the doc's own dimension...
                    tokens[[t, axis]] = 1.0;
                    // ...plus a little noise in the neighbouring one.
                    tokens[[t, (axis + 1) % dim]] = 0.1 * (t as f32);
                }
                (doc_id, tokens, Meta::new())
            })
            .collect();

        let (_dir, index_path) = build_index("test_many", dim, docs);
        let searcher = open_index(&index_path);
        assert_eq!(searcher.num_docs(), 10);

        // Query along doc 5's direction.
        let query = unit_row(dim, 5);
        let ranked = searcher.search_exact(&query, 3, 30).unwrap();
        assert_eq!(ranked[0].doc_id, 5);
    }

    #[test]
    fn test_multi_token_query_aggregation() {
        // MaxSim must sum the best match of every query token.
        let (_dir, index_path) = build_index(
            "test_agg",
            4,
            vec![
                // Doc 0 has tokens in +x and +y.
                (
                    0,
                    array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
                    Meta::new(),
                ),
                // Doc 1 has tokens only around +z.
                (
                    1,
                    array![[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.9, 0.1]],
                    Meta::new(),
                ),
            ],
        );
        let searcher = open_index(&index_path);

        // +x and +y query tokens both match doc 0 perfectly.
        let query = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]];
        let ranked = searcher.search_exact(&query, 2, 10).unwrap();

        assert_eq!(ranked[0].doc_id, 0);
        // Doc 0: max(1, 0) for +x plus max(0, 1) for +y = 2.0.
        assert!((ranked[0].score - 2.0).abs() < 1e-5);
        // Doc 1: no overlap with either query token, so roughly 0.
        assert!(ranked[1].score < 0.2);
    }

    #[test]
    fn test_single_doc_single_token() {
        let (_dir, index_path) =
            build_index("test_single", 2, vec![(0, array![[0.6, 0.8]], Meta::new())]);
        let searcher = open_index(&index_path);
        assert_eq!(searcher.num_docs(), 1);
        assert_eq!(searcher.num_tokens(), 1);

        let query = array![[0.6, 0.8]];
        let hits = searcher.search(&query, 1).unwrap();
        assert_eq!(hits.len(), 1);
        // dot(q, d) = 0.36 + 0.64 = 1.0
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_top_k_limits_results() {
        let docs = (0..5u32)
            .map(|doc_id| (doc_id, array![[1.0, 0.0, 0.0, 0.0]], Meta::new()))
            .collect();
        let (_dir, index_path) = build_index("test_topk", 4, docs);
        let searcher = open_index(&index_path);

        let query = array![[1.0, 0.0, 0.0, 0.0]];

        let limited = searcher.search(&query, 3).unwrap();
        assert_eq!(limited.len(), 3);

        let everything = searcher.search(&query, 10).unwrap();
        assert_eq!(everything.len(), 5);
    }

    #[test]
    fn test_variable_token_counts() {
        // Documents with 1, 3 and 2 tokens respectively.
        let (_dir, index_path) = build_index(
            "test_vartok",
            2,
            vec![
                (0, array![[1.0, 0.0]], Meta::new()),
                (1, array![[0.0, 1.0], [0.5, 0.5], [0.3, 0.7]], Meta::new()),
                (2, array![[0.7, 0.7], [0.8, 0.6]], Meta::new()),
            ],
        );
        let searcher = open_index(&index_path);
        assert_eq!(searcher.num_docs(), 3);
        assert_eq!(searcher.num_tokens(), 6); // 1 + 3 + 2

        let query = array![[0.0, 1.0]];
        let ranked = searcher.search_exact(&query, 3, 10).unwrap();
        assert_eq!(ranked.len(), 3);
        // Doc 1 contains the token [0, 1], a perfect match.
        assert_eq!(ranked[0].doc_id, 1);
    }

    #[test]
    fn test_labels_sidecar_format() {
        let mut meta0 = Meta::new();
        meta0.insert("page".to_string(), serde_json::json!(1));

        let (_dir, index_path) = build_index(
            "test_labels",
            2,
            vec![
                (10, array![[1.0, 0.0], [0.0, 1.0]], meta0),
                (20, array![[0.5, 0.5]], Meta::new()),
            ],
        );

        // Inspect labels.json directly, without going through the searcher.
        let raw = fs::read_to_string(with_ext(&index_path, "labels.json")).unwrap();
        let labels: Vec<TokenLabel> = serde_json::from_str(&raw).unwrap();

        assert_eq!(labels.len(), 3);
        assert_eq!(labels[0].doc_id, 10);
        assert_eq!(labels[0].seq_id, 0);
        assert_eq!(labels[0].metadata["page"], 1);
        assert_eq!(labels[1].doc_id, 10);
        assert_eq!(labels[1].seq_id, 1);
        assert_eq!(labels[2].doc_id, 20);
        assert_eq!(labels[2].seq_id, 0);
        assert!(labels[2].metadata.is_empty());
    }

    #[test]
    fn test_exact_vs_approximate_consistency() {
        // On a well-separated dataset, exact and approximate search must
        // agree on the top-1 document.
        // Eight docs in distinct directions (enough nodes for HNSW to work well).
        let dim = 8;
        let docs = (0..8u32)
            .map(|doc_id| (doc_id, unit_row(dim, doc_id as usize), Meta::new()))
            .collect();
        let (_dir, index_path) = build_index("test_consistency", dim, docs);
        let searcher = open_index(&index_path);

        let query = unit_row(dim, 2);

        let exact = searcher.search_exact(&query, 1, 10).unwrap();
        assert_eq!(exact[0].doc_id, 2);
        assert!((exact[0].score - 1.0).abs() < 1e-5);

        // Approximate search should find doc 2 as well.
        let approx = searcher.search(&query, 1).unwrap();
        assert_eq!(approx[0].doc_id, 2);
    }

    #[test]
    #[should_panic(expected = "no documents inserted")]
    fn test_build_empty_panics() {
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("test_empty");
        // Building with nothing inserted returns an error; `unwrap` turns it
        // into the expected panic.
        MultiVectorBuilder::new(4).build(&index_path).unwrap();
    }

    #[test]
    #[should_panic(expected = "embedding dim 3 != expected 4")]
    fn test_dimension_mismatch_panics() {
        // A 3-wide token vector inserted into a dim-4 builder must panic.
        MultiVectorBuilder::new(4).insert(0, array![[1.0, 2.0, 3.0]], HashMap::new());
    }
}
839    fn test_dimension_mismatch_panics() {
840        let mut builder = MultiVectorBuilder::new(4);
841        builder.insert(0, array![[1.0, 2.0, 3.0]], HashMap::new());
842    }
843}