use crate::types::Usage;
/// Dollar prices per million tokens (MTok) for each billing category of a model.
///
/// The cost estimators divide raw token counts by 1,000,000 and multiply by
/// these rates, so every field is a price per million tokens.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    /// Price per MTok for input tokens (`Usage::input_tokens`).
    pub input: f64,
    /// Price per MTok for cache-creation tokens at the `5m` tier (used by `estimate_cost`).
    pub cache_write_5m: f64,
    /// Price per MTok for cache-creation tokens at the `1h` tier (used by `estimate_cost_1h`).
    pub cache_write_1h: f64,
    /// Price per MTok for tokens read from the cache.
    pub cache_read: f64,
    /// Price per MTok for output tokens.
    pub output: f64,
}
/// Built-in price tables for known models.
///
/// Each distinct price point is spelled out once on the newest model of its
/// tier; older models that share that price point are aliases of it.
impl ModelPricing {
    /// Opus 4.6 rates.
    pub const OPUS_46: Self = Self {
        input: 5.0,
        cache_write_5m: 6.25,
        cache_write_1h: 10.0,
        cache_read: 0.50,
        output: 25.0,
    };

    /// Opus 4.5 rates — identical to [`Self::OPUS_46`].
    pub const OPUS_45: Self = Self::OPUS_46;

    /// Opus 4.1 rates.
    pub const OPUS_41: Self = Self {
        input: 15.0,
        cache_write_5m: 18.75,
        cache_write_1h: 30.0,
        cache_read: 1.50,
        output: 75.0,
    };

    /// Opus 4 rates — identical to [`Self::OPUS_41`].
    pub const OPUS_4: Self = Self::OPUS_41;

    /// Sonnet 4.6 rates.
    pub const SONNET_46: Self = Self {
        input: 3.0,
        cache_write_5m: 3.75,
        cache_write_1h: 6.0,
        cache_read: 0.30,
        output: 15.0,
    };

    /// Sonnet 4.5 rates — identical to [`Self::SONNET_46`].
    pub const SONNET_45: Self = Self::SONNET_46;

    /// Sonnet 4 rates — identical to [`Self::SONNET_46`].
    pub const SONNET_4: Self = Self::SONNET_46;

    /// Haiku 4.5 rates.
    pub const HAIKU_45: Self = Self {
        input: 1.0,
        cache_write_5m: 1.25,
        cache_write_1h: 2.0,
        cache_read: 0.10,
        output: 5.0,
    };
}
/// Estimated cost of a request, split by billing category (in dollars).
///
/// `Default` yields an all-zero breakdown.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct CostBreakdown {
    /// Cost of input tokens.
    pub input_cost: f64,
    /// Cost of cache-creation tokens.
    pub cache_write_cost: f64,
    /// Cost of cache-read tokens.
    pub cache_read_cost: f64,
    /// Cost of output tokens.
    pub output_cost: f64,
}
impl CostBreakdown {
    /// Returns the overall cost: the sum of all four category costs.
    pub fn total(&self) -> f64 {
        let &Self {
            input_cost,
            cache_write_cost,
            cache_read_cost,
            output_cost,
        } = self;
        // Summed left to right, matching the category order of the struct.
        input_cost + cache_write_cost + cache_read_cost + output_cost
    }
}
impl std::fmt::Display for CostBreakdown {
    /// Writes the total, then each category cost, all with six decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let grand_total = self.total();
        let Self {
            input_cost,
            cache_write_cost,
            cache_read_cost,
            output_cost,
        } = self;
        f.write_fmt(format_args!(
            "${:.6} (in=${:.6} cache_w=${:.6} cache_r=${:.6} out=${:.6})",
            grand_total, input_cost, cache_write_cost, cache_read_cost, output_cost
        ))
    }
}
/// Estimates the cost of a request, billing every cache-creation token at the
/// `5m` cache-write rate (`pricing.cache_write_5m`).
///
/// Token counts are converted to millions of tokens and multiplied by the
/// per-million-token prices in `pricing`. Cache counters that are `None` are
/// treated as zero tokens.
///
/// NOTE(review): `usage.cache_creation_input_tokens_1h` is never consulted. Confirm
/// whether it overlaps `cache_creation_input_tokens`. If it does, and the caller
/// mixed TTLs, the overlapping part may need to be billed at the 1h rate.
pub fn estimate_cost(pricing: ModelPricing, usage: &Usage) -> CostBreakdown {
    estimate_with_cache_write_rate(pricing, usage, pricing.cache_write_5m)
}

