// graphrag_core/reranking/cross_encoder.rs
1//! Cross-Encoder reranking for improved retrieval accuracy
2//!
3//! Cross-encoders jointly encode query and document, providing more accurate
4//! relevance scores than bi-encoder approaches. This implementation provides
5//! a trait-based interface that can be backed by ONNX models, API calls, or
6//! other implementations.
7//!
8//! Reference: "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks"
9//! Reimers & Gurevych (2019)
10
11use async_trait::async_trait;
12
13use crate::retrieval::SearchResult;
14use crate::Result;
15
16/// Configuration for cross-encoder reranking
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Model name (HF Hub id) or local path of the cross-encoder checkpoint
    pub model_name: String,

    /// Maximum sequence length in tokens (BERT-style models cap at 512)
    pub max_length: usize,

    /// Batch size for inference
    pub batch_size: usize,

    /// Top-k results to return after reranking (applied after filtering)
    pub top_k: usize,

    /// Minimum confidence threshold (0.0-1.0); candidates scoring below
    /// this are dropped before top-k truncation
    pub min_confidence: f32,

    /// Enable score normalization (sigmoid over raw logits)
    pub normalize_scores: bool,
}
37
38impl Default for CrossEncoderConfig {
39    fn default() -> Self {
40        Self {
41            model_name: "cross-encoder/ms-marco-MiniLM-L-6-v2".to_string(),
42            max_length: 512,
43            batch_size: 32,
44            top_k: 10,
45            min_confidence: 0.0,
46            normalize_scores: true,
47        }
48    }
49}
50
/// Result of cross-encoder reranking with confidence score
#[derive(Debug, Clone)]
pub struct RankedResult {
    /// Original search result, carried through unchanged
    pub result: SearchResult,

    /// Cross-encoder relevance score (typically 0.0-1.0 after normalization;
    /// raw logits when `normalize_scores` is disabled)
    pub relevance_score: f32,

    /// Original retrieval score (for comparison against the reranked score)
    pub original_score: f32,

    /// Score improvement over original (relevance_score - original_score)
    pub score_delta: f32,
}
66
/// Cross-encoder trait for reranking retrieved results
///
/// Implementations jointly score (query, document) pairs; higher scores mean
/// higher relevance. Methods are async so backends may run model inference
/// or call remote APIs.
#[async_trait]
pub trait CrossEncoder: Send + Sync {
    /// Rerank a list of search results based on relevance to query
    async fn rerank(&self, query: &str, candidates: Vec<SearchResult>)
        -> Result<Vec<RankedResult>>;

    /// Score a single query-document pair
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32>;

    /// Batch score multiple query-document pairs, one score per pair
    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>>;
}
80
#[cfg(feature = "neural-embeddings")]
use candle_core::{DType, Device, Tensor};
#[cfg(feature = "neural-embeddings")]
use candle_nn::VarBuilder;
#[cfg(feature = "neural-embeddings")]
use candle_transformers::models::bert::{BertModel, Config};
#[cfg(feature = "huggingface-hub")]
use hf_hub::api::sync::Api;
#[cfg(feature = "neural-embeddings")]
use tokenizers::Tokenizer;

