1use chrono::{DateTime, Duration, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8
9use crate::tokens::mappings::normalize_model_name;
10
/// Upstream pricing table published by the LiteLLM project: a JSON map from
/// model name to per-token costs and context-window limits.
pub const PRICING_URL: &str =
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";

/// How long downloaded pricing stays fresh before `fetch_pricing` re-downloads.
pub const CACHE_DURATION: Duration = Duration::hours(24);

/// File name of the on-disk pricing cache (placed in the user cache directory).
pub const CACHE_FILE_NAME: &str = "agent-io-pricing-cache.json";
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelPricing {
24 pub model: String,
25 pub input_cost_per_token: Option<f64>,
26 pub output_cost_per_token: Option<f64>,
27 pub cache_read_input_token_cost: Option<f64>,
28 pub cache_creation_input_token_cost: Option<f64>,
29 pub max_tokens: Option<u64>,
30 pub max_input_tokens: Option<u64>,
31 pub max_output_tokens: Option<u64>,
32}
33
34impl ModelPricing {
35 pub fn calculate_cost(
37 &self,
38 input_tokens: u64,
39 output_tokens: u64,
40 cached_tokens: u64,
41 cache_creation_tokens: u64,
42 ) -> TokenCostCalculated {
43 let mut prompt_cost = 0.0;
44 let mut completion_cost = 0.0;
45
46 if let Some(cost) = self.input_cost_per_token {
48 prompt_cost += (input_tokens as f64) * cost;
49 }
50
51 if let Some(cost) = self.cache_read_input_token_cost {
53 prompt_cost -= (input_tokens as f64) * (self.input_cost_per_token.unwrap_or(0.0));
54 prompt_cost += (cached_tokens as f64) * cost;
55 }
56
57 if let Some(cost) = self.cache_creation_input_token_cost {
59 prompt_cost += (cache_creation_tokens as f64) * cost;
60 }
61
62 if let Some(cost) = self.output_cost_per_token {
64 completion_cost = (output_tokens as f64) * cost;
65 }
66
67 TokenCostCalculated {
68 prompt_cost,
69 completion_cost,
70 total_cost: prompt_cost + completion_cost,
71 }
72 }
73}
74
/// Cost breakdown (USD) for a single LLM call, as returned by
/// [`ModelPricing::calculate_cost`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCostCalculated {
    /// Cost attributed to the prompt side (input, cache reads, cache writes).
    pub prompt_cost: f64,
    /// Cost attributed to the completion (output) tokens.
    pub completion_cost: f64,
    /// `prompt_cost + completion_cost`.
    pub total_cost: f64,
}
82
/// Accumulated token counts, USD costs, and call count for one model.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelUsageStats {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    // Cost fields stay 0.0 when no pricing was available for the recorded calls.
    pub prompt_cost: f64,
    pub completion_cost: f64,
    pub total_cost: f64,
    /// Number of calls folded into this entry.
    pub calls: u64,
}
94
/// Aggregate token/cost totals across all models, plus a per-model breakdown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageSummary {
    pub total_prompt_tokens: u64,
    // Cost totals (USD) only grow when pricing was known at `add` time.
    pub total_prompt_cost: f64,
    pub total_completion_tokens: u64,
    pub total_completion_cost: f64,
    pub total_tokens: u64,
    pub total_cost: f64,
    /// Per-model stats, keyed by the model name passed to `add`.
    pub by_model: HashMap<String, ModelUsageStats>,
}
106
107impl UsageSummary {
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn add(&mut self, model: &str, usage: &crate::llm::Usage, pricing: Option<&ModelPricing>) {
114 self.total_prompt_tokens += usage.prompt_tokens;
115 self.total_completion_tokens += usage.completion_tokens;
116 self.total_tokens += usage.total_tokens;
117
118 let model_stats = self.by_model.entry(model.to_string()).or_default();
119 model_stats.prompt_tokens += usage.prompt_tokens;
120 model_stats.completion_tokens += usage.completion_tokens;
121 model_stats.total_tokens += usage.total_tokens;
122 model_stats.calls += 1;
123
124 if let Some(pricing) = pricing {
125 let cost = pricing.calculate_cost(
126 usage.prompt_tokens,
127 usage.completion_tokens,
128 usage.prompt_cached_tokens.unwrap_or(0),
129 usage.prompt_cache_creation_tokens.unwrap_or(0),
130 );
131
132 self.total_prompt_cost += cost.prompt_cost;
133 self.total_completion_cost += cost.completion_cost;
134 self.total_cost += cost.total_cost;
135
136 model_stats.prompt_cost += cost.prompt_cost;
137 model_stats.completion_cost += cost.completion_cost;
138 model_stats.total_cost += cost.total_cost;
139 }
140 }
141
142 pub fn merge(&mut self, other: &UsageSummary) {
144 self.total_prompt_tokens += other.total_prompt_tokens;
145 self.total_prompt_cost += other.total_prompt_cost;
146 self.total_completion_tokens += other.total_completion_tokens;
147 self.total_completion_cost += other.total_completion_cost;
148 self.total_tokens += other.total_tokens;
149 self.total_cost += other.total_cost;
150
151 for (model, stats) in &other.by_model {
152 let entry = self.by_model.entry(model.clone()).or_default();
153 entry.prompt_tokens += stats.prompt_tokens;
154 entry.completion_tokens += stats.completion_tokens;
155 entry.total_tokens += stats.total_tokens;
156 entry.prompt_cost += stats.prompt_cost;
157 entry.completion_cost += stats.completion_cost;
158 entry.total_cost += stats.total_cost;
159 entry.calls += stats.calls;
160 }
161 }
162}
163
/// On-disk representation of the pricing cache file
/// (`CACHE_FILE_NAME` under the user's cache directory).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedPricing {
    /// Pricing table keyed by model name.
    pricing: HashMap<String, ModelPricing>,
    /// When the table was last downloaded; compared against `CACHE_DURATION`.
    last_update: DateTime<Utc>,
}
170
/// In-memory model pricing table backed by a time-limited on-disk cache.
pub struct TokenCost {
    /// Model name -> pricing entry.
    pricing: HashMap<String, ModelPricing>,
    /// When `pricing` was last refreshed; `None` until first load/fetch.
    last_update: Option<DateTime<Utc>>,
    /// Cache file location; `None` when no usable cache directory exists.
    cache_path: Option<PathBuf>,
}
177
178impl TokenCost {
179 pub fn new() -> Self {
180 Self {
181 pricing: HashMap::new(),
182 last_update: None,
183 cache_path: Self::get_cache_path(),
184 }
185 }
186
187 fn get_cache_path() -> Option<PathBuf> {
189 if let Ok(xdg_cache) = std::env::var("XDG_CACHE_HOME") {
191 let cache_dir = PathBuf::from(xdg_cache);
192 let _ = fs::create_dir_all(&cache_dir);
193 return Some(cache_dir.join(CACHE_FILE_NAME));
194 }
195
196 if let Some(home) = dirs::home_dir() {
198 let cache_dir = home.join(".cache");
199 let _ = fs::create_dir_all(&cache_dir);
200 return Some(cache_dir.join(CACHE_FILE_NAME));
201 }
202
203 None
204 }
205
206 pub fn load_cache(&mut self) -> Result<(), String> {
208 let cache_path = match &self.cache_path {
209 Some(p) => p,
210 None => return Err("No cache path available".into()),
211 };
212
213 if !cache_path.exists() {
214 return Err("Cache file does not exist".into());
215 }
216
217 let content =
218 fs::read_to_string(cache_path).map_err(|e| format!("Failed to read cache: {}", e))?;
219
220 let cached: CachedPricing =
221 serde_json::from_str(&content).map_err(|e| format!("Failed to parse cache: {}", e))?;
222
223 self.pricing = cached.pricing;
224 self.last_update = Some(cached.last_update);
225
226 Ok(())
227 }
228
229 fn save_cache(&self) -> Result<(), String> {
231 let cache_path = match &self.cache_path {
232 Some(p) => p,
233 None => return Ok(()),
234 };
235
236 let cached = CachedPricing {
237 pricing: self.pricing.clone(),
238 last_update: self.last_update.unwrap_or_else(Utc::now),
239 };
240
241 let content = serde_json::to_string_pretty(&cached)
242 .map_err(|e| format!("Failed to serialize cache: {}", e))?;
243
244 fs::write(cache_path, content).map_err(|e| format!("Failed to write cache: {}", e))?;
245
246 Ok(())
247 }
248
249 pub async fn fetch_pricing(&mut self) -> Result<(), String> {
251 if self.load_cache().is_ok() && !self.needs_refresh() {
253 return Ok(());
254 }
255
256 let response = reqwest::get(PRICING_URL)
258 .await
259 .map_err(|e| format!("Failed to fetch pricing: {}", e))?;
260
261 if !response.status().is_success() {
262 if self.last_update.is_some() {
264 return Ok(());
265 }
266 return Err(format!(
267 "Failed to fetch pricing: HTTP {}",
268 response.status()
269 ));
270 }
271
272 let pricing_data: HashMap<String, ModelPricing> = response
273 .json()
274 .await
275 .map_err(|e| format!("Failed to parse pricing: {}", e))?;
276
277 self.pricing = pricing_data;
278 self.last_update = Some(Utc::now());
279
280 let _ = self.save_cache();
282
283 Ok(())
284 }
285
286 pub fn needs_refresh(&self) -> bool {
288 match self.last_update {
289 None => true,
290 Some(last) => {
291 let elapsed = Utc::now() - last;
292 elapsed > CACHE_DURATION
293 }
294 }
295 }
296
297 pub fn get_model_pricing(&self, model_name: &str) -> Option<&ModelPricing> {
299 if let Some(pricing) = self.pricing.get(model_name) {
301 return Some(pricing);
302 }
303
304 let normalized = normalize_model_name(model_name);
306
307 if let Some(pricing) = self.pricing.get(&normalized) {
309 return Some(pricing);
310 }
311
312 self.pricing.get(&normalized.replace('/', "-"))
314 }
315
316 pub fn calculate_cost(
318 &self,
319 model: &str,
320 usage: &crate::llm::Usage,
321 ) -> Option<TokenCostCalculated> {
322 let pricing = self.get_model_pricing(model)?;
323 Some(pricing.calculate_cost(
324 usage.prompt_tokens,
325 usage.completion_tokens,
326 usage.prompt_cached_tokens.unwrap_or(0),
327 usage.prompt_cache_creation_tokens.unwrap_or(0),
328 ))
329 }
330}
331
332impl Default for TokenCost {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_pricing() {
        let gpt4o = ModelPricing {
            model: String::from("gpt-4o"),
            input_cost_per_token: Some(0.0000025),
            output_cost_per_token: Some(0.00001),
            cache_read_input_token_cost: Some(0.00000125),
            cache_creation_input_token_cost: Some(0.000003125),
            max_tokens: Some(128000),
            max_input_tokens: Some(128000),
            max_output_tokens: Some(4096),
        };

        let result = gpt4o.calculate_cost(1000, 500, 200, 100);

        // Every component must come out strictly positive for non-zero usage.
        assert!(result.prompt_cost > 0.0);
        assert!(result.completion_cost > 0.0);
        assert!(result.total_cost > 0.0);
    }

    #[test]
    fn test_usage_summary() {
        let usage = crate::llm::Usage::new(100, 50);
        let mut summary = UsageSummary::new();

        summary.add("gpt-4o", &usage, None);

        // Token counts are recorded even with no pricing supplied.
        assert_eq!(summary.total_tokens, 150);
        assert_eq!(summary.total_prompt_tokens, 100);
        assert_eq!(summary.total_completion_tokens, 50);
    }
}
373}