use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use crate::tokens::mappings::normalize_model_name;
pub const PRICING_URL: &str =
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
pub const CACHE_DURATION: Duration = Duration::hours(24);
pub const CACHE_FILE_NAME: &str = "agent-io-pricing-cache.json";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
pub model: String,
pub input_cost_per_token: Option<f64>,
pub output_cost_per_token: Option<f64>,
pub cache_read_input_token_cost: Option<f64>,
pub cache_creation_input_token_cost: Option<f64>,
pub max_tokens: Option<u64>,
pub max_input_tokens: Option<u64>,
pub max_output_tokens: Option<u64>,
}
impl ModelPricing {
pub fn calculate_cost(
&self,
input_tokens: u64,
output_tokens: u64,
cached_tokens: u64,
cache_creation_tokens: u64,
) -> TokenCostCalculated {
let mut prompt_cost = 0.0;
let mut completion_cost = 0.0;
if let Some(cost) = self.input_cost_per_token {
prompt_cost += (input_tokens as f64) * cost;
}
if let Some(cost) = self.cache_read_input_token_cost {
prompt_cost -= (input_tokens as f64) * (self.input_cost_per_token.unwrap_or(0.0));
prompt_cost += (cached_tokens as f64) * cost;
}
if let Some(cost) = self.cache_creation_input_token_cost {
prompt_cost += (cache_creation_tokens as f64) * cost;
}
if let Some(cost) = self.output_cost_per_token {
completion_cost = (output_tokens as f64) * cost;
}
TokenCostCalculated {
prompt_cost,
completion_cost,
total_cost: prompt_cost + completion_cost,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenCostCalculated {
pub prompt_cost: f64,
pub completion_cost: f64,
pub total_cost: f64,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelUsageStats {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
pub prompt_cost: f64,
pub completion_cost: f64,
pub total_cost: f64,
pub calls: u64,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageSummary {
pub total_prompt_tokens: u64,
pub total_prompt_cost: f64,
pub total_completion_tokens: u64,
pub total_completion_cost: f64,
pub total_tokens: u64,
pub total_cost: f64,
pub by_model: HashMap<String, ModelUsageStats>,
}
impl UsageSummary {
pub fn new() -> Self {
Self::default()
}
pub fn add(&mut self, model: &str, usage: &crate::llm::Usage, pricing: Option<&ModelPricing>) {
self.total_prompt_tokens += usage.prompt_tokens;
self.total_completion_tokens += usage.completion_tokens;
self.total_tokens += usage.total_tokens;
let model_stats = self.by_model.entry(model.to_string()).or_default();
model_stats.prompt_tokens += usage.prompt_tokens;
model_stats.completion_tokens += usage.completion_tokens;
model_stats.total_tokens += usage.total_tokens;
model_stats.calls += 1;
if let Some(pricing) = pricing {
let cost = pricing.calculate_cost(
usage.prompt_tokens,
usage.completion_tokens,
usage.prompt_cached_tokens.unwrap_or(0),
usage.prompt_cache_creation_tokens.unwrap_or(0),
);
self.total_prompt_cost += cost.prompt_cost;
self.total_completion_cost += cost.completion_cost;
self.total_cost += cost.total_cost;
model_stats.prompt_cost += cost.prompt_cost;
model_stats.completion_cost += cost.completion_cost;
model_stats.total_cost += cost.total_cost;
}
}
pub fn merge(&mut self, other: &UsageSummary) {
self.total_prompt_tokens += other.total_prompt_tokens;
self.total_prompt_cost += other.total_prompt_cost;
self.total_completion_tokens += other.total_completion_tokens;
self.total_completion_cost += other.total_completion_cost;
self.total_tokens += other.total_tokens;
self.total_cost += other.total_cost;
for (model, stats) in &other.by_model {
let entry = self.by_model.entry(model.clone()).or_default();
entry.prompt_tokens += stats.prompt_tokens;
entry.completion_tokens += stats.completion_tokens;
entry.total_tokens += stats.total_tokens;
entry.prompt_cost += stats.prompt_cost;
entry.completion_cost += stats.completion_cost;
entry.total_cost += stats.total_cost;
entry.calls += stats.calls;
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedPricing {
pricing: HashMap<String, ModelPricing>,
last_update: DateTime<Utc>,
}
pub struct TokenCost {
pricing: HashMap<String, ModelPricing>,
last_update: Option<DateTime<Utc>>,
cache_path: Option<PathBuf>,
}
impl TokenCost {
pub fn new() -> Self {
Self {
pricing: HashMap::new(),
last_update: None,
cache_path: Self::get_cache_path(),
}
}
fn get_cache_path() -> Option<PathBuf> {
if let Ok(xdg_cache) = std::env::var("XDG_CACHE_HOME") {
let cache_dir = PathBuf::from(xdg_cache);
let _ = fs::create_dir_all(&cache_dir);
return Some(cache_dir.join(CACHE_FILE_NAME));
}
if let Some(home) = dirs::home_dir() {
let cache_dir = home.join(".cache");
let _ = fs::create_dir_all(&cache_dir);
return Some(cache_dir.join(CACHE_FILE_NAME));
}
None
}
pub fn load_cache(&mut self) -> Result<(), String> {
let cache_path = match &self.cache_path {
Some(p) => p,
None => return Err("No cache path available".into()),
};
if !cache_path.exists() {
return Err("Cache file does not exist".into());
}
let content =
fs::read_to_string(cache_path).map_err(|e| format!("Failed to read cache: {}", e))?;
let cached: CachedPricing =
serde_json::from_str(&content).map_err(|e| format!("Failed to parse cache: {}", e))?;
self.pricing = cached.pricing;
self.last_update = Some(cached.last_update);
Ok(())
}
fn save_cache(&self) -> Result<(), String> {
let cache_path = match &self.cache_path {
Some(p) => p,
None => return Ok(()),
};
let cached = CachedPricing {
pricing: self.pricing.clone(),
last_update: self.last_update.unwrap_or_else(Utc::now),
};
let content = serde_json::to_string_pretty(&cached)
.map_err(|e| format!("Failed to serialize cache: {}", e))?;
fs::write(cache_path, content).map_err(|e| format!("Failed to write cache: {}", e))?;
Ok(())
}
pub async fn fetch_pricing(&mut self) -> Result<(), String> {
if self.load_cache().is_ok() && !self.needs_refresh() {
return Ok(());
}
let response = reqwest::get(PRICING_URL)
.await
.map_err(|e| format!("Failed to fetch pricing: {}", e))?;
if !response.status().is_success() {
if self.last_update.is_some() {
return Ok(());
}
return Err(format!(
"Failed to fetch pricing: HTTP {}",
response.status()
));
}
let pricing_data: HashMap<String, ModelPricing> = response
.json()
.await
.map_err(|e| format!("Failed to parse pricing: {}", e))?;
self.pricing = pricing_data;
self.last_update = Some(Utc::now());
let _ = self.save_cache();
Ok(())
}
pub fn needs_refresh(&self) -> bool {
match self.last_update {
None => true,
Some(last) => {
let elapsed = Utc::now() - last;
elapsed > CACHE_DURATION
}
}
}
pub fn get_model_pricing(&self, model_name: &str) -> Option<&ModelPricing> {
if let Some(pricing) = self.pricing.get(model_name) {
return Some(pricing);
}
let normalized = normalize_model_name(model_name);
if let Some(pricing) = self.pricing.get(&normalized) {
return Some(pricing);
}
self.pricing.get(&normalized.replace('/', "-"))
}
pub fn calculate_cost(
&self,
model: &str,
usage: &crate::llm::Usage,
) -> Option<TokenCostCalculated> {
let pricing = self.get_model_pricing(model)?;
Some(pricing.calculate_cost(
usage.prompt_tokens,
usage.completion_tokens,
usage.prompt_cached_tokens.unwrap_or(0),
usage.prompt_cache_creation_tokens.unwrap_or(0),
))
}
}
impl Default for TokenCost {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_pricing() {
let pricing = ModelPricing {
model: "gpt-4o".to_string(),
input_cost_per_token: Some(0.0000025),
output_cost_per_token: Some(0.00001),
cache_read_input_token_cost: Some(0.00000125),
cache_creation_input_token_cost: Some(0.000003125),
max_tokens: Some(128000),
max_input_tokens: Some(128000),
max_output_tokens: Some(4096),
};
let cost = pricing.calculate_cost(1000, 500, 200, 100);
assert!(cost.prompt_cost > 0.0);
assert!(cost.completion_cost > 0.0);
assert!(cost.total_cost > 0.0);
}
#[test]
fn test_usage_summary() {
let mut summary = UsageSummary::new();
let usage = crate::llm::Usage::new(100, 50);
summary.add("gpt-4o", &usage, None);
assert_eq!(summary.total_prompt_tokens, 100);
assert_eq!(summary.total_completion_tokens, 50);
assert_eq!(summary.total_tokens, 150);
}
}