// agent_io/tokens/service.rs
1//! Token usage tracking and cost calculation
2
3use chrono::{DateTime, Duration, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8
9use crate::tokens::mappings::normalize_model_name;
10
11/// Default pricing URL (LiteLLM model prices)
12pub const PRICING_URL: &str =
13    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
14
15/// Cache duration for pricing data
16pub const CACHE_DURATION: Duration = Duration::hours(24);
17
18/// Cache file name
19pub const CACHE_FILE_NAME: &str = "agent-io-pricing-cache.json";
20
/// Pricing information for a single model, mirroring the fields of the
/// LiteLLM price table (per-token costs; presumably USD — matches the
/// upstream table's convention, TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Model name this pricing entry applies to.
    pub model: String,
    /// Cost per prompt (input) token; `None` when unknown.
    pub input_cost_per_token: Option<f64>,
    /// Cost per completion (output) token; `None` when unknown.
    pub output_cost_per_token: Option<f64>,
    /// Discounted rate for prompt tokens served from a provider cache.
    pub cache_read_input_token_cost: Option<f64>,
    /// Surcharge rate for prompt tokens written into a provider cache.
    pub cache_creation_input_token_cost: Option<f64>,
    /// Overall context window limit, if published.
    pub max_tokens: Option<u64>,
    /// Maximum prompt tokens accepted, if published.
    pub max_input_tokens: Option<u64>,
    /// Maximum completion tokens producible, if published.
    pub max_output_tokens: Option<u64>,
}
33
34impl ModelPricing {
35    /// Calculate cost for given token usage
36    pub fn calculate_cost(
37        &self,
38        input_tokens: u64,
39        output_tokens: u64,
40        cached_tokens: u64,
41        cache_creation_tokens: u64,
42    ) -> TokenCostCalculated {
43        let mut prompt_cost = 0.0;
44        let mut completion_cost = 0.0;
45
46        // Input tokens cost
47        if let Some(cost) = self.input_cost_per_token {
48            prompt_cost += (input_tokens as f64) * cost;
49        }
50
51        // Cached tokens cost (usually cheaper)
52        if let Some(cost) = self.cache_read_input_token_cost {
53            prompt_cost -= (input_tokens as f64) * (self.input_cost_per_token.unwrap_or(0.0));
54            prompt_cost += (cached_tokens as f64) * cost;
55        }
56
57        // Cache creation cost
58        if let Some(cost) = self.cache_creation_input_token_cost {
59            prompt_cost += (cache_creation_tokens as f64) * cost;
60        }
61
62        // Output tokens cost
63        if let Some(cost) = self.output_cost_per_token {
64            completion_cost = (output_tokens as f64) * cost;
65        }
66
67        TokenCostCalculated {
68            prompt_cost,
69            completion_cost,
70            total_cost: prompt_cost + completion_cost,
71        }
72    }
73}
74
/// Cost breakdown produced by [`ModelPricing::calculate_cost`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCostCalculated {
    /// Cost attributable to the prompt (input + cache adjustments).
    pub prompt_cost: f64,
    /// Cost attributable to the completion (output tokens).
    pub completion_cost: f64,
    /// Sum of `prompt_cost` and `completion_cost`.
    pub total_cost: f64,
}
82
/// Aggregated usage for a single model within a session.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelUsageStats {
    /// Total prompt (input) tokens sent to this model.
    pub prompt_tokens: u64,
    /// Total completion (output) tokens received from this model.
    pub completion_tokens: u64,
    /// Total tokens as reported by the provider (`Usage::total_tokens`).
    pub total_tokens: u64,
    /// Accumulated prompt cost (same units as the `ModelPricing` rates).
    pub prompt_cost: f64,
    /// Accumulated completion cost.
    pub completion_cost: f64,
    /// Accumulated total cost.
    pub total_cost: f64,
    /// Number of completions recorded (one per `UsageSummary::add`).
    pub calls: u64,
}
94
/// Session-wide usage totals plus a per-model breakdown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageSummary {
    /// Prompt tokens across all models.
    pub total_prompt_tokens: u64,
    /// Prompt cost across all models (only grows when pricing is known).
    pub total_prompt_cost: f64,
    /// Completion tokens across all models.
    pub total_completion_tokens: u64,
    /// Completion cost across all models.
    pub total_completion_cost: f64,
    /// Total tokens across all models.
    pub total_tokens: u64,
    /// Total cost across all models.
    pub total_cost: f64,
    /// Per-model statistics, keyed by model name as passed to `add`.
    pub by_model: HashMap<String, ModelUsageStats>,
}
106
107impl UsageSummary {
108    pub fn new() -> Self {
109        Self::default()
110    }
111
112    /// Add usage from a completion
113    pub fn add(&mut self, model: &str, usage: &crate::llm::Usage, pricing: Option<&ModelPricing>) {
114        self.total_prompt_tokens += usage.prompt_tokens;
115        self.total_completion_tokens += usage.completion_tokens;
116        self.total_tokens += usage.total_tokens;
117
118        let model_stats = self.by_model.entry(model.to_string()).or_default();
119        model_stats.prompt_tokens += usage.prompt_tokens;
120        model_stats.completion_tokens += usage.completion_tokens;
121        model_stats.total_tokens += usage.total_tokens;
122        model_stats.calls += 1;
123
124        if let Some(pricing) = pricing {
125            let cost = pricing.calculate_cost(
126                usage.prompt_tokens,
127                usage.completion_tokens,
128                usage.prompt_cached_tokens.unwrap_or(0),
129                usage.prompt_cache_creation_tokens.unwrap_or(0),
130            );
131
132            self.total_prompt_cost += cost.prompt_cost;
133            self.total_completion_cost += cost.completion_cost;
134            self.total_cost += cost.total_cost;
135
136            model_stats.prompt_cost += cost.prompt_cost;
137            model_stats.completion_cost += cost.completion_cost;
138            model_stats.total_cost += cost.total_cost;
139        }
140    }
141
142    /// Merge another usage summary
143    pub fn merge(&mut self, other: &UsageSummary) {
144        self.total_prompt_tokens += other.total_prompt_tokens;
145        self.total_prompt_cost += other.total_prompt_cost;
146        self.total_completion_tokens += other.total_completion_tokens;
147        self.total_completion_cost += other.total_completion_cost;
148        self.total_tokens += other.total_tokens;
149        self.total_cost += other.total_cost;
150
151        for (model, stats) in &other.by_model {
152            let entry = self.by_model.entry(model.clone()).or_default();
153            entry.prompt_tokens += stats.prompt_tokens;
154            entry.completion_tokens += stats.completion_tokens;
155            entry.total_tokens += stats.total_tokens;
156            entry.prompt_cost += stats.prompt_cost;
157            entry.completion_cost += stats.completion_cost;
158            entry.total_cost += stats.total_cost;
159            entry.calls += stats.calls;
160        }
161    }
162}
163
/// On-disk (serialized) form of the pricing table plus its fetch timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedPricing {
    /// Pricing table keyed by model name.
    pricing: HashMap<String, ModelPricing>,
    /// When the pricing data was last fetched from the remote source.
    last_update: DateTime<Utc>,
}
170
/// Service that fetches, caches, and queries model pricing, and computes
/// completion costs from token usage.
pub struct TokenCost {
    /// In-memory pricing table keyed by model name.
    pricing: HashMap<String, ModelPricing>,
    /// Timestamp of the last successful pricing load (remote or cache).
    last_update: Option<DateTime<Utc>>,
    /// Resolved cache file location; `None` when no cache dir is available.
    cache_path: Option<PathBuf>,
}
177
impl TokenCost {
    /// Create an empty service (no pricing loaded). The cache path is
    /// resolved eagerly, including best-effort creation of the cache dir.
    pub fn new() -> Self {
        Self {
            pricing: HashMap::new(),
            last_update: None,
            cache_path: Self::get_cache_path(),
        }
    }

