// oxirs_graphrag/retrieval/reranker.rs

use std::cmp::Ordering;

use async_trait::async_trait;

use crate::{GraphRAGResult, ScoredEntity};
5
6#[async_trait]
8pub trait RerankerTrait: Send + Sync {
9 async fn rerank(
11 &self,
12 query: &str,
13 candidates: Vec<ScoredEntity>,
14 ) -> GraphRAGResult<Vec<ScoredEntity>>;
15}
16
/// Score-based reranker configuration.
///
/// Holds the two tunables read by `Reranker::rerank`: a multiplier applied to
/// candidates whose source is `ScoreSource::Fused`, and a minimum score below
/// which candidates are dropped.
#[derive(Debug, Clone, PartialEq)]
pub struct Reranker {
    /// Factor multiplied into the score of fused candidates.
    multi_source_boost: f64,
    /// Candidates scoring strictly below this value are discarded.
    min_score: f64,
}
24
25impl Default for Reranker {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl Reranker {
32 pub fn new() -> Self {
33 Self {
34 multi_source_boost: 1.2,
35 min_score: 0.1,
36 }
37 }
38
39 pub fn with_multi_source_boost(mut self, boost: f64) -> Self {
41 self.multi_source_boost = boost;
42 self
43 }
44
45 pub fn with_min_score(mut self, min_score: f64) -> Self {
47 self.min_score = min_score;
48 self
49 }
50
51 pub fn rerank(&self, candidates: Vec<ScoredEntity>) -> Vec<ScoredEntity> {
53 let mut reranked: Vec<ScoredEntity> = candidates
54 .into_iter()
55 .filter(|e| e.score >= self.min_score)
56 .map(|mut e| {
57 if e.source == crate::ScoreSource::Fused {
59 e.score *= self.multi_source_boost;
60 }
61 e
62 })
63 .collect();
64
65 reranked.sort_by(|a, b| {
66 b.score
67 .partial_cmp(&a.score)
68 .unwrap_or(std::cmp::Ordering::Equal)
69 });
70 reranked
71 }
72}
73
74pub struct CrossEncoderReranker<E>
76where
77 E: CrossEncoderModel,
78{
79 model: E,
80 batch_size: usize,
81}
82
83#[async_trait]
85pub trait CrossEncoderModel: Send + Sync {
86 async fn score(&self, query: &str, document: &str) -> GraphRAGResult<f64>;
88
89 async fn score_batch(&self, query: &str, documents: &[&str]) -> GraphRAGResult<Vec<f64>>;
91}
92
93impl<E: CrossEncoderModel> CrossEncoderReranker<E> {
94 pub fn new(model: E) -> Self {
95 Self {
96 model,
97 batch_size: 32,
98 }
99 }
100
101 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
102 self.batch_size = batch_size;
103 self
104 }
105
106 pub async fn rerank(
108 &self,
109 query: &str,
110 candidates: Vec<ScoredEntity>,
111 ) -> GraphRAGResult<Vec<ScoredEntity>> {
112 if candidates.is_empty() {
113 return Ok(vec![]);
114 }
115
116 let docs: Vec<&str> = candidates.iter().map(|e| e.uri.as_str()).collect();
118
119 let mut all_scores = Vec::with_capacity(candidates.len());
121 for chunk in docs.chunks(self.batch_size) {
122 let scores = self.model.score_batch(query, chunk).await?;
123 all_scores.extend(scores);
124 }
125
126 let mut reranked: Vec<ScoredEntity> = candidates
128 .into_iter()
129 .zip(all_scores)
130 .map(|(mut e, cross_score)| {
131 e.score = e.score * 0.3 + cross_score * 0.7;
133 e
134 })
135 .collect();
136
137 reranked.sort_by(|a, b| {
138 b.score
139 .partial_cmp(&a.score)
140 .unwrap_or(std::cmp::Ordering::Equal)
141 });
142 Ok(reranked)
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;
    use std::collections::HashMap;

    /// The default reranker drops the below-threshold candidate and ranks the
    /// boosted fused candidate first.
    #[test]
    fn test_simple_reranker() {
        // Small fixture builder so each candidate reads as one line.
        let entity = |uri: &str, score: f64, source: ScoreSource| ScoredEntity {
            uri: uri.to_string(),
            score,
            source,
            metadata: HashMap::new(),
        };

        let candidates = vec![
            entity("http://a", 0.5, ScoreSource::Vector),
            // Fused: 0.6 boosted by the default 1.2 outranks "http://a".
            entity("http://b", 0.6, ScoreSource::Fused),
            // Below the default minimum score of 0.1, so it is filtered out.
            entity("http://c", 0.05, ScoreSource::Keyword),
        ];

        let reranked = Reranker::new().rerank(candidates);

        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].uri, "http://b");
    }
}