// oxirs_graphrag/fusion/colbert_reranker.rs

//! ColBERT-style late interaction reranking
//!
//! ColBERT computes MaxSim: for each query token embedding, find the maximum
//! cosine similarity with any document token embedding, then sum across all
//! query tokens.  This module implements a pure-Rust approximation that uses
//! pre-computed token-level embeddings (f32 arrays) to perform late interaction
//! scoring without requiring an external ML runtime.

use crate::{GraphRAGError, GraphRAGResult, ScoredEntity};
use std::collections::HashMap;

/// A token-level embedding (one embedding per sub-word token)
pub type TokenEmbedding = Vec<f32>;

/// Sequence of token embeddings for a passage (document or query)
/// (empty when the source text contained no whitespace-separated tokens)
pub type TokenSequence = Vec<TokenEmbedding>;

/// Cosine similarity between two equal-length embedding vectors
///
/// Returns a value clamped into `[-1.0, 1.0]`; if either vector is
/// (numerically) zero-length the similarity is defined as 0.0.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Embedding dimensions must match");

    // Fused accumulation: dot product and both squared norms in one pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < 1e-9 || norm_b < 1e-9 {
        0.0
    } else {
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }
}

32/// MaxSim score for one query token against all document tokens
33fn max_sim(query_token: &[f32], doc_tokens: &[TokenEmbedding]) -> f32 {
34    doc_tokens
35        .iter()
36        .map(|dt| cosine_similarity(query_token, dt))
37        .fold(f32::NEG_INFINITY, f32::max)
38}
39
40/// ColBERT score: sum of MaxSim across all query tokens
41fn colbert_score(query_tokens: &TokenSequence, doc_tokens: &TokenSequence) -> f32 {
42    if query_tokens.is_empty() || doc_tokens.is_empty() {
43        return 0.0;
44    }
45    query_tokens
46        .iter()
47        .map(|qt| max_sim(qt, doc_tokens))
48        .sum::<f32>()
49        / query_tokens.len() as f32 // normalise by number of query tokens
50}
51
// ─── Configuration ─────────────────────────────────────────────────────────

/// ColBERT reranker configuration
#[derive(Debug, Clone)]
pub struct ColbertRerankerConfig {
    /// Weight given to the ColBERT score vs the original retrieval score
    /// (0.0 = pure original, 1.0 = pure ColBERT)
    pub colbert_weight: f64,
    /// Minimum ColBERT score required to keep a candidate (0.0 = keep all)
    ///
    /// NOTE(review): `ColbertReranker::rerank` applies this threshold *after*
    /// optional normalisation, so when `normalise_scores` is true the value
    /// is compared against scores scaled into roughly [0, 1] rather than raw
    /// MaxSim units.
    pub min_colbert_score: f32,
    /// Maximum candidates to rerank (truncated before scoring for speed)
    pub max_candidates: usize,
    /// Whether to normalise ColBERT scores across the candidate set
    pub normalise_scores: bool,
}

68impl Default for ColbertRerankerConfig {
69    fn default() -> Self {
70        Self {
71            colbert_weight: 0.7,
72            min_colbert_score: 0.0,
73            max_candidates: 100,
74            normalise_scores: true,
75        }
76    }
77}
78
// ─── Token encoder trait ────────────────────────────────────────────────────

/// Trait for encoding text into token-level embeddings
///
/// Implementations must be `Send + Sync` so a reranker holding one can be
/// shared across threads.
pub trait TokenEncoder: Send + Sync {
    /// Encode text into a sequence of token embeddings
    ///
    /// # Errors
    /// May return an error when the underlying encoder fails; the mock
    /// implementation in this module always succeeds.
    fn encode(&self, text: &str) -> GraphRAGResult<TokenSequence>;
}

/// Simple whitespace-based encoder using random unit vectors per token type
/// (placeholder for a real transformer encoder in production)
pub struct MockTokenEncoder {
    /// Output dimensionality of every generated embedding
    dim: usize,
    /// Explicitly registered token → embedding overrides
    vocab: HashMap<String, Vec<f32>>,
}

impl MockTokenEncoder {
    /// Create with the given embedding dimension
    pub fn new(dim: usize) -> Self {
        MockTokenEncoder {
            dim,
            vocab: HashMap::new(),
        }
    }

    /// Register a specific token embedding (for deterministic tests)
    pub fn register_token(&mut self, token: impl Into<String>, embedding: Vec<f32>) {
        self.vocab.insert(token.into(), embedding);
    }