/// Estimates the cost of a request, billing every cache-creation token at the
/// `1h` cache-write rate (`pricing.cache_write_1h`).
///
/// Otherwise identical to [`estimate_cost`].
pub fn estimate_cost_1h(pricing: ModelPricing, usage: &Usage) -> CostBreakdown {
    estimate_with_cache_write_rate(pricing, usage, pricing.cache_write_1h)
}

/// Shared cost computation.
///
/// `cache_write_rate` is the per-million-token price applied to
/// `usage.cache_creation_input_tokens`. All other categories use `pricing`.
fn estimate_with_cache_write_rate(
    pricing: ModelPricing,
    usage: &Usage,
    cache_write_rate: f64,
) -> CostBreakdown {
    const TOKENS_PER_MTOK: f64 = 1_000_000.0;
    // Each term is `tokens / 1e6 * rate`, in that order, so results are
    // bit-identical to the previous per-function implementations.
    CostBreakdown {
        input_cost: usage.input_tokens as f64 / TOKENS_PER_MTOK * pricing.input,
        cache_write_cost: usage.cache_creation_input_tokens.unwrap_or(0) as f64
            / TOKENS_PER_MTOK
            * cache_write_rate,
        cache_read_cost: usage.cache_read_input_tokens.unwrap_or(0) as f64 / TOKENS_PER_MTOK
            * pricing.cache_read,
        output_cost: usage.output_tokens as f64 / TOKENS_PER_MTOK * pricing.output,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing dollar amounts produced by `f64` arithmetic.
    const EPS: f64 = 1e-9;

    /// Returns true when `actual` is within [`EPS`] of `expected`.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn sonnet_46_basic_cost() {
        let usage = Usage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
            cache_creation_input_tokens_1h: None,
            server_tool_use: None,
        };
        let cost = estimate_cost(ModelPricing::SONNET_46, &usage);
        assert!(approx(cost.input_cost, 0.003));
        assert!(approx(cost.output_cost, 0.0075));
        assert!(approx(cost.total(), 0.0105));
    }

    #[test]
    fn sonnet_46_with_caching() {
        let usage = Usage {
            input_tokens: 3,
            output_tokens: 256,
            cache_creation_input_tokens: Some(274),
            cache_read_input_tokens: Some(2048),
            cache_creation_input_tokens_1h: None,
            server_tool_use: None,
        };
        let cost = estimate_cost(ModelPricing::SONNET_46, &usage);
        // Per MTok: 3 in × $3, 274 cache-write × $3.75, 2048 cache-read × $0.30, 256 out × $15.
        assert!(approx(cost.input_cost, 0.000009));
        assert!(approx(cost.cache_write_cost, 0.0010275));
        assert!(approx(cost.cache_read_cost, 0.0006144));
        assert!(approx(cost.output_cost, 0.00384));
        assert!(approx(cost.total(), 0.0054909));
    }

    #[test]
    fn sonnet_46_with_1h_caching() {
        let usage = Usage {
            input_tokens: 3,
            output_tokens: 256,
            cache_creation_input_tokens: Some(274),
            cache_read_input_tokens: Some(2048),
            cache_creation_input_tokens_1h: None,
            server_tool_use: None,
        };
        let cost = estimate_cost_1h(ModelPricing::SONNET_46, &usage);
        // Only the cache-write rate differs from the 5m estimate: 274 × $6 per MTok.
        assert!(approx(cost.input_cost, 0.000009));
        assert!(approx(cost.cache_write_cost, 0.001644));
        assert!(approx(cost.cache_read_cost, 0.0006144));
        assert!(approx(cost.output_cost, 0.00384));
        assert!(approx(cost.total(), 0.0061074));
    }

    #[test]
    fn zero_usage_costs_nothing() {
        let usage = Usage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
            cache_creation_input_tokens_1h: None,
            server_tool_use: None,
        };
        assert!(approx(estimate_cost(ModelPricing::OPUS_46, &usage).total(), 0.0));
        assert!(approx(estimate_cost_1h(ModelPricing::OPUS_46, &usage).total(), 0.0));
    }

    #[test]
    fn display_format() {
        let cost = CostBreakdown {
            input_cost: 0.003,
            cache_write_cost: 0.001,
            cache_read_cost: 0.0005,
            output_cost: 0.0075,
        };
        assert_eq!(
            cost.to_string(),
            "$0.012000 (in=$0.003000 cache_w=$0.001000 cache_r=$0.000500 out=$0.007500)"
        );
    }
}