// oxirs_embed/cross_encoder.rs

//! Cross-encoder re-ranker for embedding search results.
//!
//! Implements a lightweight simulation of cross-encoder scoring using
//! token-overlap (Jaccard) similarity. In a production system the `score`
//! function would call a transformer model; here it is kept deterministic and
//! dependency-free for testing purposes.

use std::collections::HashSet;

// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------

/// A (query, document) pair submitted for re-ranking, together with the
/// initial retrieval score produced by an upstream embedding model.
#[derive(Debug, Clone)]
pub struct CandidatePair {
    /// The user query text.
    pub query: String,
    /// The candidate document / passage scored against the query.
    pub document: String,
    /// Score assigned by the first-stage retrieval model.
    pub initial_score: f32,
}

/// The outcome of re-ranking a single candidate document.
#[derive(Debug, Clone)]
pub struct RerankResult {
    /// Text of the candidate document.
    pub document: String,
    /// First-stage retrieval score, carried through unchanged.
    pub initial_score: f32,
    /// Score produced by the cross-encoder.
    pub rerank_score: f32,
    /// Position in the sorted result list, starting at 1 (lower is better).
    pub rank: usize,
}

/// Configuration for a `CrossEncoder` instance.
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Maximum token length (advisory only; the simulation ignores it).
    pub max_length: usize,
    /// If `true`, rerank scores are min-max normalised to `[0, 1]` before
    /// being returned.
    pub normalize_scores: bool,
    /// Number of pairs processed per batch.
    pub batch_size: usize,
}

51impl Default for CrossEncoderConfig {
52    fn default() -> Self {
53        CrossEncoderConfig {
54            max_length: 512,
55            normalize_scores: false,
56            batch_size: 32,
57        }
58    }
59}
60
61/// Stateful cross-encoder that tracks the total number of pairs scored.
62pub struct CrossEncoder {
63    config: CrossEncoderConfig,
64    total_scored: u64,
65}
66
67impl CrossEncoder {
68    /// Create a new `CrossEncoder` with the given configuration.
69    pub fn new(config: CrossEncoderConfig) -> Self {
70        CrossEncoder {
71            config,
72            total_scored: 0,
73        }
74    }
75
76    /// Score a single `(query, document)` pair using token-overlap similarity.
77    pub fn score(&mut self, pair: &CandidatePair) -> f32 {
78        self.total_scored += 1;
79        token_overlap_score(&pair.query, &pair.document)
80    }
81
82    /// Score multiple pairs at once.  Respects `batch_size` but processes
83    /// sequentially in the simulation (no concurrency overhead).
84    pub fn score_batch(&mut self, pairs: &[CandidatePair]) -> Vec<f32> {
85        pairs.iter().map(|p| self.score(p)).collect()
86    }
87
88    /// Re-rank `candidates` against `query`.
89    ///
90    /// Returns a list of [`RerankResult`] sorted by `rerank_score` descending,
91    /// with `rank` fields assigned from 1.
92    pub fn rerank(
93        &mut self,
94        query: &str,
95        candidates: &[String],
96        initial_scores: &[f32],
97    ) -> Vec<RerankResult> {
98        let n = candidates.len().min(initial_scores.len());
99        let pairs: Vec<CandidatePair> = (0..n)
100            .map(|i| CandidatePair {
101                query: query.to_string(),
102                document: candidates[i].clone(),
103                initial_score: initial_scores[i],
104            })
105            .collect();
106
107        let mut raw_scores = self.score_batch(&pairs);
108
109        if self.config.normalize_scores {
110            raw_scores = normalize_scores(&raw_scores);
111        }
112
113        let mut results: Vec<RerankResult> = (0..n)
114            .map(|i| RerankResult {
115                document: candidates[i].clone(),
116                initial_score: initial_scores[i],
117                rerank_score: raw_scores[i],
118                rank: 0, // filled below
119            })
120            .collect();
121
122        // Sort descending by rerank_score.
123        results.sort_by(|a, b| {
124            b.rerank_score
125                .partial_cmp(&a.rerank_score)
126                .unwrap_or(std::cmp::Ordering::Equal)
127        });
128
129        // Assign 1-based ranks.
130        for (idx, r) in results.iter_mut().enumerate() {
131            r.rank = idx + 1;
132        }
133
134        results
135    }
136
137    /// Like `rerank` but truncates the output to the top-`k` results.
138    pub fn top_k(
139        &mut self,
140        query: &str,
141        candidates: &[String],
142        initial_scores: &[f32],
143        k: usize,
144    ) -> Vec<RerankResult> {
145        let mut all = self.rerank(query, candidates, initial_scores);
146        all.truncate(k);
147        all
148    }
149
150    /// Total number of individual pairs that have been scored so far.
151    pub fn total_scored(&self) -> u64 {
152        self.total_scored
153    }
154}
155
// ---------------------------------------------------------------------------
// Internal helpers (also exposed for tests)
// ---------------------------------------------------------------------------

