1use crate::reranking::types::{RerankingError, RerankingResult};
8use scirs2_core::random::Random;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12
/// Kinds of scoring backend a cross-encoder can be configured with.
///
/// NOTE(review): this enum is not referenced by the code in this file —
/// `CrossEncoder::new` selects a backend from the strings `"local"`, `"api"` and
/// `"mock"` instead. Confirm where (if anywhere) it is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CrossEncoderBackend {
    /// In-process scoring; presumably corresponds to `LocalBackend`.
    Local,
    /// Scoring through a hosted API; presumably corresponds to `ApiBackend`.
    Api,
    /// Remote scoring; no matching backend implementation exists in this file.
    Remote,
    /// Test double; presumably corresponds to `MockBackend`.
    Mock,
}
25
26pub trait CrossEncoderBackendTrait: Send + Sync {
28 fn score(&self, query: &str, document: &str) -> RerankingResult<f32>;
30
31 fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
33 pairs.iter().map(|(q, d)| self.score(q, d)).collect()
35 }
36}
37
/// In-process cross-encoder backend that scores pairs with a lexical-overlap heuristic.
///
/// The "model" is marked as loaded lazily, on the first scoring call.
#[derive(Debug, Clone)]
pub struct LocalBackend {
    /// Model identifier; included in the log line emitted when the model is first loaded.
    model_name: String,
    /// Maximum input length. NOTE(review): stored but not read by the scoring code in
    /// this file.
    max_length: usize,
    /// Target device string (`"cpu"` when built by `CrossEncoder::new`).
    /// NOTE(review): stored but not read by the scoring code in this file.
    device: String,
    /// "Model loaded" flag. It sits behind an `Arc`, so clones of this backend share it.
    model_loaded: Arc<RwLock<bool>>,
}
46
47impl LocalBackend {
48 pub fn new(model_name: String, max_length: usize, device: String) -> Self {
49 Self {
50 model_name,
51 max_length,
52 device,
53 model_loaded: Arc::new(RwLock::new(false)),
54 }
55 }
56
57 fn ensure_loaded(&self) -> RerankingResult<()> {
58 let mut loaded = self
59 .model_loaded
60 .write()
61 .map_err(|e| RerankingError::BackendError {
62 message: format!("Lock poisoned: {}", e),
63 })?;
64
65 if !*loaded {
66 tracing::info!("Loading cross-encoder model: {}", self.model_name);
67 *loaded = true;
70 }
71 Ok(())
72 }
73
74 fn compute_similarity(&self, query: &str, document: &str) -> f32 {
75 let q = query.to_lowercase();
80 let d = document.to_lowercase();
81
82 if d.contains(&q) {
84 return 0.95;
85 }
86
87 let q_words: Vec<&str> = q.split_whitespace().collect();
89 let d_words: Vec<&str> = d.split_whitespace().collect();
90
91 if q_words.is_empty() {
92 return 0.5;
93 }
94
95 let overlap_count = q_words
96 .iter()
97 .filter(|qw| d_words.iter().any(|dw| dw.contains(*qw) || qw.contains(dw)))
98 .count();
99
100 let overlap_ratio = overlap_count as f32 / q_words.len() as f32;
101
102 let doc_len = d_words.len();
104 let length_factor = if doc_len < 10 {
105 0.8
106 } else if doc_len > 500 {
107 0.85
108 } else {
109 1.0
110 };
111
112 let base_score = 0.4 + overlap_ratio * 0.5;
114 (base_score * length_factor).min(0.99)
115 }
116}
117
118impl CrossEncoderBackendTrait for LocalBackend {
119 fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
120 self.ensure_loaded()?;
121
122 if query.is_empty() || document.is_empty() {
123 return Ok(0.0);
124 }
125
126 let score = self.compute_similarity(query, document);
127 Ok(score)
128 }
129
130 fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
131 self.ensure_loaded()?;
132
133 Ok(pairs
136 .iter()
137 .map(|(q, d)| self.compute_similarity(q, d))
138 .collect())
139 }
140}
141
/// Connection settings for a hosted reranking HTTP API.
///
/// NOTE(review): the scoring implementation in this file does not read these fields and
/// never contacts `endpoint`; confirm before relying on it for real scores.
///
/// `Debug` is implemented by hand so that formatting this value (in logs or panic
/// messages) never reveals the API key.
#[derive(Clone)]
pub struct ApiBackend {
    /// Secret credential for the API; redacted from `Debug` output.
    api_key: String,
    /// URL of the rerank endpoint.
    endpoint: String,
    /// Model name to request from the service.
    model: String,
    /// Request timeout in milliseconds.
    timeout_ms: u64,
}

