1use std::collections::HashSet;
9
/// One query/document pairing to be scored by the cross-encoder.
#[derive(Debug, Clone)]
pub struct CandidatePair {
    /// Query text the document is compared against.
    pub query: String,

    /// Text of the candidate document.
    pub document: String,

    /// Score the candidate already carried before reranking
    /// (presumably from a first-stage retriever — confirm with callers).
    pub initial_score: f32,
}
25
/// One entry of a reranked candidate list.
#[derive(Debug, Clone)]
pub struct RerankResult {
    /// Text of the candidate document.
    pub document: String,

    /// Score the candidate was supplied with before reranking.
    pub initial_score: f32,

    /// Score assigned during reranking.
    pub rerank_score: f32,

    /// Position in the reranked list.
    pub rank: usize,
}
38
/// Tunable settings for a cross-encoder.
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Maximum input length.
    ///
    /// NOTE(review): no code visible in this file reads this field, so its
    /// unit (tokens vs. characters) cannot be confirmed from here.
    pub max_length: usize,

    /// Whether reranked scores are min-max normalised before sorting.
    pub normalize_scores: bool,

    /// Number of pairs per scoring batch.
    ///
    /// NOTE(review): no code visible in this file reads this field.
    pub batch_size: usize,
}
50
51impl Default for CrossEncoderConfig {
52 fn default() -> Self {
53 CrossEncoderConfig {
54 max_length: 512,
55 normalize_scores: false,
56 batch_size: 32,
57 }
58 }
59}
60
61pub struct CrossEncoder {
63 config: CrossEncoderConfig,
64 total_scored: u64,
65}
66
67impl CrossEncoder {
68 pub fn new(config: CrossEncoderConfig) -> Self {
70 CrossEncoder {
71 config,
72 total_scored: 0,
73 }
74 }
75
76 pub fn score(&mut self, pair: &CandidatePair) -> f32 {
78 self.total_scored += 1;
79 token_overlap_score(&pair.query, &pair.document)
80 }
81
82 pub fn score_batch(&mut self, pairs: &[CandidatePair]) -> Vec<f32> {
85 pairs.iter().map(|p| self.score(p)).collect()
86 }
87
88 pub fn rerank(
93 &mut self,
94 query: &str,
95 candidates: &[String],
96 initial_scores: &[f32],
97 ) -> Vec<RerankResult> {
98 let n = candidates.len().min(initial_scores.len());
99 let pairs: Vec<CandidatePair> = (0..n)
100 .map(|i| CandidatePair {
101 query: query.to_string(),
102 document: candidates[i].clone(),
103 initial_score: initial_scores[i],
104 })
105 .collect();
106
107 let mut raw_scores = self.score_batch(&pairs);
108
109 if self.config.normalize_scores {
110 raw_scores = normalize_scores(&raw_scores);
111 }
112
113 let mut results: Vec<RerankResult> = (0..n)
114 .map(|i| RerankResult {
115 document: candidates[i].clone(),
116 initial_score: initial_scores[i],
117 rerank_score: raw_scores[i],
118 rank: 0, })
120 .collect();
121
122 results.sort_by(|a, b| {
124 b.rerank_score
125 .partial_cmp(&a.rerank_score)
126 .unwrap_or(std::cmp::Ordering::Equal)
127 });
128
129 for (idx, r) in results.iter_mut().enumerate() {
131 r.rank = idx + 1;
132 }
133
134 results
135 }
136
137 pub fn top_k(
139 &mut self,
140 query: &str,
141 candidates: &[String],
142 initial_scores: &[f32],
143 k: usize,
144 ) -> Vec<RerankResult> {
145 let mut all = self.rerank(query, candidates, initial_scores);
146 all.truncate(k);
147 all
148 }
149
150 pub fn total_scored(&self) -> u64 {
152 self.total_scored
153 }
154}
155
/// Jaccard similarity between the whitespace-separated token sets of `a`
/// and `b`.
///
/// Each string is split on Unicode whitespace and its tokens are collected
/// into a set, so repeated tokens count once. Comparison is exact, and
/// therefore case-sensitive. The result is `|A ∩ B| / |A ∪ B|`, a value in
/// `[0.0, 1.0]`.
///
/// Two strings that both contain no tokens (empty or whitespace-only) score
/// `1.0`; if only one of them is token-free the score is `0.0`.
pub(crate) fn token_overlap_score(a: &str, b: &str) -> f32 {
    let tokens_a: HashSet<&str> = a.split_whitespace().collect();
    let tokens_b: HashSet<&str> = b.split_whitespace().collect();

    // Two token-free inputs are treated as identical. Returning here also
    // guarantees the denominator below is never zero.
    if tokens_a.is_empty() && tokens_b.is_empty() {
        return 1.0;
    }

    let shared = tokens_a.intersection(&tokens_b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|, which avoids a second pass over both sets.
    let combined = tokens_a.len() + tokens_b.len() - shared;

    shared as f32 / combined as f32
}
181
/// Min-max rescales `scores` so the smallest value maps to `0.0` and the
/// largest to `1.0`, preserving input order.
///
/// An empty slice yields an empty vector. When every score is the same
/// (including the single-element case) the spread is zero and every output
/// is `0.0`.
pub(crate) fn normalize_scores(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    // Track the running minimum and maximum together in one pass.
    let (lowest, highest) = scores
        .iter()
        .fold((f32::MAX, f32::MIN), |(lo, hi), &score| {
            (lo.min(score), hi.max(score))
        });

    let spread = highest - lowest;
    if spread == 0.0 {
        // No spread to scale by; dividing would produce NaN.
        vec![0.0; scores.len()]
    } else {
        scores
            .iter()
            .map(|&score| (score - lowest) / spread)
            .collect()
    }
}
199
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures -------------------------------------------------------

    /// Encoder built from the stock configuration.
    fn default_encoder() -> CrossEncoder {
        CrossEncoder::new(CrossEncoderConfig::default())
    }

    /// Encoder that differs from the stock one only by enabling
    /// score normalisation.
    fn norming_encoder() -> CrossEncoder {
        let config = CrossEncoderConfig {
            normalize_scores: true,
            ..CrossEncoderConfig::default()
        };
        CrossEncoder::new(config)
    }

    /// Shorthand for building a `CandidatePair` from string slices.
    fn pair(query: &str, document: &str, initial_score: f32) -> CandidatePair {
        CandidatePair {
            query: query.into(),
            document: document.into(),
            initial_score,
        }
    }

    /// Turns string slices into the owned strings `rerank` expects.
    fn owned(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| text.to_string()).collect()
    }

    // ---- token_overlap_score --------------------------------------------

    #[test]
    fn test_token_overlap_identical_strings() {
        let text = "the quick brown fox";
        let score = token_overlap_score(text, text);
        assert!(
            (score - 1.0).abs() < 1e-6,
            "identical strings should score 1.0"
        );
    }

    #[test]
    fn test_token_overlap_disjoint_strings() {
        let score = token_overlap_score("apple orange", "banana grape");
        assert!(score.abs() < 1e-6, "disjoint strings should score 0.0");
    }

    #[test]
    fn test_token_overlap_partial_match() {
        // One shared token ("the") out of three distinct tokens.
        let expected: f32 = 1.0 / 3.0;
        let score = token_overlap_score("the fox", "the cat");
        assert!((score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_token_overlap_both_empty() {
        let score = token_overlap_score("", "");
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_token_overlap_one_empty() {
        let score = token_overlap_score("hello", "");
        assert!(score.abs() < 1e-6);
    }

    // ---- normalize_scores -----------------------------------------------

    #[test]
    fn test_normalize_scores_range() {
        let norm = normalize_scores(&[0.1f32, 0.5, 0.9]);
        norm.iter().for_each(|&v| {
            assert!(v >= 0.0, "normalised value {v} is below 0");
            assert!(v <= 1.0, "normalised value {v} is above 1");
        });
    }

    #[test]
    fn test_normalize_scores_min_is_zero() {
        let norm = normalize_scores(&[2.0f32, 4.0, 6.0]);
        let lowest = norm[0];
        assert!(lowest.abs() < 1e-6);
    }

    #[test]
    fn test_normalize_scores_max_is_one() {
        let norm = normalize_scores(&[2.0f32, 4.0, 6.0]);
        let highest = norm[2];
        assert!((highest - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_scores_all_equal() {
        let norm = normalize_scores(&[3.0f32, 3.0, 3.0]);
        let has_non_zero = norm.iter().any(|&v| v != 0.0);
        assert!(!has_non_zero);
    }

    #[test]
    fn test_normalize_scores_empty() {
        assert!(normalize_scores(&[]).is_empty());
    }

    // ---- CrossEncoder::score --------------------------------------------

    #[test]
    fn test_score_identical() {
        let mut enc = default_encoder();
        let s = enc.score(&pair("foo bar", "foo bar", 0.9));
        assert!((s - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_disjoint() {
        let mut enc = default_encoder();
        let s = enc.score(&pair("apple", "banana", 0.1));
        assert!(s.abs() < 1e-6);
    }

    #[test]
    fn test_score_increments_total_scored() {
        let mut enc = default_encoder();
        assert_eq!(enc.total_scored(), 0);

        enc.score(&pair("x", "y", 0.0));

        assert_eq!(enc.total_scored(), 1);
    }

    // ---- CrossEncoder::score_batch --------------------------------------

    #[test]
    fn test_score_batch_length_matches_input() {
        let mut enc = default_encoder();
        let mut pairs = Vec::with_capacity(5);
        for i in 0..5 {
            pairs.push(pair(&format!("query {i}"), &format!("doc {i}"), 0.5));
        }

        let scores = enc.score_batch(&pairs);

        assert_eq!(scores.len(), 5);
    }

    #[test]
    fn test_score_batch_increments_total_scored() {
        let mut enc = default_encoder();
        let mut pairs = Vec::with_capacity(10);
        for i in 0..10 {
            pairs.push(pair("q", &format!("d {i}"), 0.0));
        }

        enc.score_batch(&pairs);

        assert_eq!(enc.total_scored(), 10);
    }

    // ---- CrossEncoder::rerank -------------------------------------------

    #[test]
    fn test_rerank_sorted_descending() {
        let mut enc = default_encoder();
        let candidates = owned(&["apple", "apple banana", "apple banana cherry"]);
        let initial = vec![0.3, 0.6, 0.9];

        let results = enc.rerank("apple banana cherry", &candidates, &initial);

        // Every result must score at least as high as the one after it.
        for (earlier, later) in results.iter().zip(results.iter().skip(1)) {
            assert!(earlier.rerank_score >= later.rerank_score);
        }
    }

    #[test]
    fn test_rerank_rank_field_correct() {
        let mut enc = default_encoder();
        let candidates = owned(&["a b c", "x y z"]);

        let results = enc.rerank("a b c", &candidates, &[0.5, 0.5]);

        assert_eq!(results[0].rank, 1);
        assert_eq!(results[1].rank, 2);
    }

    #[test]
    fn test_rerank_empty_candidates() {
        let mut enc = default_encoder();
        assert!(enc.rerank("query", &[], &[]).is_empty());
    }

    #[test]
    fn test_rerank_total_scored_increments() {
        let mut enc = default_encoder();
        let mut docs: Vec<String> = Vec::new();
        let mut scores: Vec<f32> = Vec::new();
        for i in 0..3 {
            docs.push(format!("doc {i}"));
            scores.push(i as f32 * 0.1);
        }

        enc.rerank("q", &docs, &scores);

        assert_eq!(enc.total_scored(), 3);
    }

    // ---- CrossEncoder::top_k --------------------------------------------

    #[test]
    fn test_top_k_limits_output() {
        let mut enc = default_encoder();
        let mut docs: Vec<String> = Vec::new();
        let mut initial: Vec<f32> = Vec::new();
        for i in 0..10 {
            docs.push(format!("word{i} text"));
            initial.push(i as f32 * 0.1);
        }

        let results = enc.top_k("word5 text", &docs, &initial, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_top_k_returns_all_when_k_exceeds_count() {
        let mut enc = default_encoder();
        let docs = owned(&["a", "b"]);

        let results = enc.top_k("a", &docs, &[0.5, 0.2], 100);

        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_top_k_rank_starts_at_one() {
        let mut enc = default_encoder();
        let docs = owned(&["hello world", "foo bar"]);

        let results = enc.top_k("hello world", &docs, &[0.5, 0.5], 2);

        assert_eq!(results[0].rank, 1);
    }

    // ---- rerank with normalisation --------------------------------------

    #[test]
    fn test_rerank_with_normalisation_range() {
        let mut enc = norming_encoder();
        let docs = owned(&["a b", "c d", "e f"]);
        let initial = vec![0.1, 0.5, 0.9];

        let results = enc.rerank("a b", &docs, &initial);

        for r in results.iter() {
            assert!(r.rerank_score >= 0.0 && r.rerank_score <= 1.0);
        }
    }
}