ai_tokenopt 0.5.10

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! Per-model token estimation calibration.
//!
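//! Tracks the ratio between estimated and actual token counts reported
//! by the LLM (e.g. Ollama's `prompt_eval_count`). Applies a learned
//! correction factor to future estimates. Each observation moves the factor
//! 20% of the way toward the observed ratio, so a stable bias is roughly
//! two-thirds corrected after 5 calls and about 90% corrected after 10.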
//!
//! The calibrator is in-memory only — correction factors reset on restart.

use std::collections::HashMap;

/// Learns, per model, how far raw token estimates drift from reality.
///
/// Each model name maps to a single multiplier that is an exponential
/// moving average of the observed `actual / estimated` ratios. A model
/// that has never been observed is treated as having a factor of 1.0.
#[derive(Clone, Debug)]
pub struct EstimationCalibrator {
    /// Correction multiplier keyed by model name (1.0 means "no correction").
    corrections: HashMap<String, f32>,
}

/// Weight given to each new observation in the exponential moving average.
///
/// Every update computes
/// `new_factor = SMOOTHING * observed_ratio + (1 - SMOOTHING) * old_factor`,
/// so a larger value reacts faster to new data but smooths less.
const SMOOTHING: f32 = 0.2_f32;

impl EstimationCalibrator {
    /// Create a new calibrator with no correction history.
    ///
    /// Until a model has been observed, its correction factor is 1.0.
    #[must_use]
    pub fn new() -> Self {
        Self {
            corrections: HashMap::new(),
        }
    }

    /// Record an observed estimated vs. actual token count pair.
    ///
    /// Updates the stored correction factor for `model` by blending the
    /// observed `actual / estimated` ratio into it with an exponential
    /// moving average. The first observation for a model is blended with
    /// the neutral factor 1.0.
    ///
    /// Observations with `estimated == 0` are ignored because the ratio
    /// would be undefined.
    pub fn record_observation(&mut self, model: &str, estimated: u32, actual: u32) {
        if estimated == 0 {
            return;
        }

        // Counts above 2^24 lose low-order bits in f32, but the relative
        // error (< 1e-7) is irrelevant for a smoothed ratio.
        #[allow(clippy::cast_precision_loss)]
        let observed_ratio = actual as f32 / estimated as f32;

        // Look up by `&str` first so an owned key is only allocated the
        // first time a model is seen, not on every observation.
        if let Some(factor) = self.corrections.get_mut(model) {
            *factor = Self::blend(*factor, observed_ratio);
        } else {
            self.corrections
                .insert(model.to_owned(), Self::blend(1.0, observed_ratio));
        }
    }

    /// Apply the learned correction factor to a raw estimate.
    ///
    /// If no correction is available for the model, the raw estimate is
    /// returned exactly, without a round-trip through floating point.
    /// Otherwise the estimate is multiplied by the factor and rounded up.
    /// In both cases the result is at least 1.
    #[must_use]
    pub fn corrected_estimate(&self, model: &str, raw_estimate: u32) -> u32 {
        let corrected = match self.corrections.get(model) {
            // Bypass f32 entirely: `u32 -> f32 -> u32` is lossy above 2^24
            // and would silently alter an estimate that needs no correction.
            None => raw_estimate,
            Some(&factor) => {
                // The factor is a non-negative finite f32, so the product is
                // non-negative; float-to-int `as` saturates at `u32::MAX`.
                #[allow(
                    clippy::cast_possible_truncation,
                    clippy::cast_sign_loss,
                    clippy::cast_precision_loss
                )]
                let scaled = (raw_estimate as f32 * factor).ceil() as u32;
                scaled
            }
        };
        corrected.max(1)
    }

    /// Get the current correction factor for a model (1.0 if unknown).
    #[must_use]
    pub fn correction_factor(&self, model: &str) -> f32 {
        self.corrections.get(model).copied().unwrap_or(1.0)
    }

    /// Blend one observed ratio into the current factor:
    /// `SMOOTHING * observed_ratio + (1 - SMOOTHING) * current`.
    fn blend(current: f32, observed_ratio: f32) -> f32 {
        SMOOTHING.mul_add(observed_ratio, (1.0 - SMOOTHING) * current)
    }
}

impl Default for EstimationCalibrator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// An unknown model leaves the estimate untouched.
    #[test]
    fn cold_start_returns_raw_estimate() {
        let cal = EstimationCalibrator::new();
        assert_eq!(cal.corrected_estimate("llama3", 100), 100);
    }

    /// An unknown model reports the neutral factor.
    #[test]
    fn cold_start_factor_is_one() {
        let cal = EstimationCalibrator::new();
        assert!((cal.correction_factor("llama3") - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn single_observation_blends_with_default() {
        let mut cal = EstimationCalibrator::new();
        // Estimated 100, actual 120 → ratio 1.2
        // factor = 0.8 * 1.0 + 0.2 * 1.2 = 1.04
        cal.record_observation("llama3", 100, 120);
        let factor = cal.correction_factor("llama3");
        assert!((factor - 1.04).abs() < 0.01);
    }

    #[test]
    fn convergence_after_ten_observations() {
        let mut cal = EstimationCalibrator::new();
        // Consistently underestimate by 20% (actual/estimated = 1.2).
        // The remaining gap shrinks by 0.8 per call: 0.2 * 0.8^10 ≈ 0.021,
        // which is inside the tolerance below (five calls would leave ≈ 0.066).
        for _ in 0..10 {
            cal.record_observation("model", 100, 120);
        }
        let factor = cal.correction_factor("model");
        // Should converge close to 1.2
        assert!((factor - 1.2).abs() < 0.05);
    }

    #[test]
    fn multiple_models_independent() {
        let mut cal = EstimationCalibrator::new();
        cal.record_observation("model_a", 100, 80); // overestimate
        cal.record_observation("model_b", 100, 130); // underestimate

        let factor_a = cal.correction_factor("model_a");
        let factor_b = cal.correction_factor("model_b");

        assert!(factor_a < 1.0);
        assert!(factor_b > 1.0);
    }

    #[test]
    fn zero_estimated_is_ignored() {
        let mut cal = EstimationCalibrator::new();
        cal.record_observation("model", 0, 100);
        assert!((cal.correction_factor("model") - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn corrected_estimate_applies_factor() {
        let mut cal = EstimationCalibrator::new();
        // Drive factor toward 1.2
        for _ in 0..20 {
            cal.record_observation("model", 100, 120);
        }
        let corrected = cal.corrected_estimate("model", 100);
        // Should be approximately 120
        assert!((115..=125).contains(&corrected));
    }

    #[test]
    fn corrected_estimate_rounds_up() {
        let mut cal = EstimationCalibrator::new();
        // One observation gives a factor of ≈ 1.04, so 10 → 10.4 → 11.
        cal.record_observation("model", 100, 120);
        assert_eq!(cal.corrected_estimate("model", 10), 11);
    }

    #[test]
    fn corrected_estimate_minimum_is_one() {
        let cal = EstimationCalibrator::new();
        assert!(cal.corrected_estimate("model", 0) >= 1);
    }
}