#[cfg(feature = "neural-embeddings")]
use crate::GraphRAGError;
91
/// Cross-encoder implementation using Candle (BERT)
#[cfg(feature = "neural-embeddings")]
pub struct CandleCrossEncoder {
    // Reranking configuration (threshold, top-k, normalization).
    config: CrossEncoderConfig,
    // BERT encoder loaded from the configured checkpoint.
    model: BertModel,
    // Hugging Face tokenizer paired with the model.
    tokenizer: Tokenizer,
    // Inference device (CPU as constructed by `new`).
    device: Device,
}
100
#[cfg(feature = "neural-embeddings")]
impl CandleCrossEncoder {
    /// Download the configured cross-encoder from the Hugging Face Hub and
    /// load it onto the CPU.
    ///
    /// # Errors
    /// Returns `GraphRAGError::Embedding` if the model weights, tokenizer,
    /// or model config cannot be downloaded, read, or parsed.
    ///
    /// NOTE(review): `Api` comes from `hf_hub`, gated behind the
    /// `huggingface-hub` feature, while this impl is gated behind
    /// `neural-embeddings` only — enabling `neural-embeddings` without
    /// `huggingface-hub` will not compile. Confirm the feature relationship
    /// in Cargo.toml.
    pub fn new(config: CrossEncoderConfig) -> Result<Self> {
        let api = Api::new().map_err(|e| GraphRAGError::Embedding {
            message: format!("Failed to create HF Hub API: {}", e),
        })?;
        let repo = api.model(config.model_name.clone());

        // Prefer safetensors; fall back to the legacy PyTorch checkpoint name.
        // NOTE(review): `VarBuilder::from_mmaped_safetensors` below only
        // understands the safetensors format, so the `pytorch_model.bin`
        // fallback would download successfully but fail to load — TODO confirm
        // whether a separate loading path is needed for that case.
        let model_file = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("pytorch_model.bin"))
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to download model '{}': {}", config.model_name, e),
            })?;

        let tokenizer_file = repo
            .get("tokenizer.json")
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to download tokenizer: {}", e),
            })?;

        let config_file = repo
            .get("config.json")
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to download config: {}", e),
            })?;

        let device = Device::Cpu;

        // Read and parse the model architecture config as two explicit steps
        // so each failure mode gets its own error message.
        let config_text =
            std::fs::read_to_string(config_file).map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to read config: {}", e),
            })?;
        let model_config: Config =
            serde_json::from_str(&config_text).map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to parse config: {}", e),
            })?;

        let tokenizer =
            Tokenizer::from_file(tokenizer_file).map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to load tokenizer: {}", e),
            })?;

        // SAFETY: memory-maps the checkpoint file; sound as long as the file
        // is not truncated or modified while mapped (candle's documented
        // contract for `from_mmaped_safetensors`).
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device).map_err(
                |e| GraphRAGError::Embedding {
                    message: format!("Failed to load weights: {}", e),
                },
            )?
        };

        let model = BertModel::load(vb, &model_config).map_err(|e| GraphRAGError::Embedding {
            message: format!("Failed to load BERT model: {}", e),
        })?;

        Ok(Self {
            config,
            model,
            tokenizer,
            device,
        })
    }
}
164
#[cfg(feature = "neural-embeddings")]
#[async_trait]
impl CrossEncoder for CandleCrossEncoder {
    /// Score every candidate against the query, drop those below the
    /// configured confidence threshold, and return the top-k ordered by
    /// descending relevance.
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<SearchResult>,
    ) -> Result<Vec<RankedResult>> {
        let mut ranked = Vec::with_capacity(candidates.len());

        for candidate in candidates {
            let score = self.score_pair(query, &candidate.content).await?;
            // Capture the retrieval score before `candidate` is moved into
            // the RankedResult below (the previous code read
            // `candidate.score` after the move, which fails to compile).
            let original_score = candidate.score;

            if score >= self.config.min_confidence {
                ranked.push(RankedResult {
                    result: candidate,
                    relevance_score: score,
                    original_score,
                    score_delta: score - original_score,
                });
            }
        }

        // Sort descending by cross-encoder score; NaN pairs compare Equal so
        // the sort cannot panic.
        ranked.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(self.config.top_k);
        Ok(ranked)
    }

    /// Jointly encode `(query, document)` and produce a single relevance
    /// score, sigmoid-normalized when `normalize_scores` is set.
    ///
    /// NOTE(review): no truncation to `config.max_length` is applied here —
    /// long documents may exceed the model's positional limit. TODO confirm
    /// whether the tokenizer is configured with truncation upstream.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32> {
        let tokens = self
            .tokenizer
            .encode((query, document), true)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Tokenization failed: {}", e),
            })?;

        // Shape [1, seq_len] after unsqueezing the batch dimension.
        let token_ids = Tensor::new(tokens.get_ids(), &self.device)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Tensor creation failed: {}", e),
            })?
            .unsqueeze(0)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Unsqueeze failed: {}", e),
            })?;

        // Segment ids distinguishing the query (0) from the document (1).
        let token_type_ids = Tensor::new(tokens.get_type_ids(), &self.device)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Type tensor creation failed: {}", e),
            })?
            .unsqueeze(0)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Unsqueeze failed: {}", e),
            })?;

        let logits = self
            .model
            .forward(&token_ids, &token_type_ids)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Forward pass failed: {}", e),
            })?;

        // NOTE(review): candle's `BertModel::forward` returns encoder hidden
        // states ([batch, seq, hidden]); a cross-encoder checkpoint's
        // classification head does not appear to be applied anywhere here, so
        // `squeeze(0)` yields a rank-2 tensor and `to_vec1` would error on
        // typical models. Verify against the checkpoint actually used.
        let score = logits
            .squeeze(0)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Squeeze failed: {}", e),
            })?
            .to_vec1::<f32>()
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("To vec failed: {}", e),
            })?;

        // Raw logit of the (assumed single-output) relevance head.
        let raw_score = score[0];

        // Map the logit into (0, 1) with a sigmoid when normalization is on;
        // raw logits still rank correctly, so this is optional.
        if self.config.normalize_scores {
            Ok(1.0 / (1.0 + (-raw_score).exp()))
        } else {
            Ok(raw_score)
        }
    }

    /// Score pairs sequentially. True batched inference (per
    /// `config.batch_size`) is not implemented yet; this preserves the
    /// per-pair error semantics of `score_pair`.
    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>> {
        let mut scores = Vec::with_capacity(pairs.len());
        for (q, d) in pairs {
            scores.push(self.score_pair(&q, &d).await?);
        }
        Ok(scores)
    }
}
262
/// Statistics about reranking performance
#[derive(Debug, Clone)]
pub struct RerankingStats {
    /// Number of candidates that entered reranking
    pub candidates_count: usize,