/// Jaccard similarity over whitespace-tokenised sets.
///
/// Returns `1.0` for identical strings (even empty ones) and `0.0` for
/// completely disjoint token sets.
pub(crate) fn token_overlap_score(a: &str, b: &str) -> f32 {
    let set_a: HashSet<&str> = a.split_whitespace().collect();
    let set_b: HashSet<&str> = b.split_whitespace().collect();

    if set_a.is_empty() && set_b.is_empty() {
        // Both empty → treat as identical.
        return 1.0;
    }

    // |A ∪ B| = |A| + |B| - |A ∩ B|; the union is at least 1 here because
    // the both-empty case returned above, so the division is always safe.
    let intersection = set_a.intersection(&set_b).count();
    let union = set_a.len() + set_b.len() - intersection;
    intersection as f32 / union as f32
}

/// Min-max normalise a slice of scores to `[0, 1]`.
///
/// If all scores are equal the output is all-zeros (undefined range).
pub(crate) fn normalize_scores(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    // Single pass to find the extremes; `f32::min`/`max` ignore NaN inputs.
    let mut lo = f32::MAX;
    let mut hi = f32::MIN;
    for &s in scores {
        lo = lo.min(s);
        hi = hi.max(s);
    }

    let span = hi - lo;
    if span == 0.0 {
        // Degenerate range: every score is identical.
        return vec![0.0; scores.len()];
    }
    scores.iter().map(|&s| (s - lo) / span).collect()
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Encoder with the out-of-the-box configuration.
    fn default_encoder() -> CrossEncoder {
        CrossEncoder::new(CrossEncoderConfig::default())
    }

    /// Encoder with score normalisation switched on.
    fn norming_encoder() -> CrossEncoder {
        CrossEncoder::new(CrossEncoderConfig {
            normalize_scores: true,
            ..CrossEncoderConfig::default()
        })
    }

    // ----- token_overlap_score -----

    #[test]
    fn test_token_overlap_identical_strings() {
        let score = token_overlap_score("the quick brown fox", "the quick brown fox");
        assert!(
            (score - 1.0).abs() < 1e-6,
            "identical strings should score 1.0"
        );
    }

    #[test]
    fn test_token_overlap_disjoint_strings() {
        let score = token_overlap_score("apple orange", "banana grape");
        assert!(
            (score - 0.0).abs() < 1e-6,
            "disjoint strings should score 0.0"
        );
    }

    #[test]
    fn test_token_overlap_partial_match() {
        // intersection = {the}, union = {the, fox, cat} → 1/3
        let score = token_overlap_score("the fox", "the cat");
        assert!((score - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_token_overlap_both_empty() {
        let score = token_overlap_score("", "");
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_token_overlap_one_empty() {
        let score = token_overlap_score("hello", "");
        assert!((score - 0.0).abs() < 1e-6);
    }

    // ----- normalize_scores -----

    #[test]
    fn test_normalize_scores_range() {
        let scores = vec![0.1f32, 0.5, 0.9];
        let norm = normalize_scores(&scores);
        // Every normalised value must land inside [0, 1].
        for &v in &norm {
            assert!(v >= 0.0, "normalised value {v} is below 0");
            assert!(v <= 1.0, "normalised value {v} is above 1");
        }
    }

    #[test]
    fn test_normalize_scores_min_is_zero() {
        let scores = vec![2.0f32, 4.0, 6.0];
        let norm = normalize_scores(&scores);
        assert!((norm[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_scores_max_is_one() {
        let scores = vec![2.0f32, 4.0, 6.0];
        let norm = normalize_scores(&scores);
        assert!((norm[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_scores_all_equal() {
        let scores = vec![3.0f32, 3.0, 3.0];
        let norm = normalize_scores(&scores);
        assert!(norm.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_normalize_scores_empty() {
        let norm = normalize_scores(&[]);
        assert!(norm.is_empty());
    }

    // ----- CrossEncoder::score -----

    #[test]
    fn test_score_identical() {
        let mut enc = default_encoder();
        let pair = CandidatePair {
            query: "foo bar".into(),
            document: "foo bar".into(),
            initial_score: 0.9,
        };
        let s = enc.score(&pair);
        assert!((s - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_disjoint() {
        let mut enc = default_encoder();
        let pair = CandidatePair {
            query: "apple".into(),
            document: "banana".into(),
            initial_score: 0.1,
        };
        let s = enc.score(&pair);
        assert!((s - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_increments_total_scored() {
        let mut enc = default_encoder();
        assert_eq!(enc.total_scored(), 0);
        let pair = CandidatePair {
            query: "x".into(),
            document: "y".into(),
            initial_score: 0.0,
        };
        enc.score(&pair);
        assert_eq!(enc.total_scored(), 1);
    }

    // ----- CrossEncoder::score_batch -----

    #[test]
    fn test_score_batch_length_matches_input() {
        let mut enc = default_encoder();
        let pairs: Vec<CandidatePair> = (0..5)
            .map(|i| CandidatePair {
                query: format!("query {i}"),
                document: format!("doc {i}"),
                initial_score: 0.5,
            })
            .collect();
        let scores = enc.score_batch(&pairs);
        assert_eq!(scores.len(), 5);
    }

    #[test]
    fn test_score_batch_increments_total_scored() {
        let mut enc = default_encoder();
        let pairs: Vec<CandidatePair> = (0..10)
            .map(|i| CandidatePair {
                query: "q".into(),
                document: format!("d {i}"),
                initial_score: 0.0,
            })
            .collect();
        enc.score_batch(&pairs);
        assert_eq!(enc.total_scored(), 10);
    }

    // ----- CrossEncoder::rerank -----

    #[test]
    fn test_rerank_sorted_descending() {
        let mut enc = default_encoder();
        let candidates = vec![
            "apple".to_string(),
            "apple banana".to_string(),
            "apple banana cherry".to_string(),
        ];
        let query = "apple banana cherry";
        let initial = vec![0.3, 0.6, 0.9];
        let results = enc.rerank(query, &candidates, &initial);
        // Adjacent results must never be in ascending rerank-score order.
        for w in results.windows(2) {
            assert!(w[0].rerank_score >= w[1].rerank_score);
        }
    }

    #[test]
    fn test_rerank_rank_field_correct() {
        let mut enc = default_encoder();
        let candidates = vec!["a b c".to_string(), "x y z".to_string()];
        let results = enc.rerank("a b c", &candidates, &[0.5, 0.5]);
        assert_eq!(results[0].rank, 1);
        assert_eq!(results[1].rank, 2);
    }

    #[test]
    fn test_rerank_empty_candidates() {
        let mut enc = default_encoder();
        let results = enc.rerank("query", &[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rerank_total_scored_increments() {
        let mut enc = default_encoder();
        let docs: Vec<String> = (0..3).map(|i| format!("doc {i}")).collect();
        let scores: Vec<f32> = (0..3).map(|i| i as f32 * 0.1).collect();
        enc.rerank("q", &docs, &scores);
        assert_eq!(enc.total_scored(), 3);
    }

    // ----- CrossEncoder::top_k -----

    #[test]
    fn test_top_k_limits_output() {
        let mut enc = default_encoder();
        let docs: Vec<String> = (0..10).map(|i| format!("word{i} text")).collect();
        let initial: Vec<f32> = (0..10).map(|i| i as f32 * 0.1).collect();
        let results = enc.top_k("word5 text", &docs, &initial, 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_top_k_returns_all_when_k_exceeds_count() {
        let mut enc = default_encoder();
        let docs = vec!["a".to_string(), "b".to_string()];
        let results = enc.top_k("a", &docs, &[0.5, 0.2], 100);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_top_k_rank_starts_at_one() {
        let mut enc = default_encoder();
        let docs = vec!["hello world".to_string(), "foo bar".to_string()];
        let results = enc.top_k("hello world", &docs, &[0.5, 0.5], 2);
        assert_eq!(results[0].rank, 1);
    }

    // ----- normalisation integration -----

    #[test]
    fn test_rerank_with_normalisation_range() {
        let mut enc = norming_encoder();
        let docs = vec!["a b".to_string(), "c d".to_string(), "e f".to_string()];
        let initial = vec![0.1, 0.5, 0.9];
        let results = enc.rerank("a b", &docs, &initial);
        for r in &results {
            assert!(r.rerank_score >= 0.0 && r.rerank_score <= 1.0);
        }
    }
}