nodedb_vector/rerank/
recall.rs1use super::types::RerankError;
4
5pub fn recall_scale(
20 target_recall: Option<f32>,
21 base_ef: usize,
22 base_oversample: u8,
23) -> Result<(usize, u8), RerankError> {
24 let r = match target_recall {
25 None => return Ok((base_ef, base_oversample)),
26 Some(v) => v,
27 };
28
29 if r.is_nan() || r <= 0.0 || r > 1.0 {
30 return Err(RerankError::BadInput(format!(
31 "target_recall must be in (0.0, 1.0], got {r}"
32 )));
33 }
34
35 if r <= 0.80 {
36 return Ok((base_ef, base_oversample));
37 }
38
39 let scale = (1.0_f32 + (r - 0.80) / 0.20 * 4.0).clamp(1.0, 5.0);
40
41 let ef = base_ef.max((base_ef as f32 * scale).ceil() as usize);
42
43 let oversample_scaled = (base_oversample as f32 * scale.sqrt()).ceil() as u32;
44 let oversample = base_oversample.max(oversample_scaled.min(u8::MAX as u32) as u8);
45
46 Ok((ef, oversample))
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `target_recall` is rejected with `RerankError::BadInput`.
    fn assert_bad_input(target_recall: f32) {
        assert!(
            matches!(
                recall_scale(Some(target_recall), 100, 1),
                Err(RerankError::BadInput(_))
            ),
            "expected BadInput for target_recall = {target_recall}"
        );
    }

    #[test]
    fn none_returns_base() {
        assert_eq!(recall_scale(None, 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn boundary_identity_at_0_80() {
        assert_eq!(recall_scale(Some(0.80), 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn below_threshold_returns_base() {
        assert_eq!(recall_scale(Some(0.50), 100, 1).unwrap(), (100, 1));
        assert_eq!(recall_scale(Some(f32::MIN_POSITIVE), 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn recall_0_90_scale_3() {
        assert_eq!(recall_scale(Some(0.90), 100, 1).unwrap(), (300, 2));
    }

    #[test]
    fn recall_1_00_scale_5() {
        assert_eq!(recall_scale(Some(1.00), 100, 1).unwrap(), (500, 3));
    }

    #[test]
    fn recall_0_95_spot_check() {
        assert_eq!(recall_scale(Some(0.95), 200, 4).unwrap(), (800, 8));
    }

    #[test]
    fn oversample_saturates_at_u8_max() {
        // 200 * sqrt(5) is about 447, which does not fit in a u8.
        assert_eq!(recall_scale(Some(1.00), 100, 200).unwrap(), (500, 255));
    }

    #[test]
    fn scaling_is_monotonic_and_never_below_base() {
        let mut prev = (0_usize, 0_u8);
        for step in 0_u8..=100 {
            // Sweep (0.80, 1.0]; `min` guards against f32 overshoot past 1.0.
            let r = (0.80 + f32::from(step) * 0.002).min(1.0);
            let (ef, oversample) = recall_scale(Some(r), 100, 4).unwrap();
            assert!(ef >= 100 && oversample >= 4, "below base at r = {r}");
            assert!(
                ef >= prev.0 && oversample >= prev.1,
                "not monotonic at r = {r}"
            );
            prev = (ef, oversample);
        }
    }

    #[test]
    fn zero_is_bad_input() {
        assert_bad_input(0.0);
    }

    #[test]
    fn above_one_is_bad_input() {
        assert_bad_input(1.01);
    }

    #[test]
    fn nan_is_bad_input() {
        assert_bad_input(f32::NAN);
    }

    #[test]
    fn negative_is_bad_input() {
        assert_bad_input(-0.5);
    }

    #[test]
    fn infinities_are_bad_input() {
        assert_bad_input(f32::INFINITY);
        assert_bad_input(f32::NEG_INFINITY);
    }

    #[test]
    fn bad_input_message_names_offending_value() {
        match recall_scale(Some(1.5), 100, 1) {
            Err(RerankError::BadInput(msg)) => {
                assert!(msg.contains("target_recall must be in (0.0, 1.0]"), "{msg}");
                assert!(msg.contains("1.5"), "{msg}");
            }
            other => panic!("expected BadInput, got {other:?}"),
        }
    }
}