use anyhow::Result;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::simd;
use crate::types::{DistanceMetric, SearchResult};
/// Tunable settings for a `ColbertIndex`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertConfig {
/// Metric handed to `simd::compute_distance_simd` whenever a query token is
/// compared with a document token.
pub metric: DistanceMetric,
/// Upper bound on stored token embeddings per document; longer documents
/// keep only their first `max_doc_tokens` embeddings.
pub max_doc_tokens: usize,
/// Upper bound on query tokens used by a search; extra query tokens beyond
/// the first `max_query_tokens` are ignored.
pub max_query_tokens: usize,
/// Compression flag. NOTE(review): nothing in the visible code of this file
/// reads this field — confirm whether it is consumed elsewhere.
pub compress_tokens: bool,
/// When `true`, search scores documents through rayon's `par_iter`;
/// otherwise it uses a sequential iterator.
pub parallel_search: bool,
}
impl Default for ColbertConfig {
fn default() -> Self {
Self {
metric: DistanceMetric::Cosine,
max_doc_tokens: 300,
max_query_tokens: 32,
compress_tokens: false,
parallel_search: true,
}
}
}
impl ColbertConfig {
    /// Returns the configuration with `metric` used for token comparisons.
    #[must_use]
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    /// Returns the configuration with the per-document token cap set to
    /// `max_doc_tokens`.
    #[must_use]
    pub fn with_max_doc_tokens(mut self, max_doc_tokens: usize) -> Self {
        self.max_doc_tokens = max_doc_tokens;
        self
    }

    /// Returns the configuration with the per-query token cap set to
    /// `max_query_tokens`.
    #[must_use]
    pub fn with_max_query_tokens(mut self, max_query_tokens: usize) -> Self {
        self.max_query_tokens = max_query_tokens;
        self
    }

    /// Returns the configuration with the `compress_tokens` flag set to
    /// `compress`.
    #[must_use]
    pub fn with_compression(mut self, compress: bool) -> Self {
        self.compress_tokens = compress;
        self
    }