    /// Generate a deterministic unit vector for an unknown token
    ///
    /// Each component is an LCG-style hash of the token bytes seeded with
    /// the component index, mapped into [-1, 1]; the whole vector is then
    /// scaled to unit length unless it is numerically zero.
    fn hash_embed(&self, token: &str) -> Vec<f32> {
        let mut components = Vec::with_capacity(self.dim);
        for i in 0..self.dim {
            // Deterministic pseudo-random value from token bytes + index
            let mut acc = i as u64;
            for byte in token.bytes() {
                acc = acc.wrapping_mul(6364136223846793005).wrapping_add(byte as u64);
            }
            components.push((acc as i64) as f32 / i64::MAX as f32);
        }

        // Normalise to unit length (skip when norm is numerically zero)
        let norm = components.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-9 {
            for x in components.iter_mut() {
                *x /= norm;
            }
        }
        components
    }
}

129impl TokenEncoder for MockTokenEncoder {
130    fn encode(&self, text: &str) -> GraphRAGResult<TokenSequence> {
131        let tokens: TokenSequence = text
132            .split_whitespace()
133            .map(|tok| {
134                let lower = tok.to_lowercase();
135                self.vocab
136                    .get(&lower)
137                    .cloned()
138                    .unwrap_or_else(|| self.hash_embed(&lower))
139            })
140            .collect();
141        Ok(tokens)
142    }
143}
144
// ─── Reranker ───────────────────────────────────────────────────────────────

/// ColBERT-style late interaction reranker
///
/// Generic over the [`TokenEncoder`] used to embed both queries and
/// documents (they must share one embedding space for cosine similarity
/// to be meaningful).
pub struct ColbertReranker<E: TokenEncoder> {
    // Shared encoder for query and document sides.
    encoder: E,
    // Blend weight, score threshold, candidate cap, normalisation flag.
    config: ColbertRerankerConfig,
    /// Optional document text lookup: entity_uri → text representation
    doc_store: HashMap<String, String>,
}

impl<E: TokenEncoder> ColbertReranker<E> {
    /// Create a new reranker
    ///
    /// The document store starts empty; call [`Self::register_documents`]
    /// before [`Self::rerank`] so candidates have real text to score against.
    pub fn new(encoder: E, config: ColbertRerankerConfig) -> Self {
        Self {
            encoder,
            config,
            doc_store: HashMap::new(),
        }
    }

    /// Register entity text representations for late interaction scoring
    ///
    /// A later registration for the same URI overwrites the earlier one.
    pub fn register_documents(&mut self, docs: impl IntoIterator<Item = (String, String)>) {
        for (uri, text) in docs {
            self.doc_store.insert(uri, text);
        }
    }

    /// Rerank candidates using ColBERT late interaction.
    ///
    /// Candidates without a registered document text fall back to their
    /// original scores (not discarded, so no precision loss for unknown docs).
    ///
    /// Pipeline (order matters): truncate to `max_candidates` → ColBERT-score
    /// each candidate → optionally normalise by the batch maximum → drop
    /// candidates below `min_colbert_score` → blend with the original score
    /// using `colbert_weight` → sort descending by blended score.
    ///
    /// NOTE(review): truncation happens *before* scoring, which assumes the
    /// input is already ordered best-first by the upstream retriever — TODO
    /// confirm callers pass sorted candidates.
    pub fn rerank(
        &self,
        query: &str,
        mut candidates: Vec<ScoredEntity>,
    ) -> GraphRAGResult<Vec<ScoredEntity>> {
        // Nothing to rerank, or no query signal: return input unchanged.
        if candidates.is_empty() || query.is_empty() {
            return Ok(candidates);
        }

        // Encode query
        let query_tokens = self.encoder.encode(query)?;

        // Truncate to max_candidates
        candidates.truncate(self.config.max_candidates);

        // Score each candidate
        let mut scored: Vec<(ScoredEntity, f32)> = candidates
            .into_iter()
            .map(|entity| {
                let colbert = self.score_entity(query, &query_tokens, &entity);
                (entity, colbert)
            })
            .collect();

        // Normalise ColBERT scores if requested: divide by the batch maximum
        // so the best candidate maps to ~1.0; skipped when the max is ~0 or
        // negative to avoid dividing by zero (or flipping signs).
        if self.config.normalise_scores {
            let max_c = scored
                .iter()
                .map(|(_, c)| *c)
                .fold(f32::NEG_INFINITY, f32::max);
            if max_c > 1e-9 {
                scored.iter_mut().for_each(|(_, c)| *c /= max_c);
            }
        }

        // Blend and filter.  The threshold is applied to the (possibly
        // normalised) ColBERT score, not to the blended score.
        let w = self.config.colbert_weight;
        let min_c = self.config.min_colbert_score;

        let mut result: Vec<ScoredEntity> = scored
            .into_iter()
            .filter(|(_, c)| *c >= min_c)
            .map(|(mut entity, c)| {
                // Linear interpolation between original and ColBERT scores.
                entity.score = (1.0 - w) * entity.score + w * c as f64;
                entity
            })
            .collect();

        // Sort descending; partial_cmp falls back to Equal on NaN so the
        // sort never panics.
        result.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(result)
    }

