// graphrag_core/reranking/cross_encoder.rs

//! Cross-Encoder reranking for improved retrieval accuracy
//!
//! Cross-encoders jointly encode query and document, providing more accurate
//! relevance scores than bi-encoder approaches. This implementation provides
//! a trait-based interface that can be backed by ONNX models, API calls, or
//! other implementations.
//!
//! Reference: "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
//! Reimers & Gurevych (2019)
10
11use std::collections::HashMap;
12use async_trait::async_trait;
13
14use crate::retrieval::SearchResult;
15use crate::Result;
16
/// Configuration for cross-encoder reranking.
#[derive(Debug, Clone, PartialEq)]
pub struct CrossEncoderConfig {
    /// Model name or path identifying the cross-encoder backend.
    pub model_name: String,

    /// Maximum input sequence length accepted by the model.
    pub max_length: usize,

    /// Number of query-document pairs scored per inference batch.
    pub batch_size: usize,

    /// Number of top-ranked results to keep after reranking.
    pub top_k: usize,

    /// Minimum relevance score (0.0-1.0) a result must reach to be kept.
    pub min_confidence: f32,

    /// When true, raw scores are squashed into (0, 1) via a sigmoid.
    pub normalize_scores: bool,
}
38
39impl Default for CrossEncoderConfig {
40    fn default() -> Self {
41        Self {
42            model_name: "cross-encoder/ms-marco-MiniLM-L-6-v2".to_string(),
43            max_length: 512,
44            batch_size: 32,
45            top_k: 10,
46            min_confidence: 0.0,
47            normalize_scores: true,
48        }
49    }
50}
51
52/// Result of cross-encoder reranking with confidence score
53#[derive(Debug, Clone)]
54pub struct RankedResult {
55    /// Original search result
56    pub result: SearchResult,
57
58    /// Cross-encoder relevance score (typically 0.0-1.0 after normalization)
59    pub relevance_score: f32,
60
61    /// Original retrieval score (for comparison)
62    pub original_score: f32,
63
64    /// Score improvement over original (relevance_score - original_score)
65    pub score_delta: f32,
66}
67
68/// Cross-encoder trait for reranking retrieved results
69#[async_trait]
70pub trait CrossEncoder: Send + Sync {
71    /// Rerank a list of search results based on relevance to query
72    async fn rerank(
73        &self,
74        query: &str,
75        candidates: Vec<SearchResult>,
76    ) -> Result<Vec<RankedResult>>;
77
78    /// Score a single query-document pair
79    async fn score_pair(&self, query: &str, document: &str) -> Result<f32>;
80
81    /// Batch score multiple query-document pairs
82    async fn score_batch(
83        &self,
84        pairs: Vec<(String, String)>,
85    ) -> Result<Vec<f32>>;
86}
87
88/// Confidence-based cross-encoder implementation
89///
90/// This implementation uses semantic similarity and confidence metrics
91/// to rerank results. For production use with actual transformer models,
92/// consider using ONNXCrossEncoder or APICrossEncoder implementations.
93pub struct ConfidenceCrossEncoder {
94    config: CrossEncoderConfig,
95}
96
97impl ConfidenceCrossEncoder {
98    /// Create a new confidence-based cross-encoder
99    pub fn new(config: CrossEncoderConfig) -> Self {
100        Self { config }
101    }
102
103    /// Calculate relevance score based on text similarity and length
104    fn calculate_relevance(&self, query: &str, document: &str) -> f32 {
105        // Tokenize
106        let query_tokens: Vec<&str> = query.split_whitespace().collect();
107        let doc_tokens: Vec<&str> = document.split_whitespace().collect();
108
109        if query_tokens.is_empty() || doc_tokens.is_empty() {
110            return 0.0;
111        }
112
113        // Calculate token overlap (Jaccard similarity as baseline)
114        let query_set: HashMap<&str, ()> = query_tokens.iter()
115            .map(|t| (*t, ()))
116            .collect();
117        let doc_set: HashMap<&str, ()> = doc_tokens.iter()
118            .map(|t| (*t, ()))
119            .collect();
120
121        let intersection: usize = query_set.keys()
122            .filter(|k| doc_set.contains_key(*k))
123            .count();
124
125        let union_size = query_set.len() + doc_set.len() - intersection;
126
127        let jaccard = if union_size > 0 {
128            intersection as f32 / union_size as f32
129        } else {
130            0.0
131        };
132
133        // Boost score based on document length (prefer longer, more informative docs)
134        let length_factor = (doc_tokens.len() as f32 / 100.0).min(1.0);
135
136        // Combined score
137        let raw_score = jaccard * 0.7 + length_factor * 0.3;
138
139        if self.config.normalize_scores {
140            // Normalize to 0-1 range using sigmoid-like function
141            1.0 / (1.0 + (-5.0 * (raw_score - 0.5)).exp())
142        } else {
143            raw_score
144        }
145    }
146}
147
148#[async_trait]
149impl CrossEncoder for ConfidenceCrossEncoder {
150    async fn rerank(
151        &self,
152        query: &str,
153        candidates: Vec<SearchResult>,
154    ) -> Result<Vec<RankedResult>> {
155        if candidates.is_empty() {
156            return Ok(Vec::new());
157        }
158
159        // Score all candidates
160        let mut ranked: Vec<RankedResult> = candidates
161            .into_iter()
162            .map(|result| {
163                let relevance_score = self.calculate_relevance(query, &result.content);
164                let original_score = result.score;
165                let score_delta = relevance_score - original_score;
166
167                RankedResult {
168                    result,
169                    relevance_score,
170                    original_score,
171                    score_delta,
172                }
173            })
174            .collect();
175
176        // Sort by relevance score (descending)
177        ranked.sort_by(|a, b| {
178            b.relevance_score
179                .partial_cmp(&a.relevance_score)
180                .unwrap_or(std::cmp::Ordering::Equal)
181        });
182
183        // Filter by confidence threshold
184        ranked.retain(|r| r.relevance_score >= self.config.min_confidence);
185
186        // Truncate to top-k
187        ranked.truncate(self.config.top_k);
188
189        log::info!(
190            "Reranked {} candidates, returning top-{}",
191            ranked.len(),
192            self.config.top_k
193        );
194
195        Ok(ranked)
196    }
197
198    async fn score_pair(&self, query: &str, document: &str) -> Result<f32> {
199        Ok(self.calculate_relevance(query, document))
200    }
201
202    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>> {
203        let scores = pairs
204            .iter()
205            .map(|(query, doc)| self.calculate_relevance(query, doc))
206            .collect();
207
208        Ok(scores)
209    }
210}
211
/// Summary metrics describing one reranking pass.
#[derive(Debug, Clone)]
pub struct RerankingStats {
    /// How many candidates entered the reranker.
    pub candidates_count: usize,