    /// Resolve the pricing-cache file path: `$XDG_CACHE_HOME` first, then
    /// `~/.cache`, else `None`. Directory-creation failures are deliberately
    /// ignored — a later read/write will surface any real problem.
    fn get_cache_path() -> Option<PathBuf> {
        // Try XDG cache directory first
        if let Ok(xdg_cache) = std::env::var("XDG_CACHE_HOME") {
            let cache_dir = PathBuf::from(xdg_cache);
            let _ = fs::create_dir_all(&cache_dir);
            return Some(cache_dir.join(CACHE_FILE_NAME));
        }

        // Fallback to home directory
        if let Some(home) = dirs::home_dir() {
            let cache_dir = home.join(".cache");
            let _ = fs::create_dir_all(&cache_dir);
            return Some(cache_dir.join(CACHE_FILE_NAME));
        }

        None
    }

    /// Load pricing from the on-disk cache into memory.
    ///
    /// Errors (as human-readable strings) when the cache path is unknown or
    /// the file is missing, unreadable, or fails to parse. On success, both
    /// `pricing` and `last_update` are replaced with the cached values.
    pub fn load_cache(&mut self) -> Result<(), String> {
        let cache_path = match &self.cache_path {
            Some(p) => p,
            None => return Err("No cache path available".into()),
        };

        if !cache_path.exists() {
            return Err("Cache file does not exist".into());
        }

        let content =
            fs::read_to_string(cache_path).map_err(|e| format!("Failed to read cache: {}", e))?;

        let cached: CachedPricing =
            serde_json::from_str(&content).map_err(|e| format!("Failed to parse cache: {}", e))?;

        self.pricing = cached.pricing;
        self.last_update = Some(cached.last_update);

        Ok(())
    }

    /// Persist the in-memory pricing table to the cache file.
    ///
    /// Quietly succeeds when no cache path was resolved (caching is
    /// best-effort, not required for correctness).
    fn save_cache(&self) -> Result<(), String> {
        let cache_path = match &self.cache_path {
            Some(p) => p,
            None => return Ok(()),
        };

        let cached = CachedPricing {
            pricing: self.pricing.clone(),
            // Fallback to "now" so a cache written before any fetch still
            // carries a plausible timestamp.
            last_update: self.last_update.unwrap_or_else(Utc::now),
        };

        let content = serde_json::to_string_pretty(&cached)
            .map_err(|e| format!("Failed to serialize cache: {}", e))?;

        fs::write(cache_path, content).map_err(|e| format!("Failed to write cache: {}", e))?;

        Ok(())
    }