impl std::fmt::Debug for ApiBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiBackend")
            .field("api_key", &"<redacted>")
            .field("endpoint", &self.endpoint)
            .field("model", &self.model)
            .field("timeout_ms", &self.timeout_ms)
            .finish()
    }
}
150
151impl ApiBackend {
152 pub fn new(api_key: String, endpoint: String, model: String, timeout_ms: u64) -> Self {
153 Self {
154 api_key,
155 endpoint,
156 model,
157 timeout_ms,
158 }
159 }
160}
161
162impl CrossEncoderBackendTrait for ApiBackend {
163 fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
164 tracing::debug!(
167 "API reranking: {} chars query, {} chars doc",
168 query.len(),
169 document.len()
170 );
171
172 let mut rng = Random::seed(42);
174 let base_score = rng.gen_range(0.4..0.9);
175 Ok(base_score)
176 }
177
178 fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
179 tracing::debug!("Batch API reranking: {} pairs", pairs.len());
181
182 let mut rng = Random::seed(42);
183 Ok(pairs.iter().map(|_| rng.gen_range(0.4..0.9)).collect())
184 }
185}
186
/// Test double whose scores can be pinned per `(query, document)` pair.
///
/// The pinned-score table sits behind an `Arc`, so clones of a `MockBackend` share it.
#[derive(Debug, Clone)]
pub struct MockBackend {
    /// Pinned scores, keyed by a string encoding of the `(query, document)` pair.
    scores: Arc<RwLock<HashMap<String, f32>>>,
}
192
193impl MockBackend {
194 pub fn new() -> Self {
195 Self {
196 scores: Arc::new(RwLock::new(HashMap::new())),
197 }
198 }
199
200 pub fn set_score(&self, query: &str, document: &str, score: f32) {
201 let key = format!("{}||{}", query, document);
202 if let Ok(mut scores) = self.scores.write() {
203 scores.insert(key, score);
204 }
205 }
206}
207
208impl Default for MockBackend {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
214impl CrossEncoderBackendTrait for MockBackend {
215 fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
216 let key = format!("{}||{}", query, document);
217
218 if let Ok(scores) = self.scores.read() {
219 if let Some(&score) = scores.get(&key) {
220 return Ok(score);
221 }
222 }
223
224 let overlap = query
226 .split_whitespace()
227 .filter(|w| document.contains(w))
228 .count();
229
230 let query_words = query.split_whitespace().count().max(1);
231 let score = 0.5 + (overlap as f32 / query_words as f32) * 0.4;
232
233 Ok(score.min(0.95))
234 }
235}
236
/// Reranker that scores `(query, document)` pairs through a pluggable backend.
///
/// Clones share the same backend instance through the `Arc`. `Debug` is not derived;
/// the backend trait object carries no `Debug` bound.
#[derive(Clone)]
pub struct CrossEncoder {
    /// Name reported by `model_name()`.
    model_name: String,
    /// Backend that performs the actual scoring.
    backend: Arc<dyn CrossEncoderBackendTrait>,
    /// Number of pairs handed to the backend per call in `batch_score`.
    batch_size: usize,
}
244
245impl CrossEncoder {
246 pub fn new(model_name: &str, backend_type: &str) -> RerankingResult<Self> {
248 let backend: Arc<dyn CrossEncoderBackendTrait> = match backend_type {
249 "local" => Arc::new(LocalBackend::new(
250 model_name.to_string(),
251 512,
252 "cpu".to_string(),
253 )),
254 "api" => {
255 let api_key =
257 std::env::var("RERANK_API_KEY").unwrap_or_else(|_| "mock_api_key".to_string());
258
259 Arc::new(ApiBackend::new(
260 api_key,
261 "https://api.cohere.ai/v1/rerank".to_string(),
262 model_name.to_string(),
263 5000,
264 ))
265 }
266 "mock" => Arc::new(MockBackend::new()),
267 _ => {
268 return Err(RerankingError::InvalidConfiguration {
269 message: format!("Unknown backend type: {}", backend_type),
270 });
271 }
272 };
273
274 Ok(Self {
275 model_name: model_name.to_string(),
276 backend,
277 batch_size: 32,
278 })
279 }
280
281 pub fn with_mock_backend() -> Self {
283 Self {
284 model_name: "mock".to_string(),
285 backend: Arc::new(MockBackend::new()),
286 batch_size: 32,
287 }
288 }
289
290 pub fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
292 self.backend.score(query, document)
293 }
294
295 pub fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
297 if pairs.is_empty() {
298 return Ok(Vec::new());
299 }
300
301 let mut all_scores = Vec::with_capacity(pairs.len());
303
304 for chunk in pairs.chunks(self.batch_size) {
305 let scores = self.backend.batch_score(chunk)?;
306 all_scores.extend(scores);
307 }
308
309 Ok(all_scores)
310 }
311
312 pub fn model_name(&self) -> &str {
314 &self.model_name
315 }
316
317 pub fn set_batch_size(&mut self, batch_size: usize) {
319 self.batch_size = batch_size;
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU `LocalBackend` with the 512-length limit shared by every local test.
    fn cpu_backend(model_name: &str) -> LocalBackend {
        LocalBackend::new(model_name.to_string(), 512, "cpu".to_string())
    }

    #[test]
    fn test_local_backend_basic() {
        let score = cpu_backend("cross-encoder/ms-marco-MiniLM-L-6-v2")
            .score("machine learning", "deep learning tutorial")
            .unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_local_backend_exact_match() {
        let backend = cpu_backend("test-model");
        let score = backend
            .score("rust programming", "This is about rust programming")
            .unwrap();
        assert!(score > 0.9);
    }

    #[test]
    fn test_local_backend_no_match() {
        let score = cpu_backend("test-model")
            .score("python", "javascript tutorial")
            .unwrap();
        assert!(score < 0.6);
    }

    #[test]
    fn test_mock_backend() {
        let backend = MockBackend::default();
        backend.set_score("test", "document", 0.85);

        let score = backend.score("test", "document").unwrap();
        assert!((score - 0.85).abs() < 0.01);
    }

    #[test]
    fn test_cross_encoder_creation() {
        let encoder = CrossEncoder::new("ms-marco-MiniLM", "local").unwrap();
        assert_eq!(encoder.model_name(), "ms-marco-MiniLM");
    }

    #[test]
    fn test_cross_encoder_scoring() {
        let score = CrossEncoder::with_mock_backend()
            .score("query", "relevant document")
            .unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_batch_scoring() {
        let encoder = CrossEncoder::with_mock_backend();
        let pairs: Vec<(String, String)> = (1..=3)
            .map(|i| (format!("query{}", i), format!("doc{}", i)))
            .collect();

        let scores = encoder.batch_score(&pairs).unwrap();
        assert_eq!(scores.len(), 3);

        scores
            .into_iter()
            .for_each(|score| assert!((0.0..=1.0).contains(&score)));
    }

    #[test]
    fn test_empty_input() {
        let backend = cpu_backend("test-model");

        for &(query, document) in &[("", "document"), ("query", "")] {
            let score = backend.score(query, document).unwrap();
            assert_eq!(score, 0.0);
        }
    }

    #[test]
    fn test_invalid_backend() {
        let result = CrossEncoder::new("model", "invalid_backend");
        assert!(result.is_err());
    }
}