    /// How many results survived filtering and truncation.
    pub results_count: usize,

    /// Mean of `score_delta` over the retained results.
    pub avg_score_improvement: f32,

    /// Largest `score_delta` among the retained results.
    pub max_score_improvement: f32,

    /// Share of candidates removed, as a percentage (0-100).
    pub filter_rate: f32,
}
230
231impl RerankingStats {
232    /// Calculate statistics from ranked results
233    pub fn from_results(
234        original_count: usize,
235        ranked: &[RankedResult],
236    ) -> Self {
237        let results_count = ranked.len();
238
239        let avg_score_improvement = if !ranked.is_empty() {
240            ranked.iter().map(|r| r.score_delta).sum::<f32>() / ranked.len() as f32
241        } else {
242            0.0
243        };
244
245        let max_score_improvement = ranked
246            .iter()
247            .map(|r| r.score_delta)
248            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
249            .unwrap_or(0.0);
250
251        let filter_rate = if original_count > 0 {
252            ((original_count - results_count) as f32 / original_count as f32) * 100.0
253        } else {
254            0.0
255        };
256
257        Self {
258            candidates_count: original_count,
259            results_count,
260            avg_score_improvement,
261            max_score_improvement,
262            filter_rate,
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::ResultType;

    /// Build a minimal chunk-type `SearchResult` for test scenarios.
    fn create_test_result(id: &str, content: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_owned(),
            content: content.to_owned(),
            score,
            result_type: ResultType::Chunk,
            entities: vec![],
            source_chunks: vec![],
        }
    }

    #[tokio::test]
    async fn test_rerank_basic() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig {
            top_k: 3,
            min_confidence: 0.0,
            ..Default::default()
        });

        let query = "machine learning algorithms";
        let candidates = vec![
            create_test_result(
                "1",
                "Machine learning is a subset of artificial intelligence",
                0.5,
            ),
            create_test_result("2", "The weather today is sunny", 0.6),
            create_test_result(
                "3",
                "Neural networks are machine learning algorithms used for pattern recognition",
                0.4,
            ),
        ];

        let ranked = encoder.rerank(query, candidates).await.unwrap();

        // All three candidates survive the zero threshold.
        assert_eq!(ranked.len(), 3);

        // Results come back sorted by relevance, best first.
        for pair in ranked.windows(2) {
            assert!(pair[0].relevance_score >= pair[1].relevance_score);
        }
    }

    #[tokio::test]
    async fn test_confidence_filtering() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig {
            top_k: 10,
            min_confidence: 0.5, // deliberately strict threshold
            ..Default::default()
        });

        let query = "specific technical query";
        let candidates = vec![
            create_test_result("1", "highly relevant technical content", 0.3),
            create_test_result("2", "somewhat relevant", 0.4),
            create_test_result("3", "not relevant at all", 0.5),
        ];

        let ranked = encoder.rerank(query, candidates).await.unwrap();

        // Everything that survives must meet the threshold.
        assert!(ranked.iter().all(|r| r.relevance_score >= 0.5));
    }

    #[tokio::test]
    async fn test_score_pair() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig::default());

        let score = encoder
            .score_pair(
                "artificial intelligence",
                "AI and machine learning are related fields",
            )
            .await
            .unwrap();

        // Normalized scores live in the unit interval.
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_reranking_stats() {
        let ranked = vec![
            RankedResult {
                result: create_test_result("1", "test", 0.5),
                relevance_score: 0.8,
                original_score: 0.5,
                score_delta: 0.3,
            },
            RankedResult {
                result: create_test_result("2", "test", 0.6),
                relevance_score: 0.7,
                original_score: 0.6,
                score_delta: 0.1,
            },
        ];

        let stats = RerankingStats::from_results(5, &ranked);

        assert_eq!(stats.candidates_count, 5);
        assert_eq!(stats.results_count, 2);
        // 3 of 5 candidates filtered out => 60% (approximate float compare).
        assert!((stats.filter_rate - 60.0).abs() < 0.001);
        assert!(stats.avg_score_improvement > 0.0);
    }
}