use std::collections::{HashMap, HashSet};

use async_trait::async_trait;

use crate::retrieval::SearchResult;
use crate::Result;
16
/// Configuration for cross-encoder reranking.
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Identifier of the cross-encoder model. NOTE(review): not read by
    /// `ConfidenceCrossEncoder` in this file — presumably consumed by a
    /// model-backed implementation; confirm.
    pub model_name: String,

    /// Maximum input length (tokens) for a query/document pair.
    /// Not read by `ConfidenceCrossEncoder` in this file.
    pub max_length: usize,

    /// Number of pairs scored per batch. Not read by
    /// `ConfidenceCrossEncoder` in this file.
    pub batch_size: usize,

    /// Maximum number of results returned by `rerank`.
    pub top_k: usize,

    /// Minimum relevance score a result must reach to be kept.
    pub min_confidence: f32,

    /// When true, raw relevance scores are squashed through a sigmoid
    /// into (0, 1) before filtering/sorting.
    pub normalize_scores: bool,
}
38
39impl Default for CrossEncoderConfig {
40 fn default() -> Self {
41 Self {
42 model_name: "cross-encoder/ms-marco-MiniLM-L-6-v2".to_string(),
43 max_length: 512,
44 batch_size: 32,
45 top_k: 10,
46 min_confidence: 0.0,
47 normalize_scores: true,
48 }
49 }
50}
51
/// A search result annotated with its reranking outcome.
#[derive(Debug, Clone)]
pub struct RankedResult {
    /// The underlying search result that was reranked.
    pub result: SearchResult,

    /// Relevance score assigned by the cross-encoder.
    pub relevance_score: f32,

    /// Score the result carried before reranking.
    pub original_score: f32,

    /// `relevance_score - original_score`; positive means the reranker
    /// promoted the result relative to its original score.
    pub score_delta: f32,
}
67
/// Interface for cross-encoder rerankers: scorers that judge a
/// (query, document) pair jointly rather than embedding each side
/// independently.
#[async_trait]
pub trait CrossEncoder: Send + Sync {
    /// Rerank `candidates` against `query`, returning results ordered by
    /// relevance (implementations may also filter and truncate the list).
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<SearchResult>,
    ) -> Result<Vec<RankedResult>>;

    /// Score a single query/document pair.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32>;

    /// Score a batch of (query, document) pairs, one score per pair,
    /// in input order.
    async fn score_batch(
        &self,
        pairs: Vec<(String, String)>,
    ) -> Result<Vec<f32>>;
}
87
/// A lightweight, deterministic `CrossEncoder` that approximates
/// relevance with lexical overlap (token Jaccard similarity plus a
/// length factor) instead of a neural model.
pub struct ConfidenceCrossEncoder {
    /// Reranking parameters (`top_k`, `min_confidence`, normalization).
    config: CrossEncoderConfig,
}
96
97impl ConfidenceCrossEncoder {
98 pub fn new(config: CrossEncoderConfig) -> Self {
100 Self { config }
101 }
102
103 fn calculate_relevance(&self, query: &str, document: &str) -> f32 {
105 let query_tokens: Vec<&str> = query.split_whitespace().collect();
107 let doc_tokens: Vec<&str> = document.split_whitespace().collect();
108
109 if query_tokens.is_empty() || doc_tokens.is_empty() {
110 return 0.0;
111 }
112
113 let query_set: HashMap<&str, ()> = query_tokens.iter()
115 .map(|t| (*t, ()))
116 .collect();
117 let doc_set: HashMap<&str, ()> = doc_tokens.iter()
118 .map(|t| (*t, ()))
119 .collect();
120
121 let intersection: usize = query_set.keys()
122 .filter(|k| doc_set.contains_key(*k))
123 .count();
124
125 let union_size = query_set.len() + doc_set.len() - intersection;
126
127 let jaccard = if union_size > 0 {
128 intersection as f32 / union_size as f32
129 } else {
130 0.0
131 };
132
133 let length_factor = (doc_tokens.len() as f32 / 100.0).min(1.0);
135
136 let raw_score = jaccard * 0.7 + length_factor * 0.3;
138
139 if self.config.normalize_scores {
140 1.0 / (1.0 + (-5.0 * (raw_score - 0.5)).exp())
142 } else {
143 raw_score
144 }
145 }
146}
147
148#[async_trait]
149impl CrossEncoder for ConfidenceCrossEncoder {
150 async fn rerank(
151 &self,
152 query: &str,
153 candidates: Vec<SearchResult>,
154 ) -> Result<Vec<RankedResult>> {
155 if candidates.is_empty() {
156 return Ok(Vec::new());
157 }
158
159 let mut ranked: Vec<RankedResult> = candidates
161 .into_iter()
162 .map(|result| {
163 let relevance_score = self.calculate_relevance(query, &result.content);
164 let original_score = result.score;
165 let score_delta = relevance_score - original_score;
166
167 RankedResult {
168 result,
169 relevance_score,
170 original_score,
171 score_delta,
172 }
173 })
174 .collect();
175
176 ranked.sort_by(|a, b| {
178 b.relevance_score
179 .partial_cmp(&a.relevance_score)
180 .unwrap_or(std::cmp::Ordering::Equal)
181 });
182
183 ranked.retain(|r| r.relevance_score >= self.config.min_confidence);
185
186 ranked.truncate(self.config.top_k);
188
189 log::info!(
190 "Reranked {} candidates, returning top-{}",
191 ranked.len(),
192 self.config.top_k
193 );
194
195 Ok(ranked)
196 }
197
198 async fn score_pair(&self, query: &str, document: &str) -> Result<f32> {
199 Ok(self.calculate_relevance(query, document))
200 }
201
202 async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>> {
203 let scores = pairs
204 .iter()
205 .map(|(query, doc)| self.calculate_relevance(query, doc))
206 .collect();
207
208 Ok(scores)
209 }
210}
211
/// Summary statistics describing one reranking pass.
#[derive(Debug, Clone)]
pub struct RerankingStats {
    /// Number of candidates that entered reranking.
    pub candidates_count: usize,

