use async_trait::async_trait;

use crate::retrieval::SearchResult;
use crate::{GraphRAGError, Result};

/// Configuration for cross-encoder reranking.
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Hugging Face model identifier of the cross-encoder.
    pub model_name: String,
    /// Maximum token length for a query/document pair.
    pub max_length: usize,
    /// Number of pairs scored per batch.
    pub batch_size: usize,
    /// Maximum number of results to keep after reranking.
    pub top_k: usize,
    /// Minimum relevance score a result needs to be kept.
    pub min_confidence: f32,
    /// Whether to squash raw logits into (0, 1) with a sigmoid.
    pub normalize_scores: bool,
}

impl Default for CrossEncoderConfig {
    fn default() -> Self {
        Self {
            model_name: "cross-encoder/ms-marco-MiniLM-L-6-v2".to_string(),
            max_length: 512,
            batch_size: 32,
            top_k: 10,
            min_confidence: 0.0,
            normalize_scores: true,
        }
    }
}

/// A search result together with its cross-encoder relevance score.
#[derive(Debug, Clone)]
pub struct RankedResult {
    /// The underlying search result.
    pub result: SearchResult,
    /// Relevance score assigned by the cross-encoder.
    pub relevance_score: f32,
    /// Score the result carried before reranking.
    pub original_score: f32,
    /// Difference between the reranked and original scores.
    pub score_delta: f32,
}

/// Scores query/document pairs jointly and reranks retrieval candidates.
#[async_trait]
pub trait CrossEncoder: Send + Sync {
    /// Rerank `candidates` against `query`, returning results in descending
    /// relevance order.
    async fn rerank(&self, query: &str, candidates: Vec<SearchResult>)
        -> Result<Vec<RankedResult>>;

    /// Score a single query/document pair.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32>;

    /// Score a batch of `(query, document)` pairs.
    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>>;
}

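// A minimal usage sketch. `rerank_with` is an illustrative helper (not part of
// the original API): holding the encoder behind a trait object keeps the
// neural and score-passthrough implementations interchangeable at call sites.
#[allow(dead_code)]
async fn rerank_with(
    encoder: &dyn CrossEncoder,
    query: &str,
    hits: Vec<SearchResult>,
) -> Result<Vec<RankedResult>> {
    encoder.rerank(query, hits).await
}
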
#[cfg(feature = "neural-embeddings")]
use candle_core::{DType, Device, Tensor};
#[cfg(feature = "neural-embeddings")]
use candle_nn::VarBuilder;
#[cfg(feature = "neural-embeddings")]
use candle_transformers::models::bert::{BertModel, Config};
#[cfg(feature = "huggingface-hub")]
use hf_hub::api::sync::Api;
#[cfg(feature = "neural-embeddings")]
use tokenizers::Tokenizer;

/// Cross-encoder backed by a BERT model loaded through Candle.
#[cfg(feature = "neural-embeddings")]
pub struct CandleCrossEncoder {
    config: CrossEncoderConfig,
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

#[cfg(feature = "neural-embeddings")]
impl CandleCrossEncoder {
    /// Download the model from the Hugging Face Hub and load it for inference.
    pub fn new(config: CrossEncoderConfig) -> Result<Self> {
        let api = Api::new().map_err(|e| GraphRAGError::Embedding {
            message: format!("Failed to create HF Hub API: {}", e),
        })?;
        let repo = api.model(config.model_name.clone());

        // Prefer safetensors weights, falling back to the PyTorch checkpoint.
        let model_file = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("pytorch_model.bin"))
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to download model '{}': {}", config.model_name, e),
            })?;

        let tokenizer_file = repo
            .get("tokenizer.json")
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to download tokenizer: {}", e),
            })?;

        let config_file = repo
            .get("config.json")
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to download config: {}", e),
            })?;

        // Inference runs on the CPU; swap in a GPU device here if available.
        let device = Device::Cpu;
        let model_config: Config =
            serde_json::from_str(&std::fs::read_to_string(config_file).map_err(|e| {
                GraphRAGError::Embedding {
                    message: format!("Failed to read config: {}", e),
                }
            })?)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to parse config: {}", e),
            })?;

        let tokenizer =
            Tokenizer::from_file(tokenizer_file).map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to load tokenizer: {}", e),
            })?;

        // SAFETY: memory-mapping requires the weight file to remain unmodified
        // for the lifetime of the mapping.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device).map_err(
                |e| GraphRAGError::Embedding {
                    message: format!("Failed to load weights: {}", e),
                },
            )?
        };

        let model = BertModel::load(vb, &model_config).map_err(|e| GraphRAGError::Embedding {
            message: format!("Failed to load BERT model: {}", e),
        })?;

        Ok(Self {
            config,
            model,
            tokenizer,
            device,
        })
    }
}

#[cfg(feature = "neural-embeddings")]
#[async_trait]
impl CrossEncoder for CandleCrossEncoder {
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<SearchResult>,
    ) -> Result<Vec<RankedResult>> {
        let mut ranked = Vec::new();

        for candidate in candidates {
            let score = self.score_pair(query, &candidate.content).await?;
            // Read the original score before `candidate` is moved into the struct.
            let original_score = candidate.score;
            let score_delta = score - original_score;

            if score >= self.config.min_confidence {
                ranked.push(RankedResult {
                    result: candidate,
                    relevance_score: score,
                    original_score,
                    score_delta,
                });
            }
        }

        // Highest relevance first; NaN-safe comparison falls back to Equal.
        ranked.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(self.config.top_k);
        Ok(ranked)
    }

    async fn score_pair(&self, query: &str, document: &str) -> Result<f32> {
        // Cross-encoders score the query and document jointly as one sequence.
        let tokens = self
            .tokenizer
            .encode((query, document), true)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Tokenization failed: {}", e),
            })?;

        let token_ids = Tensor::new(tokens.get_ids(), &self.device)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Tensor creation failed: {}", e),
            })?
            .unsqueeze(0)
            .map_err(|_| GraphRAGError::Embedding {
                message: "Unsqueeze failed".to_string(),
            })?;

        let token_type_ids = Tensor::new(tokens.get_type_ids(), &self.device)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Type tensor creation failed: {}", e),
            })?
            .unsqueeze(0)
            .map_err(|_| GraphRAGError::Embedding {
                message: "Unsqueeze failed".to_string(),
            })?;

        let logits = self
            .model
            .forward(&token_ids, &token_type_ids)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Forward pass failed: {}", e),
            })?;

        let score = logits
            .squeeze(0)
            .map_err(|_| GraphRAGError::Embedding {
                message: "Squeeze failed".to_string(),
            })?
            .to_vec1::<f32>()
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("To vec failed: {}", e),
            })?;

        let raw_score = *score.first().ok_or_else(|| GraphRAGError::Embedding {
            message: "Model returned no logits".to_string(),
        })?;

        if self.config.normalize_scores {
            // Sigmoid maps the raw logit into (0, 1).
            Ok(1.0 / (1.0 + (-raw_score).exp()))
        } else {
            Ok(raw_score)
        }
    }

    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>> {
        // Pairs are scored sequentially; `config.batch_size` is not used here yet.
        let mut scores = Vec::new();
        for (q, d) in pairs {
            scores.push(self.score_pair(&q, &d).await?);
        }
        Ok(scores)
    }
}

/// Summary statistics for a reranking pass.
#[derive(Debug, Clone)]
pub struct RerankingStats {
    /// Number of candidates before reranking.
    pub candidates_count: usize,
    /// Number of results kept after filtering and truncation.
    pub results_count: usize,
    /// Mean score delta across the kept results.
    pub avg_score_improvement: f32,
    /// Largest score delta across the kept results.
    pub max_score_improvement: f32,
    /// Percentage of candidates dropped by reranking.
    pub filter_rate: f32,
}

impl RerankingStats {
    /// Compute statistics from the original candidate count and ranked output.
    pub fn from_results(original_count: usize, ranked: &[RankedResult]) -> Self {
        let results_count = ranked.len();

        let avg_score_improvement = if !ranked.is_empty() {
            ranked.iter().map(|r| r.score_delta).sum::<f32>() / ranked.len() as f32
        } else {
            0.0
        };

        let max_score_improvement = ranked
            .iter()
            .map(|r| r.score_delta)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0);

        let filter_rate = if original_count > 0 {
            ((original_count - results_count) as f32 / original_count as f32) * 100.0
        } else {
            0.0
        };

        Self {
            candidates_count: original_count,
            results_count,
            avg_score_improvement,
            max_score_improvement,
            filter_rate,
        }
    }
}

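// Sketch of the intended flow (assuming a `CrossEncoder` implementation and a
// candidate list from an earlier retrieval stage):
//
//     let ranked = encoder.rerank(query, candidates).await?;
//     let stats = RerankingStats::from_results(candidate_count, &ranked);
//
// `filter_rate` is a percentage: it counts candidates dropped by both the
// `min_confidence` threshold and the `top_k` cut.
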
/// Fallback reranker that reuses each candidate's existing score as its
/// confidence, applying the configured threshold and `top_k` without a model.
pub struct ConfidenceCrossEncoder {
    config: CrossEncoderConfig,
}

impl ConfidenceCrossEncoder {
    pub fn new(config: CrossEncoderConfig) -> Self {
        Self { config }
    }
}

#[async_trait]
impl CrossEncoder for ConfidenceCrossEncoder {
    async fn rerank(
        &self,
        _query: &str,
        candidates: Vec<SearchResult>,
    ) -> Result<Vec<RankedResult>> {
        // Reuse each candidate's retrieval score as its relevance, then apply
        // the same filtering, ordering, and truncation as the neural path.
        let mut ranked: Vec<RankedResult> = candidates
            .into_iter()
            .filter(|c| c.score >= self.config.min_confidence)
            .map(|candidate| {
                let score = candidate.score;
                RankedResult {
                    result: candidate,
                    relevance_score: score,
                    original_score: score,
                    score_delta: 0.0,
                }
            })
            .collect();

        ranked.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(self.config.top_k);
        Ok(ranked)
    }

    async fn score_pair(&self, _query: &str, _document: &str) -> Result<f32> {
        Ok(0.0)
    }

    async fn score_batch(&self, pairs: Vec<(String, String)>) -> Result<Vec<f32>> {
        Ok(vec![0.0; pairs.len()])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::ResultType;

    fn create_test_result(id: &str, content: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            content: content.to_string(),
            score,
            result_type: ResultType::Chunk,
            entities: Vec::new(),
            source_chunks: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_rerank_basic() {
        let config = CrossEncoderConfig {
            top_k: 3,
            min_confidence: 0.0,
            ..Default::default()
        };

        let encoder = ConfidenceCrossEncoder::new(config);

        let query = "machine learning algorithms";
        let candidates = vec![
            create_test_result(
                "1",
                "Machine learning is a subset of artificial intelligence",
                0.5,
            ),
            create_test_result("2", "The weather today is sunny", 0.6),
            create_test_result(
                "3",
                "Neural networks are machine learning algorithms used for pattern recognition",
                0.4,
            ),
        ];

        let ranked = encoder.rerank(query, candidates).await.unwrap();

        assert_eq!(ranked.len(), 3);

        // Results must come back in descending relevance order.
        assert!(ranked[0].relevance_score >= ranked[1].relevance_score);
        assert!(ranked[1].relevance_score >= ranked[2].relevance_score);
    }

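    // Supplementary check of the `top_k` cut: with three candidates and
    // `top_k = 2`, only the two highest-scoring results should survive.
    #[tokio::test]
    async fn test_top_k_truncation() {
        let config = CrossEncoderConfig {
            top_k: 2,
            min_confidence: 0.0,
            ..Default::default()
        };
        let encoder = ConfidenceCrossEncoder::new(config);

        let candidates = vec![
            create_test_result("1", "first", 0.9),
            create_test_result("2", "second", 0.8),
            create_test_result("3", "third", 0.7),
        ];

        let ranked = encoder.rerank("query", candidates).await.unwrap();
        assert_eq!(ranked.len(), 2);
        assert!(ranked[0].relevance_score >= ranked[1].relevance_score);
    }
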
    #[tokio::test]
    async fn test_confidence_filtering() {
        let config = CrossEncoderConfig {
            top_k: 10,
            min_confidence: 0.5,
            ..Default::default()
        };

        let encoder = ConfidenceCrossEncoder::new(config);

        let query = "specific technical query";
        let candidates = vec![
            create_test_result("1", "highly relevant technical content", 0.3),
            create_test_result("2", "somewhat relevant", 0.4),
            create_test_result("3", "not relevant at all", 0.5),
        ];

        let ranked = encoder.rerank(query, candidates).await.unwrap();

        for result in &ranked {
            assert!(result.relevance_score >= 0.5);
        }
    }

    #[tokio::test]
    async fn test_score_pair() {
        let config = CrossEncoderConfig::default();
        let encoder = ConfidenceCrossEncoder::new(config);

        let score = encoder
            .score_pair(
                "artificial intelligence",
                "AI and machine learning are related fields",
            )
            .await
            .unwrap();

        assert!(score >= 0.0 && score <= 1.0);
    }

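    // `ConfidenceCrossEncoder::score_batch` returns a placeholder score per
    // pair; this pins down the one-score-per-pair contract.
    #[tokio::test]
    async fn test_score_batch() {
        let encoder = ConfidenceCrossEncoder::new(CrossEncoderConfig::default());

        let pairs = vec![
            ("query one".to_string(), "document one".to_string()),
            ("query two".to_string(), "document two".to_string()),
        ];

        let scores = encoder.score_batch(pairs).await.unwrap();
        assert_eq!(scores.len(), 2);
    }
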
    #[test]
    fn test_reranking_stats() {
        let ranked = vec![
            RankedResult {
                result: create_test_result("1", "test", 0.5),
                relevance_score: 0.8,
                original_score: 0.5,
                score_delta: 0.3,
            },
            RankedResult {
                result: create_test_result("2", "test", 0.6),
                relevance_score: 0.7,
                original_score: 0.6,
                score_delta: 0.1,
            },
        ];

        let stats = RerankingStats::from_results(5, &ranked);

        assert_eq!(stats.candidates_count, 5);
        assert_eq!(stats.results_count, 2);
        assert!((stats.filter_rate - 60.0).abs() < 0.001);
        assert!(stats.avg_score_improvement > 0.0);
    }
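
    // Degenerate input: zero candidates must not divide by zero in the
    // average improvement or the filter rate.
    #[test]
    fn test_reranking_stats_empty() {
        let stats = RerankingStats::from_results(0, &[]);
        assert_eq!(stats.candidates_count, 0);
        assert_eq!(stats.results_count, 0);
        assert!(stats.avg_score_improvement.abs() < 0.001);
        assert!(stats.filter_rate.abs() < 0.001);
    }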
}