use crate::{GraphRAGError, GraphRAGResult, ScoredEntity};
use std::collections::HashMap;
/// A single token's dense embedding vector.
pub type TokenEmbedding = Vec<f32>;
/// An ordered sequence of token embeddings, as produced by a [`TokenEncoder`].
pub type TokenSequence = Vec<TokenEmbedding>;
/// Cosine similarity of two equal-length vectors, clamped to `[-1, 1]`.
///
/// If either vector has (near-)zero L2 norm it has no direction, so `0.0` is
/// returned instead of dividing by ~0. Dimension equality is only checked in
/// debug builds.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Embedding dimensions must match");
    const EPS: f32 = 1e-9;
    let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let len_a = l2(a);
    let len_b = l2(b);
    if len_a < EPS || len_b < EPS {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    // Clamp guards against tiny floating-point overshoot past +/-1.
    (dot / (len_a * len_b)).clamp(-1.0, 1.0)
}
/// MaxSim: the highest cosine similarity between `query_token` and any
/// document token.
///
/// Returns `f32::NEG_INFINITY` when `doc_tokens` is empty; callers are expected
/// to guard against that case.
fn max_sim(query_token: &[f32], doc_tokens: &[TokenEmbedding]) -> f32 {
    let mut best = f32::NEG_INFINITY;
    for candidate in doc_tokens {
        best = best.max(cosine_similarity(query_token, candidate));
    }
    best
}
fn colbert_score(query_tokens: &TokenSequence, doc_tokens: &TokenSequence) -> f32 {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return 0.0;
}
query_tokens
.iter()
.map(|qt| max_sim(qt, doc_tokens))
.sum::<f32>()
/ query_tokens.len() as f32 }
/// Tuning parameters for [`ColbertReranker`].
#[derive(Debug, Clone)]
pub struct ColbertRerankerConfig {
    /// Weight of the ColBERT score in the final blend:
    /// `final = (1 - colbert_weight) * original_score + colbert_weight * colbert`.
    pub colbert_weight: f64,
    /// Candidates whose ColBERT score (after optional normalisation) is below
    /// this value are dropped from the result.
    pub min_colbert_score: f32,
    /// Only the first `max_candidates` candidates, in input order, are scored;
    /// the rest are discarded.
    pub max_candidates: usize,
    /// When `true`, ColBERT scores are divided by the largest score among the
    /// candidates (only if that maximum exceeds `1e-9`).
    pub normalise_scores: bool,
}
impl Default for ColbertRerankerConfig {
fn default() -> Self {
Self {
colbert_weight: 0.7,
min_colbert_score: 0.0,
max_candidates: 100,
normalise_scores: true,
}
}
}
/// Converts text into a sequence of per-token embeddings for late-interaction
/// scoring. How text is split into tokens is up to the implementation.
///
/// The `Send + Sync` supertraits allow an encoder to be shared across threads.
pub trait TokenEncoder: Send + Sync {
    /// Encodes `text` into a [`TokenSequence`].
    ///
    /// # Errors
    /// Implementation-defined. In this module, `ColbertReranker::rerank`
    /// propagates query-encoding errors but scores a document whose encoding
    /// fails as `0.0`; `colbert_score_batch` propagates both kinds.
    fn encode(&self, text: &str) -> GraphRAGResult<TokenSequence>;
}
/// Deterministic, dependency-free [`TokenEncoder`] (used by this file's tests).
///
/// Splits text on whitespace, lowercases each token, and returns either an
/// explicitly registered embedding or one derived from a hash of the token.
pub struct MockTokenEncoder {
    // Number of components in hash-derived embeddings.
    dim: usize,
    // Explicit token -> embedding overrides, consulted before hashing.
    vocab: HashMap<String, TokenEmbedding>,
}
impl MockTokenEncoder {
    /// Creates a mock encoder whose hash-derived embeddings have `dim`
    /// components, with no explicitly registered tokens.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            vocab: HashMap::new(),
        }
    }

    /// Registers an explicit embedding for `token`, overriding the
    /// hash-derived one.
    ///
    /// The key is lowercased because this type's `encode` lowercases every
    /// token before looking it up. Without this, a key containing uppercase
    /// characters could never be matched. Registering the same token again
    /// (case-insensitively) replaces the previous embedding.
    ///
    /// The embedding is stored as given (it is not normalised). NOTE(review):
    /// it should presumably have `dim` components so it is comparable with
    /// hash-derived embeddings; this is not checked here.
    pub fn register_token(&mut self, token: impl Into<String>, embedding: Vec<f32>) {
        let key: String = token.into();
        self.vocab.insert(key.to_lowercase(), embedding);
    }

    /// Derives a deterministic pseudo-random embedding of length `dim` from
    /// the bytes of `token`, scaled to unit L2 norm. If the norm is ~0, the
    /// vector is left unscaled.
    fn hash_embed(&self, token: &str) -> TokenEmbedding {
        let mut v: Vec<f32> = (0..self.dim)
            .map(|i| {
                // Each component seeds a 64-bit multiply-add fold over the
                // token bytes with its own index, giving distinct values per
                // component.
                let hash: u64 = token.bytes().fold(i as u64, |acc, b| {
                    acc.wrapping_mul(6364136223846793005).wrapping_add(b as u64)
                });
                // Reinterpret as signed and scale to roughly [-1, 1].
                ((hash as i64) as f32) / (i64::MAX as f32)
            })
            .collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-9 {
            v.iter_mut().for_each(|x| *x /= norm);
        }
        v
    }
}
impl TokenEncoder for MockTokenEncoder {
fn encode(&self, text: &str) -> GraphRAGResult<TokenSequence> {
let tokens: TokenSequence = text
.split_whitespace()
.map(|tok| {
let lower = tok.to_lowercase();
self.vocab
.get(&lower)
.cloned()
.unwrap_or_else(|| self.hash_embed(&lower))
})
.collect();
Ok(tokens)
}
}
/// Reranks scored entities by blending each entity's existing score with a
/// ColBERT late-interaction score computed against its registered text.
pub struct ColbertReranker<E: TokenEncoder> {
    // Encoder used for both the query and the documents.
    encoder: E,
    // Blending, filtering and truncation settings.
    config: ColbertRerankerConfig,
    // Entity URI -> document text. Entities missing here are scored against
    // their URI string instead.
    doc_store: HashMap<String, String>,
}
impl<E: TokenEncoder> ColbertReranker<E> {
    /// Creates a reranker with the given encoder and configuration and an
    /// empty document store.
    pub fn new(encoder: E, config: ColbertRerankerConfig) -> Self {
        Self {
            encoder,
            config,
            doc_store: HashMap::new(),
        }
    }

