use serde::{Deserialize, Serialize};
/// Token counts and estimated cost for a single model request.
///
/// Normally built with [`TokenUsage::new`], which derives `total_tokens` and
/// `cost_usd` from the token counts and per-million-token prices. All fields
/// are public, so values built by hand are not re-validated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
/// Identifier of the model that served the request (e.g. `"test/model"`).
pub model: String,
/// Number of input (prompt) tokens.
pub input_tokens: u64,
/// Number of output (completion) tokens.
pub output_tokens: u64,
/// `input_tokens + output_tokens`; `new` saturates at `u64::MAX` instead of overflowing.
pub total_tokens: u64,
/// Estimated cost in US dollars, as computed by `new` from per-million-token prices.
pub cost_usd: f64,
/// UTC creation time (`new` sets this to `chrono::Utc::now()`).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl TokenUsage {
    /// Normalizes a per-million-token price.
    ///
    /// A finite, strictly positive price is returned as-is. NaN, ±infinity,
    /// zero and negative prices all collapse to `0.0`, so a bad price makes
    /// usage free rather than producing a nonsensical cost.
    fn sanitize_price(value: f64) -> f64 {
        match value {
            price if price.is_finite() && price > 0.0 => price,
            _ => 0.0,
        }
    }

    /// Builds a usage record and prices it.
    ///
    /// * `model` – identifier of the model that served the request.
    /// * `input_tokens` / `output_tokens` – token counts for the request.
    /// * `input_price_per_million` / `output_price_per_million` – USD price per
    ///   1,000,000 tokens. Invalid prices (NaN, infinite, zero or negative) are
    ///   treated as `0.0`.
    ///
    /// `total_tokens` saturates at `u64::MAX` rather than overflowing, and the
    /// timestamp is taken from the current UTC clock.
    pub fn new(
        model: impl Into<String>,
        input_tokens: u64,
        output_tokens: u64,
        input_price_per_million: f64,
        output_price_per_million: f64,
    ) -> Self {
        let model_name: String = model.into();

        // Scale the token count to millions first, then apply the sanitized rate.
        let priced = |tokens: u64, rate: f64| -> f64 {
            (tokens as f64 / 1_000_000.0) * Self::sanitize_price(rate)
        };
        let input_part = priced(input_tokens, input_price_per_million);
        let output_part = priced(output_tokens, output_price_per_million);

        Self {
            model: model_name,
            input_tokens,
            output_tokens,
            total_tokens: input_tokens.saturating_add(output_tokens),
            cost_usd: input_part + output_part,
            timestamp: chrono::Utc::now(),
        }
    }

    /// Returns the estimated cost of this usage in US dollars.
    pub fn cost(&self) -> f64 {
        self.cost_usd
    }
}
/// Time window that a cost figure or budget limit refers to.
///
/// NOTE(review): how `Day` and `Month` boundaries are determined (calendar vs.
/// rolling window, time zone) is not defined in this file — confirm with the
/// code that aggregates costs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UsagePeriod {
/// Spend within a single session (compare `CostRecord::session_id`).
Session,
/// Spend within one day.
Day,
/// Spend within one month.
Month,
}
/// One recorded [`TokenUsage`] entry, tagged with a unique id and the session
/// it is attributed to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostRecord {
    /// Unique record identifier (a random UUID v4 string when built via `new`).
    pub id: String,
    /// The usage and cost being recorded.
    pub usage: TokenUsage,
    /// Session this record belongs to.
    pub session_id: String,
}

