// oxify_vector/colbert.rs

1//! ColBERT-style Multi-Vector Search
2//!
3//! Late interaction model for dense retrieval with token-level matching.
4//!
5//! ## Algorithm Overview
6//!
7//! ColBERT (Contextualized Late Interaction over BERT) represents documents
8//! as collections of token embeddings and uses MaxSim for scoring:
9//!
10//! 1. Each document/query → sequence of token embeddings
11//! 2. Score = Σ max(sim(q_token, d_token)) for all query tokens
12//! 3. "Late interaction": token-level matching instead of single vector
13//!
14//! ## Benefits
15//!
16//! - **Fine-grained matching**: Matches specific parts of documents
17//! - **Better accuracy**: Captures more semantic nuance than single vectors
18//! - **Interpretability**: Can identify which tokens matched
19//!
20//! ## Example
21//!
22//! ```rust
23//! use oxify_vector::colbert::{ColbertIndex, ColbertConfig};
24//! use std::collections::HashMap;
25//!
26//! # fn example() -> anyhow::Result<()> {
27//! let config = ColbertConfig::default();
28//! let mut index = ColbertIndex::new(config);
29//!
30//! // Each document has multiple token embeddings
31//! let mut doc_tokens = HashMap::new();
32//! doc_tokens.insert("doc1".to_string(), vec![
33//!     vec![0.1, 0.2, 0.3],
34//!     vec![0.2, 0.3, 0.4],
35//!     vec![0.3, 0.4, 0.5],
36//! ]);
37//!
38//! index.build(&doc_tokens)?;
39//!
40//! let query_tokens = vec![
41//!     vec![0.15, 0.25, 0.35],
42//!     vec![0.25, 0.35, 0.45],
43//! ];
44//!
45//! let results = index.search(&query_tokens, 10)?;
46//! # Ok(())
47//! # }
48//! ```
49
50use anyhow::Result;
51use rayon::prelude::*;
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54
55use crate::simd;
56use crate::types::{DistanceMetric, SearchResult};
57
58/// ColBERT configuration
/// ColBERT configuration
///
/// See [`Default`] for the standard settings (cosine metric, 300 document
/// tokens, 32 query tokens, parallel search).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertConfig {
    /// Distance metric for token similarity
    /// Cosine similarity is standard for ColBERT
    pub metric: DistanceMetric,

    /// Maximum number of tokens per document
    /// Longer documents are truncated (excess tokens are dropped, never scored)
    pub max_doc_tokens: usize,

    /// Maximum number of tokens per query
    /// Longer queries are truncated at search time
    pub max_query_tokens: usize,

    /// Enable compression for token storage
    // NOTE(review): nothing in this file reads this flag — confirm it is
    // consumed elsewhere before relying on it.
    pub compress_tokens: bool,

    /// Use parallel search (rayon) instead of a sequential document scan
    pub parallel_search: bool,
}
78
79impl Default for ColbertConfig {
80    fn default() -> Self {
81        Self {
82            metric: DistanceMetric::Cosine,
83            max_doc_tokens: 300,
84            max_query_tokens: 32,
85            compress_tokens: false,
86            parallel_search: true,
87        }
88    }
89}
90
91impl ColbertConfig {
92    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
93        self.metric = metric;
94        self
95    }
96
97    pub fn with_max_doc_tokens(mut self, max_doc_tokens: usize) -> Self {
98        self.max_doc_tokens = max_doc_tokens;
99        self
100    }
101
102    pub fn with_max_query_tokens(mut self, max_query_tokens: usize) -> Self {
103        self.max_query_tokens = max_query_tokens;
104        self
105    }
106
107    pub fn with_compression(mut self, compress: bool) -> Self {
108        self.compress_tokens = compress;
109        self
110    }
111}
112
/// Multi-vector representation of a document
///
/// One embedding per token; all embeddings within an index share the same
/// dimensionality (enforced by `ColbertIndex::add`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorDoc {
    /// Caller-supplied document identifier.
    pub entity_id: String,
    /// One embedding vector per token (already truncated to `max_doc_tokens`).
    pub token_embeddings: Vec<Vec<f32>>,
}
119
/// ColBERT search result with token-level match information
#[derive(Debug, Clone)]
pub struct ColbertSearchResult {
    /// Identifier of the matched document.
    pub entity_id: String,
    /// MaxSim score: sum over query tokens of the best doc-token similarity.
    pub score: f32,
    /// Token-level scores (query_token_idx -> (best_doc_token_idx, score));
    /// element i describes the best match for query token i.
    pub token_matches: Vec<(usize, f32)>,
}
128
/// ColBERT index for multi-vector search
///
/// Brute-force MaxSim scorer: every search scans all stored documents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertIndex {
    // Search/scoring configuration.
    config: ColbertConfig,
    // All indexed documents, scanned linearly at query time.
    documents: Vec<MultiVectorDoc>,
    // Embedding dimensionality; `None` until the first document arrives.
    dim: Option<usize>,
}
136
137impl ColbertIndex {
138    pub fn new(config: ColbertConfig) -> Self {
139        Self {
140            config,
141            documents: Vec::new(),
142            dim: None,
143        }
144    }
145
146    /// Build index from multi-vector documents
147    pub fn build(&mut self, doc_tokens: &HashMap<String, Vec<Vec<f32>>>) -> Result<()> {
148        if doc_tokens.is_empty() {
149            anyhow::bail!("Cannot build ColBERT index with empty documents");
150        }
151
152        // Determine dimension from first token
153        let first_doc_tokens = doc_tokens.values().next().unwrap();
154        if first_doc_tokens.is_empty() {
155            anyhow::bail!("Document has no token embeddings");
156        }
157        self.dim = Some(first_doc_tokens[0].len());
158
159        // Store all documents
160        self.documents.clear();
161        for (entity_id, tokens) in doc_tokens {
162            // Truncate to max_doc_tokens
163            let truncated_tokens = if tokens.len() > self.config.max_doc_tokens {
164                tokens[..self.config.max_doc_tokens].to_vec()
165            } else {
166                tokens.clone()
167            };
168
169            self.documents.push(MultiVectorDoc {
170                entity_id: entity_id.clone(),
171                token_embeddings: truncated_tokens,
172            });
173        }
174
175        Ok(())
176    }
177
178    /// Add a single document to the index
179    pub fn add(&mut self, entity_id: String, token_embeddings: Vec<Vec<f32>>) -> Result<()> {
180        if token_embeddings.is_empty() {
181            anyhow::bail!("Cannot add document with no token embeddings");
182        }
183
184        // Set dimension if first document
185        if self.dim.is_none() {
186            self.dim = Some(token_embeddings[0].len());
187        }
188
189        // Verify all tokens have correct dimension
190        let dim = self.dim.unwrap();
191        for token in &token_embeddings {
192            if token.len() != dim {
193                anyhow::bail!(
194                    "Token dimension {} does not match index dimension {}",
195                    token.len(),
196                    dim
197                );
198            }
199        }
200
201        // Truncate to max_doc_tokens
202        let truncated_tokens = if token_embeddings.len() > self.config.max_doc_tokens {
203            token_embeddings[..self.config.max_doc_tokens].to_vec()
204        } else {
205            token_embeddings
206        };
207
208        self.documents.push(MultiVectorDoc {
209            entity_id,
210            token_embeddings: truncated_tokens,
211        });
212
213        Ok(())
214    }
215
216    /// Search using MaxSim scoring
217    pub fn search(&self, query_tokens: &[Vec<f32>], k: usize) -> Result<Vec<ColbertSearchResult>> {
218        if self.documents.is_empty() {
219            return Ok(Vec::new());
220        }
221
222        // Truncate query to max_query_tokens
223        let query = if query_tokens.len() > self.config.max_query_tokens {
224            &query_tokens[..self.config.max_query_tokens]
225        } else {
226            query_tokens
227        };
228
229        // Compute MaxSim score for each document
230        let results: Vec<ColbertSearchResult> = if self.config.parallel_search {
231            self.documents
232                .par_iter()
233                .map(|doc| self.compute_maxsim_score(query, doc))
234                .collect()
235        } else {
236            self.documents
237                .iter()
238                .map(|doc| self.compute_maxsim_score(query, doc))
239                .collect()
240        };
241
242        // Sort by score (descending) and return top-k
243        let mut sorted_results = results;
244        sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
245
246        Ok(sorted_results.into_iter().take(k).collect())
247    }
248
249    /// Compute similarity score between two vectors
250    ///
251    /// Uses SIMD-optimized calculations for better performance.
252    #[inline]
253    fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
254        // Use SIMD-optimized implementations for hot path performance
255        simd::compute_distance_simd(self.config.metric, a, b)
256    }
257
258    /// Compute MaxSim score: sum of max similarities for each query token
259    fn compute_maxsim_score(
260        &self,
261        query_tokens: &[Vec<f32>],
262        doc: &MultiVectorDoc,
263    ) -> ColbertSearchResult {
264        let mut total_score = 0.0;
265        let mut token_matches = Vec::with_capacity(query_tokens.len());
266
267        for query_token in query_tokens {
268            // Find the best matching document token
269            let (best_doc_idx, best_score) = doc
270                .token_embeddings
271                .iter()
272                .enumerate()
273                .map(|(idx, doc_token)| {
274                    let score = self.compute_similarity(query_token, doc_token);
275                    (idx, score)
276                })
277                .max_by(|(_, a), (_, b)| {
278                    // Handle NaN values by treating them as negative infinity
279                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
280                })
281                .unwrap_or((0, 0.0));
282
283            total_score += best_score;
284            token_matches.push((best_doc_idx, best_score));
285        }
286
287        ColbertSearchResult {
288            entity_id: doc.entity_id.clone(),
289            score: total_score,
290            token_matches,
291        }
292    }
293
294    /// Convert ColBERT results to standard SearchResult format
295    pub fn to_search_results(&self, results: Vec<ColbertSearchResult>) -> Vec<SearchResult> {
296        results
297            .into_iter()
298            .enumerate()
299            .map(|(rank, r)| SearchResult {
300                entity_id: r.entity_id,
301                score: r.score,
302                distance: r.score,
303                rank: rank + 1,
304            })
305            .collect()
306    }
307
308    /// Get index statistics
309    pub fn stats(&self) -> ColbertStats {
310        let total_tokens: usize = self
311            .documents
312            .iter()
313            .map(|d| d.token_embeddings.len())
314            .sum();
315
316        let avg_tokens = if self.documents.is_empty() {
317            0.0
318        } else {
319            total_tokens as f32 / self.documents.len() as f32
320        };
321
322        let memory_bytes = self.estimate_memory();
323
324        ColbertStats {
325            num_documents: self.documents.len(),
326            total_tokens,
327            avg_tokens_per_doc: avg_tokens,
328            dimension: self.dim.unwrap_or(0),
329            memory_bytes,
330        }
331    }
332
333    fn estimate_memory(&self) -> usize {
334        let total_tokens: usize = self
335            .documents
336            .iter()
337            .map(|d| d.token_embeddings.len())
338            .sum();
339        let dim = self.dim.unwrap_or(0);
340
341        // Tokens: total_tokens * dim * 4 bytes (f32)
342        total_tokens * dim * 4
343    }
344
345    /// Remove a document by entity_id
346    pub fn remove(&mut self, entity_id: &str) -> bool {
347        if let Some(pos) = self.documents.iter().position(|d| d.entity_id == entity_id) {
348            self.documents.remove(pos);
349            true
350        } else {
351            false
352        }
353    }
354
355    /// Get number of documents
356    pub fn len(&self) -> usize {
357        self.documents.len()
358    }
359
360    /// Check if index is empty
361    pub fn is_empty(&self) -> bool {
362        self.documents.is_empty()
363    }
364}
365
/// ColBERT index statistics
#[derive(Debug, Clone)]
pub struct ColbertStats {
    /// Number of indexed documents.
    pub num_documents: usize,
    /// Total token embeddings across all documents.
    pub total_tokens: usize,
    /// Mean tokens per document (0.0 for an empty index).
    pub avg_tokens_per_doc: f32,
    /// Embedding dimension (0 before any document is added).
    pub dimension: usize,
    /// Estimated bytes of raw embedding data (f32 payload only).
    pub memory_bytes: usize,
}
375
#[cfg(test)]
mod tests {
    // Unit tests for the ColBERT index: construction, MaxSim scoring,
    // truncation limits, error paths, and statistics. Several tests assert
    // exact error-message substrings and internal state (`index.documents`),
    // so they intentionally pin implementation details.
    use super::*;