    /// Returns the configuration with parallel document scoring switched on
    /// or off.
    ///
    /// Provided so that every field has a matching builder method.
    #[must_use]
    pub fn with_parallel_search(mut self, parallel: bool) -> Self {
        self.parallel_search = parallel;
        self
    }
}
/// A document stored in the index as a list of token embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorDoc {
/// Caller-supplied identifier, echoed back in search results.
pub entity_id: String,
/// One embedding vector per retained token, in the order supplied.
pub token_embeddings: Vec<Vec<f32>>,
}
/// Score of one document for one query, with per-query-token match details.
#[derive(Debug, Clone)]
pub struct ColbertSearchResult {
/// Identifier of the scored document.
pub entity_id: String,
/// Sum, over the query tokens used, of the best per-token score.
pub score: f32,
/// One entry per query token used: the index of the best-scoring document
/// token and that score. A document without tokens yields `(0, 0.0)`.
pub token_matches: Vec<(usize, f32)>,
}
/// In-memory index of multi-vector documents that is searched by scoring
/// every stored document against the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColbertIndex {
// Settings supplied at construction time.
config: ColbertConfig,
// Stored documents, in insertion order.
documents: Vec<MultiVectorDoc>,
// Embedding dimension, taken from the first embedding stored; `None` until
// a document has been added or the index has been built.
dim: Option<usize>,
}
impl ColbertIndex {
pub fn new(config: ColbertConfig) -> Self {
Self {
config,
documents: Vec::new(),
dim: None,
}
}
pub fn build(&mut self, doc_tokens: &HashMap<String, Vec<Vec<f32>>>) -> Result<()> {
if doc_tokens.is_empty() {
anyhow::bail!("Cannot build ColBERT index with empty documents");
}
let first_doc_tokens = doc_tokens.values().next().unwrap();
if first_doc_tokens.is_empty() {
anyhow::bail!("Document has no token embeddings");
}
self.dim = Some(first_doc_tokens[0].len());
self.documents.clear();
for (entity_id, tokens) in doc_tokens {
let truncated_tokens = if tokens.len() > self.config.max_doc_tokens {
tokens[..self.config.max_doc_tokens].to_vec()
} else {
tokens.clone()
};
self.documents.push(MultiVectorDoc {
entity_id: entity_id.clone(),
token_embeddings: truncated_tokens,
});
}
Ok(())
}
pub fn add(&mut self, entity_id: String, token_embeddings: Vec<Vec<f32>>) -> Result<()> {
if token_embeddings.is_empty() {
anyhow::bail!("Cannot add document with no token embeddings");
}
if self.dim.is_none() {
self.dim = Some(token_embeddings[0].len());
}
let dim = self.dim.unwrap();
for token in &token_embeddings {
if token.len() != dim {
anyhow::bail!(
"Token dimension {} does not match index dimension {}",
token.len(),
dim
);
}
}
let truncated_tokens = if token_embeddings.len() > self.config.max_doc_tokens {
token_embeddings[..self.config.max_doc_tokens].to_vec()
} else {
token_embeddings
};
self.documents.push(MultiVectorDoc {
entity_id,
token_embeddings: truncated_tokens,
});
Ok(())
}
pub fn search(&self, query_tokens: &[Vec<f32>], k: usize) -> Result<Vec<ColbertSearchResult>> {
if self.documents.is_empty() {
return Ok(Vec::new());
}
let query = if query_tokens.len() > self.config.max_query_tokens {
&query_tokens[..self.config.max_query_tokens]
} else {
query_tokens
};
let results: Vec<ColbertSearchResult> = if self.config.parallel_search {
self.documents
.par_iter()
.map(|doc| self.compute_maxsim_score(query, doc))
.collect()
} else {
self.documents
.iter()
.map(|doc| self.compute_maxsim_score(query, doc))
.collect()
};
let mut sorted_results = results;
sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
Ok(sorted_results.into_iter().take(k).collect())
}
#[inline]
fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
simd::compute_distance_simd(self.config.metric, a, b)
}
fn compute_maxsim_score(
&self,
query_tokens: &[Vec<f32>],
doc: &MultiVectorDoc,
) -> ColbertSearchResult {
let mut total_score = 0.0;
let mut token_matches = Vec::with_capacity(query_tokens.len());
for query_token in query_tokens {
let (best_doc_idx, best_score) = doc
.token_embeddings
.iter()
.enumerate()
.map(|(idx, doc_token)| {
let score = self.compute_similarity(query_token, doc_token);
(idx, score)
})
.max_by(|(_, a), (_, b)| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or((0, 0.0));
total_score += best_score;
token_matches.push((best_doc_idx, best_score));
}
ColbertSearchResult {
entity_id: doc.entity_id.clone(),
score: total_score,
token_matches,
}
}
pub fn to_search_results(&self, results: Vec<ColbertSearchResult>) -> Vec<SearchResult> {
results
.into_iter()
.enumerate()
.map(|(rank, r)| SearchResult {
entity_id: r.entity_id,
score: r.score,
distance: r.score,
rank: rank + 1,
})
.collect()
}
pub fn stats(&self) -> ColbertStats {
let total_tokens: usize = self
.documents
.iter()
.map(|d| d.token_embeddings.len())
.sum();
let avg_tokens = if self.documents.is_empty() {
0.0
} else {
total_tokens as f32 / self.documents.len() as f32
};
let memory_bytes = self.estimate_memory();
ColbertStats {
num_documents: self.documents.len(),
total_tokens,
avg_tokens_per_doc: avg_tokens,
dimension: self.dim.unwrap_or(0),
memory_bytes,
}
}
fn estimate_memory(&self) -> usize {
let total_tokens: usize = self
.documents
.iter()
.map(|d| d.token_embeddings.len())
.sum();
let dim = self.dim.unwrap_or(0);
total_tokens * dim * 4
}
pub fn remove(&mut self, entity_id: &str) -> bool {
if let Some(pos) = self.documents.iter().position(|d| d.entity_id == entity_id) {
self.documents.remove(pos);
true
} else {
false
}
}
pub fn len(&self) -> usize {
self.documents.len()
}
pub fn is_empty(&self) -> bool {
self.documents.is_empty()
}
}
/// Summary figures describing the contents of a `ColbertIndex`.
#[derive(Debug, Clone)]
pub struct ColbertStats {
/// Number of stored documents.
pub num_documents: usize,
/// Number of token embeddings summed over all documents.
pub total_tokens: usize,
/// `total_tokens / num_documents`, or `0.0` for an empty index.
pub avg_tokens_per_doc: f32,
/// Embedding dimension of the index, or `0` while it is unset.
pub dimension: usize,
/// Estimated payload size: `total_tokens * dimension * 4` bytes.
pub memory_bytes: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty index using the default configuration.
    fn default_index() -> ColbertIndex {
        ColbertIndex::new(ColbertConfig::default())
    }

