use anyhow::Result;
use std::collections::HashMap;
/// Settings that control how the ColBERT-style reranker encodes and scores text.
#[derive(Clone, Debug)]
pub struct ColBERTConfig {
    /// Maximum number of whitespace-separated query tokens that get an embedding;
    /// tokens beyond this count are dropped by the query encoder.
    pub max_query_tokens: usize,
    /// Maximum number of whitespace-separated document tokens that get an
    /// embedding; tokens beyond this count are dropped by the document encoder.
    pub max_doc_tokens: usize,
    /// Length of every token vector produced by the encoders.
    pub embedding_dim: usize,
    /// Metric used to compare one query-token vector with one document-token vector.
    pub similarity_metric: SimilarityMetric,
    /// When `true`, the encoders scale each produced vector to unit L2 norm.
    pub normalize: bool,
}

impl Default for ColBERTConfig {
    /// Returns the baseline configuration: 32 query tokens, 128 document
    /// tokens, 128-dimensional vectors, cosine similarity, normalization on.
    fn default() -> Self {
        ColBERTConfig {
            similarity_metric: SimilarityMetric::Cosine,
            normalize: true,
            embedding_dim: 128,
            max_doc_tokens: 128,
            max_query_tokens: 32,
        }
    }
}
/// Selects how two token vectors are compared during scoring.
///
/// The enum carries no data, so it is `Copy` and fully comparable/hashable
/// (`Eq` + `Hash`), which lets it be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    /// Cosine similarity between the two vectors.
    Cosine,
    /// Plain inner product of the two vectors.
    DotProduct,
    /// Euclidean (L2) distance; the reranker negates it so that a larger
    /// value still means "more similar".
    L2,
}
/// Per-token embedding vectors for one piece of text, optionally paired with
/// the token strings they were produced from.
#[derive(Clone, Debug)]
pub struct TokenEmbeddings {
    /// One vector per token, in token order.
    pub embeddings: Vec<Vec<f32>>,
    /// The token strings, when the producer supplied them.
    pub tokens: Option<Vec<String>>,
}

impl TokenEmbeddings {
    /// Wraps bare embedding vectors; no token strings are attached.
    pub fn new(embeddings: Vec<Vec<f32>>) -> Self {
        let tokens = None;
        TokenEmbeddings { embeddings, tokens }
    }

    /// Wraps embedding vectors together with their token strings.
    ///
    /// The two lists are stored as given; their lengths are not compared.
    pub fn with_tokens(embeddings: Vec<Vec<f32>>, tokens: Vec<String>) -> Self {
        let tokens = Some(tokens);
        TokenEmbeddings { embeddings, tokens }
    }

    /// Number of token vectors held.
    pub fn num_tokens(&self) -> usize {
        let vectors = &self.embeddings;
        vectors.len()
    }

    /// Length of the first token vector, or `0` when there are no vectors.
    ///
    /// Only the first vector is inspected; later vectors are not checked.
    pub fn embedding_dim(&self) -> usize {
        match self.embeddings.first() {
            Some(first_vector) => first_vector.len(),
            None => 0,
        }
    }
}
/// ColBERT-style late-interaction reranker: a query and a document are each
/// represented by one vector per token, and the score is the sum over query
/// tokens of the best similarity against any document token.
///
/// NOTE: the encoders here fill every token vector with random values, so
/// they act as placeholders rather than a trained model.
pub struct ColBERTReranker {
    /// Token limits, vector size, metric and normalization settings.
    config: ColBERTConfig,
    /// Embeddings stored through `cache_document`, keyed by document id.
    doc_cache: HashMap<String, TokenEmbeddings>,
}

impl ColBERTReranker {
    /// Creates a reranker with an empty document cache.
    ///
    /// The current body always returns `Ok`; the `Result` is part of the
    /// public signature and is kept for callers.
    pub fn new(config: ColBERTConfig) -> Result<Self> {
        Ok(Self {
            config,
            doc_cache: HashMap::new(),
        })
    }

    /// Encodes `query` into at most `max_query_tokens` token embeddings.
    ///
    /// Tokens are whitespace-separated words. An empty or all-whitespace
    /// query yields zero tokens. The body contains no await points.
    pub async fn encode_query(&self, query: &str) -> Result<TokenEmbeddings> {
        Ok(self.encode_text(query, self.config.max_query_tokens))
    }

    /// Encodes `document` into at most `max_doc_tokens` token embeddings.
    ///
    /// Tokens are whitespace-separated words. An empty or all-whitespace
    /// document yields zero tokens. The body contains no await points.
    pub async fn encode_document(&self, document: &str) -> Result<TokenEmbeddings> {
        Ok(self.encode_text(document, self.config.max_doc_tokens))
    }