impl CostRecord {
    /// Wraps `usage` in a new record for `session_id`, assigning a freshly
    /// generated UUID v4 (as a string) as the record id.
    pub fn new(session_id: impl Into<String>, usage: TokenUsage) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        let session_id: String = session_id.into();
        Self {
            id,
            usage,
            session_id,
        }
    }
}
/// Outcome of comparing accumulated spend against a budget limit.
///
/// Derives `PartialEq` so callers and tests can compare outcomes directly
/// (e.g. `assert_eq!(check, BudgetCheck::Allowed)`). `Eq` is deliberately not
/// derived because the payloads contain `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheck {
    /// Spend is within budget.
    Allowed,
    /// Spend is approaching the limit for `period`. The threshold that
    /// triggers a warning is chosen by whatever code produces this value.
    Warning {
        /// Spend so far in the period, in USD.
        current_usd: f64,
        /// Configured limit for the period, in USD.
        limit_usd: f64,
        /// Window the figures refer to.
        period: UsagePeriod,
    },
    /// The budget for `period` has been exceeded.
    Exceeded {
        /// Spend so far in the period, in USD.
        current_usd: f64,
        /// Configured limit for the period, in USD.
        limit_usd: f64,
        /// Window the figures refer to.
        period: UsagePeriod,
    },
}
/// Snapshot of accumulated spend and usage counters.
///
/// `Default` is derived and yields an empty summary: every cost is `0.0`,
/// every counter is `0`, and `by_model` is empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostSummary {
    /// Spend in the session window, in USD.
    pub session_cost_usd: f64,
    /// Spend in the daily window, in USD.
    pub daily_cost_usd: f64,
    /// Spend in the monthly window, in USD.
    pub monthly_cost_usd: f64,
    /// Total tokens across the summarized requests.
    pub total_tokens: u64,
    /// Number of summarized requests.
    pub request_count: usize,
    /// Per-model breakdown.
    /// NOTE(review): presumably keyed by the same identifier as `ModelStats::model` — confirm.
    pub by_model: std::collections::HashMap<String, ModelStats>,
}

/// Aggregated spend and usage for a single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelStats {
    /// Model identifier these statistics belong to.
    pub model: String,
    /// Total spend for this model, in USD.
    pub cost_usd: f64,
    /// Total tokens consumed by this model.
    pub total_tokens: u64,
    /// Number of requests served by this model.
    pub request_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 1000 input tokens at $3/M plus 500 output tokens at $15/M = $0.0105.
    #[test]
    fn token_usage_calculation() {
        let usage = TokenUsage::new("test/model", 1000, 500, 3.0, 15.0);
        assert!((usage.cost_usd - 0.0105).abs() < 0.0001);
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.total_tokens, 1500);
    }

    #[test]
    fn token_usage_zero_tokens() {
        let usage = TokenUsage::new("test/model", 0, 0, 3.0, 15.0);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
        assert_eq!(usage.total_tokens, 0);
    }

    #[test]
    fn token_usage_negative_or_non_finite_prices_are_clamped() {
        let usage = TokenUsage::new("test/model", 1000, 1000, -3.0, f64::NAN);
        assert!(usage.cost_usd.abs() < f64::EPSILON);
        assert_eq!(usage.total_tokens, 2000);
    }

    /// Infinite prices (either sign) must be clamped too, not just NaN/negatives.
    #[test]
    fn token_usage_infinite_prices_are_clamped() {
        let usage = TokenUsage::new("test/model", 1000, 1000, f64::INFINITY, f64::NEG_INFINITY);
        assert!(usage.cost_usd.is_finite());
        assert!(usage.cost_usd.abs() < f64::EPSILON);
    }

    /// `total_tokens` must saturate rather than overflow/panic.
    #[test]
    fn token_usage_total_tokens_saturates() {
        let usage = TokenUsage::new("test/model", u64::MAX, 1, 0.0, 0.0);
        assert_eq!(usage.total_tokens, u64::MAX);
    }

    #[test]
    fn token_usage_cost_accessor_matches_field() {
        // 2M input tokens at $1.5/M = exactly $3.0.
        let usage = TokenUsage::new("test/model", 2_000_000, 0, 1.5, 0.0);
        assert!((usage.cost() - 3.0).abs() < 1e-12);
        assert_eq!(usage.cost().to_bits(), usage.cost_usd.to_bits());
    }

    #[test]
    fn cost_record_creation() {
        let usage = TokenUsage::new("test/model", 100, 50, 1.0, 2.0);
        let record = CostRecord::new("session-123", usage);
        assert_eq!(record.session_id, "session-123");
        assert!(!record.id.is_empty());
        assert_eq!(record.usage.model, "test/model");
    }

    #[test]
    fn cost_summary_default_is_empty() {
        let summary = CostSummary::default();
        assert!(summary.session_cost_usd.abs() < f64::EPSILON);
        assert!(summary.daily_cost_usd.abs() < f64::EPSILON);
        assert!(summary.monthly_cost_usd.abs() < f64::EPSILON);
        assert_eq!(summary.total_tokens, 0);
        assert_eq!(summary.request_count, 0);
        assert!(summary.by_model.is_empty());
    }
}