1use crate::{GraphRAGError, GraphRAGResult, ScoredEntity};
10use std::collections::HashMap;
11
/// Embedding vector for a single token.
pub type TokenEmbedding = Vec<f32>;

/// Ordered per-token embeddings for one piece of text (a query or a document).
pub type TokenSequence = Vec<TokenEmbedding>;
17
/// Cosine similarity of two equal-length vectors, clamped to `[-1, 1]`.
///
/// Returns `0.0` when either vector has (near-)zero magnitude, because the angle
/// is undefined there. Equal dimensions are only asserted in debug builds.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Embedding dimensions must match");

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    // `f32::min` returns the non-NaN operand, so this is the same test as
    // checking each norm against the threshold separately.
    if norm_a.min(norm_b) < 1e-9 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
31
32fn max_sim(query_token: &[f32], doc_tokens: &[TokenEmbedding]) -> f32 {
34 doc_tokens
35 .iter()
36 .map(|dt| cosine_similarity(query_token, dt))
37 .fold(f32::NEG_INFINITY, f32::max)
38}
39
40fn colbert_score(query_tokens: &TokenSequence, doc_tokens: &TokenSequence) -> f32 {
42 if query_tokens.is_empty() || doc_tokens.is_empty() {
43 return 0.0;
44 }
45 query_tokens
46 .iter()
47 .map(|qt| max_sim(qt, doc_tokens))
48 .sum::<f32>()
49 / query_tokens.len() as f32 }
51
/// Tuning knobs for [`ColbertReranker`].
#[derive(Debug, Clone, PartialEq)]
pub struct ColbertRerankerConfig {
    /// Weight `w` of the ColBERT score in the fused score
    /// `(1 - w) * first_stage_score + w * colbert_score`. Default `0.7`.
    pub colbert_weight: f64,
    /// Candidates whose ColBERT score (after normalisation, when enabled) is
    /// below this threshold are dropped.
    ///
    /// Cosine-based scores can be negative, so the default of `0.0` also
    /// discards negatively-scored candidates. Use `f32::NEG_INFINITY` to keep
    /// every candidate.
    pub min_colbert_score: f32,
    /// Only the first `max_candidates` input candidates are scored. The rest
    /// are discarded before reranking. Default `100`.
    pub max_candidates: usize,
    /// When `true`, ColBERT scores are divided by the largest score in the batch
    /// (only if that maximum is positive) before filtering and fusion.
    /// Default `true`.
    pub normalise_scores: bool,
}