    /// Populate the pricing table, preferring a fresh local cache over the
    /// network.
    ///
    /// Order of operations: (1) use the cache if it loads and is still
    /// fresh; (2) otherwise fetch `PRICING_URL`; (3) on an HTTP error
    /// status, fall back to previously-loaded (possibly stale) cache data
    /// if any exists; (4) on success, update the cache (write errors are
    /// ignored — see `save_cache`).
    pub async fn fetch_pricing(&mut self) -> Result<(), String> {
        // Try to load from cache first
        if self.load_cache().is_ok() && !self.needs_refresh() {
            return Ok(());
        }

        // Fetch from remote
        let response = reqwest::get(PRICING_URL)
            .await
            .map_err(|e| format!("Failed to fetch pricing: {}", e))?;

        if !response.status().is_success() {
            // If we have cached data, use it even if expired
            if self.last_update.is_some() {
                return Ok(());
            }
            return Err(format!(
                "Failed to fetch pricing: HTTP {}",
                response.status()
            ));
        }

        let pricing_data: HashMap<String, ModelPricing> = response
            .json()
            .await
            .map_err(|e| format!("Failed to parse pricing: {}", e))?;

        self.pricing = pricing_data;
        self.last_update = Some(Utc::now());

        // Save to cache
        let _ = self.save_cache();

        Ok(())
    }

    /// True when pricing has never been loaded, or the last load is older
    /// than `CACHE_DURATION`.
    pub fn needs_refresh(&self) -> bool {
        match self.last_update {
            None => true,
            Some(last) => {
                let elapsed = Utc::now() - last;
                elapsed > CACHE_DURATION
            }
        }
    }

    /// Look up pricing for a model, trying progressively looser keys:
    /// exact name, then the normalized name, then a slash→dash variant.
    pub fn get_model_pricing(&self, model_name: &str) -> Option<&ModelPricing> {
        // Try exact match first
        if let Some(pricing) = self.pricing.get(model_name) {
            return Some(pricing);
        }

        // Try normalized model name
        let normalized = normalize_model_name(model_name);

        // Try with the normalized name
        if let Some(pricing) = self.pricing.get(&normalized) {
            return Some(pricing);
        }

        // Last resort: slash→dash variant (e.g. "provider/model" →
        // "provider-model") — presumably to match table keys that use
        // dashes; TODO confirm against the LiteLLM key format.
        self.pricing.get(&normalized.replace('/', "-"))
    }

    /// Calculate the cost of a completion, or `None` when no pricing is
    /// known for `model`. Missing cache-token counts are treated as zero.
    pub fn calculate_cost(
        &self,
        model: &str,
        usage: &crate::llm::Usage,
    ) -> Option<TokenCostCalculated> {
        let pricing = self.get_model_pricing(model)?;
        Some(pricing.calculate_cost(
            usage.prompt_tokens,
            usage.completion_tokens,
            usage.prompt_cached_tokens.unwrap_or(0),
            usage.prompt_cache_creation_tokens.unwrap_or(0),
        ))
    }
}
331
332impl Default for TokenCost {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_pricing() {
        let pricing = ModelPricing {
            model: "gpt-4o".to_string(),
            input_cost_per_token: Some(0.0000025),
            output_cost_per_token: Some(0.00001),
            cache_read_input_token_cost: Some(0.00000125),
            cache_creation_input_token_cost: Some(0.000003125),
            max_tokens: Some(128000),
            max_input_tokens: Some(128000),
            max_output_tokens: Some(4096),
        };

        let cost = pricing.calculate_cost(1000, 500, 200, 100);

        // Non-trivial request must have strictly positive components.
        assert!(cost.prompt_cost > 0.0);
        assert!(cost.completion_cost > 0.0);
        assert!(cost.total_cost > 0.0);

        // Output pricing has no cache interaction: exactly 500 * 1e-5.
        assert!((cost.completion_cost - 0.005).abs() < 1e-12);

        // The total is always the sum of the two components.
        assert!((cost.total_cost - (cost.prompt_cost + cost.completion_cost)).abs() < 1e-12);
    }

    #[test]
    fn test_usage_summary() {
        let mut summary = UsageSummary::new();
        let usage = crate::llm::Usage::new(100, 50);

        summary.add("gpt-4o", &usage, None);

        assert_eq!(summary.total_prompt_tokens, 100);
        assert_eq!(summary.total_completion_tokens, 50);
        assert_eq!(summary.total_tokens, 150);

        // Per-model stats were recorded under the given model name.
        let stats = summary.by_model.get("gpt-4o").expect("model entry recorded");
        assert_eq!(stats.calls, 1);
        assert_eq!(stats.total_tokens, 150);

        // Without pricing, no cost may accumulate.
        assert_eq!(summary.total_cost, 0.0);
    }
}