1use crate::DeterministicScore;
4use std::fmt;
5use std::num::NonZeroUsize;
6
/// Errors reported by the scoring helpers.
///
/// Within this file only `weighted_sum` constructs these values
/// (`LengthMismatch` and `NonFiniteWeight`); the distance-related variants
/// are presumably produced by code elsewhere in the crate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreError {
    /// Two inputs that are required to have the same length did not.
    LengthMismatch {
        /// Static description of the violated requirement; printed as the
        /// prefix of the `Display` message.
        expected_desc: &'static str,
        /// Number of elements in the first input.
        first_len: usize,
        /// Number of elements in the second input.
        second_len: usize,
    },
    /// A weight was NaN or infinite.
    NonFiniteWeight {
        /// Zero-based position of the offending weight.
        index: usize,
    },
    /// A distance value was NaN or infinite.
    NonFiniteDistance,
    /// A distance value was outside the range accepted by a metric.
    InvalidDistanceRange {
        /// Name of the metric that rejected the value.
        metric_name: &'static str,
        /// The offending distance, printed as eight hex digits.
        /// NOTE(review): presumably the raw bit pattern of an `f32`
        /// distance — confirm against the code that builds this variant.
        dist_bits: u32,
    },
    /// The requested distance metric is not supported.
    UnsupportedMetric,
}
36
37impl fmt::Display for ScoreError {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 match self {
40 ScoreError::LengthMismatch {
41 expected_desc,
42 first_len,
43 second_len,
44 } => write!(
45 f,
46 "{expected_desc}: first has {first_len} elements, second has {second_len}"
47 ),
48 ScoreError::NonFiniteWeight { index } => {
49 write!(f, "weight at index {index} must be finite")
50 }
51 ScoreError::NonFiniteDistance => {
52 write!(f, "distance must be finite (not NaN or infinity)")
53 }
54 ScoreError::InvalidDistanceRange {
55 metric_name,
56 dist_bits,
57 } => write!(
58 f,
59 "distance value (bits=0x{dist_bits:08x}) is out of valid range for metric {metric_name}"
60 ),
61 ScoreError::UnsupportedMetric => {
62 write!(f, "unsupported distance metric")
63 }
64 }
65 }
66}
67
// The impl body is empty, so every `std::error::Error` method keeps its
// default; in particular `source()` returns `None` (no wrapped cause).
impl std::error::Error for ScoreError {}
69
70#[inline]
72pub fn sum_scores(scores: &[DeterministicScore]) -> DeterministicScore {
73 if scores.is_empty() {
74 return DeterministicScore::ZERO;
75 }
76 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
77 DeterministicScore::from_raw(sum.clamp(
78 DeterministicScore::NEG_INF.to_raw() as i128,
79 i64::MAX as i128,
80 ) as i64)
81}
82
83#[inline]
85pub fn avg_scores(scores: &[DeterministicScore]) -> DeterministicScore {
86 if scores.is_empty() {
87 return DeterministicScore::ZERO;
88 }
89 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
90 let mean = sum / scores.len() as i128;
91 DeterministicScore::from_raw(mean.clamp(
92 DeterministicScore::NEG_INF.to_raw() as i128,
93 i64::MAX as i128,
94 ) as i64)
95}
96
97#[inline]
99pub fn avg_scores_checked(scores: &[DeterministicScore]) -> (DeterministicScore, bool) {
100 if scores.is_empty() {
101 return (DeterministicScore::ZERO, false);
102 }
103 const SATURATION_THRESHOLD: i128 = (i64::MAX as i128) * 9 / 10;
104 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
105 let mean = sum / scores.len() as i128;
106 let abs_mass: i128 = scores
109 .iter()
110 .map(|s| (s.to_raw() as i128).unsigned_abs() as i128)
111 .sum();
112 let near_saturation =
113 abs_mass > SATURATION_THRESHOLD || mean.unsigned_abs() as i128 > SATURATION_THRESHOLD;
114 let result = DeterministicScore::from_raw(mean.clamp(
115 DeterministicScore::NEG_INF.to_raw() as i128,
116 i64::MAX as i128,
117 ) as i64);
118 (result, near_saturation)
119}
120
121#[inline]
123pub fn max_score(scores: &[DeterministicScore]) -> DeterministicScore {
124 scores
125 .iter()
126 .copied()
127 .max()
128 .unwrap_or(DeterministicScore::NEG_INF)
129}
130
131#[inline]
133pub fn min_score(scores: &[DeterministicScore]) -> DeterministicScore {
134 scores
135 .iter()
136 .copied()
137 .min()
138 .unwrap_or(DeterministicScore::MAX)
139}
140
141#[inline]
143pub fn rrf_score(rank: usize, k: usize) -> DeterministicScore {
144 let Some(denominator) = k.checked_add(rank) else {
145 return DeterministicScore::ZERO;
146 };
147 if denominator == 0 {
148 return DeterministicScore::ZERO;
149 }
150 DeterministicScore::from_f64(1.0 / (denominator as f64))
151}
152
153#[inline]
155pub fn rrf_score_one_based(rank: NonZeroUsize, k: usize) -> DeterministicScore {
156 let Some(denominator) = k.checked_add(rank.get()) else {
157 return DeterministicScore::ZERO;
158 };
159 DeterministicScore::from_f64(1.0 / denominator as f64)
160}
161
162#[inline]
164pub fn rrf_score_zero_based(index: usize, k: usize) -> DeterministicScore {
165 let Some(rank) = index.checked_add(1).and_then(NonZeroUsize::new) else {
166 return DeterministicScore::ZERO;
167 };
168 rrf_score_one_based(rank, k)
169}
170
171const SCALE_RAW: i128 = 4_294_967_296; #[inline]
175pub fn weighted_sum(
176 scores: &[DeterministicScore],
177 weights: &[f64],
178) -> Result<DeterministicScore, ScoreError> {
179 if scores.len() != weights.len() {
180 return Err(ScoreError::LengthMismatch {
181 expected_desc: "scores and weights must have same length",
182 first_len: scores.len(),
183 second_len: weights.len(),
184 });
185 }
186 let mut acc = 0i128;
187 for (index, (&score, &weight)) in scores.iter().zip(weights.iter()).enumerate() {
188 if !weight.is_finite() {
189 return Err(ScoreError::NonFiniteWeight { index });
190 }
191 let w = DeterministicScore::from_f64(weight);
192 acc += (score.to_raw() as i128 * w.to_raw() as i128) / SCALE_RAW;
193 }
194 Ok(DeterministicScore::from_raw(acc.clamp(
195 DeterministicScore::NEG_INF.to_raw() as i128,
196 DeterministicScore::MAX.to_raw() as i128,
197 ) as i64))
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::num::NonZeroUsize;

    /// Tolerance for comparisons against simple decimal fractions.
    const TIGHT_EPS: f64 = 1e-9;

    /// Builds a score from a float.
    fn s(v: f64) -> DeterministicScore {
        DeterministicScore::from_f64(v)
    }

    /// Builds a one-based rank; panics when `n` is zero.
    fn rank(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }

    /// `true` when `actual`, read back as `f64`, is within `eps` of `expected`.
    fn near(actual: DeterministicScore, expected: f64, eps: f64) -> bool {
        (actual.to_f64() - expected).abs() < eps
    }

    // ---- sum_scores ------------------------------------------------------

    #[test]
    fn sum_basic() {
        let total = sum_scores(&[s(0.1), s(0.2), s(0.3)]);
        assert!(near(total, 0.6, TIGHT_EPS));
    }

    #[test]
    fn sum_empty() {
        assert_eq!(sum_scores(&[]), DeterministicScore::ZERO);
    }

    // ---- avg_scores ------------------------------------------------------

    #[test]
    fn avg_basic() {
        let mean = avg_scores(&[s(0.1), s(0.2), s(0.3)]);
        assert!(near(mean, 0.2, TIGHT_EPS));
    }

    // ---- rrf_score -------------------------------------------------------

    #[test]
    fn rrf_basic() {
        let (r1, r2) = (rrf_score(1, 60), rrf_score(2, 60));
        assert!(r1 > r2);
        assert!(near(r1, 1.0 / 61.0, TIGHT_EPS));
    }

    // ---- weighted_sum ----------------------------------------------------

    #[test]
    fn weighted_sum_basic() {
        let blended = weighted_sum(&[s(0.5), s(1.0)], &[0.4, 0.6]).unwrap();
        assert!(near(blended, 0.8, 1e-6));
    }

    #[test]
    fn weighted_sum_length_mismatch() {
        let outcome = weighted_sum(&[s(0.1)], &[0.5, 0.5]);
        assert!(matches!(outcome, Err(ScoreError::LengthMismatch { .. })));
    }

    #[test]
    fn weighted_sum_rejects_nan() {
        let outcome = weighted_sum(&[s(0.1)], &[f64::NAN]);
        assert!(matches!(
            outcome,
            Err(ScoreError::NonFiniteWeight { index: 0 })
        ));
    }

    // ---- saturation ------------------------------------------------------

    #[test]
    fn sum_negative_saturation_clamps_to_neg_inf() {
        let saturated = sum_scores(&[DeterministicScore::NEG_INF; 3]);
        assert_eq!(saturated, DeterministicScore::NEG_INF);
        assert!(saturated.is_infinite());
        assert_eq!(saturated.to_f64(), f64::NEG_INFINITY);
    }

    #[test]
    fn avg_negative_saturation_clamps_to_neg_inf() {
        let saturated = avg_scores(&[DeterministicScore::NEG_INF; 2]);
        assert_eq!(saturated, DeterministicScore::NEG_INF);
    }

    #[test]
    fn sum_order_independent() {
        let (a, b, c) = (s(1e9), s(-1e9), s(0.5));
        // Same three values in three different orders.
        let [r1, r2, r3] = [[a, b, c], [c, a, b], [b, c, a]].map(|order| sum_scores(&order));
        assert_eq!(r1, r2);
        assert_eq!(r2, r3);
    }

    // ---- avg_scores_checked ----------------------------------------------

    #[test]
    fn avg_scores_checked_empty_returns_zero_no_flag() {
        let (mean, flag) = avg_scores_checked(&[]);
        assert_eq!((mean, flag), (DeterministicScore::ZERO, false));
    }

    #[test]
    fn avg_scores_checked_near_saturation_sets_flag() {
        let (_, flag) = avg_scores_checked(&[DeterministicScore::MAX; 2]);
        assert!(flag);
    }

    // ---- max_score / min_score -------------------------------------------

    #[test]
    fn max_score_empty_returns_neg_inf() {
        let fallback = max_score(&[]);
        assert_eq!(fallback, DeterministicScore::NEG_INF);
    }

    #[test]
    fn min_score_empty_returns_max() {
        let fallback = min_score(&[]);
        assert_eq!(fallback, DeterministicScore::MAX);
    }

    // ---- rrf_score edge cases --------------------------------------------

    #[test]
    fn rrf_score_zero_denominator_returns_zero() {
        let score = rrf_score(0, 0);
        assert_eq!(score, DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_score_overflow_returns_zero() {
        let score = rrf_score(usize::MAX, 1);
        assert_eq!(score, DeterministicScore::ZERO);
    }

    // ---- weighted_sum edge cases -----------------------------------------

    #[test]
    fn weighted_sum_empty_returns_zero() {
        let outcome = weighted_sum(&[], &[]);
        assert_eq!(outcome, Ok(DeterministicScore::ZERO));
    }

    #[test]
    fn weighted_sum_rejects_infinite_weight() {
        let outcome = weighted_sum(&[s(1.0)], &[f64::INFINITY]);
        assert_eq!(outcome, Err(ScoreError::NonFiniteWeight { index: 0 }));
    }

    // ---- rrf_score_one_based / rrf_score_zero_based ----------------------

    #[test]
    fn rrf_one_based_rank_1_equals_legacy_rank_1() {
        let one_based = rrf_score_one_based(rank(1), 60);
        let legacy = rrf_score(1, 60);
        assert_eq!(
            one_based, legacy,
            "rrf_score_one_based(1,60) must match rrf_score(1,60)"
        );
    }

    #[test]
    fn rrf_zero_based_index_0_equals_one_based_rank_1() {
        assert_eq!(
            rrf_score_zero_based(0, 60),
            rrf_score_one_based(rank(1), 60)
        );
    }

    #[test]
    fn rrf_one_based_monotone_decreasing() {
        let [r1, r2, r10] = [1, 2, 10].map(|n| rrf_score_one_based(rank(n), 60));
        assert!(r1 > r2);
        assert!(r2 > r10);
    }

    #[test]
    fn rrf_one_based_value_matches_formula() {
        let score = rrf_score_one_based(rank(1), 60);
        assert!(near(score, 1.0 / 61.0, TIGHT_EPS));
    }

    #[test]
    fn rrf_one_based_overflow_returns_zero() {
        let score = rrf_score_one_based(rank(usize::MAX), 1);
        assert_eq!(score, DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_zero_based_overflow_returns_zero() {
        assert_eq!(
            rrf_score_zero_based(usize::MAX, 1),
            DeterministicScore::ZERO
        );
    }

    // ---- avg_scores_checked flag properties ------------------------------

    #[test]
    fn avg_scores_checked_saturation_flag_is_order_independent() {
        // Large values of opposite sign that cancel in the signed sum.
        let big = DeterministicScore::from_raw(i64::MAX / 2);
        let neg = DeterministicScore::from_raw(i64::MIN / 2 + 1);

        let (_, flag_a) = avg_scores_checked(&[big, neg, big, neg]);
        let (_, flag_b) = avg_scores_checked(&[big, big, neg, neg]);
        assert_eq!(
            flag_a, flag_b,
            "near_saturation flag must be order-independent: order_a={flag_a}, order_b={flag_b}"
        );
    }

    #[test]
    fn avg_scores_checked_normal_values_no_saturation_flag() {
        let (_, flag) = avg_scores_checked(&[s(0.1), s(0.2), s(-0.1), s(0.3)]);
        assert!(!flag, "small scores should not trigger near_saturation");
    }
}