use anyhow::Result;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::simd;
use crate::types::{DistanceMetric, SearchResult};

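/// Configuration for the ColBERT-style multi-vector index.
///
/// All options can be set with the builder methods below. Illustrative
/// sketch only; the values are placeholders, not recommendations:
///
/// ```ignore
/// let config = ColbertConfig::default()
///     .with_metric(DistanceMetric::Cosine)
///     .with_max_doc_tokens(180)
///     .with_max_query_tokens(32);
/// ```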
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertConfig {
    /// Metric used to compare individual token embeddings.
    pub metric: DistanceMetric,

    /// Maximum number of token embeddings kept per document; longer
    /// documents are truncated.
    pub max_doc_tokens: usize,

    /// Maximum number of query token embeddings used per search; longer
    /// queries are truncated.
    pub max_query_tokens: usize,

    /// Whether stored token embeddings should be compressed. The flag is
    /// carried in the config but not applied anywhere in this module yet.
    pub compress_tokens: bool,

    /// Score documents in parallel with rayon instead of sequentially.
    pub parallel_search: bool,
}

impl Default for ColbertConfig {
    fn default() -> Self {
        Self {
            metric: DistanceMetric::Cosine,
            max_doc_tokens: 300,
            max_query_tokens: 32,
            compress_tokens: false,
            parallel_search: true,
        }
    }
}

impl ColbertConfig {
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    pub fn with_max_doc_tokens(mut self, max_doc_tokens: usize) -> Self {
        self.max_doc_tokens = max_doc_tokens;
        self
    }

    pub fn with_max_query_tokens(mut self, max_query_tokens: usize) -> Self {
        self.max_query_tokens = max_query_tokens;
        self
    }

    pub fn with_compression(mut self, compress: bool) -> Self {
        self.compress_tokens = compress;
        self
    }
}

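/// A document stored as one embedding per token (multi-vector representation).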
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorDoc {
    /// Identifier of the entity this document belongs to.
    pub entity_id: String,
    /// One embedding per token, all with the same dimensionality.
    pub token_embeddings: Vec<Vec<f32>>,
}

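/// Result of a MaxSim search for a single document.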
#[derive(Debug, Clone)]
pub struct ColbertSearchResult {
    pub entity_id: String,
    /// Total MaxSim score (sum over query tokens of the best token score).
    pub score: f32,
    /// For each query token, the index of the best-matching document token
    /// and the corresponding score.
    pub token_matches: Vec<(usize, f32)>,
}

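/// ColBERT-style late-interaction index over per-token document embeddings.
///
/// Documents are added with [`ColbertIndex::add`] or [`ColbertIndex::build`]
/// and scored with MaxSim in [`ColbertIndex::search`]. The sketch below is
/// illustrative only; the import path and the embedding values are
/// assumptions, not part of this module's API:
///
/// ```ignore
/// use crate::colbert::{ColbertConfig, ColbertIndex};
///
/// let mut index = ColbertIndex::new(ColbertConfig::default());
/// index.add("doc1".to_string(), vec![vec![1.0, 0.0], vec![0.9, 0.1]]).unwrap();
/// let top = index.search(&[vec![1.0, 0.0]], 10).unwrap();
/// ```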
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertIndex {
    config: ColbertConfig,
    documents: Vec<MultiVectorDoc>,
    /// Embedding dimensionality, fixed by the first document that is indexed.
    dim: Option<usize>,
}

impl ColbertIndex {
    pub fn new(config: ColbertConfig) -> Self {
        Self {
            config,
            documents: Vec::new(),
            dim: None,
        }
    }

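    /// Builds the index from a map of entity id to token embeddings,
    /// replacing any previously indexed documents.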
    pub fn build(&mut self, doc_tokens: &HashMap<String, Vec<Vec<f32>>>) -> Result<()> {
        if doc_tokens.is_empty() {
            anyhow::bail!("Cannot build ColBERT index with empty documents");
        }

        // Infer the embedding dimension from the first document.
        let first_doc_tokens = doc_tokens.values().next().unwrap();
        if first_doc_tokens.is_empty() {
            anyhow::bail!("Document has no token embeddings");
        }
        self.dim = Some(first_doc_tokens[0].len());

        self.documents.clear();
        for (entity_id, tokens) in doc_tokens {
            // Keep at most `max_doc_tokens` embeddings per document.
            let truncated_tokens = if tokens.len() > self.config.max_doc_tokens {
                tokens[..self.config.max_doc_tokens].to_vec()
            } else {
                tokens.clone()
            };

            self.documents.push(MultiVectorDoc {
                entity_id: entity_id.clone(),
                token_embeddings: truncated_tokens,
            });
        }

        Ok(())
    }

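    /// Adds a single document, truncating it to `max_doc_tokens` and
    /// validating that its embeddings match the index dimensionality.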
    pub fn add(&mut self, entity_id: String, token_embeddings: Vec<Vec<f32>>) -> Result<()> {
        if token_embeddings.is_empty() {
            anyhow::bail!("Cannot add document with no token embeddings");
        }

        // The first document added fixes the index dimensionality.
        if self.dim.is_none() {
            self.dim = Some(token_embeddings[0].len());
        }

        let dim = self.dim.unwrap();
        for token in &token_embeddings {
            if token.len() != dim {
                anyhow::bail!(
                    "Token dimension {} does not match index dimension {}",
                    token.len(),
                    dim
                );
            }
        }

        // Keep at most `max_doc_tokens` embeddings per document.
        let truncated_tokens = if token_embeddings.len() > self.config.max_doc_tokens {
            token_embeddings[..self.config.max_doc_tokens].to_vec()
        } else {
            token_embeddings
        };

        self.documents.push(MultiVectorDoc {
            entity_id,
            token_embeddings: truncated_tokens,
        });

        Ok(())
    }

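    /// Scores every indexed document against the query tokens with MaxSim
    /// and returns the top `k` results in descending score order.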
    pub fn search(&self, query_tokens: &[Vec<f32>], k: usize) -> Result<Vec<ColbertSearchResult>> {
        if self.documents.is_empty() {
            return Ok(Vec::new());
        }

        // Truncate the query to `max_query_tokens`.
        let query = if query_tokens.len() > self.config.max_query_tokens {
            &query_tokens[..self.config.max_query_tokens]
        } else {
            query_tokens
        };

        // Score all documents, in parallel if configured.
        let results: Vec<ColbertSearchResult> = if self.config.parallel_search {
            self.documents
                .par_iter()
                .map(|doc| self.compute_maxsim_score(query, doc))
                .collect()
        } else {
            self.documents
                .iter()
                .map(|doc| self.compute_maxsim_score(query, doc))
                .collect()
        };

        // Sort by descending score (NaN-safe) and keep the top k.
        let mut sorted_results = results;
        sorted_results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(sorted_results.into_iter().take(k).collect())
    }

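    /// Compares two token embeddings with the configured metric via the SIMD
    /// kernel; the returned value is whatever that routine produces for the
    /// metric, and MaxSim scoring treats larger values as better matches.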
    #[inline]
    fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        simd::compute_distance_simd(self.config.metric, a, b)
    }

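    /// Computes the MaxSim score of one document:
    /// `score(Q, D) = sum over query tokens q_i of max over doc tokens d_j of sim(q_i, d_j)`.
    /// Also records, per query token, which document token matched best.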
    fn compute_maxsim_score(
        &self,
        query_tokens: &[Vec<f32>],
        doc: &MultiVectorDoc,
    ) -> ColbertSearchResult {
        let mut total_score = 0.0;
        let mut token_matches = Vec::with_capacity(query_tokens.len());

        for query_token in query_tokens {
            // Find the document token that scores best against this query token.
            let (best_doc_idx, best_score) = doc
                .token_embeddings
                .iter()
                .enumerate()
                .map(|(idx, doc_token)| {
                    let score = self.compute_similarity(query_token, doc_token);
                    (idx, score)
                })
                .max_by(|(_, a), (_, b)| {
                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                })
                .unwrap_or((0, 0.0));

            total_score += best_score;
            token_matches.push((best_doc_idx, best_score));
        }

        ColbertSearchResult {
            entity_id: doc.entity_id.clone(),
            score: total_score,
            token_matches,
        }
    }

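    /// Converts ColBERT results into generic [`SearchResult`]s with 1-based
    /// ranks; the MaxSim score is reported as both `score` and `distance`.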
    pub fn to_search_results(&self, results: Vec<ColbertSearchResult>) -> Vec<SearchResult> {
        results
            .into_iter()
            .enumerate()
            .map(|(rank, r)| SearchResult {
                entity_id: r.entity_id,
                score: r.score,
                distance: r.score,
                rank: rank + 1,
            })
            .collect()
    }

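    /// Returns summary statistics about the indexed documents.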
    pub fn stats(&self) -> ColbertStats {
        let total_tokens: usize = self
            .documents
            .iter()
            .map(|d| d.token_embeddings.len())
            .sum();

        let avg_tokens = if self.documents.is_empty() {
            0.0
        } else {
            total_tokens as f32 / self.documents.len() as f32
        };

        let memory_bytes = self.estimate_memory();

        ColbertStats {
            num_documents: self.documents.len(),
            total_tokens,
            avg_tokens_per_doc: avg_tokens,
            dimension: self.dim.unwrap_or(0),
            memory_bytes,
        }
    }

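    /// Rough memory estimate for the stored embeddings only (4 bytes per
    /// `f32`); entity id strings and `Vec` overhead are not counted.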
    fn estimate_memory(&self) -> usize {
        let total_tokens: usize = self
            .documents
            .iter()
            .map(|d| d.token_embeddings.len())
            .sum();
        let dim = self.dim.unwrap_or(0);

        // total tokens * dimension * 4 bytes per f32
        total_tokens * dim * 4
    }

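    /// Removes the first document with the given entity id, returning `true`
    /// if one was found.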
    pub fn remove(&mut self, entity_id: &str) -> bool {
        if let Some(pos) = self.documents.iter().position(|d| d.entity_id == entity_id) {
            self.documents.remove(pos);
            true
        } else {
            false
        }
    }

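    /// Number of indexed documents.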
    pub fn len(&self) -> usize {
        self.documents.len()
    }

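    /// Returns `true` if no documents are indexed.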
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }
}

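/// Summary statistics reported by [`ColbertIndex::stats`].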
#[derive(Debug, Clone)]
pub struct ColbertStats {
    pub num_documents: usize,
    pub total_tokens: usize,
    pub avg_tokens_per_doc: f32,
    pub dimension: usize,
    pub memory_bytes: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_colbert_creation() {
        let config = ColbertConfig::default();
        let index = ColbertIndex::new(config);

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_colbert_add_document() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let tokens = vec![vec![0.1, 0.2, 0.3], vec![0.2, 0.3, 0.4]];

        assert!(index.add("doc1".to_string(), tokens).is_ok());
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_colbert_search() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let doc1_tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.8, 0.2, 0.0],
        ];

        let doc2_tokens = vec![
            vec![0.0, 1.0, 0.0],
            vec![0.1, 0.9, 0.0],
            vec![0.2, 0.8, 0.0],
        ];

        assert!(index.add("doc1".to_string(), doc1_tokens).is_ok());
        assert!(index.add("doc2".to_string(), doc2_tokens).is_ok());

        let query_tokens = vec![vec![0.95, 0.05, 0.0], vec![0.85, 0.15, 0.0]];

        let results = index.search(&query_tokens, 2);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results.len(), 2);

        // doc1 is far more similar to the query and should rank first.
        assert_eq!(results[0].entity_id, "doc1");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_colbert_maxsim_scoring() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let doc_tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        assert!(index.add("doc1".to_string(), doc_tokens).is_ok());

        let query_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let results = index.search(&query_tokens, 1);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results.len(), 1);

        assert_eq!(results[0].token_matches.len(), 2);
    }

    #[test]
    fn test_colbert_remove() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let tokens = vec![vec![0.1, 0.2, 0.3]];

        assert!(index.add("doc1".to_string(), tokens.clone()).is_ok());
        assert!(index.add("doc2".to_string(), tokens).is_ok());

        assert_eq!(index.len(), 2);

        assert!(index.remove("doc1"));
        assert_eq!(index.len(), 1);

        assert!(!index.remove("doc1"));
    }

    #[test]
    fn test_colbert_build_from_hashmap() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let mut doc_tokens = HashMap::new();
        doc_tokens.insert(
            "doc1".to_string(),
            vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
        );
        doc_tokens.insert(
            "doc2".to_string(),
            vec![vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]],
        );
        doc_tokens.insert(
            "doc3".to_string(),
            vec![vec![0.0, 0.0, 1.0], vec![0.0, 0.1, 0.9]],
        );

        let build_result = index.build(&doc_tokens);
        assert!(build_result.is_ok());
        assert_eq!(index.len(), 3);

        let query_tokens = vec![vec![1.0, 0.0, 0.0]];
        let results = index.search(&query_tokens, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, "doc1");
    }

    #[test]
    fn test_colbert_token_truncation() {
        let config = ColbertConfig::default().with_max_doc_tokens(5);
        let mut index = ColbertIndex::new(config);

        let long_doc_tokens: Vec<Vec<f32>> =
            (0..10).map(|i| vec![i as f32 / 10.0, 0.0, 0.0]).collect();

        assert!(index.add("doc1".to_string(), long_doc_tokens).is_ok());

        assert_eq!(index.documents[0].token_embeddings.len(), 5);
    }

    #[test]
    fn test_colbert_query_truncation() {
        let config = ColbertConfig::default().with_max_query_tokens(3);
        let mut index = ColbertIndex::new(config);

        let doc_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]];
        assert!(index.add("doc1".to_string(), doc_tokens).is_ok());

        let long_query: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 / 10.0, 0.0, 0.0]).collect();

        let results = index.search(&long_query, 1);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results[0].token_matches.len(), 3);
    }

    #[test]
    fn test_colbert_parallel_vs_sequential() {
        let config_parallel = ColbertConfig::default().with_compression(false);
        let mut index_parallel = ColbertIndex::new(config_parallel);

        let config_sequential = ColbertConfig {
            parallel_search: false,
            ..Default::default()
        };
        let mut index_sequential = ColbertIndex::new(config_sequential);

        let mut doc_tokens = HashMap::new();
        for i in 0..20 {
            let tokens: Vec<Vec<f32>> = (0..10)
                .map(|j| vec![(i + j) as f32 / 20.0, 0.0, 0.0])
                .collect();
            doc_tokens.insert(format!("doc{}", i), tokens);
        }

        assert!(index_parallel.build(&doc_tokens).is_ok());
        assert!(index_sequential.build(&doc_tokens).is_ok());

        let query_tokens = vec![vec![0.5, 0.0, 0.0]];
        let results_parallel = index_parallel.search(&query_tokens, 5).unwrap();
        let results_sequential = index_sequential.search(&query_tokens, 5).unwrap();

        assert_eq!(results_parallel.len(), results_sequential.len());
        assert_eq!(
            results_parallel[0].entity_id,
            results_sequential[0].entity_id
        );
    }

    #[test]
    fn test_colbert_different_metrics() {
        let metrics = vec![
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];

        for metric in metrics {
            let config = ColbertConfig::default().with_metric(metric);
            let mut index = ColbertIndex::new(config);

            let doc_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]];
            assert!(index.add("doc1".to_string(), doc_tokens).is_ok());

            let query_tokens = vec![vec![1.0, 0.0, 0.0]];
            let results = index.search(&query_tokens, 1);
            assert!(results.is_ok());
        }
    }

    #[test]
    fn test_colbert_empty_index_search() {
        let config = ColbertConfig::default();
        let index = ColbertIndex::new(config);

        let query_tokens = vec![vec![1.0, 0.0, 0.0]];
        let results = index.search(&query_tokens, 5);

        assert!(results.is_ok());
        assert_eq!(results.unwrap().len(), 0);
    }

    #[test]
    fn test_colbert_empty_tokens_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let empty_tokens: Vec<Vec<f32>> = vec![];
        let result = index.add("doc1".to_string(), empty_tokens);

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot add document with no token embeddings"));
    }

    #[test]
    fn test_colbert_dimension_mismatch_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let doc1_tokens = vec![vec![1.0, 0.0, 0.0]];
        assert!(index.add("doc1".to_string(), doc1_tokens).is_ok());

        let doc2_tokens = vec![vec![1.0, 0.0, 0.0, 0.0]];
        let result = index.add("doc2".to_string(), doc2_tokens);

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("does not match index dimension"));
    }

    #[test]
    fn test_colbert_build_empty_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let empty_docs = HashMap::new();
        let result = index.build(&empty_docs);

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot build ColBERT index with empty documents"));
    }

    #[test]
    fn test_colbert_build_empty_tokens_error() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let mut doc_tokens = HashMap::new();
        doc_tokens.insert("doc1".to_string(), vec![]);

        let result = index.build(&doc_tokens);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Document has no token embeddings"));
    }

    #[test]
    fn test_colbert_stats() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]],
            )
            .unwrap();
        index
            .add("doc2".to_string(), vec![vec![0.0, 1.0], vec![0.1, 0.9]])
            .unwrap();
        index.add("doc3".to_string(), vec![vec![0.5, 0.5]]).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_documents, 3);
        assert_eq!(stats.total_tokens, 6);
        assert!((stats.avg_tokens_per_doc - 2.0).abs() < 0.01);
        assert_eq!(stats.dimension, 2);
        assert!(stats.memory_bytes > 0);
    }

    #[test]
    fn test_colbert_to_search_results() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
            )
            .unwrap();
        index
            .add(
                "doc2".to_string(),
                vec![vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]],
            )
            .unwrap();

        let query_tokens = vec![vec![1.0, 0.0, 0.0]];
        let colbert_results = index.search(&query_tokens, 2).unwrap();

        let search_results = index.to_search_results(colbert_results);

        assert_eq!(search_results.len(), 2);
        assert_eq!(search_results[0].rank, 1);
        assert_eq!(search_results[1].rank, 2);
        assert_eq!(search_results[0].entity_id, "doc1");
    }

    #[test]
    fn test_colbert_large_scale() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        for i in 0..100 {
            let tokens: Vec<Vec<f32>> = (0..10)
                .map(|j| vec![(i + j) as f32 / 100.0, 0.0, 0.0])
                .collect();
            index.add(format!("doc{}", i), tokens).unwrap();
        }

        assert_eq!(index.len(), 100);

        let query_tokens = vec![vec![0.5, 0.0, 0.0], vec![0.6, 0.0, 0.0]];
        let results = index.search(&query_tokens, 10).unwrap();

        assert_eq!(results.len(), 10);
        assert!(results[0].score >= results[9].score);
    }

    #[test]
    fn test_colbert_token_match_information() {
        let config = ColbertConfig::default();
        let mut index = ColbertIndex::new(config);

        let doc_tokens = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        index.add("doc1".to_string(), doc_tokens).unwrap();

        let query_tokens = vec![vec![1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0]];

        let results = index.search(&query_tokens, 1).unwrap();
        assert_eq!(results.len(), 1);

        let token_matches = &results[0].token_matches;
        assert_eq!(token_matches.len(), 2);

        // The first query token should best match document token 0.
        assert_eq!(token_matches[0].0, 0);
        // The second query token should best match document token 2.
        assert_eq!(token_matches[1].0, 2);
    }
}