    /// Shared encoder: keeps the first `max_tokens` whitespace-separated
    /// tokens of `text` and attaches one freshly generated vector to each.
    fn encode_text(&self, text: &str, max_tokens: usize) -> TokenEmbeddings {
        // `take` stops early, so tokens past the limit are never allocated.
        let tokens: Vec<String> = text
            .split_whitespace()
            .take(max_tokens)
            .map(str::to_owned)
            .collect();
        let embeddings: Vec<Vec<f32>> = tokens.iter().map(|_| self.random_embedding()).collect();
        TokenEmbeddings::with_tokens(embeddings, tokens)
    }

    /// Produces one `embedding_dim`-long vector with components drawn as
    /// `rand::random::<f32>() - 0.5`, scaled to unit length when
    /// `config.normalize` is set.
    fn random_embedding(&self) -> Vec<f32> {
        let raw: Vec<f32> = (0..self.config.embedding_dim)
            .map(|_| rand::random::<f32>() - 0.5)
            .collect();
        if self.config.normalize {
            Self::normalize_vector(raw)
        } else {
            raw
        }
    }

    /// Computes the late-interaction score of a document for a query.
    ///
    /// For every query vector the highest similarity against any document
    /// vector is taken (using the configured metric), and these maxima are
    /// summed.
    ///
    /// Empty inputs:
    /// * a query with no tokens scores `0.0` (empty sum);
    /// * a document with no tokens, scored against a non-empty query, yields
    ///   `f32::NEG_INFINITY` (the maximum over an empty set), so it ranks
    ///   below every non-empty document under any metric.
    ///
    /// # Errors
    ///
    /// Returns an error when both sides are non-empty and the length of the
    /// first query vector differs from the length of the first document
    /// vector. Only the first vector on each side is inspected.
    pub fn compute_score(
        &self,
        query_tokens: &TokenEmbeddings,
        doc_tokens: &TokenEmbeddings,
    ) -> Result<f32> {
        let query_vecs = &query_tokens.embeddings;
        let doc_vecs = &doc_tokens.embeddings;

        // An empty side reports dimension 0, which is not a real mismatch;
        // only compare dimensions when both sides actually hold vectors.
        if !query_vecs.is_empty()
            && !doc_vecs.is_empty()
            && query_tokens.embedding_dim() != doc_tokens.embedding_dim()
        {
            anyhow::bail!(
                "Dimension mismatch: query={}, doc={}",
                query_tokens.embedding_dim(),
                doc_tokens.embedding_dim()
            );
        }

        let mut total_score = 0.0_f32;
        for query_emb in query_vecs {
            let best_match = doc_vecs
                .iter()
                .map(|doc_emb| self.compute_token_similarity(query_emb, doc_emb))
                .fold(f32::NEG_INFINITY, f32::max);
            total_score += best_match;
        }
        Ok(total_score)
    }

    /// Compares two token vectors with the configured metric; larger always
    /// means more similar (L2 distance is negated for that reason).
    fn compute_token_similarity(&self, vec1: &[f32], vec2: &[f32]) -> f32 {
        match self.config.similarity_metric {
            SimilarityMetric::Cosine => Self::cosine_similarity(vec1, vec2),
            SimilarityMetric::DotProduct => Self::dot_product(vec1, vec2),
            SimilarityMetric::L2 => -Self::l2_distance(vec1, vec2),
        }
    }

    /// Cosine of the angle between `vec1` and `vec2`.
    ///
    /// The dot product is divided by both norms, so the result is correct for
    /// vectors that are not unit length. Returns `0.0` when either vector has
    /// zero norm, avoiding a `0 / 0` NaN.
    fn cosine_similarity(vec1: &[f32], vec2: &[f32]) -> f32 {
        let norm_product =
            Self::dot_product(vec1, vec1).sqrt() * Self::dot_product(vec2, vec2).sqrt();
        if norm_product > 0.0 {
            Self::dot_product(vec1, vec2) / norm_product
        } else {
            0.0
        }
    }

    /// Inner product over the common prefix of the two slices (`zip` stops at
    /// the shorter one).
    fn dot_product(vec1: &[f32], vec2: &[f32]) -> f32 {
        vec1.iter().zip(vec2.iter()).map(|(a, b)| a * b).sum()
    }

    /// Euclidean distance over the common prefix of the two slices.
    fn l2_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
        vec1.iter()
            .zip(vec2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Scales `vec` to unit L2 norm; a vector whose norm is not strictly
    /// positive (all zeros, or NaN norm) is returned unchanged.
    fn normalize_vector(vec: Vec<f32>) -> Vec<f32> {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            vec.into_iter().map(|x| x / norm).collect()
        } else {
            vec
        }
    }

    /// Stores `embeddings` under `doc_id`, replacing any previous entry.
    pub fn cache_document(&mut self, doc_id: String, embeddings: TokenEmbeddings) {
        self.doc_cache.insert(doc_id, embeddings);
    }

    /// Looks up the embeddings cached under `doc_id`, if any.
    pub fn get_cached_document(&self, doc_id: &str) -> Option<&TokenEmbeddings> {
        self.doc_cache.get(doc_id)
    }

