1use crate::DeterministicScore;
4use std::fmt;
5
/// Errors produced by the fallible score-combination helpers (e.g. `weighted_sum`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreError {
    /// Two inputs that must have the same number of elements did not.
    LengthMismatch {
        /// Static description of the violated requirement; used as the message prefix.
        expected_desc: &'static str,
        /// Element count of the first input.
        first_len: usize,
        /// Element count of the second input.
        second_len: usize,
    },
    /// A weight was not a finite number.
    NonFiniteWeight {
        /// Position of the offending weight within the weights slice.
        index: usize,
    },
}

impl fmt::Display for ScoreError {
    /// Renders a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NonFiniteWeight { index } => {
                write!(f, "weight at index {index} must be finite")
            }
            Self::LengthMismatch {
                expected_desc,
                first_len,
                second_len,
            } => {
                write!(
                    f,
                    "{expected_desc}: first has {first_len} elements, second has {second_len}"
                )
            }
        }
    }
}

// No underlying cause is carried, so the default `source()` (None) is sufficient.
impl std::error::Error for ScoreError {}
37
38#[inline]
39pub fn sum_scores(scores: &[DeterministicScore]) -> DeterministicScore {
40 if scores.is_empty() {
41 return DeterministicScore::ZERO;
42 }
43 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
44 DeterministicScore::from_raw(sum.clamp(
45 DeterministicScore::NEG_INF.to_raw() as i128,
46 i64::MAX as i128,
47 ) as i64)
48}
49
50#[inline]
51pub fn avg_scores(scores: &[DeterministicScore]) -> DeterministicScore {
52 if scores.is_empty() {
53 return DeterministicScore::ZERO;
54 }
55 let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
56 let mean = sum / scores.len() as i128;
57 DeterministicScore::from_raw(mean.clamp(
58 DeterministicScore::NEG_INF.to_raw() as i128,
59 i64::MAX as i128,
60 ) as i64)
61}
62
63#[inline]
64pub fn avg_scores_checked(scores: &[DeterministicScore]) -> (DeterministicScore, bool) {
65 if scores.is_empty() {
66 return (DeterministicScore::ZERO, false);
67 }
68 const SATURATION_THRESHOLD: i128 = (i64::MAX as i128) * 9 / 10;
69 let mut sum = 0i128;
70 let mut near_saturation = false;
71 for score in scores {
72 sum += score.to_raw() as i128;
73 near_saturation |= sum.abs() > SATURATION_THRESHOLD;
74 }
75 let mean = sum / scores.len() as i128;
76 near_saturation |= mean.abs() > SATURATION_THRESHOLD;
77 let result = DeterministicScore::from_raw(mean.clamp(
78 DeterministicScore::NEG_INF.to_raw() as i128,
79 i64::MAX as i128,
80 ) as i64);
81 (result, near_saturation)
82}
83
84#[inline]
85pub fn max_score(scores: &[DeterministicScore]) -> DeterministicScore {
86 scores
87 .iter()
88 .copied()
89 .max()
90 .unwrap_or(DeterministicScore::NEG_INF)
91}
92
93#[inline]
94pub fn min_score(scores: &[DeterministicScore]) -> DeterministicScore {
95 scores
96 .iter()
97 .copied()
98 .min()
99 .unwrap_or(DeterministicScore::MAX)
100}
101
102#[inline]
104pub fn rrf_score(rank: usize, k: usize) -> DeterministicScore {
105 let Some(denominator) = k.checked_add(rank) else {
106 return DeterministicScore::ZERO;
107 };
108 if denominator == 0 {
109 return DeterministicScore::ZERO;
110 }
111 DeterministicScore::from_f64(1.0 / (denominator as f64))
112}
113
114#[inline]
115pub fn weighted_sum(
116 scores: &[DeterministicScore],
117 weights: &[f64],
118) -> Result<DeterministicScore, ScoreError> {
119 if scores.len() != weights.len() {
120 return Err(ScoreError::LengthMismatch {
121 expected_desc: "scores and weights must have same length",
122 first_len: scores.len(),
123 second_len: weights.len(),
124 });
125 }
126 let mut acc = DeterministicScore::ZERO;
127 for (index, (&score, &weight)) in scores.iter().zip(weights.iter()).enumerate() {
128 if !weight.is_finite() {
129 return Err(ScoreError::NonFiniteWeight { index });
130 }
131 acc = acc + score * weight;
132 }
133 Ok(acc)
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a score from a float.
    fn score(value: f64) -> DeterministicScore {
        DeterministicScore::from_f64(value)
    }

    /// Asserts that `actual`, converted to `f64`, lies strictly within
    /// `tolerance` of `expected`.
    fn assert_close(actual: DeterministicScore, expected: f64, tolerance: f64) {
        assert!((actual.to_f64() - expected).abs() < tolerance);
    }

    // ---- sum_scores -------------------------------------------------------

    #[test]
    fn sum_basic() {
        let total = sum_scores(&[score(0.1), score(0.2), score(0.3)]);
        assert_close(total, 0.6, 1e-9);
    }

    #[test]
    fn sum_empty() {
        assert_eq!(sum_scores(&[]), DeterministicScore::ZERO);
    }

    // ---- avg_scores -------------------------------------------------------

    #[test]
    fn avg_basic() {
        let mean = avg_scores(&[score(0.1), score(0.2), score(0.3)]);
        assert_close(mean, 0.2, 1e-9);
    }

    // ---- rrf_score --------------------------------------------------------

    #[test]
    fn rrf_basic() {
        let first = rrf_score(1, 60);
        let second = rrf_score(2, 60);
        assert!(first > second);
        assert_close(first, 1.0 / 61.0, 1e-9);
    }

    // ---- weighted_sum -----------------------------------------------------

    #[test]
    fn weighted_sum_basic() {
        let inputs = [score(0.5), score(1.0)];
        let combined = weighted_sum(&inputs, &[0.4, 0.6]).unwrap();
        assert_close(combined, 0.8, 1e-6);
    }

    #[test]
    fn weighted_sum_length_mismatch() {
        let outcome = weighted_sum(&[score(0.1)], &[0.5, 0.5]);
        assert!(matches!(outcome, Err(ScoreError::LengthMismatch { .. })));
    }

    #[test]
    fn weighted_sum_rejects_nan() {
        let outcome = weighted_sum(&[score(0.1)], &[f64::NAN]);
        assert!(matches!(
            outcome,
            Err(ScoreError::NonFiniteWeight { index: 0 })
        ));
    }

    // ---- saturation behaviour ---------------------------------------------

    #[test]
    fn sum_negative_saturation_clamps_to_neg_inf() {
        let floor = DeterministicScore::NEG_INF;
        let total = sum_scores(&[floor, floor, floor]);
        assert_eq!(total, DeterministicScore::NEG_INF);
        assert!(total.is_infinite());
        assert_eq!(total.to_f64(), f64::NEG_INFINITY);
    }

    #[test]
    fn avg_negative_saturation_clamps_to_neg_inf() {
        let floor = DeterministicScore::NEG_INF;
        assert_eq!(avg_scores(&[floor, floor]), DeterministicScore::NEG_INF);
    }

    #[test]
    fn sum_order_independent() {
        let large_pos = DeterministicScore::from_f64(1e9);
        let large_neg = DeterministicScore::from_f64(-1e9);
        let small = DeterministicScore::from_f64(0.5);

        let first = sum_scores(&[large_pos, large_neg, small]);
        let second = sum_scores(&[small, large_pos, large_neg]);
        let third = sum_scores(&[large_neg, small, large_pos]);

        assert_eq!(first, second);
        assert_eq!(second, third);
    }

    // ---- avg_scores_checked -----------------------------------------------

    #[test]
    fn avg_scores_checked_empty_returns_zero_no_flag() {
        let (mean, near_saturation) = avg_scores_checked(&[]);
        assert_eq!(mean, DeterministicScore::ZERO);
        assert!(!near_saturation);
    }

    #[test]
    fn avg_scores_checked_near_saturation_sets_flag() {
        let inputs = [DeterministicScore::MAX, DeterministicScore::MAX];
        let (_, near_saturation) = avg_scores_checked(&inputs);
        assert!(near_saturation);
    }

    // ---- empty-input defaults ---------------------------------------------

    #[test]
    fn max_score_empty_returns_neg_inf() {
        assert_eq!(max_score(&[]), DeterministicScore::NEG_INF);
    }

    #[test]
    fn min_score_empty_returns_max() {
        assert_eq!(min_score(&[]), DeterministicScore::MAX);
    }

    // ---- rrf_score edge cases ---------------------------------------------

    #[test]
    fn rrf_score_zero_denominator_returns_zero() {
        assert_eq!(rrf_score(0, 0), DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_score_overflow_returns_zero() {
        assert_eq!(rrf_score(usize::MAX, 1), DeterministicScore::ZERO);
    }

    // ---- weighted_sum edge cases ------------------------------------------

    #[test]
    fn weighted_sum_empty_returns_zero() {
        let combined = weighted_sum(&[], &[]).unwrap();
        assert_eq!(combined, DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_rejects_infinite_weight() {
        let failure = weighted_sum(&[score(1.0)], &[f64::INFINITY]).unwrap_err();
        assert_eq!(failure, ScoreError::NonFiniteWeight { index: 0 });
    }
}