// khive_score/ops.rs

//! Aggregation and fusion operations for deterministic scores.

use crate::DeterministicScore;
use std::fmt;
use std::num::NonZeroUsize;

/// Errors produced by score aggregation and distance conversion operations.
///
/// Every payload is a plain value (`&'static str`, `usize`, `u32`), so errors can be
/// compared with `==` and cloned without allocating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreError {
    /// The two input slices have different lengths.
    LengthMismatch {
        /// Human-readable description of which two slices were compared.
        expected_desc: &'static str,
        /// Length of the first slice.
        first_len: usize,
        /// Length of the second slice.
        second_len: usize,
    },
    /// A weight at the given index is NaN or infinite.
    NonFiniteWeight {
        /// Zero-based index of the offending weight.
        index: usize,
    },
    /// The distance value is NaN, `+Inf`, or `-Inf`.
    NonFiniteDistance,
    /// The distance is finite but outside the valid range for the metric.
    InvalidDistanceRange {
        /// The metric whose range was violated.
        metric_name: &'static str,
        /// Bit-representation of the out-of-range distance for diagnostics.
        dist_bits: u32,
    },
    /// The distance metric is not one of the three currently supported variants.
    UnsupportedMetric,
}

37impl fmt::Display for ScoreError {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        match self {
40            ScoreError::LengthMismatch {
41                expected_desc,
42                first_len,
43                second_len,
44            } => write!(
45                f,
46                "{expected_desc}: first has {first_len} elements, second has {second_len}"
47            ),
48            ScoreError::NonFiniteWeight { index } => {
49                write!(f, "weight at index {index} must be finite")
50            }
51            ScoreError::NonFiniteDistance => {
52                write!(f, "distance must be finite (not NaN or infinity)")
53            }
54            ScoreError::InvalidDistanceRange {
55                metric_name,
56                dist_bits,
57            } => write!(
58                f,
59                "distance value (bits=0x{dist_bits:08x}) is out of valid range for metric {metric_name}"
60            ),
61            ScoreError::UnsupportedMetric => {
62                write!(f, "unsupported distance metric")
63            }
64        }
65    }
66}
67
68impl std::error::Error for ScoreError {}
69
70/// Return the saturating sum of `scores`, clamped to `[NEG_INF, MAX]`.
71#[inline]
72pub fn sum_scores(scores: &[DeterministicScore]) -> DeterministicScore {
73    if scores.is_empty() {
74        return DeterministicScore::ZERO;
75    }
76    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
77    DeterministicScore::from_raw(sum.clamp(
78        DeterministicScore::NEG_INF.to_raw() as i128,
79        i64::MAX as i128,
80    ) as i64)
81}
82
83/// Return the arithmetic mean of `scores`, clamped to `[NEG_INF, MAX]`.
84#[inline]
85pub fn avg_scores(scores: &[DeterministicScore]) -> DeterministicScore {
86    if scores.is_empty() {
87        return DeterministicScore::ZERO;
88    }
89    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
90    let mean = sum / scores.len() as i128;
91    DeterministicScore::from_raw(mean.clamp(
92        DeterministicScore::NEG_INF.to_raw() as i128,
93        i64::MAX as i128,
94    ) as i64)
95}
96
97/// Return the mean of `scores` and a boolean saturation flag.
98#[inline]
99pub fn avg_scores_checked(scores: &[DeterministicScore]) -> (DeterministicScore, bool) {
100    if scores.is_empty() {
101        return (DeterministicScore::ZERO, false);
102    }
103    const SATURATION_THRESHOLD: i128 = (i64::MAX as i128) * 9 / 10;
104    let sum: i128 = scores.iter().map(|s| s.to_raw() as i128).sum();
105    let mean = sum / scores.len() as i128;
106    // Use order-independent measures: check the absolute sum of all input
107    // magnitudes (independent of sign cancellation order) and the final mean.
108    let abs_mass: i128 = scores
109        .iter()
110        .map(|s| (s.to_raw() as i128).unsigned_abs() as i128)
111        .sum();
112    let near_saturation =
113        abs_mass > SATURATION_THRESHOLD || mean.unsigned_abs() as i128 > SATURATION_THRESHOLD;
114    let result = DeterministicScore::from_raw(mean.clamp(
115        DeterministicScore::NEG_INF.to_raw() as i128,
116        i64::MAX as i128,
117    ) as i64);
118    (result, near_saturation)
119}
120
121/// Return the maximum score, or [`DeterministicScore::NEG_INF`] for an empty slice.
122#[inline]
123pub fn max_score(scores: &[DeterministicScore]) -> DeterministicScore {
124    scores
125        .iter()
126        .copied()
127        .max()
128        .unwrap_or(DeterministicScore::NEG_INF)
129}
130
131/// Return the minimum score, or [`DeterministicScore::MAX`] for an empty slice.
132#[inline]
133pub fn min_score(scores: &[DeterministicScore]) -> DeterministicScore {
134    scores
135        .iter()
136        .copied()
137        .min()
138        .unwrap_or(DeterministicScore::MAX)
139}
140
141/// RRF score `1 / (k + rank)`. Rank is 1-based; prefer `rrf_score_one_based` or `rrf_score_zero_based`.
142#[inline]
143pub fn rrf_score(rank: usize, k: usize) -> DeterministicScore {
144    let Some(denominator) = k.checked_add(rank) else {
145        return DeterministicScore::ZERO;
146    };
147    if denominator == 0 {
148        return DeterministicScore::ZERO;
149    }
150    DeterministicScore::from_f64(1.0 / (denominator as f64))
151}
152
153/// RRF score with 1-based rank (first result = rank 1). `k` is the smoothing constant.
154#[inline]
155pub fn rrf_score_one_based(rank: NonZeroUsize, k: usize) -> DeterministicScore {
156    let Some(denominator) = k.checked_add(rank.get()) else {
157        return DeterministicScore::ZERO;
158    };
159    DeterministicScore::from_f64(1.0 / denominator as f64)
160}
161
162/// RRF score with 0-based index (index 0 → rank 1 internally).
163#[inline]
164pub fn rrf_score_zero_based(index: usize, k: usize) -> DeterministicScore {
165    let Some(rank) = index.checked_add(1).and_then(NonZeroUsize::new) else {
166        return DeterministicScore::ZERO;
167    };
168    rrf_score_one_based(rank, k)
169}
170
/// Fixed-point scale factor, 2^32 (= 4_294_967_296), used to renormalise the product
/// of two raw scores. Intended to match `DeterministicScore::SCALE`.
const SCALE_RAW: i128 = 1 << 32;