    /// Removes every cached document.
    pub fn clear_cache(&mut self) {
        self.doc_cache.clear();
    }
}
/// Scores a list of documents against one query and returns the best ones.
pub struct ColBERTBatchReranker {
    /// Underlying reranker used for encoding and scoring.
    reranker: ColBERTReranker,
}

impl ColBERTBatchReranker {
    /// Builds a batch reranker around a new `ColBERTReranker` using `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ColBERTReranker::new`.
    pub fn new(config: ColBERTConfig) -> Result<Self> {
        Ok(Self {
            reranker: ColBERTReranker::new(config)?,
        })
    }

    /// Ranks `documents` for `query` and returns up to `top_k` results.
    ///
    /// Each result is `(index into documents, score)`, ordered from highest
    /// to lowest score. Equal scores keep their input order (the sort is
    /// stable), and NaN scores are placed after all other scores.
    ///
    /// When `top_k` is `0` or `documents` is empty, nothing is encoded and an
    /// empty list is returned.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised while encoding or scoring.
    pub async fn rerank(
        &mut self,
        query: &str,
        documents: &[String],
        top_k: usize,
    ) -> Result<Vec<(usize, f32)>> {
        // Nothing can be returned, so skip the encoding work entirely.
        if top_k == 0 || documents.is_empty() {
            return Ok(Vec::new());
        }

        let query_tokens = self.reranker.encode_query(query).await?;
        let mut scores: Vec<(usize, f32)> = Vec::with_capacity(documents.len());
        for (idx, doc) in documents.iter().enumerate() {
            let doc_tokens = self.reranker.encode_document(doc).await?;
            let score = self.reranker.compute_score(&query_tokens, &doc_tokens)?;
            scores.push((idx, score));
        }

        scores.sort_by(|a, b| Self::descending_nan_last(a.1, b.1));
        scores.truncate(top_k);
        Ok(scores)
    }

    /// Comparator for a descending sort of scores that is a consistent total
    /// ordering even with NaN present: NaN values compare equal to each other
    /// and sort after every non-NaN value. `sort_by` requires such a
    /// comparator; mapping "incomparable" to `Equal` does not provide one.
    fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            // Both are real numbers, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
            // `false < true`, so the NaN side is ordered later.
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for approximate floating-point comparisons.
    const EPSILON: f32 = 1e-6;

    /// Two 3-dimensional vectors report two tokens of dimension three.
    #[test]
    fn test_token_embeddings_creation() {
        let rows = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let token_embs = TokenEmbeddings::new(rows);

        assert_eq!(token_embs.num_tokens(), 2);
        assert_eq!(token_embs.embedding_dim(), 3);
    }

    /// A 3-4-5 triangle normalizes to (0.6, 0.8) with unit length.
    #[test]
    fn test_normalize_vector() {
        let normalized = ColBERTReranker::normalize_vector(vec![3.0, 4.0]);

        assert!((normalized[0] - 0.6).abs() < EPSILON);
        assert!((normalized[1] - 0.8).abs() < EPSILON);

        let squared_sum: f32 = normalized.iter().map(|x| x * x).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < EPSILON);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];

        assert_eq!(ColBERTReranker::dot_product(&lhs, &rhs), 32.0);
    }

    /// Distance from the origin to (3, 4) is 5.
    #[test]
    fn test_l2_distance() {
        let origin = vec![0.0, 0.0];
        let point = vec![3.0, 4.0];

        assert_eq!(ColBERTReranker::l2_distance(&origin, &point), 5.0);
    }

    /// Each query basis vector has an exact match in the document, so the
    /// two per-token maxima of 1.0 sum to 2.0.
    #[test]
    fn test_compute_score() {
        let reranker = ColBERTReranker::new(ColBERTConfig {
            embedding_dim: 3,
            normalize: false,
            ..Default::default()
        })
        .unwrap();

        let query_tokens = TokenEmbeddings::new(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
        ]);
        let doc_tokens = TokenEmbeddings::new(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.0, 1.0, 0.0],
        ]);

        let score = reranker.compute_score(&query_tokens, &doc_tokens).unwrap();
        assert_eq!(score, 2.0);
    }

    /// End-to-end smoke test: encoding non-empty text yields tokens and a
    /// finite score.
    #[tokio::test]
    async fn test_colbert_reranker_basic() {
        let reranker = ColBERTReranker::new(ColBERTConfig::default()).unwrap();

        let query_tokens = reranker.encode_query("test query").await.unwrap();
        let doc_tokens = reranker.encode_document("test document").await.unwrap();
        assert!(query_tokens.num_tokens() > 0);
        assert!(doc_tokens.num_tokens() > 0);

        let score = reranker.compute_score(&query_tokens, &doc_tokens).unwrap();
        assert!(score.is_finite());
    }
}