    /// Number of results that survived filtering/truncation.
    pub results_count: usize,

    /// Mean `score_delta` over the surviving results (0.0 when empty).
    pub avg_score_improvement: f32,

    /// Largest `score_delta` among the surviving results (0.0 when empty).
    pub max_score_improvement: f32,

    /// Percentage (0–100) of candidates removed by reranking.
    pub filter_rate: f32,
}
230
231impl RerankingStats {
232 pub fn from_results(
234 original_count: usize,
235 ranked: &[RankedResult],
236 ) -> Self {
237 let results_count = ranked.len();
238
239 let avg_score_improvement = if !ranked.is_empty() {
240 ranked.iter().map(|r| r.score_delta).sum::<f32>() / ranked.len() as f32
241 } else {
242 0.0
243 };
244
245 let max_score_improvement = ranked
246 .iter()
247 .map(|r| r.score_delta)
248 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
249 .unwrap_or(0.0);
250
251 let filter_rate = if original_count > 0 {
252 ((original_count - results_count) as f32 / original_count as f32) * 100.0
253 } else {
254 0.0
255 };
256
257 Self {
258 candidates_count: original_count,
259 results_count,
260 avg_score_improvement,
261 max_score_improvement,
262 filter_rate,
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::ResultType;

    /// Build a minimal chunk-typed `SearchResult` for the tests below.
    fn create_test_result(id: &str, content: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            content: content.to_string(),
            score,
            result_type: ResultType::Chunk,
            entities: Vec::new(),
            source_chunks: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_rerank_basic() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig {
            top_k: 3,
            min_confidence: 0.0,
            ..Default::default()
        });

        let candidates = vec![
            create_test_result(
                "1",
                "Machine learning is a subset of artificial intelligence",
                0.5,
            ),
            create_test_result("2", "The weather today is sunny", 0.6),
            create_test_result(
                "3",
                "Neural networks are machine learning algorithms used for pattern recognition",
                0.4,
            ),
        ];

        let ranked = encoder
            .rerank("machine learning algorithms", candidates)
            .await
            .unwrap();

        // All three survive the zero confidence floor...
        assert_eq!(ranked.len(), 3);

        // ...and come back sorted by descending relevance.
        for pair in ranked.windows(2) {
            assert!(pair[0].relevance_score >= pair[1].relevance_score);
        }
    }

    #[tokio::test]
    async fn test_confidence_filtering() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig {
            top_k: 10,
            min_confidence: 0.5,
            ..Default::default()
        });

        let candidates = vec![
            create_test_result("1", "highly relevant technical content", 0.3),
            create_test_result("2", "somewhat relevant", 0.4),
            create_test_result("3", "not relevant at all", 0.5),
        ];

        let ranked = encoder
            .rerank("specific technical query", candidates)
            .await
            .unwrap();

        // Every surviving result must clear the confidence floor.
        assert!(ranked.iter().all(|r| r.relevance_score >= 0.5));
    }

    #[tokio::test]
    async fn test_score_pair() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig::default());

        let score = encoder
            .score_pair(
                "artificial intelligence",
                "AI and machine learning are related fields",
            )
            .await
            .unwrap();

        // Normalized scores live in [0, 1].
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_reranking_stats() {
        let ranked = vec![
            RankedResult {
                result: create_test_result("1", "test", 0.5),
                relevance_score: 0.8,
                original_score: 0.5,
                score_delta: 0.3,
            },
            RankedResult {
                result: create_test_result("2", "test", 0.6),
                relevance_score: 0.7,
                original_score: 0.6,
                score_delta: 0.1,
            },
        ];

        let stats = RerankingStats::from_results(5, &ranked);

        assert_eq!(stats.candidates_count, 5);
        assert_eq!(stats.results_count, 2);
        // 3 of 5 candidates dropped -> 60% filter rate.
        assert!((stats.filter_rate - 60.0).abs() < 0.001);
        assert!(stats.avg_score_improvement > 0.0);
    }
}