oxirs_graphrag/retrieval/
reranker.rs

1//! Cross-encoder reranking for GraphRAG
2
3use crate::{GraphRAGResult, ScoredEntity};
4use async_trait::async_trait;
5
/// Reranker trait for cross-encoder reranking.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
#[async_trait]
pub trait RerankerTrait: Send + Sync {
    /// Rerank candidates given the original query.
    ///
    /// Takes ownership of `candidates` and resolves to a new list of
    /// entities wrapped in [`GraphRAGResult`]. The ordering and scoring
    /// policy is defined by the implementor.
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<ScoredEntity>,
    ) -> GraphRAGResult<Vec<ScoredEntity>>;
}
16
/// Simple score-based reranker (no cross-encoder).
///
/// Holds the two tuning knobs used by the heuristic reranking pass: a
/// multiplicative boost for fused results and a minimum score cut-off.
#[derive(Debug, Clone)]
pub struct Reranker {
    /// Boost factor for entities appearing in multiple sources
    multi_source_boost: f64,
    /// Minimum score threshold
    min_score: f64,
}
24
25impl Default for Reranker {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl Reranker {
32    pub fn new() -> Self {
33        Self {
34            multi_source_boost: 1.2,
35            min_score: 0.1,
36        }
37    }
38
39    /// Set multi-source boost factor
40    pub fn with_multi_source_boost(mut self, boost: f64) -> Self {
41        self.multi_source_boost = boost;
42        self
43    }
44
45    /// Set minimum score threshold
46    pub fn with_min_score(mut self, min_score: f64) -> Self {
47        self.min_score = min_score;
48        self
49    }
50
51    /// Rerank candidates based on heuristics
52    pub fn rerank(&self, candidates: Vec<ScoredEntity>) -> Vec<ScoredEntity> {
53        let mut reranked: Vec<ScoredEntity> = candidates
54            .into_iter()
55            .filter(|e| e.score >= self.min_score)
56            .map(|mut e| {
57                // Boost fused results
58                if e.source == crate::ScoreSource::Fused {
59                    e.score *= self.multi_source_boost;
60                }
61                e
62            })
63            .collect();
64
65        reranked.sort_by(|a, b| {
66            b.score
67                .partial_cmp(&a.score)
68                .unwrap_or(std::cmp::Ordering::Equal)
69        });
70        reranked
71    }
72}
73
74/// Cross-encoder reranker using an embedding model
75pub struct CrossEncoderReranker<E>
76where
77    E: CrossEncoderModel,
78{
79    model: E,
80    batch_size: usize,
81}
82
/// Trait for cross-encoder models.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
#[async_trait]
pub trait CrossEncoderModel: Send + Sync {
    /// Score a query-document pair
    async fn score(&self, query: &str, document: &str) -> GraphRAGResult<f64>;

    /// Score multiple pairs in batch.
    ///
    /// The reranker in this file pairs the returned scores positionally with
    /// `documents`, so implementations are expected to return one score per
    /// document, in the same order.
    async fn score_batch(&self, query: &str, documents: &[&str]) -> GraphRAGResult<Vec<f64>>;
}
92
93impl<E: CrossEncoderModel> CrossEncoderReranker<E> {
94    pub fn new(model: E) -> Self {
95        Self {
96            model,
97            batch_size: 32,
98        }
99    }
100
101    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
102        self.batch_size = batch_size;
103        self
104    }
105
106    /// Rerank using cross-encoder
107    pub async fn rerank(
108        &self,
109        query: &str,
110        candidates: Vec<ScoredEntity>,
111    ) -> GraphRAGResult<Vec<ScoredEntity>> {
112        if candidates.is_empty() {
113            return Ok(vec![]);
114        }
115
116        // Get document representations (URIs for now)
117        let docs: Vec<&str> = candidates.iter().map(|e| e.uri.as_str()).collect();
118
119        // Score in batches
120        let mut all_scores = Vec::with_capacity(candidates.len());
121        for chunk in docs.chunks(self.batch_size) {
122            let scores = self.model.score_batch(query, chunk).await?;
123            all_scores.extend(scores);
124        }
125
126        // Combine with original scores
127        let mut reranked: Vec<ScoredEntity> = candidates
128            .into_iter()
129            .zip(all_scores)
130            .map(|(mut e, cross_score)| {
131                // Weighted combination of original and cross-encoder scores
132                e.score = e.score * 0.3 + cross_score * 0.7;
133                e
134            })
135            .collect();
136
137        reranked.sort_by(|a, b| {
138            b.score
139                .partial_cmp(&a.score)
140                .unwrap_or(std::cmp::Ordering::Equal)
141        });
142        Ok(reranked)
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;
    use std::collections::HashMap;

    /// Build a candidate with the given URI, score and source and no metadata.
    fn entity(uri: &str, score: f64, source: ScoreSource) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source,
            metadata: HashMap::new(),
        }
    }

    /// The default reranker drops candidates below the threshold, boosts
    /// fused candidates and orders the rest by descending score.
    #[test]
    fn test_simple_reranker() {
        let reranker = Reranker::new();

        let candidates = vec![
            entity("http://a", 0.5, ScoreSource::Vector),
            entity("http://b", 0.6, ScoreSource::Fused),
            entity("http://c", 0.05, ScoreSource::Keyword),
        ];

        let reranked = reranker.rerank(candidates);

        // 'b' should be first (boosted), 'a' second, 'c' filtered out
        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].uri, "http://b");
        assert_eq!(reranked[1].uri, "http://a");

        // 'b' is fused, so its score is multiplied by the default boost (1.2);
        // 'a' is not fused and keeps its score.
        assert!((reranked[0].score - 0.6 * 1.2).abs() < 1e-9);
        assert!((reranked[1].score - 0.5).abs() < 1e-9);
    }
}