use std::collections::HashMap;
/// Learns a per-model multiplicative correction for raw estimates.
///
/// Each model's factor is a smoothed average of the `actual / estimated` ratios reported
/// through `record_observation`. `corrected_estimate` scales new raw estimates by that factor.
#[derive(Clone, Debug)]
pub struct EstimationCalibrator {
    /// Smoothed `actual / estimated` ratio, keyed by model name.
    ///
    /// A model absent from this map is treated as having a factor of `1.0`.
    corrections: HashMap<String, f32>,
}
/// Weight given to the newest `actual / estimated` ratio in each model's exponential moving
/// average; the previously stored factor keeps the remaining `1.0 - SMOOTHING`.
const SMOOTHING: f32 = 0.2_f32;
impl EstimationCalibrator {
    /// Correction factor used for a model with no recorded observations.
    ///
    /// A factor of `1.0` leaves raw estimates unchanged.
    const NEUTRAL_FACTOR: f32 = 1.0;

    /// Creates a calibrator with no recorded observations.
    ///
    /// Every model starts at a correction factor of `1.0`, so
    /// [`corrected_estimate`](Self::corrected_estimate) initially returns the raw estimate
    /// (clamped to a minimum of `1`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            corrections: HashMap::new(),
        }
    }

    /// Folds one `(estimated, actual)` observation for `model` into its correction factor.
    ///
    /// The factor is an exponential moving average of `actual / estimated`. The new ratio
    /// is weighted by `SMOOTHING` and the previous factor by `1.0 - SMOOTHING`. A model
    /// seen for the first time blends against a previous factor of `1.0`.
    ///
    /// Observations with `estimated == 0` are ignored because the ratio would be undefined.
    pub fn record_observation(&mut self, model: &str, estimated: u32, actual: u32) {
        if estimated == 0 {
            return;
        }
        // `u32 -> f32` is only inexact above 2^24; the ratio feeds a smoothed average,
        // so that rounding is tolerated.
        #[allow(clippy::cast_precision_loss)]
        let observed_ratio = actual as f32 / estimated as f32;
        let blend =
            |previous: f32| SMOOTHING.mul_add(observed_ratio, (1.0 - SMOOTHING) * previous);
        // Look the model up by `&str` first: updating an already-known model must not
        // allocate an owned `String` key on every call. Only a new model pays for one.
        match self.corrections.get_mut(model) {
            Some(factor) => *factor = blend(*factor),
            None => {
                self.corrections
                    .insert(model.to_owned(), blend(Self::NEUTRAL_FACTOR));
            }
        }
    }

    /// Scales `raw_estimate` by the current correction factor for `model`.
    ///
    /// The product is rounded up and clamped to at least `1`, so a raw estimate of `0`
    /// yields `1`. Products beyond `u32::MAX` saturate to `u32::MAX`.
    ///
    /// Because the result is rounded up, floating-point error in the stored factor can
    /// add one to a product that would otherwise be an exact integer.
    #[must_use]
    pub fn corrected_estimate(&self, model: &str, raw_estimate: u32) -> u32 {
        let factor = self.correction_factor(model);
        // Float-to-int `as` saturates (NaN/negative -> 0, too large -> u32::MAX) rather
        // than wrapping. The factor is an average of non-negative ratios, so these casts
        // are intentional.
        #[allow(
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss,
            clippy::cast_precision_loss
        )]
        let corrected = (raw_estimate as f32 * factor).ceil() as u32;
        corrected.max(1)
    }

    /// Returns the current correction factor for `model`.
    ///
    /// Returns `1.0` if no observation has been recorded for that model.
    #[must_use]
    pub fn correction_factor(&self, model: &str) -> f32 {
        self.corrections
            .get(model)
            .copied()
            .unwrap_or(Self::NEUTRAL_FACTOR)
    }
}
impl Default for EstimationCalibrator {
    /// Builds an empty calibrator, exactly like [`EstimationCalibrator::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cold_start_returns_raw_estimate() {
        let cal = EstimationCalibrator::new();
        assert_eq!(cal.corrected_estimate("llama3", 100), 100);
    }

    #[test]
    fn cold_start_factor_is_one() {
        let cal = EstimationCalibrator::new();
        assert!((cal.correction_factor("llama3") - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn single_observation_blends_with_default() {
        let mut cal = EstimationCalibrator::new();
        cal.record_observation("llama3", 100, 120);
        let factor = cal.correction_factor("llama3");
        // 0.2 * 1.2 + 0.8 * 1.0
        assert!((factor - 1.04).abs() < 0.01);
    }

    #[test]
    fn convergence_after_ten_observations() {
        let mut cal = EstimationCalibrator::new();
        for _ in 0..10 {
            cal.record_observation("model", 100, 120);
        }
        let factor = cal.correction_factor("model");
        // After n identical observations the factor is 1.2 - 0.2 * 0.8^n (~1.179 for n = 10).
        assert!((factor - 1.2).abs() < 0.05);
    }

    #[test]
    fn multiple_models_independent() {
        let mut cal = EstimationCalibrator::new();
        cal.record_observation("model_a", 100, 80);
        cal.record_observation("model_b", 100, 130);
        let factor_a = cal.correction_factor("model_a");
        let factor_b = cal.correction_factor("model_b");
        assert!(factor_a < 1.0);
        assert!(factor_b > 1.0);
        // model_b's observation must not leak into model_a: 0.2 * 0.8 + 0.8 * 1.0.
        assert!((factor_a - 0.96).abs() < 1e-5);
    }

    #[test]
    fn zero_estimated_is_ignored() {
        let mut cal = EstimationCalibrator::new();
        cal.record_observation("model", 0, 100);
        assert!((cal.correction_factor("model") - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn corrected_estimate_applies_factor() {
        let mut cal = EstimationCalibrator::new();
        for _ in 0..20 {
            cal.record_observation("model", 100, 120);
        }
        let corrected = cal.corrected_estimate("model", 100);
        assert!((115..=125).contains(&corrected));
    }

    #[test]
    fn corrected_estimate_minimum_is_one() {
        let cal = EstimationCalibrator::new();
        assert!(cal.corrected_estimate("model", 0) >= 1);
    }

    #[test]
    fn corrected_estimate_saturates_at_u32_max() {
        let cal = EstimationCalibrator::new();
        assert_eq!(cal.corrected_estimate("model", u32::MAX), u32::MAX);
    }

    #[test]
    fn repeated_observations_update_single_entry() {
        let mut cal = EstimationCalibrator::new();
        cal.record_observation("model", 100, 120);
        cal.record_observation("model", 100, 120);
        assert_eq!(cal.corrections.len(), 1);
        // 0.2 * 1.2 + 0.8 * 1.04
        assert!((cal.correction_factor("model") - 1.072).abs() < 1e-5);
    }
}