use std::collections::VecDeque;
use serde::{Deserialize, Serialize};
use crate::math::clamp;
/// Token cap installed by `CostAccumulator::new()`; `cap_reached` compares the
/// accumulated output-token total against it.
pub const DEFAULT_TOKEN_CAP: u32 = 10_000;
/// Maximum number of quality samples a freshly constructed accumulator retains
/// (assigned to `history_window` in `CostAccumulator::with_cap`).
pub const DEFAULT_HISTORY_WINDOW: usize = 20;
/// NOTE(review): not referenced in this file; presumably the default `n` passed
/// to `quality_decline_over_n` by callers — confirm.
pub const QUALITY_DECLINE_WINDOW: usize = 3;
/// NOTE(review): not referenced in this file; presumably the default `min_delta`
/// passed to `quality_decline_over_n` by callers — confirm.
pub const QUALITY_DECLINE_MIN_DELTA: f64 = 0.15;
/// NOTE(review): not referenced in this file; presumably a threshold compared
/// against `mean_quality_last_n` by callers — confirm.
pub const POOR_QUALITY_MEAN: f64 = 0.5;
/// Divisor for the token component of `normalize_cost`: a turn emitting this
/// many output tokens (or more) yields a token component of 1.0.
pub const TYPICAL_TURN_TOKENS_OUT: u32 = 1_000;
/// Divisor for the wall-clock component of `normalize_cost`: a turn taking this
/// long (or longer) yields a wall-clock component of 1.0.
pub const TYPICAL_TURN_WALLCLOCK_MS: u32 = 10_000;
/// Weight of the token component in `normalize_cost`; the wall-clock component
/// receives the remaining `1.0 - TOKEN_COST_WEIGHT`.
pub const TOKEN_COST_WEIGHT: f64 = 0.7;
/// Running totals of per-turn cost plus a bounded history of quality scores.
///
/// Cost totals are updated by `record_cost`; quality samples are appended by
/// `record_quality`. All fields are private and exposed through accessors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostAccumulator {
/// Sum of `tokens_in` over all `record_cost` calls (saturating at `u32::MAX`).
total_tokens_in: u32,
/// Sum of `tokens_out` over all `record_cost` calls (saturating at `u32::MAX`).
/// This is the value `cap_reached` compares against `cap_tokens`.
total_tokens_out: u32,
/// Sum of `wallclock_ms` over all `record_cost` calls, widened to `u64`.
total_wallclock_ms: u64,
/// Number of `record_cost` calls made so far.
turn_count: usize,
/// Output-token budget checked by `cap_reached`; changeable via `set_cap`.
cap_tokens: u32,
/// Recorded quality scores, oldest at the front, newest at the back.
/// Values are clamped to `[0.0, 1.0]` before being stored.
quality_history: VecDeque<f64>,
/// Maximum number of entries kept in `quality_history`; `record_quality`
/// drops the oldest entry once the length exceeds this.
history_window: usize,
}
impl CostAccumulator {
    /// Creates an empty accumulator with the default cap (`DEFAULT_TOKEN_CAP`).
    pub fn new() -> Self {
        Self::with_cap(DEFAULT_TOKEN_CAP)
    }

    /// Creates an empty accumulator with a caller-chosen token cap.
    ///
    /// All totals start at zero, the quality history starts empty, and the
    /// history window is `DEFAULT_HISTORY_WINDOW`.
    pub fn with_cap(cap_tokens: u32) -> Self {
        Self {
            total_tokens_in: 0,
            total_tokens_out: 0,
            total_wallclock_ms: 0,
            turn_count: 0,
            cap_tokens,
            quality_history: VecDeque::new(),
            history_window: DEFAULT_HISTORY_WINDOW,
        }
    }

