use serde::{Deserialize, Serialize};
/// Token counts attached to one unit of model usage.
///
/// `total_tokens` is stored as its own field rather than computed on read, so
/// a value built field-by-field (or deserialized) is not guaranteed to satisfy
/// `total_tokens == prompt_tokens + completion_tokens`.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct TokenUsage {
    /// Tokens on the prompt side; priced at the input rate by `CostCalculator`.
    pub prompt_tokens: u32,
    /// Tokens on the completion side; priced at the output rate by `CostCalculator`.
    pub completion_tokens: u32,
    /// Overall token count for this usage record.
    pub total_tokens: u32,
}
impl TokenUsage {
    /// Builds a usage record from separate prompt and completion counts.
    ///
    /// `total_tokens` is set to the sum of the two counts. The addition
    /// saturates at `u32::MAX`: a plain `+` would panic in debug builds and
    /// silently wrap around in release builds when the sum does not fit.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }

    /// Builds a usage record when only the overall token count is known.
    ///
    /// Both `prompt_tokens` and `completion_tokens` are set to zero. Note that
    /// `CostCalculator::calculate_cost` prices only those two fields, so a
    /// record created this way is costed at `0.0`.
    pub fn from_total(total_tokens: u32) -> Self {
        Self {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens,
        }
    }

    /// Returns `true` when `total_tokens` is zero.
    ///
    /// Only `total_tokens` is inspected; the prompt and completion counts are
    /// not consulted.
    pub fn is_empty(&self) -> bool {
        self.total_tokens == 0
    }
}
/// Per-model price table, expressed per one million tokens.
///
/// The currency is not encoded in this type; `UsageAccumulator::total_cost_usd`
/// suggests the amounts are meant as US dollars.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelPricing {
    /// Cost of 1,000,000 prompt (input) tokens.
    pub input_cost_per_1m: f64,
    /// Cost of 1,000,000 completion (output) tokens.
    pub output_cost_per_1m: f64,
}
/// Constructors for [`ModelPricing`].
///
/// The named presets are hard-coded rate pairs (input, output) per one million
/// tokens. NOTE(review): they are a snapshot baked into the source — confirm
/// against the providers' current price lists before relying on them.
impl ModelPricing {
    /// Creates a pricing entry from explicit per-million-token rates.
    pub fn new(input_cost_per_1m: f64, output_cost_per_1m: f64) -> Self {
        Self {
            input_cost_per_1m,
            output_cost_per_1m,
        }
    }

    /// Fallback rates (3.0 input / 15.0 output), used when no model-specific
    /// preset applies. These are the same numbers as [`Self::claude35_sonnet`].
    pub fn default_pricing() -> Self {
        Self::new(3.0, 15.0)
    }

    /// GPT-4 preset: 30.0 input / 60.0 output.
    pub fn gpt4() -> Self {
        Self::new(30.0, 60.0)
    }

    /// GPT-4 Turbo preset: 10.0 input / 30.0 output.
    pub fn gpt4_turbo() -> Self {
        Self::new(10.0, 30.0)
    }

    /// GPT-4o preset: 2.5 input / 10.0 output.
    pub fn gpt4o() -> Self {
        Self::new(2.5, 10.0)
    }

    /// GPT-4o mini preset: 0.15 input / 0.60 output.
    pub fn gpt4o_mini() -> Self {
        Self::new(0.15, 0.60)
    }

    /// Claude 3 Opus preset: 15.0 input / 75.0 output.
    pub fn claude3_opus() -> Self {
        Self::new(15.0, 75.0)
    }

    /// Claude 3.5 Sonnet preset: 3.0 input / 15.0 output.
    pub fn claude35_sonnet() -> Self {
        Self::new(3.0, 15.0)
    }

    /// Claude 3 Haiku preset: 0.25 input / 1.25 output.
    pub fn claude3_haiku() -> Self {
        Self::new(0.25, 1.25)
    }
}
/// Stateless helpers that turn token usage into a monetary cost.
pub struct CostCalculator;

impl CostCalculator {
    /// Number of tokens the `*_cost_per_1m` rates refer to.
    const TOKENS_PER_MILLION: f64 = 1_000_000.0;

