Skip to main content

khive_score/
ops.rs

1//! Aggregation and fusion operations for deterministic scores.
2
3use crate::DeterministicScore;
4use std::fmt;
5
/// Errors produced by the fallible aggregation helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreError {
    /// Two inputs that must be the same length were not.
    LengthMismatch {
        /// Description of the length requirement that was violated.
        expected_desc: &'static str,
        /// Length of the first input.
        first_len: usize,
        /// Length of the second input.
        second_len: usize,
    },
    /// A weight was NaN or infinite.
    NonFiniteWeight {
        /// Position of the offending weight in the weights slice.
        index: usize,
    },
}

impl fmt::Display for ScoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `Copy`, so matching on `*self` binds them by value.
        match *self {
            Self::NonFiniteWeight { index } => {
                write!(f, "weight at index {index} must be finite")
            }
            Self::LengthMismatch {
                expected_desc,
                first_len,
                second_len,
            } => write!(
                f,
                "{expected_desc}: first has {first_len} elements, second has {second_len}"
            ),
        }
    }
}

impl std::error::Error for ScoreError {}
37
38#[inline]
39pub fn sum_scores(scores: &[DeterministicScore]) -> DeterministicScore {
40    if scores.is_empty() {
41        return DeterministicScore::ZERO;
42    }
43    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
44    DeterministicScore::from_raw(sum.clamp(
45        DeterministicScore::NEG_INF.to_raw() as i128,
46        i64::MAX as i128,
47    ) as i64)
48}
49
50#[inline]
51pub fn avg_scores(scores: &[DeterministicScore]) -> DeterministicScore {
52    if scores.is_empty() {
53        return DeterministicScore::ZERO;
54    }
55    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
56    let mean = sum / scores.len() as i128;
57    DeterministicScore::from_raw(mean.clamp(
58        DeterministicScore::NEG_INF.to_raw() as i128,
59        i64::MAX as i128,
60    ) as i64)
61}
62
63#[inline]
64pub fn avg_scores_checked(scores: &[DeterministicScore]) -> (DeterministicScore, bool) {
65    if scores.is_empty() {
66        return (DeterministicScore::ZERO, false);
67    }
68    const SATURATION_THRESHOLD: i128 = (i64::MAX as i128) * 9 / 10;
69    let mut sum = 0i128;
70    let mut near_saturation = false;
71    for score in scores {
72        sum += score.to_raw() as i128;
73        near_saturation |= sum.abs() > SATURATION_THRESHOLD;
74    }
75    let mean = sum / scores.len() as i128;
76    near_saturation |= mean.abs() > SATURATION_THRESHOLD;
77    let result = DeterministicScore::from_raw(mean.clamp(
78        DeterministicScore::NEG_INF.to_raw() as i128,
79        i64::MAX as i128,
80    ) as i64);
81    (result, near_saturation)
82}
83
84#[inline]
85pub fn max_score(scores: &[DeterministicScore]) -> DeterministicScore {
86    scores
87        .iter()
88        .copied()
89        .max()
90        .unwrap_or(DeterministicScore::NEG_INF)
91}
92
93#[inline]
94pub fn min_score(scores: &[DeterministicScore]) -> DeterministicScore {
95    scores
96        .iter()
97        .copied()
98        .min()
99        .unwrap_or(DeterministicScore::MAX)
100}
101
102/// Reciprocal Rank Fusion score: `1 / (k + rank)`.
103#[inline]
104pub fn rrf_score(rank: usize, k: usize) -> DeterministicScore {
105    let Some(denominator) = k.checked_add(rank) else {
106        return DeterministicScore::ZERO;
107    };
108    if denominator == 0 {
109        return DeterministicScore::ZERO;
110    }
111    DeterministicScore::from_f64(1.0 / (denominator as f64))
112}
113
/// Fixed-point scale (2^32), widened to `i128` for product arithmetic.
/// Must stay equal to `DeterministicScore::SCALE`.
const SCALE_RAW: i128 = 1 << 32;
115
116#[inline]
117pub fn weighted_sum(
118    scores: &[DeterministicScore],
119    weights: &[f64],
120) -> Result<DeterministicScore, ScoreError> {
121    if scores.len() != weights.len() {
122        return Err(ScoreError::LengthMismatch {
123            expected_desc: "scores and weights must have same length",
124            first_len: scores.len(),
125            second_len: weights.len(),
126        });
127    }
128    let mut acc = 0i128;
129    for (index, (&score, &weight)) in scores.iter().zip(weights.iter()).enumerate() {
130        if !weight.is_finite() {
131            return Err(ScoreError::NonFiniteWeight { index });
132        }
133        let w = DeterministicScore::from_f64(weight);
134        acc += (score.to_raw() as i128 * w.to_raw() as i128) / SCALE_RAW;
135    }
136    Ok(DeterministicScore::from_raw(acc.clamp(
137        DeterministicScore::NEG_INF.to_raw() as i128,
138        DeterministicScore::MAX.to_raw() as i128,
139    ) as i64))
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a score from an `f64`.
    fn s(v: f64) -> DeterministicScore {
        DeterministicScore::from_f64(v)
    }

    #[test]
    fn sum_basic() {
        let scores = [s(0.1), s(0.2), s(0.3)];
        let result = sum_scores(&scores);
        assert!((result.to_f64() - 0.6).abs() < 1e-9);
    }

    #[test]
    fn sum_empty() {
        let result = sum_scores(&[]);
        assert_eq!(result, DeterministicScore::ZERO);
    }

    #[test]
    fn avg_basic() {
        let scores = [s(0.1), s(0.2), s(0.3)];
        let result = avg_scores(&scores);
        assert!((result.to_f64() - 0.2).abs() < 1e-9);
    }

    #[test]
    fn rrf_basic() {
        let r1 = rrf_score(1, 60);
        let r2 = rrf_score(2, 60);
        assert!(r1 > r2);
        assert!((r1.to_f64() - 1.0 / 61.0).abs() < 1e-9);
    }

    #[test]
    fn weighted_sum_basic() {
        let scores = [s(0.5), s(1.0)];
        let weights = [0.4, 0.6];
        let result = weighted_sum(&scores, &weights).unwrap();
        assert!((result.to_f64() - 0.8).abs() < 1e-6);
    }

    #[test]
    fn weighted_sum_length_mismatch() {
        let err = weighted_sum(&[s(0.1)], &[0.5, 0.5]).unwrap_err();
        assert!(matches!(err, ScoreError::LengthMismatch { .. }));
    }

    #[test]
    fn weighted_sum_rejects_nan() {
        let err = weighted_sum(&[s(0.1)], &[f64::NAN]).unwrap_err();
        assert!(matches!(err, ScoreError::NonFiniteWeight { index: 0 }));
    }

    #[test]
    fn sum_negative_saturation_clamps_to_neg_inf() {
        let big_neg = DeterministicScore::NEG_INF;
        let result = sum_scores(&[big_neg, big_neg, big_neg]);
        assert_eq!(result, DeterministicScore::NEG_INF);
        assert!(result.is_infinite());
        assert_eq!(result.to_f64(), f64::NEG_INFINITY);
    }

    #[test]
    fn avg_negative_saturation_clamps_to_neg_inf() {
        let big_neg = DeterministicScore::NEG_INF;
        let result = avg_scores(&[big_neg, big_neg]);
        assert_eq!(result, DeterministicScore::NEG_INF);
    }

    #[test]
    fn sum_order_independent() {
        let a = DeterministicScore::from_f64(1e9);
        let b = DeterministicScore::from_f64(-1e9);
        let c = DeterministicScore::from_f64(0.5);
        let r1 = sum_scores(&[a, b, c]);
        let r2 = sum_scores(&[c, a, b]);
        let r3 = sum_scores(&[b, c, a]);
        assert_eq!(r1, r2);
        assert_eq!(r2, r3);
    }

    #[test]
    fn avg_scores_checked_empty_returns_zero_no_flag() {
        let (mean, flag) = avg_scores_checked(&[]);
        assert_eq!(mean, DeterministicScore::ZERO);
        assert!(!flag);
    }

    #[test]
    fn avg_scores_checked_near_saturation_sets_flag() {
        let (_, flag) = avg_scores_checked(&[DeterministicScore::MAX, DeterministicScore::MAX]);
        assert!(flag);
    }

    #[test]
    fn avg_scores_checked_small_values_no_flag() {
        let (mean, flag) = avg_scores_checked(&[s(0.1), s(0.3)]);
        assert!((mean.to_f64() - 0.2).abs() < 1e-9);
        assert!(!flag);
    }

    #[test]
    fn avg_scores_checked_mean_matches_avg_scores() {
        let scores = [s(1.5), s(-0.75), s(2.0)];
        assert_eq!(avg_scores_checked(&scores).0, avg_scores(&scores));
    }

    #[test]
    fn max_score_empty_returns_neg_inf() {
        assert_eq!(max_score(&[]), DeterministicScore::NEG_INF);
    }

    #[test]
    fn min_score_empty_returns_max() {
        assert_eq!(min_score(&[]), DeterministicScore::MAX);
    }

    #[test]
    fn max_and_min_pick_extremes() {
        let scores = [s(0.5), s(-0.25), s(0.9), s(0.1)];
        assert_eq!(max_score(&scores), s(0.9));
        assert_eq!(min_score(&scores), s(-0.25));
    }

    #[test]
    fn rrf_score_zero_denominator_returns_zero() {
        assert_eq!(rrf_score(0, 0), DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_score_overflow_returns_zero() {
        assert_eq!(rrf_score(usize::MAX, 1), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_empty_returns_zero() {
        assert_eq!(weighted_sum(&[], &[]).unwrap(), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_rejects_infinite_weight() {
        let err = weighted_sum(&[s(1.0)], &[f64::INFINITY]).unwrap_err();
        assert_eq!(err, ScoreError::NonFiniteWeight { index: 0 });
    }

    #[test]
    fn weighted_sum_reports_first_bad_index() {
        let scores = [s(1.0), s(1.0), s(1.0)];
        let err = weighted_sum(&scores, &[0.5, f64::NEG_INFINITY, f64::NAN]).unwrap_err();
        assert_eq!(err, ScoreError::NonFiniteWeight { index: 1 });
    }

    #[test]
    fn weighted_sum_negative_weight() {
        let result = weighted_sum(&[s(1.0), s(1.0)], &[1.5, -0.5]).unwrap();
        assert!((result.to_f64() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn score_error_display_messages() {
        assert_eq!(
            ScoreError::NonFiniteWeight { index: 3 }.to_string(),
            "weight at index 3 must be finite"
        );
        let err = weighted_sum(&[s(0.1)], &[0.5, 0.5]).unwrap_err();
        assert_eq!(
            err.to_string(),
            "scores and weights must have same length: first has 1 elements, second has 2"
        );
    }
}