Skip to main content

khive_score/
ops.rs

1//! Aggregation and fusion operations for deterministic scores.
2
3use crate::DeterministicScore;
4use std::fmt;
5
/// Errors produced by the fallible score operations in this module
/// (for example `weighted_sum`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreError {
    /// Two parallel input slices did not have the same number of elements.
    LengthMismatch {
        /// Description of the expected relationship; used verbatim as the
        /// prefix of the `Display` message.
        expected_desc: &'static str,
        /// Length of the first slice.
        first_len: usize,
        /// Length of the second slice.
        second_len: usize,
    },
    /// A weight was NaN or infinite.
    NonFiniteWeight {
        /// Zero-based position of the offending weight.
        index: usize,
    },
}

impl fmt::Display for ScoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every field is `Copy`, so matching on `*self` binds plain values.
        match *self {
            Self::LengthMismatch {
                expected_desc,
                first_len,
                second_len,
            } => write!(
                f,
                "{expected_desc}: first has {first_len} elements, second has {second_len}"
            ),
            Self::NonFiniteWeight { index } => write!(f, "weight at index {index} must be finite"),
        }
    }
}

impl std::error::Error for ScoreError {}
37
38#[inline]
39pub fn sum_scores(scores: &[DeterministicScore]) -> DeterministicScore {
40    if scores.is_empty() {
41        return DeterministicScore::ZERO;
42    }
43    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
44    DeterministicScore::from_raw(sum.clamp(
45        DeterministicScore::NEG_INF.to_raw() as i128,
46        i64::MAX as i128,
47    ) as i64)
48}
49
50#[inline]
51pub fn avg_scores(scores: &[DeterministicScore]) -> DeterministicScore {
52    if scores.is_empty() {
53        return DeterministicScore::ZERO;
54    }
55    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
56    let mean = sum / scores.len() as i128;
57    DeterministicScore::from_raw(mean.clamp(
58        DeterministicScore::NEG_INF.to_raw() as i128,
59        i64::MAX as i128,
60    ) as i64)
61}
62
63#[inline]
64pub fn avg_scores_checked(scores: &[DeterministicScore]) -> (DeterministicScore, bool) {
65    if scores.is_empty() {
66        return (DeterministicScore::ZERO, false);
67    }
68    const SATURATION_THRESHOLD: i128 = (i64::MAX as i128) * 9 / 10;
69    let mut sum = 0i128;
70    let mut near_saturation = false;
71    for score in scores {
72        sum += score.to_raw() as i128;
73        near_saturation |= sum.abs() > SATURATION_THRESHOLD;
74    }
75    let mean = sum / scores.len() as i128;
76    near_saturation |= mean.abs() > SATURATION_THRESHOLD;
77    let result = DeterministicScore::from_raw(mean.clamp(
78        DeterministicScore::NEG_INF.to_raw() as i128,
79        i64::MAX as i128,
80    ) as i64);
81    (result, near_saturation)
82}
83
84#[inline]
85pub fn max_score(scores: &[DeterministicScore]) -> DeterministicScore {
86    scores
87        .iter()
88        .copied()
89        .max()
90        .unwrap_or(DeterministicScore::NEG_INF)
91}
92
93#[inline]
94pub fn min_score(scores: &[DeterministicScore]) -> DeterministicScore {
95    scores
96        .iter()
97        .copied()
98        .min()
99        .unwrap_or(DeterministicScore::MAX)
100}
101
102/// Reciprocal Rank Fusion score: `1 / (k + rank)`.
103#[inline]
104pub fn rrf_score(rank: usize, k: usize) -> DeterministicScore {
105    let Some(denominator) = k.checked_add(rank) else {
106        return DeterministicScore::ZERO;
107    };
108    if denominator == 0 {
109        return DeterministicScore::ZERO;
110    }
111    DeterministicScore::from_f64(1.0 / (denominator as f64))
112}
113
114#[inline]
115pub fn weighted_sum(
116    scores: &[DeterministicScore],
117    weights: &[f64],
118) -> Result<DeterministicScore, ScoreError> {
119    if scores.len() != weights.len() {
120        return Err(ScoreError::LengthMismatch {
121            expected_desc: "scores and weights must have same length",
122            first_len: scores.len(),
123            second_len: weights.len(),
124        });
125    }
126    let mut acc = DeterministicScore::ZERO;
127    for (index, (&score, &weight)) in scores.iter().zip(weights.iter()).enumerate() {
128        if !weight.is_finite() {
129            return Err(ScoreError::NonFiniteWeight { index });
130        }
131        acc = acc + score * weight;
132    }
133    Ok(acc)
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a score from an `f64` literal.
    fn s(v: f64) -> DeterministicScore {
        DeterministicScore::from_f64(v)
    }

    #[test]
    fn sum_basic() {
        let scores = [s(0.1), s(0.2), s(0.3)];
        let result = sum_scores(&scores);
        assert!((result.to_f64() - 0.6).abs() < 1e-9);
    }

    #[test]
    fn sum_empty() {
        let result = sum_scores(&[]);
        assert_eq!(result, DeterministicScore::ZERO);
    }

    #[test]
    fn avg_basic() {
        let scores = [s(0.1), s(0.2), s(0.3)];
        let result = avg_scores(&scores);
        assert!((result.to_f64() - 0.2).abs() < 1e-9);
    }

    #[test]
    fn avg_empty() {
        assert_eq!(avg_scores(&[]), DeterministicScore::ZERO);
    }

    #[test]
    fn avg_checked_empty_is_not_flagged() {
        assert_eq!(avg_scores_checked(&[]), (DeterministicScore::ZERO, false));
    }

    #[test]
    fn avg_checked_small_values_not_flagged() {
        let (mean, flagged) = avg_scores_checked(&[s(0.1), s(0.3)]);
        assert!((mean.to_f64() - 0.2).abs() < 1e-6);
        assert!(!flagged);
    }

    #[test]
    fn max_min_basic() {
        let scores = [s(0.3), s(0.9), s(0.1)];
        assert_eq!(max_score(&scores), s(0.9));
        assert_eq!(min_score(&scores), s(0.1));
    }

    #[test]
    fn max_min_empty_identities() {
        // Empty input returns the identity element of each fold.
        assert_eq!(max_score(&[]), DeterministicScore::NEG_INF);
        assert_eq!(min_score(&[]), DeterministicScore::MAX);
    }

    #[test]
    fn rrf_basic() {
        let r1 = rrf_score(1, 60);
        let r2 = rrf_score(2, 60);
        assert!(r1 > r2);
        assert!((r1.to_f64() - 1.0 / 61.0).abs() < 1e-9);
    }

    #[test]
    fn rrf_degenerate_inputs_are_zero() {
        // Zero denominator and `usize` overflow both map to ZERO instead of dividing.
        assert_eq!(rrf_score(0, 0), DeterministicScore::ZERO);
        assert_eq!(rrf_score(usize::MAX, 1), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_basic() {
        let scores = [s(0.5), s(1.0)];
        let weights = [0.4, 0.6];
        let result = weighted_sum(&scores, &weights).unwrap();
        assert!((result.to_f64() - 0.8).abs() < 1e-6);
    }

    #[test]
    fn weighted_sum_empty_is_zero() {
        assert_eq!(weighted_sum(&[], &[]).unwrap(), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_length_mismatch() {
        let err = weighted_sum(&[s(0.1)], &[0.5, 0.5]).unwrap_err();
        assert!(matches!(err, ScoreError::LengthMismatch { .. }));
    }

    #[test]
    fn weighted_sum_rejects_nan() {
        let err = weighted_sum(&[s(0.1)], &[f64::NAN]).unwrap_err();
        assert!(matches!(err, ScoreError::NonFiniteWeight { index: 0 }));
    }

    #[test]
    fn weighted_sum_reports_first_non_finite_index() {
        let err = weighted_sum(
            &[s(0.1), s(0.2), s(0.3)],
            &[0.5, f64::INFINITY, f64::NAN],
        )
        .unwrap_err();
        assert_eq!(err, ScoreError::NonFiniteWeight { index: 1 });
    }

    #[test]
    fn sum_negative_saturation_clamps_to_neg_inf() {
        let big_neg = DeterministicScore::NEG_INF;
        let result = sum_scores(&[big_neg, big_neg, big_neg]);
        assert_eq!(result, DeterministicScore::NEG_INF);
        assert!(result.is_infinite());
        assert_eq!(result.to_f64(), f64::NEG_INFINITY);
    }

    #[test]
    fn avg_negative_saturation_clamps_to_neg_inf() {
        let big_neg = DeterministicScore::NEG_INF;
        let result = avg_scores(&[big_neg, big_neg]);
        assert_eq!(result, DeterministicScore::NEG_INF);
    }

    #[test]
    fn sum_order_independent() {
        let a = DeterministicScore::from_f64(1e9);
        let b = DeterministicScore::from_f64(-1e9);
        let c = DeterministicScore::from_f64(0.5);
        let r1 = sum_scores(&[a, b, c]);
        let r2 = sum_scores(&[c, a, b]);
        let r3 = sum_scores(&[b, c, a]);
        assert_eq!(r1, r2);
        assert_eq!(r2, r3);
    }
}