1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
7pub struct TokenUsage {
8 pub prompt: u32,
10 pub completion: u32,
12}
13
14impl TokenUsage {
15 pub const fn new(prompt: u32, completion: u32) -> Self {
17 Self { prompt, completion }
18 }
19
20 pub const fn total(&self) -> u32 {
22 self.prompt + self.completion
23 }
24}
25
26impl std::ops::Add for TokenUsage {
27 type Output = TokenUsage;
28 fn add(self, rhs: TokenUsage) -> TokenUsage {
29 TokenUsage { prompt: self.prompt + rhs.prompt, completion: self.completion + rhs.completion }
30 }
31}
32
33impl std::ops::AddAssign for TokenUsage {
34 fn add_assign(&mut self, rhs: TokenUsage) {
35 self.prompt += rhs.prompt;
36 self.completion += rhs.completion;
37 }
38}
39
/// Price pair for one model: a rate for input tokens and a rate for output tokens.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ModelPricing {
    /// USD charged per one million prompt tokens.
    pub input_per_mtok: f64,
    /// USD charged per one million completion tokens.
    pub output_per_mtok: f64,
}
48
49impl ModelPricing {
50 pub fn cost_for(&self, usage: TokenUsage) -> CostEstimate {
52 CostEstimate {
53 input_usd: (usage.prompt as f64 / 1_000_000.0) * self.input_per_mtok,
54 output_usd: (usage.completion as f64 / 1_000_000.0) * self.output_per_mtok,
55 }
56 }
57}
58
59#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
61pub struct CostEstimate {
62 pub input_usd: f64,
64 pub output_usd: f64,
66}
67
68impl CostEstimate {
69 pub fn total_usd(&self) -> f64 {
71 self.input_usd + self.output_usd
72 }
73}
74
75impl std::ops::Add for CostEstimate {
76 type Output = CostEstimate;
77 fn add(self, rhs: CostEstimate) -> CostEstimate {
78 CostEstimate {
79 input_usd: self.input_usd + rhs.input_usd,
80 output_usd: self.output_usd + rhs.output_usd,
81 }
82 }
83}
84
85pub mod pricing {
87 use super::ModelPricing;
88
89 const fn p(input: f64, output: f64) -> ModelPricing {
90 ModelPricing { input_per_mtok: input, output_per_mtok: output }
91 }
92
93 pub fn pricing_for(model: &str) -> Option<ModelPricing> {
95 let m = model.trim().to_ascii_lowercase();
96 let m = m.strip_prefix("anthropic.").unwrap_or(&m);
97
98 let pricing = match () {
99 _ if m.starts_with("gpt-4o-mini") => p(0.15, 0.60),
100 _ if m.starts_with("gpt-4o") => p(2.50, 10.0),
101 _ if m.starts_with("o1-mini") => p(3.0, 12.0),
102 _ if m.starts_with("o1") => p(15.0, 60.0),
103 _ if m.starts_with("text-embedding-3-small") => p(0.02, 0.0),
104 _ if m.starts_with("text-embedding-3-large") => p(0.13, 0.0),
105 _ if m.starts_with("claude-opus-4") => p(5.0, 25.0),
106 _ if m.starts_with("claude-sonnet-4") || m.starts_with("claude-3-5-sonnet") => p(3.0, 15.0),
107 _ if m.starts_with("claude-haiku-4") || m.starts_with("claude-3-5-haiku") => p(1.0, 5.0),
108 _ if m.starts_with("claude-3-opus") => p(15.0, 75.0),
109 _ => return None,
110 };
111 Some(pricing)
112 }
113
114 pub fn pricing_for_or(model: &str, default: ModelPricing) -> ModelPricing {
116 pricing_for(model).unwrap_or(default)
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// `+=` accumulates both counters and `total` sums them.
    #[test]
    fn usage_arithmetic() {
        let mut running = TokenUsage::new(10, 5);
        running += TokenUsage::new(3, 7);

        let expected = TokenUsage::new(13, 12);
        assert_eq!(running, expected);
        assert_eq!(running.total(), 25);
    }

    /// One million tokens on each side costs exactly the two rates added together.
    #[test]
    fn cost_computation() {
        let pricing = ModelPricing { input_per_mtok: 5.0, output_per_mtok: 25.0 };
        let usage = TokenUsage::new(1_000_000, 1_000_000);

        let total = pricing.cost_for(usage).total_usd();
        assert!((total - 30.0).abs() < 1e-9);
    }

    /// Known prefixes resolve to a price; unknown models do not.
    #[test]
    fn pricing_lookup() {
        for known in ["gpt-4o-mini", "claude-opus-4-8"].iter() {
            assert!(pricing::pricing_for(known).is_some());
        }
        assert!(pricing::pricing_for("llama3.1").is_none());
    }
}