1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
10use crate::model::ModelRole;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ModelPricing {
15 pub input_per_million: f64,
17 pub output_per_million: f64,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CostEvent {
24 pub model_role: ModelRole,
25 pub model_name: String,
26 pub input_tokens: u64,
27 pub output_tokens: u64,
28 pub cost: f64,
29 pub timestamp: chrono::DateTime<chrono::Utc>,
30}
31
32#[derive(Debug, Clone)]
34pub struct CostTracker {
35 pricing: HashMap<String, ModelPricing>,
37 events: Arc<Mutex<Vec<CostEvent>>>,
39}
40
41impl CostTracker {
42 pub fn new() -> Self {
44 let mut pricing = HashMap::new();
45
46 pricing.insert(
48 "claude-opus-4-5".into(),
49 ModelPricing {
50 input_per_million: 15.0,
51 output_per_million: 75.0,
52 },
53 );
54 pricing.insert(
55 "claude-sonnet-4".into(),
56 ModelPricing {
57 input_per_million: 3.0,
58 output_per_million: 15.0,
59 },
60 );
61
62 pricing.insert(
64 "gemini-2.5-flash".into(),
65 ModelPricing {
66 input_per_million: 0.15,
67 output_per_million: 0.60,
68 },
69 );
70
71 pricing.insert(
73 "ollama-local".into(),
74 ModelPricing {
75 input_per_million: 0.0,
76 output_per_million: 0.0,
77 },
78 );
79
80 Self {
81 pricing,
82 events: Arc::new(Mutex::new(Vec::new())),
83 }
84 }
85
86 pub fn estimate_cost(
88 &self,
89 model_name: &str,
90 input_tokens: u64,
91 output_tokens: u64,
92 ) -> f64 {
93 if let Some(pricing) = self.pricing.get(model_name) {
94 let input_cost = (input_tokens as f64 / 1_000_000.0) * pricing.input_per_million;
95 let output_cost = (output_tokens as f64 / 1_000_000.0) * pricing.output_per_million;
96 input_cost + output_cost
97 } else {
98 let input_cost = (input_tokens as f64 / 1_000_000.0) * 3.0;
100 let output_cost = (output_tokens as f64 / 1_000_000.0) * 15.0;
101 input_cost + output_cost
102 }
103 }
104
105 pub fn record(
107 &self,
108 model_role: ModelRole,
109 model_name: &str,
110 input_tokens: u64,
111 output_tokens: u64,
112 ) {
113 let cost = self.estimate_cost(model_name, input_tokens, output_tokens);
114 let event = CostEvent {
115 model_role,
116 model_name: model_name.to_string(),
117 input_tokens,
118 output_tokens,
119 cost,
120 timestamp: chrono::Utc::now(),
121 };
122 if let Ok(mut events) = self.events.lock() {
123 events.push(event);
124 }
125 }
126
127 pub fn total_cost(&self) -> f64 {
129 self.events
130 .lock()
131 .map(|events| events.iter().map(|e| e.cost).sum())
132 .unwrap_or(0.0)
133 }
134
135 pub fn cost_by_role(&self) -> HashMap<ModelRole, f64> {
137 let mut breakdown = HashMap::new();
138 if let Ok(events) = self.events.lock() {
139 for event in events.iter() {
140 *breakdown.entry(event.model_role.clone()).or_insert(0.0) += event.cost;
141 }
142 }
143 breakdown
144 }
145
146 pub fn events(&self) -> Vec<CostEvent> {
148 self.events
149 .lock()
150 .map(|events| events.clone())
151 .unwrap_or_default()
152 }
153
154 pub fn reset(&self) {
156 if let Ok(mut events) = self.events.lock() {
157 events.clear();
158 }
159 }
160}
161
162impl Default for CostTracker {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two costs agree up to floating-point rounding noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_estimate_cost_known_model() {
        let tracker = CostTracker::new();
        let cost = tracker.estimate_cost("claude-opus-4-5", 1_000_000, 1_000_000);
        // 15.0 (input) + 75.0 (output)
        assert_close(cost, 90.0);
    }

    #[test]
    fn test_estimate_cost_cheap_model() {
        let tracker = CostTracker::new();
        let cost = tracker.estimate_cost("gemini-2.5-flash", 1_000_000, 1_000_000);
        // 0.15 (input) + 0.60 (output)
        assert_close(cost, 0.75);
    }

    #[test]
    fn test_estimate_cost_local_model() {
        let tracker = CostTracker::new();
        let cost = tracker.estimate_cost("ollama-local", 1_000_000, 1_000_000);
        assert_close(cost, 0.0);
    }

    #[test]
    fn test_estimate_cost_unknown_model() {
        let tracker = CostTracker::new();
        let cost = tracker.estimate_cost("unknown-model", 1_000_000, 1_000_000);
        // Fallback rates: 3.0 (input) + 15.0 (output)
        assert_close(cost, 18.0);
    }

    #[test]
    fn test_estimate_cost_zero_tokens() {
        let tracker = CostTracker::new();
        assert_close(tracker.estimate_cost("claude-opus-4-5", 0, 0), 0.0);
        assert_close(tracker.estimate_cost("unknown-model", 0, 0), 0.0);
    }

    #[test]
    fn test_record_and_total() {
        let tracker = CostTracker::new();
        tracker.record(ModelRole::Thinking, "claude-opus-4-5", 500_000, 100_000);
        tracker.record(ModelRole::Task, "gemini-2.5-flash", 200_000, 50_000);

        // Opus: 0.5 * 15.0 + 0.1 * 75.0 = 15.0
        // Flash: 0.2 * 0.15 + 0.05 * 0.60 = 0.06
        assert_close(tracker.total_cost(), 15.06);

        let events = tracker.events();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].model_name, "claude-opus-4-5");
        assert_eq!(events[0].input_tokens, 500_000);
        assert_eq!(events[0].output_tokens, 100_000);
        assert_eq!(events[1].model_name, "gemini-2.5-flash");
        assert_eq!(events[1].input_tokens, 200_000);
        assert_eq!(events[1].output_tokens, 50_000);
    }

    #[test]
    fn test_cost_by_role() {
        let tracker = CostTracker::new();
        tracker.record(ModelRole::Thinking, "claude-opus-4-5", 1_000_000, 0);
        tracker.record(ModelRole::Task, "gemini-2.5-flash", 1_000_000, 0);

        let breakdown = tracker.cost_by_role();
        assert_eq!(breakdown.len(), 2);
        assert_close(breakdown[&ModelRole::Thinking], 15.0);
        assert_close(breakdown[&ModelRole::Task], 0.15);
    }

    #[test]
    fn test_clones_share_event_log() {
        let tracker = CostTracker::new();
        let handle = tracker.clone();
        handle.record(ModelRole::Task, "ollama-local", 10, 10);

        // The event log sits behind an `Arc`, so the original sees the event.
        assert_eq!(tracker.events().len(), 1);
    }

    #[test]
    fn test_reset() {
        let tracker = CostTracker::new();
        tracker.record(ModelRole::Thinking, "claude-opus-4-5", 1_000_000, 0);
        assert_close(tracker.total_cost(), 15.0);

        tracker.reset();
        assert_close(tracker.total_cost(), 0.0);
        assert!(tracker.events().is_empty());
    }
}