    /// Adds one turn's cost to the running totals and bumps the turn counter.
    ///
    /// Every addition saturates instead of wrapping, so extreme inputs pin the
    /// totals at their maximum rather than panicking or overflowing.
    pub fn record_cost(&mut self, tokens_in: u32, tokens_out: u32, wallclock_ms: u32) {
        self.total_tokens_in = self.total_tokens_in.saturating_add(tokens_in);
        self.total_tokens_out = self.total_tokens_out.saturating_add(tokens_out);
        self.total_wallclock_ms = self
            .total_wallclock_ms
            .saturating_add(u64::from(wallclock_ms));
        self.turn_count = self.turn_count.saturating_add(1);
    }

    /// Appends a quality sample to the bounded history.
    ///
    /// Non-finite values (NaN, ±infinity) are ignored. Finite values are
    /// clamped to `[0.0, 1.0]` before being stored. After the push, the oldest
    /// samples are evicted until the history fits `history_window`.
    pub fn record_quality(&mut self, quality: f64) {
        if !quality.is_finite() {
            return;
        }
        self.quality_history.push_back(clamp(quality, 0.0, 1.0));
        // Loop rather than pop once: state restored through `Deserialize` may
        // already hold more samples than the window allows, and a single pop
        // per push would never bring it back within bounds.
        while self.quality_history.len() > self.history_window {
            self.quality_history.pop_front();
        }
    }

    /// Total input tokens recorded so far.
    pub fn total_tokens_in(&self) -> u32 {
        self.total_tokens_in
    }

    /// Total output tokens recorded so far.
    pub fn total_tokens_out(&self) -> u32 {
        self.total_tokens_out
    }

    /// Total wall-clock time recorded so far.
    pub fn total_wallclock_ms(&self) -> u64 {
        self.total_wallclock_ms
    }

    /// Number of turns recorded via `record_cost`.
    pub fn turn_count(&self) -> usize {
        self.turn_count
    }

    /// The current token cap.
    pub fn cap_tokens(&self) -> u32 {
        self.cap_tokens
    }

    /// Replaces the token cap; existing totals are left untouched.
    pub fn set_cap(&mut self, cap: u32) {
        self.cap_tokens = cap;
    }

    /// Returns `true` once the total of *output* tokens is at or above the cap.
    ///
    /// Input tokens are not counted towards the cap.
    pub fn cap_reached(&self) -> bool {
        self.total_tokens_out >= self.cap_tokens
    }

    /// Mean of the `n` most recent quality samples.
    ///
    /// Returns `None` when `n` is zero or fewer than `n` samples are stored.
    pub fn mean_quality_last_n(&self, n: usize) -> Option<f64> {
        let len = self.quality_history.len();
        if n == 0 || len < n {
            return None;
        }
        // Sum oldest-to-newest over the trailing `n` samples.
        let sum: f64 = self.quality_history.iter().skip(len - n).sum();
        Some(sum / n as f64)
    }

