use std::sync::{Mutex, OnceLock};
use crate::models::Usage;
use crate::pricing::CostEstimate;
static PENDING: OnceLock<Mutex<CostEstimate>> = OnceLock::new();
fn cell() -> &'static Mutex<CostEstimate> {
PENDING.get_or_init(|| Mutex::new(CostEstimate::default()))
}
/// Adds the estimated cost of one turn to the pending pool.
///
/// `model` and `usage` are passed to
/// `crate::pricing::calculate_turn_cost_estimate_from_usage`. Nothing is recorded when
/// that returns `None` (no estimate for the model) or when the estimate's
/// `is_positive()` is false. Otherwise its `usd` and `cny` amounts are added to the pool,
/// and a later [`drain`] returns them.
pub fn report(model: &str, usage: &Usage) {
    let Some(cost) = crate::pricing::calculate_turn_cost_estimate_from_usage(model, usage) else {
        return;
    };
    if !cost.is_positive() {
        return;
    }
    // The pool's critical sections in this module only do plain field updates, so the data
    // behind a poisoned lock is still consistent. Recover the guard rather than silently
    // dropping this cost.
    let mut pending = cell().lock().unwrap_or_else(|e| e.into_inner());
    pending.usd += cost.usd;
    pending.cny += cost.cny;
}
/// Returns the accumulated pending cost and resets the pool to `CostEstimate::default()`.
///
/// The read and the reset happen under a single lock acquisition. A concurrent
/// [`report`] therefore lands either in this total or in the next one, never in both and
/// never lost between them. If the mutex is poisoned, the guard is recovered so the
/// accumulated total is still returned and cleared. Without this, the total would stay
/// stuck behind the poisoned lock and every call would return the default.
pub fn drain() -> CostEstimate {
    let mut pending = cell().lock().unwrap_or_else(|e| e.into_inner());
    std::mem::take(&mut *pending)
}
/// Clears the pending pool so a test starts from `CostEstimate::default()`.
///
/// A poisoned mutex is recovered, so the reset always takes effect and leftover state
/// cannot leak into the next test.
#[cfg(test)]
pub fn reset_for_tests() {
    *cell().lock().unwrap_or_else(|e| e.into_inner()) = CostEstimate::default();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model id the tests expect the pricing table to know about.
    const PRICED_MODEL: &str = "deepseek-v4-flash";

    /// Model id the tests expect to have no pricing entry.
    const UNPRICED_MODEL: &str = "deepseek-ai/deepseek-v4-pro";

    /// A small, non-zero token usage sample shared by every test.
    fn small_usage() -> Usage {
        Usage { input_tokens: 1_000, output_tokens: 500, ..Default::default() }
    }

    /// Serialises the tests in this module, then empties the pool.
    ///
    /// The tests share the global pool, so they must not run concurrently. The returned
    /// guard keeps other tests out until it is dropped.
    fn isolated_pool() -> std::sync::MutexGuard<'static, ()> {
        static GATE: OnceLock<Mutex<()>> = OnceLock::new();
        let gate = GATE
            .get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        reset_for_tests();
        gate
    }

    #[test]
    fn report_adds_to_pool_and_drain_returns_then_resets() {
        let _gate = isolated_pool();

        report(PRICED_MODEL, &small_usage());

        let first = drain();
        assert!(first.usd > 0.0, "expected positive USD cost, got {first:?}");
        assert!(first.cny > 0.0, "expected positive CNY cost, got {first:?}");

        assert_eq!(drain(), CostEstimate::default(), "drain must zero the pool");
    }

    #[test]
    fn report_skips_unknown_models() {
        let _gate = isolated_pool();

        report(UNPRICED_MODEL, &small_usage());

        assert_eq!(drain(), CostEstimate::default());
    }

    #[test]
    fn report_accumulates_across_multiple_calls() {
        let _gate = isolated_pool();

        for _ in 0..2 {
            report(PRICED_MODEL, &small_usage());
        }
        let total = drain();

        let single =
            crate::pricing::calculate_turn_cost_estimate_from_usage(PRICED_MODEL, &small_usage())
                .unwrap();
        assert!((total.usd - 2.0 * single.usd).abs() < 1e-12);
        assert!((total.cny - 2.0 * single.cny).abs() < 1e-12);
    }
}