    /// Score a single entity (returns raw ColBERT score)
    ///
    /// Falls back to the entity URI as a pseudo-document when no text was
    /// registered, and to 0.0 when encoding fails — best-effort scoring,
    /// encoder errors are deliberately not propagated here.
    fn score_entity(
        &self,
        _query: &str,
        query_tokens: &TokenSequence,
        entity: &ScoredEntity,
    ) -> f32 {
        let doc_text = match self.doc_store.get(&entity.uri) {
            Some(text) => text.clone(),
            None => {
                // Fall back to the URI itself as a pseudo-document
                entity.uri.clone()
            }
        };

        match self.encoder.encode(&doc_text) {
            Ok(doc_tokens) => colbert_score(query_tokens, &doc_tokens),
            Err(_) => 0.0,
        }
    }
}

// ─── Batch scoring helper ────────────────────────────────────────────────────

257/// Score multiple (query, document) pairs and return ColBERT scores
258pub fn colbert_score_batch<E: TokenEncoder>(
259    encoder: &E,
260    query: &str,
261    docs: &[(&str, &str)],
262) -> GraphRAGResult<Vec<f32>> {
263    let query_tokens = encoder.encode(query)?;
264    docs.iter()
265        .map(|(_, doc_text)| {
266            encoder
267                .encode(doc_text)
268                .map(|dt| colbert_score(&query_tokens, &dt))
269        })
270        .collect()
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;

    /// Convenience constructor for a hash-based mock encoder of dimension `dim`.
    fn make_encoder(dim: usize) -> MockTokenEncoder {
        MockTokenEncoder::new(dim)
    }

    /// Build a minimal `ScoredEntity` with the given URI and score.
    fn make_entity(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        }
    }

    // ── cosine_similarity ─────────────────────────────────────────────────

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let v = vec![0.6, 0.8];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_similarity(&a, &b)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // Zero vectors are defined to have similarity 0.0, not NaN.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // ── colbert_score ─────────────────────────────────────────────────────

    #[test]
    fn test_colbert_score_same_query_doc() {
        // A query identical to the document should score high
        let enc = make_encoder(8);
        let q = enc.encode("battery safety").expect("should succeed");
        let d = enc.encode("battery safety").expect("should succeed");
        let score = colbert_score(&q, &d);
        assert!(
            score > 0.8,
            "Identical query/doc should score >0.8, got {score}"
        );
    }

    #[test]
    fn test_colbert_score_empty_query() {
        let q: TokenSequence = vec![];
        let d = vec![vec![1.0f32, 0.0]];
        assert_eq!(colbert_score(&q, &d), 0.0);
    }

    #[test]
    fn test_colbert_score_empty_doc() {
        let q = vec![vec![1.0f32, 0.0]];
        let d: TokenSequence = vec![];
        assert_eq!(colbert_score(&q, &d), 0.0);
    }

    // ── MockTokenEncoder ─────────────────────────────────────────────────

    #[test]
    fn test_mock_encoder_deterministic() {
        let enc = make_encoder(16);
        let e1 = enc.encode("hello world").expect("should succeed");
        let e2 = enc.encode("hello world").expect("should succeed");
        assert_eq!(e1.len(), e2.len());
        for (a, b) in e1.iter().zip(e2.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn test_mock_encoder_registered_token() {
        let mut enc = make_encoder(4);
        enc.register_token("special", vec![1.0, 0.0, 0.0, 0.0]);
        let tokens = enc.encode("special term").expect("should succeed");
        assert_eq!(tokens.len(), 2);
        // First token should be exactly our registered vector
        assert!((tokens[0][0] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_mock_encoder_unit_length() {
        // hash_embed normalises every generated embedding to unit length.
        let enc = make_encoder(32);
        let tokens = enc
            .encode("test token normalization")
            .expect("should succeed");
        for tok in &tokens {
            let norm: f32 = tok.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5, "Token not unit length: {norm}");
        }
    }

    // ── ColbertReranker ──────────────────────────────────────────────────

    #[test]
    fn test_reranker_basic() {
        let enc = make_encoder(16);
        let mut reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        reranker.register_documents([
            (
                "http://a".to_string(),
                "battery safety cell thermal".to_string(),
            ),
            (
                "http://b".to_string(),
                "charging protocol electric".to_string(),
            ),
        ]);

        let candidates = vec![make_entity("http://a", 0.7), make_entity("http://b", 0.6)];

        let reranked = reranker
            .rerank("battery safety", candidates)
            .expect("should succeed");
        assert_eq!(reranked.len(), 2);
        // http://a should score higher (relevant doc)
        assert_eq!(reranked[0].uri, "http://a");
    }

    #[test]
    fn test_reranker_empty_candidates() {
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let result = reranker.rerank("query", vec![]).expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_reranker_empty_query() {
        // Empty query is a no-op: candidates come back unchanged.
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let candidates = vec![make_entity("http://a", 0.5)];
        let result = reranker.rerank("", candidates).expect("should succeed");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_reranker_max_candidates_limiting() {
        let enc = make_encoder(8);
        let config = ColbertRerankerConfig {
            max_candidates: 2,
            ..Default::default()
        };
        let reranker = ColbertReranker::new(enc, config);
        let candidates: Vec<ScoredEntity> = (0..10)
            .map(|i| make_entity(&format!("http://e{i}"), 0.5))
            .collect();
        let result = reranker.rerank("test", candidates).expect("should succeed");
        assert!(result.len() <= 2);
    }

    #[test]
    fn test_reranker_min_score_filter() {
        let enc = make_encoder(8);
        let config = ColbertRerankerConfig {
            min_colbert_score: 999.0, // impossible threshold
            normalise_scores: false,
            ..Default::default()
        };
        let reranker = ColbertReranker::new(enc, config);
        let candidates = vec![make_entity("http://a", 0.8)];
        let result = reranker.rerank("test", candidates).expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_reranker_fallback_without_doc_store() {
        // No documents registered – should fall back to URI-based scoring
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let candidates = vec![make_entity("http://a", 0.7), make_entity("http://b", 0.6)];
        // Should not panic and should return some ordering
        let result = reranker
            .rerank("some query", candidates)
            .expect("should succeed");
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_reranker_normalises_scores() {
        let enc = make_encoder(16);
        let config = ColbertRerankerConfig {
            normalise_scores: true,
            colbert_weight: 1.0, // fully ColBERT
            ..Default::default()
        };
        let mut reranker = ColbertReranker::new(enc, config);
        reranker.register_documents([
            ("http://x".to_string(), "alpha beta gamma".to_string()),
            ("http://y".to_string(), "delta epsilon zeta".to_string()),
        ]);
        let candidates = vec![make_entity("http://x", 0.5), make_entity("http://y", 0.5)];
        let result = reranker
            .rerank("alpha gamma", candidates)
            .expect("should succeed");
        // After normalisation + full ColBERT weight, top doc should score ~1.0
        assert!(
            result[0].score <= 1.01,
            "Score should be ≤ 1.0, got {}",
            result[0].score
        );
    }

    // ── colbert_score_batch ───────────────────────────────────────────────

    #[test]
    fn test_batch_scoring() {
        let enc = make_encoder(16);
        let docs = vec![
            ("id1", "battery safety cell"),
            ("id2", "charging electric vehicle"),
            ("id3", "battery cell chemistry"),
        ];
        let scores = colbert_score_batch(&enc, "battery safety", &docs).expect("should succeed");
        assert_eq!(scores.len(), 3);
        for s in &scores {
            assert!(*s >= 0.0, "Score should be non-negative");
        }
        // First doc is most relevant to "battery safety"
        assert!(
            scores[0] > scores[1],
            "Doc 0 should beat doc 1 for 'battery safety'"
        );
    }

    #[test]
    fn test_batch_scoring_empty_docs() {
        let enc = make_encoder(8);
        let scores = colbert_score_batch(&enc, "query", &[]).expect("should succeed");
        assert!(scores.is_empty());
    }

    #[test]
    fn test_colbert_score_partial_overlap() {
        // Shared tokens ("battery", "cell") should lift the relevant doc
        // above one with no lexical overlap.
        let enc = make_encoder(16);
        let q = enc.encode("battery cell safety").expect("should succeed");
        let d_rel = enc
            .encode("battery cell thermal runaway")
            .expect("should succeed");
        let d_irrel = enc
            .encode("aircraft propulsion jet")
            .expect("should succeed");

        let s_rel = colbert_score(&q, &d_rel);
        let s_irrel = colbert_score(&q, &d_irrel);

        assert!(
            s_rel > s_irrel,
            "Relevant doc should score higher: {s_rel} vs {s_irrel}"
        );
    }
}