use serde::{Deserialize, Serialize};
/// Token counts split into a prompt side and a completion side.
///
/// The type is `Copy` and defaults to zero for both counters. It also derives
/// serde's `Serialize`/`Deserialize`, so the field names below are part of the
/// serialized representation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
/// Number of prompt tokens.
pub prompt: u32,
/// Number of completion tokens.
pub completion: u32,
}
impl TokenUsage {
    /// Builds a usage record from a prompt-token count and a completion-token count.
    pub const fn new(prompt: u32, completion: u32) -> Self {
        Self { prompt, completion }
    }

    /// Returns the sum of the prompt and completion counters.
    ///
    /// The sum is clamped at `u32::MAX`. A plain `+` would panic in debug
    /// builds and silently wrap around in release builds once the two
    /// counters together exceed `u32::MAX`; saturating gives one well-defined
    /// result in every build profile.
    pub const fn total(&self) -> u32 {
        self.prompt.saturating_add(self.completion)
    }
}
/// Field-wise addition of two usage records.
///
/// Each counter saturates at `u32::MAX` instead of overflowing: a plain `+`
/// would panic in debug builds and wrap around in release builds when a long
/// accumulation exceeds the `u32` range.
impl std::ops::Add for TokenUsage {
    type Output = TokenUsage;

    fn add(self, rhs: TokenUsage) -> TokenUsage {
        TokenUsage {
            prompt: self.prompt.saturating_add(rhs.prompt),
            completion: self.completion.saturating_add(rhs.completion),
        }
    }
}

/// In-place field-wise accumulation with the same saturating behavior as `+`.
impl std::ops::AddAssign for TokenUsage {
    fn add_assign(&mut self, rhs: TokenUsage) {
        // Delegate to `Add` so `+` and `+=` can never disagree; `TokenUsage`
        // is `Copy`, so reading `*self` by value is free.
        *self = *self + rhs;
    }
}
/// Price rates for a model, expressed per one million tokens.
///
/// `cost_for` divides token counts by 1,000,000 before multiplying by these
/// rates and stores the results in the `*_usd` fields of `CostEstimate`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
/// Rate applied to prompt tokens, per million tokens.
pub input_per_mtok: f64,
/// Rate applied to completion tokens, per million tokens.
pub output_per_mtok: f64,
}
impl ModelPricing {
    /// Prices the given usage with this model's per-million-token rates.
    ///
    /// Prompt tokens are charged at `input_per_mtok` and completion tokens at
    /// `output_per_mtok`; the two amounts are returned separately in a
    /// `CostEstimate`.
    pub fn cost_for(&self, usage: TokenUsage) -> CostEstimate {
        const TOKENS_PER_MTOK: f64 = 1_000_000.0;

        // `u32 -> f64` is lossless, so `f64::from` yields the same value as an `as` cast.
        let in_millions = |tokens: u32| f64::from(tokens) / TOKENS_PER_MTOK;

        let input_usd = in_millions(usage.prompt) * self.input_per_mtok;
        let output_usd = in_millions(usage.completion) * self.output_per_mtok;
        CostEstimate { input_usd, output_usd }
    }
}
/// Cost split into an input component and an output component.
///
/// Defaults to zero for both components. The type derives serde's
/// `Serialize`/`Deserialize`, so the field names below are part of the
/// serialized representation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
pub struct CostEstimate {
/// Cost attributed to prompt (input) tokens; USD per the field name.
pub input_usd: f64,
/// Cost attributed to completion (output) tokens; USD per the field name.
pub output_usd: f64,
}
impl CostEstimate {
    /// Returns the combined cost: the input component plus the output component.
    pub fn total_usd(&self) -> f64 {
        let Self { input_usd, output_usd } = *self;
        input_usd + output_usd
    }
}
impl std::ops::Add for CostEstimate {
type Output = CostEstimate;
fn add(self, rhs: CostEstimate) -> CostEstimate {
CostEstimate {
input_usd: self.input_usd + rhs.input_usd,
output_usd: self.output_usd + rhs.output_usd,
}
}
}
/// Lookup of built-in per-million-token rates by model name.
pub mod pricing {
    use super::ModelPricing;

    /// Rows of `(model-name prefix, input rate, output rate)`.
    ///
    /// Rows are checked top to bottom and the first matching prefix wins, so a
    /// more specific prefix (`gpt-4o-mini`, `o1-mini`) must stay above the
    /// shorter prefix it extends (`gpt-4o`, `o1`).
    const RATE_TABLE: &[(&str, f64, f64)] = &[
        ("gpt-4o-mini", 0.15, 0.60),
        ("gpt-4o", 2.50, 10.0),
        ("o1-mini", 3.0, 12.0),
        ("o1", 15.0, 60.0),
        ("text-embedding-3-small", 0.02, 0.0),
        ("text-embedding-3-large", 0.13, 0.0),
        ("claude-opus-4", 5.0, 25.0),
        ("claude-sonnet-4", 3.0, 15.0),
        ("claude-3-5-sonnet", 3.0, 15.0),
        ("claude-haiku-4", 1.0, 5.0),
        ("claude-3-5-haiku", 1.0, 5.0),
        ("claude-3-opus", 15.0, 75.0),
    ];

    /// Returns the built-in rates for `model`, or `None` when no prefix matches.
    ///
    /// Before matching, the name is trimmed, ASCII-lowercased, and stripped of
    /// a leading `anthropic.` if present.
    pub fn pricing_for(model: &str) -> Option<ModelPricing> {
        let normalized = model.trim().to_ascii_lowercase();
        let key = normalized
            .strip_prefix("anthropic.")
            .unwrap_or(normalized.as_str());

        RATE_TABLE
            .iter()
            .find(|(prefix, _, _)| key.starts_with(*prefix))
            .map(|&(_, input, output)| ModelPricing {
                input_per_mtok: input,
                output_per_mtok: output,
            })
    }

    /// Like [`pricing_for`], but returns `default` when the model is not recognized.
    pub fn pricing_for_or(model: &str, default: ModelPricing) -> ModelPricing {
        match pricing_for(model) {
            Some(known) => known,
            None => default,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `+=` accumulates both counters and `total` reports their sum.
    #[test]
    fn usage_arithmetic() {
        let mut running = TokenUsage::new(10, 5);
        running += TokenUsage::new(3, 7);

        let expected = TokenUsage::new(13, 12);
        assert_eq!(running, expected);
        assert_eq!(running.total(), 25);
    }

    /// One million tokens on each side at 5.0 / 25.0 per million costs 30.0 in total.
    #[test]
    fn cost_computation() {
        let rates = ModelPricing { input_per_mtok: 5.0, output_per_mtok: 25.0 };
        let usage = TokenUsage::new(1_000_000, 1_000_000);

        let deviation = (rates.cost_for(usage).total_usd() - 30.0).abs();
        assert!(deviation < 1e-9);
    }

    /// Known model names resolve to pricing; an unknown one does not.
    #[test]
    fn pricing_lookup() {
        for known in ["gpt-4o-mini", "claude-opus-4-8"] {
            assert!(pricing::pricing_for(known).is_some());
        }
        assert!(pricing::pricing_for("llama3.1").is_none());
    }
}