impl Default for ColbertRerankerConfig {
    fn default() -> Self {
        Self {
            colbert_weight: 0.7,
            min_colbert_score: 0.0,
            max_candidates: 100,
            normalise_scores: true,
        }
    }
}
78
79pub trait TokenEncoder: Send + Sync {
83 fn encode(&self, text: &str) -> GraphRAGResult<TokenSequence>;
85}
86
87pub struct MockTokenEncoder {
90 dim: usize,
91 vocab: HashMap<String, TokenEmbedding>,
92}
93
94impl MockTokenEncoder {
95 pub fn new(dim: usize) -> Self {
97 Self {
98 dim,
99 vocab: HashMap::new(),
100 }
101 }
102
103 pub fn register_token(&mut self, token: impl Into<String>, embedding: Vec<f32>) {
105 self.vocab.insert(token.into(), embedding);
106 }
107
108 fn hash_embed(&self, token: &str) -> TokenEmbedding {
110 let mut v: Vec<f32> = (0..self.dim)
111 .map(|i| {
112 let hash: u64 = token.bytes().fold(i as u64, |acc, b| {
114 acc.wrapping_mul(6364136223846793005).wrapping_add(b as u64)
115 });
116 ((hash as i64) as f32) / (i64::MAX as f32)
117 })
118 .collect();
119
120 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
122 if norm > 1e-9 {
123 v.iter_mut().for_each(|x| *x /= norm);
124 }
125 v
126 }
127}
128
129impl TokenEncoder for MockTokenEncoder {
130 fn encode(&self, text: &str) -> GraphRAGResult<TokenSequence> {
131 let tokens: TokenSequence = text
132 .split_whitespace()
133 .map(|tok| {
134 let lower = tok.to_lowercase();
135 self.vocab
136 .get(&lower)
137 .cloned()
138 .unwrap_or_else(|| self.hash_embed(&lower))
139 })
140 .collect();
141 Ok(tokens)
142 }
143}
144
145pub struct ColbertReranker<E: TokenEncoder> {
149 encoder: E,
150 config: ColbertRerankerConfig,
151 doc_store: HashMap<String, String>,
153}
154
155impl<E: TokenEncoder> ColbertReranker<E> {
156 pub fn new(encoder: E, config: ColbertRerankerConfig) -> Self {
158 Self {
159 encoder,
160 config,
161 doc_store: HashMap::new(),
162 }
163 }
164
165 pub fn register_documents(&mut self, docs: impl IntoIterator<Item = (String, String)>) {
167 for (uri, text) in docs {
168 self.doc_store.insert(uri, text);
169 }
170 }
171
172 pub fn rerank(
177 &self,
178 query: &str,
179 mut candidates: Vec<ScoredEntity>,
180 ) -> GraphRAGResult<Vec<ScoredEntity>> {
181 if candidates.is_empty() || query.is_empty() {
182 return Ok(candidates);
183 }
184
185 let query_tokens = self.encoder.encode(query)?;
187
188 candidates.truncate(self.config.max_candidates);
190
191 let mut scored: Vec<(ScoredEntity, f32)> = candidates
193 .into_iter()
194 .map(|entity| {
195 let colbert = self.score_entity(query, &query_tokens, &entity);
196 (entity, colbert)
197 })
198 .collect();
199
200 if self.config.normalise_scores {
202 let max_c = scored
203 .iter()
204 .map(|(_, c)| *c)
205 .fold(f32::NEG_INFINITY, f32::max);
206 if max_c > 1e-9 {
207 scored.iter_mut().for_each(|(_, c)| *c /= max_c);
208 }
209 }
210
211 let w = self.config.colbert_weight;
213 let min_c = self.config.min_colbert_score;
214
215 let mut result: Vec<ScoredEntity> = scored
216 .into_iter()
217 .filter(|(_, c)| *c >= min_c)
218 .map(|(mut entity, c)| {
219 entity.score = (1.0 - w) * entity.score + w * c as f64;
220 entity
221 })
222 .collect();
223
224 result.sort_by(|a, b| {
225 b.score
226 .partial_cmp(&a.score)
227 .unwrap_or(std::cmp::Ordering::Equal)
228 });
229
230 Ok(result)
231 }
232
233 fn score_entity(
235 &self,
236 _query: &str,
237 query_tokens: &TokenSequence,
238 entity: &ScoredEntity,
239 ) -> f32 {
240 let doc_text = match self.doc_store.get(&entity.uri) {
241 Some(text) => text.clone(),
242 None => {
243 entity.uri.clone()
245 }
246 };
247
248 match self.encoder.encode(&doc_text) {
249 Ok(doc_tokens) => colbert_score(query_tokens, &doc_tokens),
250 Err(_) => 0.0,
251 }
252 }
253}
254
255pub fn colbert_score_batch<E: TokenEncoder>(
259 encoder: &E,
260 query: &str,
261 docs: &[(&str, &str)],
262) -> GraphRAGResult<Vec<f32>> {
263 let query_tokens = encoder.encode(query)?;
264 docs.iter()
265 .map(|(_, doc_text)| {
266 encoder
267 .encode(doc_text)
268 .map(|dt| colbert_score(&query_tokens, &dt))
269 })
270 .collect()
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;

    /// Hash-embedding encoder with no registered tokens.
    fn make_encoder(dim: usize) -> MockTokenEncoder {
        MockTokenEncoder::new(dim)
    }

    /// Candidate with the given URI and first-stage score.
    fn make_entity(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        }
    }

    // ---- cosine_similarity -------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let v = vec![0.6, 0.8];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_similarity(&a, &b)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // ---- colbert_score -----------------------------------------------------

    #[test]
    fn test_colbert_score_same_query_doc() {
        let enc = make_encoder(8);
        // Identical sequences: every query token finds itself in the document.
        let q = enc.encode("battery safety").expect("should succeed");
        let d = enc.encode("battery safety").expect("should succeed");
        let score = colbert_score(&q, &d);
        assert!(
            score > 0.8,
            "Identical query/doc should score >0.8, got {score}"
        );
    }

    #[test]
    fn test_colbert_score_empty_query() {
        let q: TokenSequence = vec![];
        let d = vec![vec![1.0f32, 0.0]];
        assert_eq!(colbert_score(&q, &d), 0.0);
    }

    #[test]
    fn test_colbert_score_empty_doc() {
        let q = vec![vec![1.0f32, 0.0]];
        let d: TokenSequence = vec![];
        assert_eq!(colbert_score(&q, &d), 0.0);
    }

    // ---- MockTokenEncoder --------------------------------------------------

    #[test]
    fn test_mock_encoder_deterministic() {
        let enc = make_encoder(16);
        let e1 = enc.encode("hello world").expect("should succeed");
        let e2 = enc.encode("hello world").expect("should succeed");
        assert_eq!(e1.len(), e2.len());
        for (a, b) in e1.iter().zip(e2.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn test_mock_encoder_registered_token() {
        let mut enc = make_encoder(4);
        enc.register_token("special", vec![1.0, 0.0, 0.0, 0.0]);
        let tokens = enc.encode("special term").expect("should succeed");
        assert_eq!(tokens.len(), 2);
        // The registered embedding is returned verbatim for the first token.
        assert!((tokens[0][0] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_mock_encoder_unit_length() {
        let enc = make_encoder(32);
        let tokens = enc
            .encode("test token normalization")
            .expect("should succeed");
        for tok in &tokens {
            let norm: f32 = tok.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5, "Token not unit length: {norm}");
        }
    }

    // ---- ColbertReranker ---------------------------------------------------

    #[test]
    fn test_reranker_basic() {
        let enc = make_encoder(16);
        let mut reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        reranker.register_documents([
            (
                "http://a".to_string(),
                "battery safety cell thermal".to_string(),
            ),
            (
                "http://b".to_string(),
                "charging protocol electric".to_string(),
            ),
        ]);

        let candidates = vec![make_entity("http://a", 0.7), make_entity("http://b", 0.6)];

        // Doc A contains both query tokens verbatim, so it must stay on top.
        let reranked = reranker
            .rerank("battery safety", candidates)
            .expect("should succeed");
        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].uri, "http://a");
    }

    #[test]
    fn test_reranker_empty_candidates() {
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let result = reranker.rerank("query", vec![]).expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_reranker_empty_query() {
        let enc = make_encoder(8);
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let candidates = vec![make_entity("http://a", 0.5)];
        let result = reranker.rerank("", candidates).expect("should succeed");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_reranker_max_candidates_limiting() {
        let enc = make_encoder(8);
        let config = ColbertRerankerConfig {
            max_candidates: 2,
            // Disable score filtering so the length reflects truncation alone.
            min_colbert_score: f32::NEG_INFINITY,
            ..Default::default()
        };
        let reranker = ColbertReranker::new(enc, config);
        let candidates: Vec<ScoredEntity> = (0..10)
            .map(|i| make_entity(&format!("http://e{i}"), 0.5))
            .collect();
        let result = reranker.rerank("test", candidates).expect("should succeed");
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_reranker_min_score_filter() {
        let enc = make_encoder(8);
        // Without normalisation, ColBERT scores never exceed 1.0, so nothing passes.
        let config = ColbertRerankerConfig {
            min_colbert_score: 999.0,
            normalise_scores: false,
            ..Default::default()
        };
        let reranker = ColbertReranker::new(enc, config);
        let candidates = vec![make_entity("http://a", 0.8)];
        let result = reranker.rerank("test", candidates).expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_reranker_fallback_without_doc_store() {
        let enc = make_encoder(8);
        // No documents registered: candidates are scored against their URIs.
        let reranker = ColbertReranker::new(enc, ColbertRerankerConfig::default());
        let candidates = vec![make_entity("http://a", 0.7), make_entity("http://b", 0.6)];
        let result = reranker
            .rerank("some query", candidates)
            .expect("should succeed");
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_reranker_normalises_scores() {
        let enc = make_encoder(16);
        // With weight 1.0 the fused score is the normalised ColBERT score alone.
        let config = ColbertRerankerConfig {
            normalise_scores: true,
            colbert_weight: 1.0,
            ..Default::default()
        };
        let mut reranker = ColbertReranker::new(enc, config);
        reranker.register_documents([
            ("http://x".to_string(), "alpha beta gamma".to_string()),
            ("http://y".to_string(), "delta epsilon zeta".to_string()),
        ]);
        let candidates = vec![make_entity("http://x", 0.5), make_entity("http://y", 0.5)];
        let result = reranker
            .rerank("alpha gamma", candidates)
            .expect("should succeed");
        // The best-matching candidate's score is divided by itself, giving exactly 1.0.
        assert_eq!(result[0].uri, "http://x");
        assert!(
            (result[0].score - 1.0).abs() < 1e-6,
            "Top normalised score should be 1.0, got {}",
            result[0].score
        );
    }

    // ---- colbert_score_batch -----------------------------------------------

    #[test]
    fn test_batch_scoring() {
        let enc = make_encoder(16);
        let docs = vec![
            ("id1", "battery safety cell"),
            ("id2", "charging electric vehicle"),
            ("id3", "battery cell chemistry"),
        ];
        let scores = colbert_score_batch(&enc, "battery safety", &docs).expect("should succeed");
        assert_eq!(scores.len(), 3);
        for s in &scores {
            assert!(*s >= 0.0, "Score should be non-negative");
        }
        assert!(
            scores[0] > scores[1],
            "Doc 0 should beat doc 1 for 'battery safety'"
        );
    }

    #[test]
    fn test_batch_scoring_empty_docs() {
        let enc = make_encoder(8);
        let scores = colbert_score_batch(&enc, "query", &[]).expect("should succeed");
        assert!(scores.is_empty());
    }

    #[test]
    fn test_colbert_score_partial_overlap() {
        let enc = make_encoder(16);
        let q = enc.encode("battery cell safety").expect("should succeed");
        let d_rel = enc
            .encode("battery cell thermal runaway")
            .expect("should succeed");
        let d_irrel = enc
            .encode("aircraft propulsion jet")
            .expect("should succeed");

        let s_rel = colbert_score(&q, &d_rel);
        let s_irrel = colbert_score(&q, &d_irrel);

        assert!(
            s_rel > s_irrel,
            "Relevant doc should score higher: {s_rel} vs {s_irrel}"
        );
    }
}