    /// Registers the text associated with each entity URI. A later entry for
    /// the same URI replaces the earlier one.
    pub fn register_documents(&mut self, docs: impl IntoIterator<Item = (String, String)>) {
        self.doc_store.extend(docs);
    }

    /// Reranks `candidates` for `query`.
    ///
    /// Steps:
    /// 1. Keep only the first `max_candidates` candidates (input order).
    /// 2. Compute a ColBERT score for each one against its registered text,
    ///    falling back to its URI.
    /// 3. If `normalise_scores` is set, divide every ColBERT score by the
    ///    maximum (when that maximum exceeds `1e-9`).
    /// 4. Drop candidates whose ColBERT score is below `min_colbert_score`.
    /// 5. Set `score = (1 - colbert_weight) * score + colbert_weight * colbert`.
    /// 6. Sort by score, descending, with NaN scores placed last.
    ///
    /// `candidates` is returned unchanged (not truncated, filtered or sorted)
    /// in three cases: it is empty, `query` is empty, or `query` encodes to no
    /// tokens (e.g. whitespace only). A token-less query carries no signal.
    ///
    /// # Errors
    /// Returns an error if encoding the query fails. A document whose encoding
    /// fails is scored `0.0` instead of aborting the rerank.
    pub fn rerank(
        &self,
        query: &str,
        mut candidates: Vec<ScoredEntity>,
    ) -> GraphRAGResult<Vec<ScoredEntity>> {
        if candidates.is_empty() || query.is_empty() {
            return Ok(candidates);
        }
        let query_tokens = self.encoder.encode(query)?;
        // A whitespace-only query yields no tokens. Scoring it would give every
        // candidate a ColBERT score of 0, uniformly deflating scores by
        // (1 - weight), or dropping everything when min_colbert_score > 0.
        // Treat it like an empty query.
        if query_tokens.is_empty() {
            return Ok(candidates);
        }
        candidates.truncate(self.config.max_candidates);
        let mut scored: Vec<(ScoredEntity, f32)> = candidates
            .into_iter()
            .map(|entity| {
                let colbert = self.score_entity(query, &query_tokens, &entity);
                (entity, colbert)
            })
            .collect();
        if self.config.normalise_scores {
            let max_c = scored
                .iter()
                .map(|(_, c)| *c)
                .fold(f32::NEG_INFINITY, f32::max);
            // Only rescale by a meaningfully positive maximum. This also avoids
            // dividing by zero or by a negative number.
            if max_c > 1e-9 {
                scored.iter_mut().for_each(|(_, c)| *c /= max_c);
            }
        }
        let w = self.config.colbert_weight;
        let min_c = self.config.min_colbert_score;
        let mut result: Vec<ScoredEntity> = scored
            .into_iter()
            .filter(|(_, c)| *c >= min_c)
            .map(|(mut entity, c)| {
                entity.score = (1.0 - w) * entity.score + w * f64::from(c);
                entity
            })
            .collect();
        // Descending sort. Non-NaN pairs compare exactly as before. If a NaN is
        // involved, NaN sorts last. This keeps the comparator a total order,
        // which `sort_by` requires (it may panic otherwise).
        result.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or_else(|| a.score.is_nan().cmp(&b.score.is_nan()))
        });
        Ok(result)
    }

    /// ColBERT score of one entity against the pre-encoded query.
    ///
    /// Uses the entity's registered document text, or its URI if none was
    /// registered. Best-effort: a document-encoding failure yields `0.0`.
    fn score_entity(
        &self,
        _query: &str,
        query_tokens: &TokenSequence,
        entity: &ScoredEntity,
    ) -> f32 {
        // Borrow the stored text instead of cloning it.
        let doc_text: &str = self
            .doc_store
            .get(&entity.uri)
            .map_or(entity.uri.as_str(), String::as_str);
        match self.encoder.encode(doc_text) {
            Ok(doc_tokens) => colbert_score(query_tokens, &doc_tokens),
            Err(_) => 0.0,
        }
    }
}
/// Scores each `(id, text)` document in `docs` against `query`.
///
/// The returned scores are in the same order as `docs`; the ids are ignored.
///
/// # Errors
/// Returns the first encoding error, from either the query or a document.
/// Remaining documents are not encoded.
pub fn colbert_score_batch<E: TokenEncoder>(
    encoder: &E,
    query: &str,
    docs: &[(&str, &str)],
) -> GraphRAGResult<Vec<f32>> {
    let query_tokens = encoder.encode(query)?;
    let mut scores = Vec::with_capacity(docs.len());
    for &(_id, text) in docs {
        let doc_tokens = encoder.encode(text)?;
        scores.push(colbert_score(&query_tokens, &doc_tokens));
    }
    Ok(scores)
}
// Unit tests for the similarity helpers, the mock encoder, the reranker and
// batch scoring. Every embedding comes from `MockTokenEncoder`, so all scores
// are deterministic across runs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;

    // Builds every test encoder the same way.
    fn make_encoder(dim: usize) -> MockTokenEncoder {
        MockTokenEncoder::new(dim)
    }

    // Builds a candidate with the given URI and upstream score. Source and
    // metadata are fixed because the reranker only reads `uri` and `score`.
    fn make_entity(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let v = vec![0.6, 0.8];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_similarity(&a, &b)).abs() < 1e-6);
    }

    // A zero vector has no direction, so the norm guard returns exactly 0.0.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // Every query token appears in the document, so each MaxSim term is ~1.
    #[test]
    fn test_colbert_score_same_query_doc() {
        let enc = make_encoder(8);
        let q = enc.encode("battery safety").expect("should succeed");
        let d = enc.encode("battery safety").expect("should succeed");
        let score = colbert_score(&q, &d);
        assert!(
            score > 0.8,
            "Identical query/doc should score >0.8, got {score}"
        );
    }

    #[test]
    fn test_colbert_score_empty_query() {
        let q: TokenSequence = vec![];
        let d = vec![vec![1.0f32, 0.0]];
        assert_eq!(colbert_score(&q, &d), 0.0);
    }

    #[test]
    fn test_colbert_score_empty_doc() {
        let q = vec![vec![1.0f32, 0.0]];
        let d: TokenSequence = vec![];
        assert_eq!(colbert_score(&q, &d), 0.0);
    }

    // Hash-derived embeddings must be identical across calls.
    #[test]
    fn test_mock_encoder_deterministic() {
        let enc = make_encoder(16);
        let e1 = enc.encode("hello world").expect("should succeed");
        let e2 = enc.encode("hello world").expect("should succeed");
        assert_eq!(e1.len(), e2.len());
        for (a, b) in e1.iter().zip(e2.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-9);
            }
        }
    }

    // A registered token bypasses hashing and returns the stored vector.
    #[test]
    fn test_mock_encoder_registered_token() {
        let mut enc = make_encoder(4);
        enc.register_token("special", vec![1.0, 0.0, 0.0, 0.0]);
        let tokens = enc.encode("special term").expect("should succeed");
        assert_eq!(tokens.len(), 2);
        assert!((tokens[0][0] - 1.0).abs() < 1e-9);
    }

    // Hash-derived embeddings are normalised to unit L2 length.
    #[test]
    fn test_mock_encoder_unit_length() {
        let enc = make_encoder(32);
        let tokens = enc
            .encode("test token normalization")
            .expect("should succeed");
        for tok in &tokens {
            let norm: f32 = tok.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5, "Token not unit length: {norm}");
        }
    }

    // The document that contains the query's words should be ranked first.
    #[test]
    fn test_reranker_basic() {
        let enc = make_encoder(16);
        let mut reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        reranker.register_documents([
            (
                "http://a".to_string(),
                "battery safety cell thermal".to_string(),
            ),
            (
                "http://b".to_string(),
                "charging protocol electric".to_string(),
            ),
        ]);
        let candidates = vec![make_entity("http://a", 0.7), make_entity("http://b", 0.6)];
        let reranked = reranker
            .rerank("battery safety", candidates)
            .expect("should succeed");
        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].uri, "http://a");
    }

    #[test]
    fn test_reranker_empty_candidates() {
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let result = reranker.rerank("query", vec![]).expect("should succeed");
        assert!(result.is_empty());
    }

    // An empty query short-circuits and returns the candidates untouched.
    #[test]
    fn test_reranker_empty_query() {
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let candidates = vec![make_entity("http://a", 0.5)];
        let result = reranker.rerank("", candidates).expect("should succeed");
        assert_eq!(result.len(), 1);
    }

    // Only the first `max_candidates` inputs are scored and can be returned.
    #[test]
    fn test_reranker_max_candidates_limiting() {
        let enc = make_encoder(8);
        let config = ColbertRerankerConfig {
            max_candidates: 2,
            ..Default::default()
        };
        let reranker = ColbertReranker::new(enc, config);
        let candidates: Vec<ScoredEntity> = (0..10)
            .map(|i| make_entity(&format!("http://e{i}"), 0.5))
            .collect();
        let result = reranker.rerank("test", candidates).expect("should succeed");
        assert!(result.len() <= 2);
    }

    // The threshold is unreachable because cosine similarities are clamped to
    // [-1, 1], and normalisation is off, so every candidate is filtered out.
    #[test]
    fn test_reranker_min_score_filter() {
        let enc = make_encoder(8);
        let config = ColbertRerankerConfig {
            min_colbert_score: 999.0,
            normalise_scores: false,
            ..Default::default()
        };
        let reranker = ColbertReranker::new(enc, config);
        let candidates = vec![make_entity("http://a", 0.8)];
        let result = reranker.rerank("test", candidates).expect("should succeed");
        assert!(result.is_empty());
    }

    // With no registered text, entities are scored against their URI string.
    // Both candidates are expected to survive the default filter.
    #[test]
    fn test_reranker_fallback_without_doc_store() {
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let candidates = vec![make_entity("http://a", 0.7), make_entity("http://b", 0.6)];
        let result = reranker
            .rerank("some query", candidates)
            .expect("should succeed");
        assert_eq!(result.len(), 2);
    }

    // With weight 1.0 the final score is just the normalised ColBERT score,
    // so it should not exceed 1 (small tolerance for float rounding).
    #[test]
    fn test_reranker_normalises_scores() {
        let enc = make_encoder(16);
        let config = ColbertRerankerConfig {
            normalise_scores: true,
            colbert_weight: 1.0,
            ..Default::default()
        };
        let mut reranker = ColbertReranker::new(enc, config);
        reranker.register_documents([
            ("http://x".to_string(), "alpha beta gamma".to_string()),
            ("http://y".to_string(), "delta epsilon zeta".to_string()),
        ]);
        let candidates = vec![make_entity("http://x", 0.5), make_entity("http://y", 0.5)];
        let result = reranker
            .rerank("alpha gamma", candidates)
            .expect("should succeed");
        assert!(
            result[0].score <= 1.01,
            "Score should be ≤ 1.0, got {}",
            result[0].score
        );
    }

    // Scores come back in input order. Doc 0 shares both query words; doc 1
    // shares none.
    #[test]
    fn test_batch_scoring() {
        let enc = make_encoder(16);
        let docs = vec![
            ("id1", "battery safety cell"),
            ("id2", "charging electric vehicle"),
            ("id3", "battery cell chemistry"),
        ];
        let scores = colbert_score_batch(&enc, "battery safety", &docs).expect("should succeed");
        assert_eq!(scores.len(), 3);
        for s in &scores {
            assert!(*s >= 0.0, "Score should be non-negative");
        }
        assert!(
            scores[0] > scores[1],
            "Doc 0 should beat doc 1 for 'battery safety'"
        );
    }

    #[test]
    fn test_batch_scoring_empty_docs() {
        let enc = make_encoder(8);
        let scores = colbert_score_batch(&enc, "query", &[]).expect("should succeed");
        assert!(scores.is_empty());
    }

    // A document sharing two of the three query words should beat one that
    // shares none.
    #[test]
    fn test_colbert_score_partial_overlap() {
        let enc = make_encoder(16);
        let q = enc.encode("battery cell safety").expect("should succeed");
        let d_rel = enc
            .encode("battery cell thermal runaway")
            .expect("should succeed");
        let d_irrel = enc
            .encode("aircraft propulsion jet")
            .expect("should succeed");
        let s_rel = colbert_score(&q, &d_rel);
        let s_irrel = colbert_score(&q, &d_irrel);
        assert!(
            s_rel > s_irrel,
            "Relevant doc should score higher: {s_rel} vs {s_irrel}"
        );
    }
}