    /// Computes the cost of `usage` under `pricing`.
    ///
    /// Prompt tokens are charged at `pricing.input_cost_per_1m` and completion
    /// tokens at `pricing.output_cost_per_1m`, both scaled from "per one
    /// million tokens". `usage.total_tokens` is not consulted, so a record
    /// that carries only a total is costed at `0.0`.
    pub fn calculate_cost(usage: &TokenUsage, pricing: &ModelPricing) -> f64 {
        // u32 -> f64 is lossless, so `f64::from` gives the same value as `as f64`.
        let input_cost =
            (f64::from(usage.prompt_tokens) / Self::TOKENS_PER_MILLION) * pricing.input_cost_per_1m;
        let output_cost = (f64::from(usage.completion_tokens) / Self::TOKENS_PER_MILLION)
            * pricing.output_cost_per_1m;
        input_cost + output_cost
    }

    /// Computes the cost of `usage` using [`ModelPricing::default_pricing`].
    pub fn calculate_cost_default(usage: &TokenUsage) -> f64 {
        Self::calculate_cost(usage, &ModelPricing::default_pricing())
    }

    /// Picks a pricing preset by case-insensitive substring match on `model`.
    ///
    /// Names that match none of the known families get
    /// [`ModelPricing::default_pricing`].
    ///
    /// Both spellings of the GPT-4.1 mini identifier (`gpt-4.1-mini` and
    /// `gpt-4-1-mini`) map to the mini preset; without the dotted form the name
    /// would fall through to the generic `gpt-4` branch and be priced at
    /// GPT-4 rates.
    ///
    /// NOTE(review): other `gpt-4.1*` names (e.g. non-mini variants) still land
    /// on the generic `gpt-4` preset because no dedicated preset exists here.
    pub fn pricing_for_model(model: &str) -> ModelPricing {
        let model_lower = model.to_lowercase();
        let has = |needle: &str| model_lower.contains(needle);

        // Order matters: the more specific names must be tested before the
        // shorter prefixes they contain ("gpt-4o-mini" before "gpt-4o" before
        // "gpt-4").
        if has("gpt-4o-mini") || has("gpt-4.1-mini") || has("gpt-4-1-mini") {
            ModelPricing::gpt4o_mini()
        } else if has("gpt-4o") {
            ModelPricing::gpt4o()
        } else if has("gpt-4-turbo") {
            ModelPricing::gpt4_turbo()
        } else if has("gpt-4") {
            ModelPricing::gpt4()
        } else if has("opus") {
            // Also covers the full "claude-3-opus" identifier.
            ModelPricing::claude3_opus()
        } else if has("sonnet") {
            // Also covers "claude-3.5-sonnet" and "claude-3-5-sonnet".
            ModelPricing::claude35_sonnet()
        } else if has("haiku") {
            // Also covers the full "claude-3-haiku" identifier.
            ModelPricing::claude3_haiku()
        } else {
            ModelPricing::default_pricing()
        }
    }
}
/// Running totals of token usage and cost across many recorded calls.
///
/// Counters are `u64` so they can hold the sum of many `u32`-sized
/// [`TokenUsage`] records. `Default` yields an accumulator with every field
/// at zero.
#[derive(Clone, Debug, Default)]
pub struct UsageAccumulator {
    /// Sum of `prompt_tokens` over all recorded usages.
    pub total_prompt_tokens: u64,
    /// Sum of `completion_tokens` over all recorded usages.
    pub total_completion_tokens: u64,
    /// Sum of `total_tokens` over all recorded usages.
    pub total_tokens: u64,
    /// Sum of the costs passed in or computed for each recorded usage.
    pub total_cost_usd: f64,
    /// Number of usages recorded so far.
    pub call_count: u64,
}
impl UsageAccumulator {
    /// Creates an accumulator with all counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one usage together with an already-computed `cost`.
    ///
    /// Each of the three token counters is advanced by the matching field of
    /// `usage` (the total comes from `usage.total_tokens`, not from re-adding
    /// prompt and completion), `cost` is added to the running cost, and the
    /// call counter goes up by one.
    pub fn add(&mut self, usage: &TokenUsage, cost: f64) {
        self.call_count += 1;
        self.total_cost_usd += cost;
        self.total_tokens += u64::from(usage.total_tokens);
        self.total_completion_tokens += u64::from(usage.completion_tokens);
        self.total_prompt_tokens += u64::from(usage.prompt_tokens);
    }

    /// Records one usage, deriving its cost from `pricing` via
    /// `CostCalculator::calculate_cost`.
    pub fn add_with_pricing(&mut self, usage: &TokenUsage, pricing: &ModelPricing) {
        self.add(usage, CostCalculator::calculate_cost(usage, pricing));
    }

