use super::types::RerankError;
/// Derives the `ef` and oversample parameters for a requested recall target.
///
/// A `None` target, or any target at or below `0.80`, returns `base_ef` and
/// `base_oversample` unchanged. Above `0.80`, a multiplier grows linearly from
/// `1.0` at `0.80` to `5.0` at `1.00`:
///
/// - `ef` is scaled by the multiplier.
/// - The oversample factor is scaled by the multiplier's square root.
///
/// Both results are rounded up and are never smaller than their base values.
/// The oversample result saturates at `u8::MAX`, and `ef` saturates at
/// `usize::MAX`.
///
/// # Errors
///
/// Returns [`RerankError::BadInput`] when `target_recall` is NaN or lies
/// outside `(0.0, 1.0]`.
pub fn recall_scale(
    target_recall: Option<f32>,
    base_ef: usize,
    base_oversample: u8,
) -> Result<(usize, u8), RerankError> {
    // The multiplier is held as an integer count of 1/SCALE_DENOM units.
    // Snapping to this grid absorbs f32 representation noise before `ceil`.
    // Example: 0.85_f32 is slightly above 0.85, and pure f32 math gave
    // ef = 201 for base 100 instead of 200.
    const SCALE_DENOM: u32 = 10_000;
    const THRESHOLD: f64 = 0.80;
    const SPAN: f64 = 0.20;
    const MAX_EXTRA: f64 = 4.0;

    let r = match target_recall {
        None => return Ok((base_ef, base_oversample)),
        Some(v) => v,
    };
    if r.is_nan() || r <= 0.0 || r > 1.0 {
        return Err(RerankError::BadInput(format!(
            "target_recall must be in (0.0, 1.0], got {r}"
        )));
    }
    // Threshold compared in f32 so the identity region matches the input type exactly.
    if r <= 0.80 {
        return Ok((base_ef, base_oversample));
    }

    // Linear multiplier in [1.0, 1.0 + MAX_EXTRA], computed in f64 and snapped
    // to the fixed-point grid.
    let denom_f = f64::from(SCALE_DENOM);
    let scale = 1.0 + (f64::from(r) - THRESHOLD) / SPAN * MAX_EXTRA;
    let scale_units = (scale * denom_f)
        .round()
        .clamp(denom_f, denom_f * (1.0 + MAX_EXTRA)) as u32;

    // Exact integer ceil(base_ef * scale).
    // u128 cannot overflow: usize::MAX * 50_000 < 2^80.
    let denom = u128::from(SCALE_DENOM);
    let ef_scaled = (base_ef as u128 * u128::from(scale_units) + denom - 1) / denom;
    let ef = base_ef.max(ef_scaled.min(usize::MAX as u128) as usize);

    // A multiplier that is exactly representable (e.g. 4.0) yields an exact
    // f64 sqrt, so perfect squares do not pick up a spurious extra unit.
    let root = (f64::from(scale_units) / denom_f).sqrt();
    let oversample_scaled = (f64::from(base_oversample) * root)
        .ceil()
        .min(f64::from(u8::MAX));
    let oversample = base_oversample.max(oversample_scaled as u8);

    Ok((ef, oversample))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `value` is rejected with `RerankError::BadInput`.
    fn assert_bad_input(value: f32) {
        assert!(
            matches!(
                recall_scale(Some(value), 100, 1),
                Err(RerankError::BadInput(_))
            ),
            "expected BadInput for target_recall = {value}"
        );
    }

    #[test]
    fn none_returns_base() {
        assert_eq!(recall_scale(None, 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn below_threshold_returns_base() {
        assert_eq!(recall_scale(Some(0.50), 100, 1).unwrap(), (100, 1));
        assert_eq!(
            recall_scale(Some(f32::MIN_POSITIVE), 100, 1).unwrap(),
            (100, 1)
        );
    }

    #[test]
    fn boundary_identity_at_0_80() {
        assert_eq!(recall_scale(Some(0.80), 100, 1).unwrap(), (100, 1));
    }

    #[test]
    fn recall_0_90_scale_3() {
        assert_eq!(recall_scale(Some(0.90), 100, 1).unwrap(), (300, 2));
    }

    #[test]
    fn recall_1_00_scale_5() {
        assert_eq!(recall_scale(Some(1.00), 100, 1).unwrap(), (500, 3));
    }

    #[test]
    fn recall_0_95_spot_check() {
        assert_eq!(recall_scale(Some(0.95), 200, 4).unwrap(), (800, 8));
    }

    #[test]
    fn oversample_saturates_at_u8_max() {
        // 255 * sqrt(5) exceeds u8::MAX, so the result must clamp rather than wrap.
        assert_eq!(
            recall_scale(Some(1.00), 100, u8::MAX).unwrap(),
            (500, u8::MAX)
        );
    }

    #[test]
    fn zero_bases_stay_zero() {
        assert_eq!(recall_scale(Some(0.90), 0, 0).unwrap(), (0, 0));
    }

    #[test]
    fn outputs_are_monotonic_in_recall() {
        let mut prev = recall_scale(Some(0.80), 100, 4).unwrap();
        for pct in 81..=100u32 {
            let cur = recall_scale(Some(pct as f32 / 100.0), 100, 4).unwrap();
            assert!(
                cur.0 >= prev.0 && cur.1 >= prev.1,
                "not monotonic at {pct}%: {prev:?} -> {cur:?}"
            );
            prev = cur;
        }
    }

    #[test]
    fn zero_is_bad_input() {
        assert_bad_input(0.0);
    }

    #[test]
    fn above_one_is_bad_input() {
        assert_bad_input(1.01);
    }

    #[test]
    fn nan_is_bad_input() {
        assert_bad_input(f32::NAN);
    }

    #[test]
    fn negative_is_bad_input() {
        assert_bad_input(-0.5);
    }

    #[test]
    fn infinity_is_bad_input() {
        assert_bad_input(f32::INFINITY);
    }
}