    /// Number of results returned after filtering and truncation
    pub results_count: usize,

    /// Average score improvement (mean of `score_delta` over results)
    pub avg_score_improvement: f32,

    /// Maximum score improvement across results
    pub max_score_improvement: f32,

    /// Percentage of candidates filtered out (0.0-100.0)
    pub filter_rate: f32,
}
281
282impl RerankingStats {
283    /// Calculate statistics from ranked results
284    pub fn from_results(original_count: usize, ranked: &[RankedResult]) -> Self {
285        let results_count = ranked.len();
286
287        let avg_score_improvement = if !ranked.is_empty() {
288            ranked.iter().map(|r| r.score_delta).sum::<f32>() / ranked.len() as f32
289        } else {
290            0.0
291        };
292
293        let max_score_improvement = ranked
294            .iter()
295            .map(|r| r.score_delta)
296            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
297            .unwrap_or(0.0);
298
299        let filter_rate = if original_count > 0 {
300            ((original_count - results_count) as f32 / original_count as f32) * 100.0
301        } else {
302            0.0
303        };
304
305        Self {
306            candidates_count: original_count,
307            results_count,
308            avg_score_improvement,
309            max_score_improvement,
310            filter_rate,
311        }
312    }
313}
314
/// Confidence-based cross-encoder implementation (Restored Fallback)
///
/// A lightweight, model-free `CrossEncoder` backend that is always available
/// regardless of feature flags.
pub struct ConfidenceCrossEncoder {
    // Reranking configuration (threshold, top-k); underscore-prefixed in the
    // original to silence unused-field warnings.
    _config: CrossEncoderConfig,
}
319
impl ConfidenceCrossEncoder {
    /// Create a new confidence-based cross-encoder with the given configuration
    pub fn new(config: CrossEncoderConfig) -> Self {
        Self { _config: config }
    }
}
326
327#[async_trait]
328impl CrossEncoder for ConfidenceCrossEncoder {
329    async fn rerank(
330        &self,
331        _query: &str,
332        candidates: Vec<SearchResult>,
333    ) -> Result<Vec<RankedResult>> {
334        // Simple passthrough/mock implementation to satisfy imports
335        let mut ranked = Vec::new();
336        for candidate in candidates {
337            ranked.push(RankedResult {
338                result: candidate.clone(),
339                relevance_score: candidate.score,
340                original_score: candidate.score,
341                score_delta: 0.0,
342            });
343        }
344        Ok(ranked)
345    }
346
347    async fn score_pair(&self, _query: &str, _document: &str) -> Result<f32> {
348        Ok(0.0)
349    }
350
351    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>> {
352        Ok(vec![0.0; pairs.len()])
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::ResultType;

    /// Build a minimal chunk-type SearchResult fixture.
    fn create_test_result(id: &str, content: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            content: content.to_string(),
            score,
            result_type: ResultType::Chunk,
            entities: Vec::new(),
            source_chunks: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_rerank_basic() {
        let config = CrossEncoderConfig {
            top_k: 3,
            min_confidence: 0.0,
            ..Default::default()
        };

        let encoder = ConfidenceCrossEncoder::new(config);

        let query = "machine learning algorithms";
        let candidates = vec![
            create_test_result(
                "1",
                "Machine learning is a subset of artificial intelligence",
                0.5,
            ),
            create_test_result("2", "The weather today is sunny", 0.6),
            create_test_result(
                "3",
                "Neural networks are machine learning algorithms used for pattern recognition",
                0.4,
            ),
        ];

        let ranked = encoder.rerank(query, candidates).await.unwrap();

        // All three candidates survive the zero confidence threshold.
        assert_eq!(ranked.len(), 3);

        // Results must come back ordered by descending relevance score
        // (ConfidenceCrossEncoder ranks by original retrieval score, so
        // candidate "2" at 0.6 should lead).
        assert!(ranked[0].relevance_score >= ranked[1].relevance_score);
        assert!(ranked[1].relevance_score >= ranked[2].relevance_score);
    }

    #[tokio::test]
    async fn test_confidence_filtering() {
        let config = CrossEncoderConfig {
            top_k: 10,
            min_confidence: 0.5, // High threshold
            ..Default::default()
        };

        let encoder = ConfidenceCrossEncoder::new(config);

        let query = "specific technical query";
        let candidates = vec![
            create_test_result("1", "highly relevant technical content", 0.3),
            create_test_result("2", "somewhat relevant", 0.4),
            create_test_result("3", "not relevant at all", 0.5),
        ];

        let ranked = encoder.rerank(query, candidates).await.unwrap();

        // Every surviving result must meet the min_confidence threshold;
        // candidates "1" and "2" (0.3, 0.4) should have been filtered out.
        for result in &ranked {
            assert!(result.relevance_score >= 0.5);
        }
    }

    #[tokio::test]
    async fn test_score_pair() {
        let config = CrossEncoderConfig::default();
        let encoder = ConfidenceCrossEncoder::new(config);

        let score = encoder
            .score_pair(
                "artificial intelligence",
                "AI and machine learning are related fields",
            )
            .await
            .unwrap();

        // The fallback scorer must stay within the normalized [0, 1] range.
        assert!(score >= 0.0 && score <= 1.0);
    }

    #[test]
    fn test_reranking_stats() {
        let ranked = vec![
            RankedResult {
                result: create_test_result("1", "test", 0.5),
                relevance_score: 0.8,
                original_score: 0.5,
                score_delta: 0.3,
            },
            RankedResult {
                result: create_test_result("2", "test", 0.6),
                relevance_score: 0.7,
                original_score: 0.6,
                score_delta: 0.1,
            },
        ];

        let stats = RerankingStats::from_results(5, &ranked);

        assert_eq!(stats.candidates_count, 5);
        assert_eq!(stats.results_count, 2);
        // Use approximate equality for floating point comparison
        assert!((stats.filter_rate - 60.0).abs() < 0.001); // 3/5 filtered = 60%
        assert!(stats.avg_score_improvement > 0.0);
    }
}
473}