    /// Records one usage, looking the pricing up from the model name via
    /// `CostCalculator::pricing_for_model`.
    pub fn add_for_model(&mut self, usage: &TokenUsage, model: &str) {
        self.add_with_pricing(usage, &CostCalculator::pricing_for_model(model));
    }

    /// Mean of `total_tokens` per recorded call; `0.0` when nothing has been
    /// recorded (avoids dividing by zero).
    pub fn avg_tokens_per_call(&self) -> f64 {
        match self.call_count {
            0 => 0.0,
            calls => self.total_tokens as f64 / calls as f64,
        }
    }

    /// Mean cost per recorded call; `0.0` when nothing has been recorded
    /// (avoids dividing by zero).
    pub fn avg_cost_per_call(&self) -> f64 {
        match self.call_count {
            0 => 0.0,
            calls => self.total_cost_usd / calls as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps both counts and sets the total to their sum.
    #[test]
    fn test_token_usage_new() {
        let TokenUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens,
        } = TokenUsage::new(100, 50);
        assert_eq!(prompt_tokens, 100);
        assert_eq!(completion_tokens, 50);
        assert_eq!(total_tokens, 150);
    }

    /// `from_total` leaves the prompt/completion breakdown at zero.
    #[test]
    fn test_token_usage_from_total() {
        let TokenUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens,
        } = TokenUsage::from_total(200);
        assert_eq!(prompt_tokens, 0);
        assert_eq!(completion_tokens, 0);
        assert_eq!(total_tokens, 200);
    }

    /// A default record is empty; one with tokens is not.
    #[test]
    fn test_token_usage_is_empty() {
        assert!(TokenUsage::default().is_empty());
        assert!(!TokenUsage::new(10, 5).is_empty());
    }

    /// 1M prompt tokens at 3.0 plus 0.5M completion tokens at 15.0 is 10.50.
    #[test]
    fn test_cost_calculator() {
        let total = CostCalculator::calculate_cost(
            &TokenUsage::new(1_000_000, 500_000),
            &ModelPricing::new(3.0, 15.0),
        );
        assert!((total - 10.50).abs() < 0.001);
    }

    /// A small request on the mini preset costs a positive but tiny amount.
    #[test]
    fn test_cost_calculator_small_usage() {
        let total = CostCalculator::calculate_cost(
            &TokenUsage::new(1000, 500),
            &ModelPricing::gpt4o_mini(),
        );
        assert!(total > 0.0);
        assert!(total < 0.001);
    }

    /// Model names resolve to the expected preset (checked via the input rate).
    #[test]
    fn test_pricing_for_model() {
        let input_rate = |model: &str| CostCalculator::pricing_for_model(model).input_cost_per_1m;
        assert_eq!(input_rate("gpt-4"), 30.0);
        assert_eq!(input_rate("gpt-4o"), 2.5);
        assert_eq!(input_rate("claude-3.5-sonnet"), 3.0);
    }

    /// Two `add` calls accumulate into every counter.
    #[test]
    fn test_usage_accumulator() {
        let mut totals = UsageAccumulator::new();
        totals.add(&TokenUsage::new(100, 50), 0.01);
        totals.add(&TokenUsage::new(200, 100), 0.02);

        assert_eq!(totals.total_prompt_tokens, 300);
        assert_eq!(totals.total_completion_tokens, 150);
        assert_eq!(totals.total_tokens, 450);
        assert!((totals.total_cost_usd - 0.03).abs() < 0.0001);
        assert_eq!(totals.call_count, 2);
    }

    /// Averages divide the running totals by the number of calls.
    #[test]
    fn test_usage_accumulator_averages() {
        let mut totals = UsageAccumulator::new();
        totals.add(&TokenUsage::new(100, 100), 0.02);
        totals.add(&TokenUsage::new(200, 200), 0.04);

        assert_eq!(totals.avg_tokens_per_call(), 300.0);
        assert!((totals.avg_cost_per_call() - 0.03).abs() < 0.0001);
    }

    /// With no recorded calls both averages are zero rather than NaN.
    #[test]
    fn test_usage_accumulator_empty() {
        let totals = UsageAccumulator::new();
        assert_eq!(totals.avg_tokens_per_call(), 0.0);
        assert_eq!(totals.avg_cost_per_call(), 0.0);
    }
}