    /// A new index starts out empty.
    #[test]
    fn test_colbert_creation() {
        let index = default_index();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    /// Adding a single document makes the index hold one entry.
    #[test]
    fn test_colbert_add_document() {
        let mut index = default_index();
        let added = index.add(
            "doc1".to_string(),
            vec![vec![0.1, 0.2, 0.3], vec![0.2, 0.3, 0.4]],
        );
        assert!(added.is_ok());
        assert_eq!(index.len(), 1);
    }

    /// The document aligned with the query outranks the orthogonal one.
    #[test]
    fn test_colbert_search() {
        let mut index = default_index();
        let along_x = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.8, 0.2, 0.0],
        ];
        let along_y = vec![
            vec![0.0, 1.0, 0.0],
            vec![0.1, 0.9, 0.0],
            vec![0.2, 0.8, 0.0],
        ];
        assert!(index.add("doc1".to_string(), along_x).is_ok());
        assert!(index.add("doc2".to_string(), along_y).is_ok());

        let query = vec![vec![0.95, 0.05, 0.0], vec![0.85, 0.15, 0.0]];
        let hits = index.search(&query, 2).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, "doc1");
        assert!(hits[0].score > hits[1].score);
    }

    /// One match entry is reported per query token.
    #[test]
    fn test_colbert_maxsim_scoring() {
        let mut index = default_index();
        let basis = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        assert!(index.add("doc1".to_string(), basis).is_ok());

        let query = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let hits = index.search(&query, 1).unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].token_matches.len(), 2);
    }

    /// Removing succeeds once per stored id and fails afterwards.
    #[test]
    fn test_colbert_remove() {
        let mut index = default_index();
        let tokens = vec![vec![0.1, 0.2, 0.3]];
        assert!(index.add("doc1".to_string(), tokens.clone()).is_ok());
        assert!(index.add("doc2".to_string(), tokens).is_ok());
        assert_eq!(index.len(), 2);

        assert!(index.remove("doc1"));
        assert_eq!(index.len(), 1);
        assert!(!index.remove("doc1"));
    }

    /// Building from a map stores every document and allows searching.
    #[test]
    fn test_colbert_build_from_hashmap() {
        let mut index = default_index();
        let mut corpus = HashMap::new();
        corpus.insert(
            "doc1".to_string(),
            vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
        );
        corpus.insert(
            "doc2".to_string(),
            vec![vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]],
        );
        corpus.insert(
            "doc3".to_string(),
            vec![vec![0.0, 0.0, 1.0], vec![0.0, 0.1, 0.9]],
        );

        assert!(index.build(&corpus).is_ok());
        assert_eq!(index.len(), 3);

        let query = vec![vec![1.0, 0.0, 0.0]];
        let hits = index.search(&query, 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, "doc1");
    }

    /// Documents longer than the cap keep only the first tokens.
    #[test]
    fn test_colbert_token_truncation() {
        let mut index = ColbertIndex::new(ColbertConfig::default().with_max_doc_tokens(5));
        let long_doc: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 / 10.0, 0.0, 0.0]).collect();

        assert!(index.add("doc1".to_string(), long_doc).is_ok());
        assert_eq!(index.documents[0].token_embeddings.len(), 5);
    }

    /// Queries longer than the cap are cut down before scoring.
    #[test]
    fn test_colbert_query_truncation() {
        let mut index = ColbertIndex::new(ColbertConfig::default().with_max_query_tokens(3));
        assert!(index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
            )
            .is_ok());

        let long_query: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 / 10.0, 0.0, 0.0]).collect();
        let hits = index.search(&long_query, 1).unwrap();
        assert_eq!(hits[0].token_matches.len(), 3);
    }

    /// Parallel and sequential scoring agree on result count and top hit.
    #[test]
    fn test_colbert_parallel_vs_sequential() {
        let mut parallel = ColbertIndex::new(ColbertConfig::default().with_compression(false));
        let mut sequential = ColbertIndex::new(ColbertConfig {
            parallel_search: false,
            ..Default::default()
        });

        let corpus: HashMap<String, Vec<Vec<f32>>> = (0..20)
            .map(|i| {
                let tokens: Vec<Vec<f32>> = (0..10)
                    .map(|j| vec![(i + j) as f32 / 20.0, 0.0, 0.0])
                    .collect();
                (format!("doc{}", i), tokens)
            })
            .collect();
        assert!(parallel.build(&corpus).is_ok());
        assert!(sequential.build(&corpus).is_ok());

        let query = vec![vec![0.5, 0.0, 0.0]];
        let from_parallel = parallel.search(&query, 5).unwrap();
        let from_sequential = sequential.search(&query, 5).unwrap();

        assert_eq!(from_parallel.len(), from_sequential.len());
        assert_eq!(from_parallel[0].entity_id, from_sequential[0].entity_id);
    }

    /// Searching works under each supported metric.
    #[test]
    fn test_colbert_different_metrics() {
        for metric in vec![
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            let mut index = ColbertIndex::new(ColbertConfig::default().with_metric(metric));
            assert!(index
                .add(
                    "doc1".to_string(),
                    vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
                )
                .is_ok());

            let query = vec![vec![1.0, 0.0, 0.0]];
            assert!(index.search(&query, 1).is_ok());
        }
    }

    /// Searching an empty index succeeds with no results.
    #[test]
    fn test_colbert_empty_index_search() {
        let index = default_index();
        let query = vec![vec![1.0, 0.0, 0.0]];
        let hits = index.search(&query, 5).unwrap();
        assert_eq!(hits.len(), 0);
    }

    /// Adding a document without tokens is rejected.
    #[test]
    fn test_colbert_empty_tokens_error() {
        let mut index = default_index();
        let no_tokens: Vec<Vec<f32>> = Vec::new();
        let err = index.add("doc1".to_string(), no_tokens).unwrap_err();
        assert!(err
            .to_string()
            .contains("Cannot add document with no token embeddings"));
    }

    /// A document whose dimension differs from the index is rejected.
    #[test]
    fn test_colbert_dimension_mismatch_error() {
        let mut index = default_index();
        assert!(index
            .add("doc1".to_string(), vec![vec![1.0, 0.0, 0.0]])
            .is_ok());

        let err = index
            .add("doc2".to_string(), vec![vec![1.0, 0.0, 0.0, 0.0]])
            .unwrap_err();
        assert!(err.to_string().contains("does not match index dimension"));
    }

    /// Building from an empty map is rejected.
    #[test]
    fn test_colbert_build_empty_error() {
        let mut index = default_index();
        let nothing = HashMap::new();
        let err = index.build(&nothing).unwrap_err();
        assert!(err
            .to_string()
            .contains("Cannot build ColBERT index with empty documents"));
    }

    /// Building with a token-less document is rejected.
    #[test]
    fn test_colbert_build_empty_tokens_error() {
        let mut index = default_index();
        let mut corpus = HashMap::new();
        corpus.insert("doc1".to_string(), Vec::new());
        let err = index.build(&corpus).unwrap_err();
        assert!(err.to_string().contains("Document has no token embeddings"));
    }

    /// Statistics reflect document count, token totals and dimension.
    #[test]
    fn test_colbert_stats() {
        let mut index = default_index();
        index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]],
            )
            .unwrap();
        index
            .add("doc2".to_string(), vec![vec![0.0, 1.0], vec![0.1, 0.9]])
            .unwrap();
        index.add("doc3".to_string(), vec![vec![0.5, 0.5]]).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_documents, 3);
        assert_eq!(stats.total_tokens, 6);
        assert!((stats.avg_tokens_per_doc - 2.0).abs() < 0.01);
        assert_eq!(stats.dimension, 2);
        assert!(stats.memory_bytes > 0);
    }

    /// Conversion keeps result order and assigns 1-based ranks.
    #[test]
    fn test_colbert_to_search_results() {
        let mut index = default_index();
        index
            .add(
                "doc1".to_string(),
                vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]],
            )
            .unwrap();
        index
            .add(
                "doc2".to_string(),
                vec![vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]],
            )
            .unwrap();

        let query = vec![vec![1.0, 0.0, 0.0]];
        let hits = index.search(&query, 2).unwrap();
        let converted = index.to_search_results(hits);

        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0].rank, 1);
        assert_eq!(converted[1].rank, 2);
        assert_eq!(converted[0].entity_id, "doc1");
    }

    /// A hundred documents can be indexed and the top ten retrieved in order.
    #[test]
    fn test_colbert_large_scale() {
        let mut index = default_index();
        for i in 0..100 {
            let tokens: Vec<Vec<f32>> = (0..10)
                .map(|j| vec![(i + j) as f32 / 100.0, 0.0, 0.0])
                .collect();
            index.add(format!("doc{}", i), tokens).unwrap();
        }
        assert_eq!(index.len(), 100);

        let query = vec![vec![0.5, 0.0, 0.0], vec![0.6, 0.0, 0.0]];
        let hits = index.search(&query, 10).unwrap();
        assert_eq!(hits.len(), 10);
        assert!(hits[0].score >= hits[9].score);
    }

    /// Each query token reports the index of its best document token.
    #[test]
    fn test_colbert_token_match_information() {
        let mut index = default_index();
        let basis = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        index.add("doc1".to_string(), basis).unwrap();

        let query = vec![vec![1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0]];
        let hits = index.search(&query, 1).unwrap();
        assert_eq!(hits.len(), 1);

        let matches = &hits[0].token_matches;
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].0, 0);
        assert_eq!(matches[1].0, 2);
    }
}