Skip to main content

nodedb_vector/rerank/
recall.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::types::RerankError;
4
5/// Scale `ef_search` and `oversample` up to meet a recall target.
6///
7/// `target_recall` must be in (0.0, 1.0]. Returns `(ef, oversample)` adjusted
8/// for the given target. When `target_recall` is `None` or already-default,
9/// returns the base values unchanged.
10///
11/// Formula (heuristic, not a guarantee):
12/// - For `r <= 0.80`: identity — `ef = base_ef`, `oversample = base_oversample`.
13/// - For `r > 0.80`: ramp scale from 1.0× at r=0.80 to 5.0× at r=1.00, linearly.
14///   `scale = 1.0 + (r - 0.80) / 0.20 * 4.0`, clamped to `[1.0, 5.0]`.
15/// - `ef` becomes `max(base_ef, (base_ef as f32 * scale).ceil() as usize)`.
16/// - `oversample` becomes `max(base_oversample, (base_oversample as f32 * scale.sqrt()).ceil() as u8)`
17///   — oversample grows sub-linearly because rerank cost is linear in oversample,
18///   while ef has a more favourable cost curve.
19pub fn recall_scale(
20    target_recall: Option<f32>,
21    base_ef: usize,
22    base_oversample: u8,
23) -> Result<(usize, u8), RerankError> {
24    let r = match target_recall {
25        None => return Ok((base_ef, base_oversample)),
26        Some(v) => v,
27    };
28
29    if r.is_nan() || r <= 0.0 || r > 1.0 {
30        return Err(RerankError::BadInput(format!(
31            "target_recall must be in (0.0, 1.0], got {r}"
32        )));
33    }
34
35    if r <= 0.80 {
36        return Ok((base_ef, base_oversample));
37    }
38
39    let scale = (1.0_f32 + (r - 0.80) / 0.20 * 4.0).clamp(1.0, 5.0);
40
41    let ef = base_ef.max((base_ef as f32 * scale).ceil() as usize);
42
43    let oversample_scaled = (base_oversample as f32 * scale.sqrt()).ceil() as u32;
44    let oversample = base_oversample.max(oversample_scaled.min(u8::MAX as u32) as u8);
45
46    Ok((ef, oversample))
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `target` is rejected with `RerankError::BadInput`.
    fn assert_bad_input(target: f32) {
        assert!(
            matches!(
                recall_scale(Some(target), 100, 1),
                Err(RerankError::BadInput(_))
            ),
            "expected BadInput for target_recall = {target}"
        );
    }

    #[test]
    fn none_returns_base() {
        assert_eq!(recall_scale(None, 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn boundary_identity_at_0_80() {
        assert_eq!(recall_scale(Some(0.80), 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn below_knee_is_identity() {
        // Any valid target at or below 0.80 leaves the base values untouched.
        assert_eq!(recall_scale(Some(0.50), 64, 3).unwrap(), (64, 3));
        assert_eq!(
            recall_scale(Some(f32::MIN_POSITIVE), 64, 3).unwrap(),
            (64, 3)
        );
    }

    #[test]
    fn recall_0_90_scale_3() {
        // scale = 1 + (0.10/0.20)*4 = 3.0; ef = 300; oversample = ceil(sqrt(3)) = 2
        assert_eq!(recall_scale(Some(0.90), 100, 1).unwrap(), (300, 2));
    }

    #[test]
    fn recall_1_00_scale_5() {
        // scale = 5.0; ef = 500; oversample = ceil(sqrt(5)) = ceil(2.236) = 3
        assert_eq!(recall_scale(Some(1.00), 100, 1).unwrap(), (500, 3));
    }

    #[test]
    fn recall_0_95_spot_check() {
        // scale = 1 + (0.15/0.20)*4 = 4.0; ef = 800; oversample = ceil(4*sqrt(4)) = ceil(8) = 8
        assert_eq!(recall_scale(Some(0.95), 200, 4).unwrap(), (800, 8));
    }

    #[test]
    fn oversample_saturates_at_u8_max() {
        // 255 * sqrt(5) is roughly 570, which must clamp to u8::MAX, not wrap.
        assert_eq!(
            recall_scale(Some(1.00), 100, u8::MAX).unwrap(),
            (500, u8::MAX)
        );
    }

    #[test]
    fn zero_bases_stay_zero() {
        // Scaling zero by any factor is still zero.
        assert_eq!(recall_scale(Some(1.00), 0, 0).unwrap(), (0, 0));
    }

    #[test]
    fn never_drops_below_base() {
        for &target in &[0.81_f32, 0.85, 0.90, 0.95, 0.99, 1.00] {
            let (ef, oversample) = recall_scale(Some(target), 128, 4).unwrap();
            assert!(ef >= 128, "ef shrank for target_recall = {target}");
            assert!(
                oversample >= 4,
                "oversample shrank for target_recall = {target}"
            );
        }
    }

    #[test]
    fn zero_is_bad_input() {
        assert_bad_input(0.0);
    }

    #[test]
    fn above_one_is_bad_input() {
        assert_bad_input(1.01);
    }

    #[test]
    fn nan_is_bad_input() {
        assert_bad_input(f32::NAN);
    }

    #[test]
    fn negative_is_bad_input() {
        assert_bad_input(-0.5);
    }

    #[test]
    fn infinity_is_bad_input() {
        assert_bad_input(f32::INFINITY);
        assert_bad_input(f32::NEG_INFINITY);
    }
}