use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::fs;
use tracing::info;
use std::path::Path;
/// Aggregated usage, cost, budget and alert state.
///
/// Persisted as pretty-printed JSON (`costs.json`) by [`CostTracker::save`] and
/// restored by [`CostTracker::load`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostTracker {
    /// Per-day aggregates keyed by the record's UTC date formatted as `%Y-%m-%d`.
    pub daily_usage: HashMap<String, DailyUsage>,
    /// Lifetime aggregates keyed by provider name.
    pub provider_totals: HashMap<String, ProviderTotal>,
    /// Lifetime aggregates keyed by `"<provider>/<model>"`.
    pub model_totals: HashMap<String, ModelTotal>,
    /// Budgets evaluated each time usage is recorded.
    pub budgets: Vec<Budget>,
    /// Alerts raised by budget checks; unacknowledged ones count as pending in stats.
    pub alerts: Vec<Alert>,
    /// Tracker configuration.
    pub settings: CostSettings,
}
/// Usage aggregated over one UTC calendar day.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyUsage {
    /// The day, formatted `%Y-%m-%d` (same string as its key in `CostTracker::daily_usage`).
    pub date: String,
    /// Sum of `UsageRecord::cost` for the day.
    pub total_cost: f64,
    /// Sum of input + output tokens for the day.
    pub total_tokens: u32,
    /// Number of records for the day.
    pub requests: u32,
    /// Per-provider breakdown of the day, keyed by provider name.
    pub by_provider: HashMap<String, ProviderDailyUsage>,
}
/// One provider's share of a [`DailyUsage`] bucket.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderDailyUsage {
    /// Provider name (same string as its key in `DailyUsage::by_provider`).
    pub provider: String,
    /// Sum of record costs for this provider on this day.
    pub cost: f64,
    /// Sum of input tokens for this provider on this day.
    pub input_tokens: u32,
    /// Sum of output tokens for this provider on this day.
    pub output_tokens: u32,
    /// Number of records for this provider on this day.
    pub requests: u32,
}
/// Lifetime usage totals for one provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderTotal {
    /// Provider name (same string as its key in `CostTracker::provider_totals`).
    pub provider: String,
    /// Sum of record costs across all days.
    pub total_cost: f64,
    /// Sum of input + output tokens across all days.
    pub total_tokens: u32,
    /// Number of records across all days.
    pub total_requests: u32,
    /// Timestamp of the provider's first recorded use.
    pub first_used: DateTime<Utc>,
    /// Timestamp of the provider's most recent recorded use.
    pub last_used: DateTime<Utc>,
}
/// Lifetime usage totals for one model of one provider.
///
/// Stored in `CostTracker::model_totals` under the key `"<provider>/<model>"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelTotal {
    /// Model name as given in `UsageRecord::model`.
    pub model: String,
    /// Provider serving the model.
    pub provider: String,
    /// Sum of record costs for this model.
    pub total_cost: f64,
    /// Sum of input + output tokens for this model.
    pub total_tokens: u32,
    /// Number of records for this model.
    pub total_requests: u32,
    /// `total_cost / total_requests`, recomputed on every record.
    pub avg_cost_per_request: f64,
}
/// A spending limit that is evaluated each time usage is recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Budget {
    /// Unique identifier (a UUID v4 string when created via `CostTracker::create_budget`).
    pub id: String,
    /// Human-readable name, included in alert messages.
    pub name: String,
    /// Which spend the budget covers.
    pub budget_type: BudgetType,
    /// Spending limit; an alert is raised when the measured spend is above it.
    /// NOTE(review): presumably expressed in `CostSettings::currency` — confirm.
    pub limit: f64,
    /// Time window; for `BudgetType::Total` it selects which spend is compared to `limit`.
    pub period: BudgetPeriod,
    /// Running spend accumulated by the budget check; not reset anywhere in this file.
    pub current_spend: f64,
    /// NOTE(review): not read by the budget check in this file, and its unit
    /// (fraction vs. percent of `limit`) is unspecified — confirm before relying on it.
    pub alert_threshold: f64,
    /// When the budget was created.
    pub created_at: DateTime<Utc>,
    /// Set to the creation time by `create_budget`; not read or updated elsewhere in this file.
    pub reset_at: DateTime<Utc>,
    /// Inactive budgets are skipped by the budget check.
    pub is_active: bool,
}
/// Which spending a [`Budget`] covers.
///
/// Derives `PartialEq`/`Eq`/`Hash` so budget types can be compared with `==`
/// and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BudgetType {
    /// All recorded spend, windowed by the budget's [`BudgetPeriod`].
    Total,
    /// Spend for the named provider only (compared with `UsageRecord::provider`).
    PerProvider(String),
    /// Spend for the named model only (compared with `UsageRecord::model`).
    PerModel(String),
    /// Total spend within a single day.
    Daily,
    /// Total spend within a single calendar month.
    Monthly,
}
/// Time window a [`Budget`] applies to.
///
/// For [`BudgetType::Total`] budgets the period selects which spend is compared
/// against the limit. Fieldless, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BudgetPeriod {
    /// A single day.
    Daily,
    /// A single week.
    Weekly,
    /// A single calendar month.
    Monthly,
    /// A single year.
    Yearly,
    /// No window: spend accumulates for the lifetime of the budget.
    OneTime,
}
/// A notification raised by the tracker, e.g. when a budget is over its limit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Alert {
    /// Unique identifier (a UUID v4 string for alerts raised by the budget check).
    pub id: String,
    /// When the alert was raised.
    pub timestamp: DateTime<Utc>,
    /// Kind of alert.
    pub alert_type: AlertType,
    /// Human-readable description.
    pub message: String,
    /// Id of the budget that triggered the alert, when there is one.
    pub budget_id: Option<String>,
    /// How urgent the alert is.
    pub severity: AlertSeverity,
    /// Unacknowledged alerts are counted as pending by `CostTracker::get_stats`.
    pub acknowledged: bool,
}
/// Kind of [`Alert`].
///
/// Fieldless, so it is `Copy` and comparable with `==` (e.g. for filtering alerts).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AlertType {
    /// A budget is approaching its limit.
    BudgetThreshold,
    /// A budget is over its limit.
    BudgetExceeded,
    /// Spending deviates from the usual pattern (not raised anywhere in this file).
    UnusualSpending,
    /// A sudden jump in cost (not raised anywhere in this file).
    CostSpike,
    /// A daily limit was hit (not raised anywhere in this file).
    DailyLimit,
}
/// Urgency of an [`Alert`].
///
/// Variants are declared in increasing order of urgency, so the derived `Ord`
/// gives `Info < Warning < Critical` (useful for "at least Warning" filters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AlertSeverity {
    /// Informational only.
    Info,
    /// Needs attention soon.
    Warning,
    /// Needs attention now.
    Critical,
}
/// Configuration for a [`CostTracker`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostSettings {
    /// When `false`, `CostTracker::record_usage` ignores incoming records.
    pub track_costs: bool,
    /// Default budget limit (100.0 by default); not read elsewhere in this file.
    pub default_budget_limit: f64,
    /// Whether unusual-spending alerts are wanted; not read elsewhere in this file.
    pub alert_on_unusual_spending: bool,
    /// NOTE(review): presumably a multiplier over normal spend (2.0 by default);
    /// not read elsewhere in this file — confirm its meaning.
    pub unusual_spending_threshold: f64,
    /// Currency code, `"USD"` by default. Alert messages in this file hard-code `$`.
    pub currency: String,
    /// Whether cost optimization is enabled; not read elsewhere in this file.
    pub cost_optimization_enabled: bool,
}
impl Default for CostSettings {
fn default() -> Self {
Self {
track_costs: true,
default_budget_limit: 100.0,
alert_on_unusual_spending: true,
unusual_spending_threshold: 2.0,
currency: "USD".to_string(),
cost_optimization_enabled: true,
}
}
}
/// One request's usage, as passed to `CostTracker::record_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// When the request happened; its UTC date selects the daily bucket.
    pub timestamp: DateTime<Utc>,
    /// Provider that served the request.
    pub provider: String,
    /// Model used for the request.
    pub model: String,
    /// Input tokens consumed.
    pub input_tokens: u32,
    /// Output tokens produced.
    pub output_tokens: u32,
    /// Cost of the request. NOTE(review): presumably in `CostSettings::currency` — confirm.
    pub cost: f64,
    /// Free-form request category; not read elsewhere in this file.
    pub request_type: String,
}
impl CostTracker {
pub fn new() -> Self {
Self {
daily_usage: HashMap::new(),
provider_totals: HashMap::new(),
model_totals: HashMap::new(),
budgets: Vec::new(),
alerts: Vec::new(),
settings: CostSettings::default(),
}
}
pub async fn load(storage_path: &Path) -> Result<Self> {
let path = storage_path.join("costs.json");
if !path.exists() {
return Ok(Self::new());
}
let content = fs::read_to_string(&path).await?;
let tracker: Self = serde_json::from_str(&content)?;
Ok(tracker)
}
pub async fn save(&self, storage_path: &Path) -> Result<()> {
let path = storage_path.join("costs.json");
let content = serde_json::to_string_pretty(self)?;
fs::write(&path, content).await?;
Ok(())
}
pub fn record_usage(&mut self, record: UsageRecord) {
if !self.settings.track_costs {
return;
}
let date_key = record.timestamp.format("%Y-%m-%d").to_string();
let daily = self.daily_usage.entry(date_key.clone()).or_insert_with(|| DailyUsage {
date: date_key,
total_cost: 0.0,
total_tokens: 0,
requests: 0,
by_provider: HashMap::new(),
});
daily.total_cost += record.cost;
daily.total_tokens += record.input_tokens + record.output_tokens;
daily.requests += 1;
let provider_daily = daily.by_provider.entry(record.provider.clone()).or_insert_with(|| ProviderDailyUsage {
provider: record.provider.clone(),
cost: 0.0,
input_tokens: 0,
output_tokens: 0,
requests: 0,
});
provider_daily.cost += record.cost;
provider_daily.input_tokens += record.input_tokens;
provider_daily.output_tokens += record.output_tokens;
provider_daily.requests += 1;
let provider_total = self.provider_totals.entry(record.provider.clone()).or_insert_with(|| ProviderTotal {
provider: record.provider.clone(),
total_cost: 0.0,
total_tokens: 0,
total_requests: 0,
first_used: record.timestamp,
last_used: record.timestamp,
});
provider_total.total_cost += record.cost;
provider_total.total_tokens += record.input_tokens + record.output_tokens;
provider_total.total_requests += 1;
provider_total.last_used = record.timestamp;
let model_key = format!("{}/{}", record.provider, record.model);
let model_total = self.model_totals.entry(model_key.clone()).or_insert_with(|| ModelTotal {
model: record.model.clone(),
provider: record.provider.clone(),
total_cost: 0.0,
total_tokens: 0,
total_requests: 0,
avg_cost_per_request: 0.0,
});
model_total.total_cost += record.cost;
model_total.total_tokens += record.input_tokens + record.output_tokens;
model_total.total_requests += 1;
model_total.avg_cost_per_request = model_total.total_cost / model_total.total_requests as f64;
self.check_budgets(&record);
}
fn check_budgets(&mut self, record: &UsageRecord) {
let monthly_spend = self.get_monthly_spend();
let today = Utc::now().format("%Y-%m-%d").to_string();
let daily_spend = self.daily_usage.get(&today).map(|d| d.total_cost).unwrap_or(0.0);
for budget in &mut self.budgets {
if !budget.is_active {
continue;
}
let should_alert = match &budget.budget_type {
BudgetType::Total => {
match budget.period {
BudgetPeriod::Daily => daily_spend > budget.limit,
BudgetPeriod::Monthly => monthly_spend > budget.limit,
BudgetPeriod::OneTime => budget.current_spend + record.cost > budget.limit,
_ => false,
}
}
BudgetType::PerProvider(provider) if provider == &record.provider => {
budget.current_spend + record.cost > budget.limit
}
BudgetType::PerModel(model) if model == &record.model => {
budget.current_spend + record.cost > budget.limit
}
_ => false,
};
if should_alert {
let alert = Alert {
id: uuid::Uuid::new_v4().to_string(),
timestamp: Utc::now(),
alert_type: if budget.current_spend >= budget.limit {
AlertType::BudgetExceeded
} else {
AlertType::BudgetThreshold
},
message: format!(
"Budget '{}' is at {:.1}% (${:.2} / ${:.2})",
budget.name,
(budget.current_spend / budget.limit) * 100.0,
budget.current_spend,
budget.limit
),
budget_id: Some(budget.id.clone()),
severity: if budget.current_spend >= budget.limit {
AlertSeverity::Critical
} else {
AlertSeverity::Warning
},
acknowledged: false,
};
self.alerts.push(alert);
}
match &budget.budget_type {
BudgetType::Total | BudgetType::PerProvider(_) | BudgetType::PerModel(_) => {
budget.current_spend += record.cost;
}
_ => {}
}
}
}
pub fn get_monthly_spend(&self) -> f64 {
let now = Utc::now();
let year_month = now.format("%Y-%m").to_string();
self.daily_usage
.iter()
.filter(|(date, _)| date.starts_with(&year_month))
.map(|(_, usage)| usage.total_cost)
.sum()
}
pub fn get_stats(&self) -> CostStats {
let total_cost: f64 = self.provider_totals.values().map(|p| p.total_cost).sum();
let total_tokens: u32 = self.provider_totals.values().map(|p| p.total_tokens).sum();
let total_requests: u32 = self.provider_totals.values().map(|p| p.total_requests).sum();
CostStats {
total_cost,
total_tokens,
total_requests,
avg_cost_per_request: if total_requests > 0 {
total_cost / total_requests as f64
} else {
0.0
},
avg_cost_per_1k_tokens: if total_tokens > 0 {
(total_cost / total_tokens as f64) * 1000.0
} else {
0.0
},
monthly_spend: self.get_monthly_spend(),
daily_average: if !self.daily_usage.is_empty() {
total_cost / self.daily_usage.len() as f64
} else {
0.0
},
active_budgets: self.budgets.iter().filter(|b| b.is_active).count(),
pending_alerts: self.alerts.iter().filter(|a| !a.acknowledged).count(),
}
}
pub fn create_budget(
&mut self,
name: &str,
budget_type: BudgetType,
limit: f64,
period: BudgetPeriod,
alert_threshold: f64,
) -> String {
let id = uuid::Uuid::new_v4().to_string();
let budget = Budget {
id: id.clone(),
name: name.to_string(),
budget_type,
limit,
period,
current_spend: 0.0,
alert_threshold,
created_at: Utc::now(),
reset_at: Utc::now(),
is_active: true,
};
self.budgets.push(budget);
info!("Created budget: {} (${:.2})", name, limit);
id
}
}
/// Summary snapshot produced by `CostTracker::get_stats`.
#[derive(Debug, Clone)]
pub struct CostStats {
    /// Sum of lifetime cost across all providers.
    pub total_cost: f64,
    /// Sum of lifetime tokens across all providers.
    pub total_tokens: u32,
    /// Sum of lifetime requests across all providers.
    pub total_requests: u32,
    /// `total_cost / total_requests`, or 0.0 when there are no requests.
    pub avg_cost_per_request: f64,
    /// Cost per 1,000 tokens, or 0.0 when there are no tokens.
    pub avg_cost_per_1k_tokens: f64,
    /// Spend in the current UTC month.
    pub monthly_spend: f64,
    /// `total_cost` divided by the number of days that have recorded usage.
    pub daily_average: f64,
    /// Number of budgets with `is_active == true`.
    pub active_budgets: usize,
    /// Number of alerts not yet acknowledged.
    pub pending_alerts: usize,
}