1use crate::DeterministicScore;
4use std::fmt;
5
/// Errors reported by the fallible score-aggregation helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreError {
    /// Two slices that must be parallel (same length) were not.
    LengthMismatch {
        /// Human-readable statement of the length requirement that was violated.
        expected_desc: &'static str,
        /// Length of the first slice.
        first_len: usize,
        /// Length of the second slice.
        second_len: usize,
    },
    /// A weight was NaN or infinite (failed `f64::is_finite`).
    NonFiniteWeight {
        /// Position of the offending weight within the weights slice.
        index: usize,
    },
}
17
18impl fmt::Display for ScoreError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 match self {
21 ScoreError::LengthMismatch {
22 expected_desc,
23 first_len,
24 second_len,
25 } => write!(
26 f,
27 "{expected_desc}: first has {first_len} elements, second has {second_len}"
28 ),
29 ScoreError::NonFiniteWeight { index } => {
30 write!(f, "weight at index {index} must be finite")
31 }
32 }
33 }
34}
35
36impl std::error::Error for ScoreError {}
37
38#[inline]
39pub fn sum_scores(scores: &[DeterministicScore]) -> DeterministicScore {
40 if scores.is_empty() {
41 return DeterministicScore::ZERO;
42 }
43 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
44 DeterministicScore::from_raw(sum.clamp(
45 DeterministicScore::NEG_INF.to_raw() as i128,
46 i64::MAX as i128,
47 ) as i64)
48}
49
50#[inline]
51pub fn avg_scores(scores: &[DeterministicScore]) -> DeterministicScore {
52 if scores.is_empty() {
53 return DeterministicScore::ZERO;
54 }
55 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
56 let mean = sum / scores.len() as i128;
57 DeterministicScore::from_raw(mean.clamp(
58 DeterministicScore::NEG_INF.to_raw() as i128,
59 i64::MAX as i128,
60 ) as i64)
61}
62
63#[inline]
64pub fn avg_scores_checked(scores: &[DeterministicScore]) -> (DeterministicScore, bool) {
65 if scores.is_empty() {
66 return (DeterministicScore::ZERO, false);
67 }
68 const SATURATION_THRESHOLD: i128 = (i64::MAX as i128) * 9 / 10;
69 let mut sum = 0i128;
70 let mut near_saturation = false;
71 for score in scores {
72 sum += score.to_raw() as i128;
73 near_saturation |= sum.abs() > SATURATION_THRESHOLD;
74 }
75 let mean = sum / scores.len() as i128;
76 near_saturation |= mean.abs() > SATURATION_THRESHOLD;
77 let result = DeterministicScore::from_raw(mean.clamp(
78 DeterministicScore::NEG_INF.to_raw() as i128,
79 i64::MAX as i128,
80 ) as i64);
81 (result, near_saturation)
82}
83
84#[inline]
85pub fn max_score(scores: &[DeterministicScore]) -> DeterministicScore {
86 scores
87 .iter()
88 .copied()
89 .max()
90 .unwrap_or(DeterministicScore::NEG_INF)
91}
92
93#[inline]
94pub fn min_score(scores: &[DeterministicScore]) -> DeterministicScore {
95 scores
96 .iter()
97 .copied()
98 .min()
99 .unwrap_or(DeterministicScore::MAX)
100}
101
102#[inline]
104pub fn rrf_score(rank: usize, k: usize) -> DeterministicScore {
105 let Some(denominator) = k.checked_add(rank) else {
106 return DeterministicScore::ZERO;
107 };
108 if denominator == 0 {
109 return DeterministicScore::ZERO;
110 }
111 DeterministicScore::from_f64(1.0 / (denominator as f64))
112}
113
114const SCALE_RAW: i128 = 4_294_967_296; #[inline]
117pub fn weighted_sum(
118 scores: &[DeterministicScore],
119 weights: &[f64],
120) -> Result<DeterministicScore, ScoreError> {
121 if scores.len() != weights.len() {
122 return Err(ScoreError::LengthMismatch {
123 expected_desc: "scores and weights must have same length",
124 first_len: scores.len(),
125 second_len: weights.len(),
126 });
127 }
128 let mut acc = 0i128;
129 for (index, (&score, &weight)) in scores.iter().zip(weights.iter()).enumerate() {
130 if !weight.is_finite() {
131 return Err(ScoreError::NonFiniteWeight { index });
132 }
133 let w = DeterministicScore::from_f64(weight);
134 acc += (score.to_raw() as i128 * w.to_raw() as i128) / SCALE_RAW;
135 }
136 Ok(DeterministicScore::from_raw(acc.clamp(
137 DeterministicScore::NEG_INF.to_raw() as i128,
138 DeterministicScore::MAX.to_raw() as i128,
139 ) as i64))
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a score from a float literal; keeps the fixtures below short.
    fn s(v: f64) -> DeterministicScore {
        DeterministicScore::from_f64(v)
    }

    /// Fails unless `score`, read back as `f64`, lies strictly within `tol` of `want`.
    fn assert_near(score: DeterministicScore, want: f64, tol: f64) {
        let got = score.to_f64();
        assert!((got - want).abs() < tol, "expected {want} +/- {tol}, got {got}");
    }

    #[test]
    fn sum_basic() {
        assert_near(sum_scores(&[s(0.1), s(0.2), s(0.3)]), 0.6, 1e-9);
    }

    #[test]
    fn sum_empty() {
        assert_eq!(sum_scores(&[]), DeterministicScore::ZERO);
    }

    #[test]
    fn avg_basic() {
        assert_near(avg_scores(&[s(0.1), s(0.2), s(0.3)]), 0.2, 1e-9);
    }

    #[test]
    fn rrf_basic() {
        let first = rrf_score(1, 60);
        let second = rrf_score(2, 60);
        assert!(first > second);
        assert_near(first, 1.0 / 61.0, 1e-9);
    }

    #[test]
    fn weighted_sum_basic() {
        let total = weighted_sum(&[s(0.5), s(1.0)], &[0.4, 0.6])
            .expect("equal lengths and finite weights must succeed");
        assert_near(total, 0.8, 1e-6);
    }

    #[test]
    fn weighted_sum_length_mismatch() {
        let outcome = weighted_sum(&[s(0.1)], &[0.5, 0.5]);
        assert!(matches!(outcome, Err(ScoreError::LengthMismatch { .. })));
    }

    #[test]
    fn weighted_sum_rejects_nan() {
        let outcome = weighted_sum(&[s(0.1)], &[f64::NAN]);
        assert!(matches!(
            outcome,
            Err(ScoreError::NonFiniteWeight { index: 0 })
        ));
    }

    #[test]
    fn sum_negative_saturation_clamps_to_neg_inf() {
        let total = sum_scores(&[DeterministicScore::NEG_INF; 3]);
        assert_eq!(total, DeterministicScore::NEG_INF);
        assert!(total.is_infinite());
        assert_eq!(total.to_f64(), f64::NEG_INFINITY);
    }

    #[test]
    fn avg_negative_saturation_clamps_to_neg_inf() {
        assert_eq!(
            avg_scores(&[DeterministicScore::NEG_INF; 2]),
            DeterministicScore::NEG_INF
        );
    }

    #[test]
    fn sum_order_independent() {
        let [a, b, c] = [1e9, -1e9, 0.5].map(s);
        let reference = sum_scores(&[a, b, c]);
        for permuted in [[c, a, b], [b, c, a]] {
            assert_eq!(sum_scores(&permuted), reference);
        }
    }

    #[test]
    fn avg_scores_checked_empty_returns_zero_no_flag() {
        assert_eq!(avg_scores_checked(&[]), (DeterministicScore::ZERO, false));
    }

    #[test]
    fn avg_scores_checked_near_saturation_sets_flag() {
        let (_, flagged) = avg_scores_checked(&[DeterministicScore::MAX; 2]);
        assert!(flagged);
    }

    #[test]
    fn max_score_empty_returns_neg_inf() {
        assert_eq!(max_score(&[]), DeterministicScore::NEG_INF);
    }

    #[test]
    fn min_score_empty_returns_max() {
        assert_eq!(min_score(&[]), DeterministicScore::MAX);
    }

    #[test]
    fn rrf_score_zero_denominator_returns_zero() {
        assert_eq!(rrf_score(0, 0), DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_score_overflow_returns_zero() {
        assert_eq!(rrf_score(usize::MAX, 1), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_empty_returns_zero() {
        assert_eq!(weighted_sum(&[], &[]), Ok(DeterministicScore::ZERO));
    }

    #[test]
    fn weighted_sum_rejects_infinite_weight() {
        assert_eq!(
            weighted_sum(&[s(1.0)], &[f64::INFINITY]),
            Err(ScoreError::NonFiniteWeight { index: 0 })
        );
    }
}