    // --- Construction and basic bookkeeping -------------------------------

    #[test]
    fn test_colbert_creation() {
        let config = ColbertConfig::default();
        let index = ColbertIndex::new(config);

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_colbert_add_document() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let tokens = vec![vec![0.1, 0.2, 0.3], vec![0.2, 0.3, 0.4]];

        assert!(index.add("doc1".to_string(), tokens).is_ok());
        assert_eq!(index.len(), 1);
    }

    // --- Search and MaxSim scoring ----------------------------------------

    #[test]
    fn test_colbert_search() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        // Add documents
        let doc1_tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.8, 0.2, 0.0],
        ];

        let doc2_tokens = vec![
            vec![0.0, 1.0, 0.0],
            vec![0.1, 0.9, 0.0],
            vec![0.2, 0.8, 0.0],
        ];

        assert!(index.add("doc1".to_string(), doc1_tokens).is_ok());
        assert!(index.add("doc2".to_string(), doc2_tokens).is_ok());

        // Search with query closer to doc1
        let query_tokens = vec![vec![0.95, 0.05, 0.0], vec![0.85, 0.15, 0.0]];

        let results = index.search(&query_tokens, 2);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results.len(), 2);

        // First result should be doc1 (higher score)
        assert_eq!(results[0].entity_id, "doc1");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_colbert_maxsim_scoring() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        // Document with 3 token embeddings
        let doc_tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        assert!(index.add("doc1".to_string(), doc_tokens).is_ok());

        // Query with 2 tokens that match first two doc tokens
        let query_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let results = index.search(&query_tokens, 1);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results.len(), 1);

        // Should have match info for both query tokens
        assert_eq!(results[0].token_matches.len(), 2);
    }

    #[test]
    fn test_colbert_remove() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let tokens = vec![vec![0.1, 0.2, 0.3]];

        assert!(index.add("doc1".to_string(), tokens.clone()).is_ok());
        assert!(index.add("doc2".to_string(), tokens).is_ok());

        assert_eq!(index.len(), 2);

        assert!(index.remove("doc1"));
        assert_eq!(index.len(), 1);

        assert!(!index.remove("doc1")); // Already removed
    }

    #[test]
    fn test_colbert_build_from_hashmap() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let mut doc_tokens = HashMap::new();
        doc_tokens.insert(
            "doc1".to_string(),
            vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
        );
        doc_tokens.insert(
            "doc2".to_string(),
            vec![vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]],
        );
        doc_tokens.insert(
            "doc3".to_string(),
            vec![vec![0.0, 0.0, 1.0], vec![0.0, 0.1, 0.9]],
        );

        let build_result = index.build(&doc_tokens);
        assert!(build_result.is_ok());
        assert_eq!(index.len(), 3);

        // Search for doc1
        let query_tokens = vec![vec![1.0, 0.0, 0.0]];
        let results = index.search(&query_tokens, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, "doc1");
    }

    // --- Truncation limits ------------------------------------------------

    #[test]
    fn test_colbert_token_truncation() {
        let config = ColbertConfig::default().with_max_doc_tokens(5);
        let mut index = ColbertIndex::new(config);

        // Create a document with 10 tokens (should be truncated to 5)
        let long_doc_tokens: Vec<Vec<f32>> =
            (0..10).map(|i| vec![i as f32 / 10.0, 0.0, 0.0]).collect();

        assert!(index.add("doc1".to_string(), long_doc_tokens).is_ok());

        // Verify truncation
        assert_eq!(index.documents[0].token_embeddings.len(), 5);
    }

    #[test]
    fn test_colbert_query_truncation() {
        let config = ColbertConfig::default().with_max_query_tokens(3);
        let mut index = ColbertIndex::new(config);

        let doc_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]];
        assert!(index.add("doc1".to_string(), doc_tokens).is_ok());

        // Create a long query (10 tokens, should be truncated to 3)
        let long_query: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 / 10.0, 0.0, 0.0]).collect();

        let results = index.search(&long_query, 1);
        assert!(results.is_ok());

        let results = results.unwrap();
        // Should have match info for only 3 query tokens (truncated)
        assert_eq!(results[0].token_matches.len(), 3);
    }

    #[test]
    fn test_colbert_parallel_vs_sequential() {
        // Test with parallel search
        let config_parallel = ColbertConfig::default().with_compression(false);
        let mut index_parallel = ColbertIndex::new(config_parallel);

        // Test with sequential search
        let config_sequential = ColbertConfig {
            parallel_search: false,
            ..Default::default()
        };
        let mut index_sequential = ColbertIndex::new(config_sequential);

        // Add same documents to both
        let mut doc_tokens = HashMap::new();
        for i in 0..20 {
            let tokens: Vec<Vec<f32>> = (0..10)
                .map(|j| vec![(i + j) as f32 / 20.0, 0.0, 0.0])
                .collect();
            doc_tokens.insert(format!("doc{}", i), tokens);
        }

        assert!(index_parallel.build(&doc_tokens).is_ok());
        assert!(index_sequential.build(&doc_tokens).is_ok());

        // Search with both
        let query_tokens = vec![vec![0.5, 0.0, 0.0]];
        let results_parallel = index_parallel.search(&query_tokens, 5).unwrap();
        let results_sequential = index_sequential.search(&query_tokens, 5).unwrap();

        // Results should be the same
        assert_eq!(results_parallel.len(), results_sequential.len());
        assert_eq!(
            results_parallel[0].entity_id,
            results_sequential[0].entity_id
        );
    }

    #[test]
    fn test_colbert_different_metrics() {
        let metrics = vec![
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];

        for metric in metrics {
            let config = ColbertConfig::default().with_metric(metric);
            let mut index = ColbertIndex::new(config);

            let doc_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]];
            assert!(index.add("doc1".to_string(), doc_tokens).is_ok());

            let query_tokens = vec![vec![1.0, 0.0, 0.0]];
            let results = index.search(&query_tokens, 1);
            assert!(results.is_ok());
        }
    }

    // --- Error paths ------------------------------------------------------

    #[test]
    fn test_colbert_empty_index_search() {
        let config = ColbertConfig::default();
        let index = ColbertIndex::new(config);

        let query_tokens = vec![vec![1.0, 0.0, 0.0]];
        let results = index.search(&query_tokens, 5);

        assert!(results.is_ok());
        assert_eq!(results.unwrap().len(), 0);
    }

    #[test]
    fn test_colbert_empty_tokens_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let empty_tokens: Vec<Vec<f32>> = vec![];
        let result = index.add("doc1".to_string(), empty_tokens);

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot add document with no token embeddings"));
    }

    #[test]
    fn test_colbert_dimension_mismatch_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        // Add first document with 3 dimensions
        let doc1_tokens = vec![vec![1.0, 0.0, 0.0]];
        assert!(index.add("doc1".to_string(), doc1_tokens).is_ok());

        // Try to add second document with 4 dimensions (should fail)
        let doc2_tokens = vec![vec![1.0, 0.0, 0.0, 0.0]];
        let result = index.add("doc2".to_string(), doc2_tokens);

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("does not match index dimension"));
    }

    #[test]
    fn test_colbert_build_empty_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let empty_docs = HashMap::new();
        let result = index.build(&empty_docs);

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot build ColBERT index with empty documents"));
    }

    #[test]
    fn test_colbert_build_empty_tokens_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let mut doc_tokens = HashMap::new();
        doc_tokens.insert("doc1".to_string(), vec![]); // Empty tokens

        let result = index.build(&doc_tokens);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Document has no token embeddings"));
    }

    // --- Statistics and result conversion ---------------------------------

    #[test]
    fn test_colbert_stats() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        // Add documents with varying token counts
        index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]],
            )
            .unwrap();
        index
            .add("doc2".to_string(), vec![vec![0.0, 1.0], vec![0.1, 0.9]])
            .unwrap();
        index.add("doc3".to_string(), vec![vec![0.5, 0.5]]).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_documents, 3);
        assert_eq!(stats.total_tokens, 6); // 3 + 2 + 1
        assert!((stats.avg_tokens_per_doc - 2.0).abs() < 0.01); // 6/3 = 2.0
        assert_eq!(stats.dimension, 2);
        assert!(stats.memory_bytes > 0);
    }

    #[test]
    fn test_colbert_to_search_results() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
            )
            .unwrap();
        index
            .add(
                "doc2".to_string(),
                vec![vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]],
            )
            .unwrap();

        let query_tokens = vec![vec![1.0, 0.0, 0.0]];
        let colbert_results = index.search(&query_tokens, 2).unwrap();

        // Convert to standard search results
        let search_results = index.to_search_results(colbert_results);

        assert_eq!(search_results.len(), 2);
        assert_eq!(search_results[0].rank, 1);
        assert_eq!(search_results[1].rank, 2);
        assert_eq!(search_results[0].entity_id, "doc1");
    }

    #[test]
    fn test_colbert_large_scale() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        // Add 100 documents with multiple tokens each
        for i in 0..100 {
            let tokens: Vec<Vec<f32>> = (0..10)
                .map(|j| vec![(i + j) as f32 / 100.0, 0.0, 0.0])
                .collect();
            index.add(format!("doc{}", i), tokens).unwrap();
        }

        assert_eq!(index.len(), 100);

        // Search
        let query_tokens = vec![vec![0.5, 0.0, 0.0], vec![0.6, 0.0, 0.0]];
        let results = index.search(&query_tokens, 10).unwrap();

        assert_eq!(results.len(), 10);
        assert!(results[0].score >= results[9].score); // Sorted by score
    }

    #[test]
    fn test_colbert_token_match_information() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        // Document with 3 distinct tokens
        let doc_tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        index.add("doc1".to_string(), doc_tokens).unwrap();

        // Query with 2 tokens
        let query_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0]];

        let results = index.search(&query_tokens, 1).unwrap();
        assert_eq!(results.len(), 1);

        // Check token matches
        let token_matches = &results[0].token_matches;
        assert_eq!(token_matches.len(), 2);

        // First query token should match first doc token (index 0)
        assert_eq!(token_matches[0].0, 0);
        // Second query token should match third doc token (index 2)
        assert_eq!(token_matches[1].0, 2);
    }
}