// oxirs_vec/reranking/cross_encoder.rs

//! Cross-encoder model wrapper for query-document relevance scoring
//!
//! Cross-encoders jointly encode query and document pairs to produce
//! accurate relevance scores. Unlike bi-encoders, they see both inputs
//! together, enabling fine-grained relevance modeling at higher computational cost.
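//!
//! A minimal usage sketch (illustrative only; it uses the mock backend, so no
//! model weights or API keys are needed, and the import path assumes the crate
//! and module layout of this file):
//!
//! ```ignore
//! use oxirs_vec::reranking::cross_encoder::CrossEncoder;
//!
//! let encoder = CrossEncoder::with_mock_backend();
//! let score = encoder.score("rust traits", "an introduction to rust traits").unwrap();
//! assert!((0.0..=1.0).contains(&score));
//! ```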

use crate::reranking::types::{RerankingError, RerankingResult};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Backend for cross-encoder inference
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CrossEncoderBackend {
    /// Local model inference (PyTorch/ONNX)
    Local,
    /// API-based inference (OpenAI, Cohere, etc.)
    Api,
    /// Remote inference server
    Remote,
    /// Mock backend for testing
    Mock,
}

/// Trait for cross-encoder backends
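///
/// A sketch of a custom backend (illustrative only; the constant score stands in
/// for real inference):
///
/// ```ignore
/// struct ConstantBackend;
///
/// impl CrossEncoderBackendTrait for ConstantBackend {
///     fn score(&self, _query: &str, _document: &str) -> RerankingResult<f32> {
///         Ok(0.5)
///     }
/// }
/// ```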
pub trait CrossEncoderBackendTrait: Send + Sync {
    /// Score a single query-document pair
    fn score(&self, query: &str, document: &str) -> RerankingResult<f32>;

    /// Score multiple query-document pairs in batch
    fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
        // Default implementation: score one by one
        pairs.iter().map(|(q, d)| self.score(q, d)).collect()
    }
}

/// Local model backend using sentence transformers
#[derive(Debug, Clone)]
pub struct LocalBackend {
    model_name: String,
    max_length: usize,
    device: String,
    model_loaded: Arc<RwLock<bool>>,
}

impl LocalBackend {
    pub fn new(model_name: String, max_length: usize, device: String) -> Self {
        Self {
            model_name,
            max_length,
            device,
            model_loaded: Arc::new(RwLock::new(false)),
        }
    }

    fn ensure_loaded(&self) -> RerankingResult<()> {
        let mut loaded = self
            .model_loaded
            .write()
            .map_err(|e| RerankingError::BackendError {
                message: format!("Lock poisoned: {}", e),
            })?;

        if !*loaded {
            tracing::info!("Loading cross-encoder model: {}", self.model_name);
            // In a real implementation, load the model here (tokenizers + PyTorch/ONNX).
            // For now, just mark it as loaded.
            *loaded = true;
        }
        Ok(())
    }

    fn compute_similarity(&self, query: &str, document: &str) -> f32 {
        // Simplified lexical similarity heuristic.
        // A real implementation would run transformer inference over the pair.

        // Normalize texts
        let q = query.to_lowercase();
        let d = document.to_lowercase();

        // Exact match bonus
        if d.contains(&q) {
            return 0.95;
        }

        // Word overlap score
        let q_words: Vec<&str> = q.split_whitespace().collect();
        let d_words: Vec<&str> = d.split_whitespace().collect();

        if q_words.is_empty() {
            return 0.5;
        }

        let overlap_count = q_words
            .iter()
            .filter(|qw| d_words.iter().any(|dw| dw.contains(*qw) || qw.contains(dw)))
            .count();

        let overlap_ratio = overlap_count as f32 / q_words.len() as f32;

        // Length penalty for very short or very long documents
        let doc_len = d_words.len();
        let length_factor = if doc_len < 10 {
            0.8
        } else if doc_len > 500 {
            0.85
        } else {
            1.0
        };

        // Combine scores
        let base_score = 0.4 + overlap_ratio * 0.5;
        (base_score * length_factor).min(0.99)
    }
}

impl CrossEncoderBackendTrait for LocalBackend {
    fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
        self.ensure_loaded()?;

        if query.is_empty() || document.is_empty() {
            return Ok(0.0);
        }

        let score = self.compute_similarity(query, document);
        Ok(score)
    }

    fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
        self.ensure_loaded()?;

        // A real implementation would batch the pairs through model inference.
        // For now, score each pair individually.
        Ok(pairs
            .iter()
            .map(|(q, d)| self.compute_similarity(q, d))
            .collect())
    }
}

/// API-based backend (e.g., Cohere Rerank API)
#[derive(Debug, Clone)]
pub struct ApiBackend {
    api_key: String,
    endpoint: String,
    model: String,
    timeout_ms: u64,
}

impl ApiBackend {
    pub fn new(api_key: String, endpoint: String, model: String, timeout_ms: u64) -> Self {
        Self {
            api_key,
            endpoint,
            model,
            timeout_ms,
        }
    }
}

impl CrossEncoderBackendTrait for ApiBackend {
    fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
        // In a real implementation, call the reranking API here.
        // For now, return a placeholder score.
        tracing::debug!(
            "API reranking: {} chars query, {} chars doc",
            query.len(),
            document.len()
        );

        // Placeholder: a seeded RNG yields a deterministic pseudo-random score
        // in place of a real API response.
        let mut rng = Random::seed(42);
        let base_score = rng.gen_range(0.4..0.9);
        Ok(base_score)
    }

    fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
        // A real implementation would send all pairs in a single batched API call.
        tracing::debug!("Batch API reranking: {} pairs", pairs.len());

        let mut rng = Random::seed(42);
        Ok(pairs.iter().map(|_| rng.gen_range(0.4..0.9)).collect())
    }
}

/// Mock backend for testing
#[derive(Debug, Clone)]
pub struct MockBackend {
    scores: Arc<RwLock<HashMap<String, f32>>>,
}

impl MockBackend {
    pub fn new() -> Self {
        Self {
            scores: Arc::new(RwLock::new(HashMap::new())),
        }
    }

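    /// Pre-set the score to return for a specific query-document pair.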
    pub fn set_score(&self, query: &str, document: &str, score: f32) {
        let key = format!("{}||{}", query, document);
        if let Ok(mut scores) = self.scores.write() {
            scores.insert(key, score);
        }
    }
}

impl Default for MockBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl CrossEncoderBackendTrait for MockBackend {
    fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
        let key = format!("{}||{}", query, document);

        if let Ok(scores) = self.scores.read() {
            if let Some(&score) = scores.get(&key) {
                return Ok(score);
            }
        }

        // Default mock score based on text overlap
        let overlap = query
            .split_whitespace()
            .filter(|w| document.contains(w))
            .count();

        let query_words = query.split_whitespace().count().max(1);
        let score = 0.5 + (overlap as f32 / query_words as f32) * 0.4;

        Ok(score.min(0.95))
    }
}

/// Cross-encoder model for relevance scoring
#[derive(Clone)]
pub struct CrossEncoder {
    model_name: String,
    backend: Arc<dyn CrossEncoderBackendTrait>,
    batch_size: usize,
}

impl CrossEncoder {
    /// Create a new cross-encoder with the specified backend type
    /// ("local", "api", or "mock").
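    ///
    /// A construction sketch (illustrative; the "api" backend reads its key from
    /// the RERANK_API_KEY environment variable, falling back to a mock key):
    ///
    /// ```ignore
    /// let local = CrossEncoder::new("cross-encoder/ms-marco-MiniLM-L-6-v2", "local").unwrap();
    /// let mock = CrossEncoder::new("any-model", "mock").unwrap();
    /// assert!(CrossEncoder::new("any-model", "not-a-backend").is_err());
    /// ```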
    pub fn new(model_name: &str, backend_type: &str) -> RerankingResult<Self> {
        let backend: Arc<dyn CrossEncoderBackendTrait> = match backend_type {
            "local" => Arc::new(LocalBackend::new(
                model_name.to_string(),
                512,
                "cpu".to_string(),
            )),
            "api" => {
                // Read API key from environment
                let api_key =
                    std::env::var("RERANK_API_KEY").unwrap_or_else(|_| "mock_api_key".to_string());

                Arc::new(ApiBackend::new(
                    api_key,
                    "https://api.cohere.ai/v1/rerank".to_string(),
                    model_name.to_string(),
                    5000,
                ))
            }
            "mock" => Arc::new(MockBackend::new()),
            _ => {
                return Err(RerankingError::InvalidConfiguration {
                    message: format!("Unknown backend type: {}", backend_type),
                });
            }
        };

        Ok(Self {
            model_name: model_name.to_string(),
            backend,
            batch_size: 32,
        })
    }

    /// Create with mock backend for testing
    pub fn with_mock_backend() -> Self {
        Self {
            model_name: "mock".to_string(),
            backend: Arc::new(MockBackend::new()),
            batch_size: 32,
        }
    }

    /// Score a single query-document pair
    pub fn score(&self, query: &str, document: &str) -> RerankingResult<f32> {
        self.backend.score(query, document)
    }

    /// Score multiple query-document pairs in batch
    pub fn batch_score(&self, pairs: &[(String, String)]) -> RerankingResult<Vec<f32>> {
        if pairs.is_empty() {
            return Ok(Vec::new());
        }

        // Process in batches
        let mut all_scores = Vec::with_capacity(pairs.len());

        for chunk in pairs.chunks(self.batch_size) {
            let scores = self.backend.batch_score(chunk)?;
            all_scores.extend(scores);
        }

        Ok(all_scores)
    }

    /// Get model name
    pub fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Set batch size
    pub fn set_batch_size(&mut self, batch_size: usize) {
        self.batch_size = batch_size;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_backend_basic() {
        let backend = LocalBackend::new(
            "cross-encoder/ms-marco-MiniLM-L-6-v2".to_string(),
            512,
            "cpu".to_string(),
        );

        let score = backend
            .score("machine learning", "deep learning tutorial")
            .unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_local_backend_exact_match() {
        let backend = LocalBackend::new("test-model".to_string(), 512, "cpu".to_string());

        let score = backend
            .score("rust programming", "This is about rust programming")
            .unwrap();
        assert!(score > 0.9);
    }

    #[test]
    fn test_local_backend_no_match() {
        let backend = LocalBackend::new("test-model".to_string(), 512, "cpu".to_string());

        let score = backend.score("python", "javascript tutorial").unwrap();
        assert!(score < 0.6);
    }

    #[test]
    fn test_mock_backend() {
        let backend = MockBackend::new();
        backend.set_score("test", "document", 0.85);

        let score = backend.score("test", "document").unwrap();
        assert!((score - 0.85).abs() < 0.01);
    }

    #[test]
    fn test_cross_encoder_creation() {
        let encoder = CrossEncoder::new("ms-marco-MiniLM", "local").unwrap();
        assert_eq!(encoder.model_name(), "ms-marco-MiniLM");
    }

    #[test]
    fn test_cross_encoder_scoring() {
        let encoder = CrossEncoder::with_mock_backend();
        let score = encoder.score("query", "relevant document").unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_batch_scoring() {
        let encoder = CrossEncoder::with_mock_backend();
        let pairs = vec![
            ("query1".to_string(), "doc1".to_string()),
            ("query2".to_string(), "doc2".to_string()),
            ("query3".to_string(), "doc3".to_string()),
        ];

        let scores = encoder.batch_score(&pairs).unwrap();
        assert_eq!(scores.len(), 3);

        for score in scores {
            assert!((0.0..=1.0).contains(&score));
        }
    }

    #[test]
    fn test_empty_input() {
        let backend = LocalBackend::new("test-model".to_string(), 512, "cpu".to_string());

        let score = backend.score("", "document").unwrap();
        assert_eq!(score, 0.0);

        let score = backend.score("query", "").unwrap();
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_invalid_backend() {
        let result = CrossEncoder::new("model", "invalid_backend");
        assert!(result.is_err());
    }
}