    /// Detects a quality drop across the last `n` samples.
    ///
    /// Only the endpoints of the window are compared: the sample `n` entries
    /// from the end (oldest in the window) and the newest sample. Samples in
    /// between are ignored. Returns `Some(oldest - newest)` when that
    /// difference is at least `min_delta`, otherwise `None`.
    ///
    /// Also returns `None` when `n < 2` or fewer than `n` samples are stored.
    pub fn quality_decline_over_n(&self, n: usize, min_delta: f64) -> Option<f64> {
        let len = self.quality_history.len();
        if n < 2 || len < n {
            return None;
        }
        let oldest = *self.quality_history.get(len - n)?;
        let newest = *self.quality_history.back()?;
        Some(oldest - newest).filter(|delta| *delta >= min_delta)
    }
}
impl Default for CostAccumulator {
fn default() -> Self {
Self::new()
}
}
/// Collapses one turn's output-token count and wall-clock time into a single
/// weighted cost score.
///
/// Each input is divided by its "typical turn" constant and passed through
/// `clamp(_, 0.0, 1.0)`. The result is
/// `TOKEN_COST_WEIGHT * tokens + (1.0 - TOKEN_COST_WEIGHT) * wallclock`.
pub fn normalize_cost(tokens_out: u32, wallclock_ms: u32) -> f64 {
    let token_ratio = f64::from(tokens_out) / f64::from(TYPICAL_TURN_TOKENS_OUT);
    let wallclock_ratio = f64::from(wallclock_ms) / f64::from(TYPICAL_TURN_WALLCLOCK_MS);

    let token_share = TOKEN_COST_WEIGHT * clamp(token_ratio, 0.0, 1.0);
    let wallclock_share = (1.0 - TOKEN_COST_WEIGHT) * clamp(wallclock_ratio, 0.0, 1.0);

    token_share + wallclock_share
}
/// Unit tests for `CostAccumulator` and `normalize_cost`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_accumulator_reports_zeros() {
        let acc = CostAccumulator::new();
        assert_eq!(acc.total_tokens_in(), 0);
        assert_eq!(acc.total_tokens_out(), 0);
        assert_eq!(acc.total_wallclock_ms(), 0);
        assert_eq!(acc.turn_count(), 0);
        assert!(!acc.cap_reached());
        assert_eq!(acc.cap_tokens(), DEFAULT_TOKEN_CAP);
    }

    #[test]
    fn record_cost_accumulates_across_turns() {
        let mut acc = CostAccumulator::new();
        acc.record_cost(100, 200, 500);
        acc.record_cost(50, 150, 1_000);
        assert_eq!(acc.total_tokens_in(), 150);
        assert_eq!(acc.total_tokens_out(), 350);
        assert_eq!(acc.total_wallclock_ms(), 1_500);
        assert_eq!(acc.turn_count(), 2);
    }

    #[test]
    fn cap_reached_fires_at_or_above_cap() {
        let mut acc = CostAccumulator::with_cap(1_000);
        acc.record_cost(0, 500, 0);
        assert!(!acc.cap_reached());
        acc.record_cost(0, 500, 0);
        assert!(acc.cap_reached(), "cap exactly met should be reached");
        acc.record_cost(0, 100, 0);
        assert!(acc.cap_reached(), "cap exceeded should stay reached");
    }

    #[test]
    fn saturating_adds_handle_overflow_gracefully() {
        let mut acc = CostAccumulator::new();
        acc.record_cost(u32::MAX, u32::MAX, u32::MAX);
        acc.record_cost(1, 1, 1);
        assert_eq!(acc.total_tokens_in(), u32::MAX);
        assert_eq!(acc.total_tokens_out(), u32::MAX);
        // The wall-clock total is a u64, so two u32 inputs add up exactly
        // instead of saturating.
        assert_eq!(acc.total_wallclock_ms(), u64::from(u32::MAX) + 1);
        assert_eq!(acc.turn_count(), 2);
    }

    #[test]
    fn record_quality_clamps_and_skips_nonfinite() {
        let mut acc = CostAccumulator::new();
        acc.record_quality(0.8);
        acc.record_quality(-1.0);
        acc.record_quality(2.0);
        acc.record_quality(f64::NAN);
        acc.record_quality(f64::INFINITY);
        let mean = acc.mean_quality_last_n(3).expect("three values stored");
        assert!(
            ((0.8 + 0.0 + 1.0) / 3.0 - mean).abs() < 1e-9,
            "mean should reflect clamped finite values only (got {mean})"
        );
        // Exactly three samples were kept: the two non-finite ones were dropped.
        assert!(acc.mean_quality_last_n(4).is_none());
    }

    #[test]
    fn quality_history_evicts_oldest_on_overflow() {
        let mut acc = CostAccumulator::new();
        for _ in 0..DEFAULT_HISTORY_WINDOW + 5 {
            acc.record_quality(0.5);
        }
        assert!(acc.mean_quality_last_n(DEFAULT_HISTORY_WINDOW).is_some());
        assert!(acc
            .mean_quality_last_n(DEFAULT_HISTORY_WINDOW + 1)
            .is_none());
    }

    #[test]
    fn mean_quality_last_n_empty_is_none() {
        let acc = CostAccumulator::new();
        assert!(acc.mean_quality_last_n(3).is_none());
    }

    #[test]
    fn mean_quality_last_n_zero_is_none() {
        let mut acc = CostAccumulator::new();
        acc.record_quality(0.9);
        assert!(acc.mean_quality_last_n(0).is_none());
    }

    #[test]
    fn mean_quality_last_n_computes_trailing_mean() {
        let mut acc = CostAccumulator::new();
        for q in [0.1, 0.2, 0.3, 0.4, 0.5] {
            acc.record_quality(q);
        }
        let mean3 = acc.mean_quality_last_n(3).expect("three trailing values");
        assert!(
            ((0.3 + 0.4 + 0.5) / 3.0 - mean3).abs() < 1e-9,
            "mean of trailing 3 should be 0.4 (got {mean3})"
        );
    }

    #[test]
    fn quality_decline_detects_monotonic_drop() {
        let mut acc = CostAccumulator::new();
        for q in [0.9, 0.7, 0.5, 0.3] {
            acc.record_quality(q);
        }
        // Window of 3 spans 0.7 -> 0.3, so the reported drop is 0.4.
        let delta = acc
            .quality_decline_over_n(3, 0.15)
            .expect("three-turn decline must fire");
        assert!((delta - 0.4).abs() < 1e-9);
    }

    #[test]
    fn quality_decline_returns_none_when_stable() {
        let mut acc = CostAccumulator::new();
        for q in [0.7, 0.65, 0.72, 0.68, 0.7] {
            acc.record_quality(q);
        }
        assert!(
            acc.quality_decline_over_n(3, 0.15).is_none(),
            "stable quality must not register as declining"
        );
    }

    #[test]
    fn quality_decline_returns_none_when_improving() {
        let mut acc = CostAccumulator::new();
        for q in [0.3, 0.5, 0.7] {
            acc.record_quality(q);
        }
        assert!(acc.quality_decline_over_n(3, 0.15).is_none());
    }

    #[test]
    fn quality_decline_requires_min_points() {
        let mut acc = CostAccumulator::new();
        acc.record_quality(0.9);
        acc.record_quality(0.3);
        assert!(acc.quality_decline_over_n(3, 0.1).is_none());
        assert!(acc.quality_decline_over_n(2, 0.1).is_some());
        assert!(acc.quality_decline_over_n(1, 0.0).is_none());
    }

    #[test]
    fn quality_decline_below_threshold_returns_none() {
        let mut acc = CostAccumulator::new();
        for q in [0.7, 0.65, 0.60] {
            acc.record_quality(q);
        }
        assert!(acc.quality_decline_over_n(3, 0.15).is_none());
    }

    // Both components hit their clamp ceiling at the "typical" values, so a
    // typical turn scores the maximum 1.0, not a mid-range value.
    #[test]
    fn normalize_cost_typical_turn_saturates_at_one() {
        let c = normalize_cost(TYPICAL_TURN_TOKENS_OUT, TYPICAL_TURN_WALLCLOCK_MS);
        assert!((c - 1.0).abs() < 1e-9);
    }

    #[test]
    fn normalize_cost_zero_input_is_zero() {
        assert_eq!(normalize_cost(0, 0), 0.0);
    }

    #[test]
    fn normalize_cost_half_typical_gives_half_weight() {
        let c = normalize_cost(500, 5_000);
        assert!((c - 0.5).abs() < 1e-9);
    }

    #[test]
    fn normalize_cost_clamps_runaway_turn() {
        let c = normalize_cost(TYPICAL_TURN_TOKENS_OUT * 10, TYPICAL_TURN_WALLCLOCK_MS * 10);
        assert!((c - 1.0).abs() < 1e-9);
    }

    #[test]
    fn normalize_cost_weights_tokens_dominantly() {
        let c = normalize_cost(TYPICAL_TURN_TOKENS_OUT, 0);
        assert!((c - TOKEN_COST_WEIGHT).abs() < 1e-9);
        let c = normalize_cost(0, TYPICAL_TURN_WALLCLOCK_MS);
        assert!((c - (1.0 - TOKEN_COST_WEIGHT)).abs() < 1e-9);
    }
}