173/// Weighted sum of `scores`. Errors on length mismatch or non-finite weights.
174#[inline]
175pub fn weighted_sum(
176    scores: &[DeterministicScore],
177    weights: &[f64],
178) -> Result<DeterministicScore, ScoreError> {
179    if scores.len() != weights.len() {
180        return Err(ScoreError::LengthMismatch {
181            expected_desc: "scores and weights must have same length",
182            first_len: scores.len(),
183            second_len: weights.len(),
184        });
185    }
186    let mut acc = 0i128;
187    for (index, (&score, &weight)) in scores.iter().zip(weights.iter()).enumerate() {
188        if !weight.is_finite() {
189            return Err(ScoreError::NonFiniteWeight { index });
190        }
191        let w = DeterministicScore::from_f64(weight);
192        acc += (score.to_raw() as i128 * w.to_raw() as i128) / SCALE_RAW;
193    }
194    Ok(DeterministicScore::from_raw(acc.clamp(
195        DeterministicScore::NEG_INF.to_raw() as i128,
196        DeterministicScore::MAX.to_raw() as i128,
197    ) as i64))
198}
199
/// Unit tests for the aggregation, RRF, and weighted-sum helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a score from a float.
    fn s(v: f64) -> DeterministicScore {
        DeterministicScore::from_f64(v)
    }

    #[test]
    fn sum_basic() {
        let scores = [s(0.1), s(0.2), s(0.3)];
        let result = sum_scores(&scores);
        assert!((result.to_f64() - 0.6).abs() < 1e-9);
    }

    #[test]
    fn sum_empty() {
        let result = sum_scores(&[]);
        assert_eq!(result, DeterministicScore::ZERO);
    }

    #[test]
    fn avg_basic() {
        let scores = [s(0.1), s(0.2), s(0.3)];
        let result = avg_scores(&scores);
        assert!((result.to_f64() - 0.2).abs() < 1e-9);
    }

    #[test]
    fn rrf_basic() {
        let r1 = rrf_score(1, 60);
        let r2 = rrf_score(2, 60);
        assert!(r1 > r2);
        assert!((r1.to_f64() - 1.0 / 61.0).abs() < 1e-9);
    }

    #[test]
    fn weighted_sum_basic() {
        let scores = [s(0.5), s(1.0)];
        let weights = [0.4, 0.6];
        let result = weighted_sum(&scores, &weights).unwrap();
        assert!((result.to_f64() - 0.8).abs() < 1e-6);
    }

    #[test]
    fn weighted_sum_length_mismatch() {
        let err = weighted_sum(&[s(0.1)], &[0.5, 0.5]).unwrap_err();
        assert!(matches!(err, ScoreError::LengthMismatch { .. }));
    }

    #[test]
    fn weighted_sum_rejects_nan() {
        let err = weighted_sum(&[s(0.1)], &[f64::NAN]).unwrap_err();
        assert!(matches!(err, ScoreError::NonFiniteWeight { index: 0 }));
    }

    #[test]
    fn sum_negative_saturation_clamps_to_neg_inf() {
        let big_neg = DeterministicScore::NEG_INF;
        let result = sum_scores(&[big_neg, big_neg, big_neg]);
        assert_eq!(result, DeterministicScore::NEG_INF);
        assert!(result.is_infinite());
        assert_eq!(result.to_f64(), f64::NEG_INFINITY);
    }

    #[test]
    fn avg_negative_saturation_clamps_to_neg_inf() {
        let big_neg = DeterministicScore::NEG_INF;
        let result = avg_scores(&[big_neg, big_neg]);
        assert_eq!(result, DeterministicScore::NEG_INF);
    }

    #[test]
    fn sum_order_independent() {
        let a = DeterministicScore::from_f64(1e9);
        let b = DeterministicScore::from_f64(-1e9);
        let c = DeterministicScore::from_f64(0.5);
        let r1 = sum_scores(&[a, b, c]);
        let r2 = sum_scores(&[c, a, b]);
        let r3 = sum_scores(&[b, c, a]);
        assert_eq!(r1, r2);
        assert_eq!(r2, r3);
    }

    #[test]
    fn avg_scores_checked_empty_returns_zero_no_flag() {
        let (mean, flag) = avg_scores_checked(&[]);
        assert_eq!(mean, DeterministicScore::ZERO);
        assert!(!flag);
    }

    #[test]
    fn avg_scores_checked_near_saturation_sets_flag() {
        let (_, flag) = avg_scores_checked(&[DeterministicScore::MAX, DeterministicScore::MAX]);
        assert!(flag);
    }

    #[test]
    fn max_score_empty_returns_neg_inf() {
        assert_eq!(max_score(&[]), DeterministicScore::NEG_INF);
    }

    #[test]
    fn min_score_empty_returns_max() {
        assert_eq!(min_score(&[]), DeterministicScore::MAX);
    }

    #[test]
    fn rrf_score_zero_denominator_returns_zero() {
        assert_eq!(rrf_score(0, 0), DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_score_overflow_returns_zero() {
        assert_eq!(rrf_score(usize::MAX, 1), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_empty_returns_zero() {
        assert_eq!(weighted_sum(&[], &[]).unwrap(), DeterministicScore::ZERO);
    }

    #[test]
    fn weighted_sum_rejects_infinite_weight() {
        let err = weighted_sum(&[s(1.0)], &[f64::INFINITY]).unwrap_err();
        assert_eq!(err, ScoreError::NonFiniteWeight { index: 0 });
    }

    // ── rrf_score_one_based / rrf_score_zero_based ────────────────────────────

    #[test]
    fn rrf_one_based_rank_1_equals_legacy_rank_1() {
        use std::num::NonZeroUsize;
        let one_based = rrf_score_one_based(NonZeroUsize::new(1).unwrap(), 60);
        let legacy = rrf_score(1, 60);
        assert_eq!(
            one_based, legacy,
            "rrf_score_one_based(1,60) must match rrf_score(1,60)"
        );
    }

    #[test]
    fn rrf_zero_based_index_0_equals_one_based_rank_1() {
        use std::num::NonZeroUsize;
        let zero_based = rrf_score_zero_based(0, 60);
        let one_based = rrf_score_one_based(NonZeroUsize::new(1).unwrap(), 60);
        assert_eq!(zero_based, one_based);
    }

    #[test]
    fn rrf_one_based_monotone_decreasing() {
        use std::num::NonZeroUsize;
        let r1 = rrf_score_one_based(NonZeroUsize::new(1).unwrap(), 60);
        let r2 = rrf_score_one_based(NonZeroUsize::new(2).unwrap(), 60);
        let r10 = rrf_score_one_based(NonZeroUsize::new(10).unwrap(), 60);
        assert!(r1 > r2);
        assert!(r2 > r10);
    }

    #[test]
    fn rrf_one_based_value_matches_formula() {
        use std::num::NonZeroUsize;
        let score = rrf_score_one_based(NonZeroUsize::new(1).unwrap(), 60);
        assert!((score.to_f64() - 1.0 / 61.0).abs() < 1e-9);
    }

    #[test]
    fn rrf_one_based_overflow_returns_zero() {
        use std::num::NonZeroUsize;
        let score = rrf_score_one_based(NonZeroUsize::new(usize::MAX).unwrap(), 1);
        assert_eq!(score, DeterministicScore::ZERO);
    }

    #[test]
    fn rrf_zero_based_overflow_returns_zero() {
        let score = rrf_score_zero_based(usize::MAX, 1);
        assert_eq!(score, DeterministicScore::ZERO);
    }

    // ── avg_scores_checked order-independence ─────────────────────────────────

    #[test]
    fn avg_scores_checked_saturation_flag_is_order_independent() {
        // Build two orderings of the same multiset.
        // The near_saturation flag must be the same regardless of order.
        let big = DeterministicScore::from_raw(i64::MAX / 2);
        let neg = DeterministicScore::from_raw(i64::MIN / 2 + 1);
        let scores_order_a = [big, neg, big, neg];
        let scores_order_b = [big, big, neg, neg];
        let (_, flag_a) = avg_scores_checked(&scores_order_a);
        let (_, flag_b) = avg_scores_checked(&scores_order_b);
        assert_eq!(
            flag_a, flag_b,
            "near_saturation flag must be order-independent: order_a={flag_a}, order_b={flag_b}"
        );
    }

    #[test]
    fn avg_scores_checked_normal_values_no_saturation_flag() {
        let scores = [s(0.1), s(0.2), s(-0.1), s(0.3)];
        let (_, flag) = avg_scores_checked(&scores);
        assert!(!flag, "small scores should not trigger near_saturation");
    }
}