use saorsa_ai::{Usage, lookup_model, lookup_model_by_prefix};
/// Cost of a single tracked request, as produced by [`CostTracker::track`].
#[derive(Clone, Debug, PartialEq)]
pub struct CostEntry {
    /// Model identifier exactly as it was passed to `track`.
    pub model: String,
    /// Number of input tokens taken from the request's usage.
    pub input_tokens: u32,
    /// Number of output tokens taken from the request's usage.
    pub output_tokens: u32,
    /// Cost in US dollars; `0.0` when no pricing was found for the model.
    pub cost_usd: f64,
}
/// Running record of per-request costs, plus their accumulated total.
#[derive(Debug, Default, Clone)]
pub struct CostTracker {
    /// Entries appended by [`CostTracker::track`], in the order they were tracked.
    pub entries: Vec<CostEntry>,
    /// Running sum, in US dollars, of the costs added by [`CostTracker::track`].
    pub session_total: f64,
}
impl CostTracker {
    /// Creates a tracker with no entries and a session total of zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Prices one request and records it in this tracker.
    ///
    /// The model is resolved by exact name first and, failing that, by prefix.
    /// The cost is `input_tokens * input_rate / 1e6 + output_tokens * output_rate / 1e6`,
    /// where the rates are the model's per-million-token prices. If the model cannot be
    /// resolved, or either rate is absent, the cost is `0.0`. Only input and output tokens
    /// are priced; the cache token counts carried by `usage` are not read here.
    ///
    /// The resulting [`CostEntry`] is appended to `entries`, its cost is added to
    /// `session_total`, and a copy of it is returned.
    pub fn track(&mut self, model: &str, usage: &Usage) -> CostEntry {
        // Rates are quoted per this many tokens.
        const TOKENS_PER_RATE_UNIT: f64 = 1_000_000.0;

        // Both rates are required: a model with only one price counts as unpriced.
        let rates = lookup_model(model)
            .or_else(|| lookup_model_by_prefix(model))
            .and_then(|info| {
                let input_rate = info.cost_per_million_input?;
                let output_rate = info.cost_per_million_output?;
                Some((input_rate, output_rate))
            });

        let cost_usd = match rates {
            Some((input_rate, output_rate)) => {
                let prompt_part = f64::from(usage.input_tokens) * input_rate / TOKENS_PER_RATE_UNIT;
                let completion_part =
                    f64::from(usage.output_tokens) * output_rate / TOKENS_PER_RATE_UNIT;
                prompt_part + completion_part
            }
            None => 0.0,
        };

        self.session_total += cost_usd;

        let recorded = CostEntry {
            model: model.to_owned(),
            input_tokens: usage.input_tokens,
            output_tokens: usage.output_tokens,
            cost_usd,
        };
        self.entries.push(recorded.clone());
        recorded
    }

    /// Renders the session total as a dollar amount.
    ///
    /// Totals below one cent are shown with four decimal places so they do not
    /// collapse to `$0.00`; everything else uses two decimal places.
    pub fn format_session_cost(&self) -> String {
        let sub_cent = self.session_total < 0.01;
        if sub_cent {
            format!("${:.4}", self.session_total)
        } else {
            format!("${:.2}", self.session_total)
        }
    }
}
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is reported as a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a [`Usage`] with the given token counts and no cache activity.
    fn usage(input_tokens: u32, output_tokens: u32) -> Usage {
        Usage {
            input_tokens,
            output_tokens,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        }
    }

    #[test]
    fn new_tracker_empty() {
        let tracker = CostTracker::new();
        assert!(tracker.entries.is_empty());
        assert!(tracker.session_total.abs() < f64::EPSILON);
    }

    #[test]
    fn track_known_model() {
        let mut tracker = CostTracker::new();
        let entry = tracker.track("claude-sonnet-4", &usage(1000, 500));
        assert!(entry.cost_usd > 0.0);
        assert_eq!(entry.input_tokens, 1000);
        assert_eq!(entry.output_tokens, 500);
        assert_eq!(entry.model, "claude-sonnet-4");
        let expected = 1000.0 * 3.0 / 1_000_000.0 + 500.0 * 15.0 / 1_000_000.0;
        assert!((entry.cost_usd - expected).abs() < f64::EPSILON);

        // The returned entry must be the one recorded, and the total must include it.
        assert_eq!(tracker.entries.len(), 1);
        let recorded = tracker.entries.last().unwrap();
        assert_eq!(recorded.model, entry.model);
        assert_eq!(recorded.input_tokens, entry.input_tokens);
        assert_eq!(recorded.output_tokens, entry.output_tokens);
        assert!((recorded.cost_usd - entry.cost_usd).abs() < f64::EPSILON);
        assert!((tracker.session_total - entry.cost_usd).abs() < f64::EPSILON);
    }

    #[test]
    fn track_unknown_model() {
        let mut tracker = CostTracker::new();
        let entry = tracker.track("totally-unknown-model", &usage(1000, 500));
        assert!(entry.cost_usd.abs() < f64::EPSILON);
        assert_eq!(tracker.entries.len(), 1);
    }

    #[test]
    fn format_cost_small() {
        let mut tracker = CostTracker::new();
        tracker.session_total = 0.0035;
        assert_eq!(tracker.format_session_cost(), "$0.0035");
    }

    #[test]
    fn format_cost_zero_uses_four_decimals() {
        let tracker = CostTracker::new();
        assert_eq!(tracker.format_session_cost(), "$0.0000");
    }

    #[test]
    fn format_cost_at_one_cent_uses_two_decimals() {
        // Exactly one cent is not "below one cent", so the two-decimal form applies.
        let mut tracker = CostTracker::new();
        tracker.session_total = 0.01;
        assert_eq!(tracker.format_session_cost(), "$0.01");
    }

    #[test]
    fn format_cost_large() {
        let mut tracker = CostTracker::new();
        tracker.session_total = 1.2345;
        assert_eq!(tracker.format_session_cost(), "$1.23");
    }

    #[test]
    fn session_total_accumulates() {
        let mut tracker = CostTracker::new();
        let one_million_in = usage(1_000_000, 0);
        tracker.track("claude-sonnet-4", &one_million_in);
        let first_total = tracker.session_total;
        assert!((first_total - 3.0).abs() < f64::EPSILON);
        tracker.track("claude-sonnet-4", &one_million_in);
        assert!((tracker.session_total - 6.0).abs() < f64::EPSILON);
        assert_eq!(tracker.entries.len(), 2);
    }

    #[test]
    fn track_prefix_matched_model() {
        let mut tracker = CostTracker::new();
        let entry = tracker.track("claude-sonnet-4-5-20250929", &usage(1000, 500));
        assert!(entry.cost_usd > 0.0);
    }

    #[test]
    fn track_model_without_pricing() {
        let mut tracker = CostTracker::new();
        let entry = tracker.track("llama3", &usage(1000, 500));
        assert!(entry.cost_usd.abs() < f64::EPSILON);
    }
}