use chrono::{DateTime, Utc};
use dashmap::DashMap;
use super::error::CodingAgentError;
use super::models::TokenUsage;
/// Accumulated token usage and estimated spend for a single agent over the current period.
///
/// Records are created and updated by [`CostTracker::record_usage`], returned as owned
/// snapshots by [`CostTracker::get_agent_stats`], and zeroed by
/// [`CostTracker::reset_period`].
#[derive(Debug, Clone)]
pub struct AgentCostRecord {
/// Identifier of the agent this record belongs to (the same string used as the map key).
pub agent_id: String,
/// Sum of `input_tokens` over every usage sample recorded this period.
pub total_input_tokens: u64,
/// Sum of `output_tokens` over every usage sample recorded this period.
pub total_output_tokens: u64,
/// Sum of `estimated_cost_usd` over every usage sample recorded this period.
pub estimated_total_cost_usd: f64,
/// Number of usage samples recorded this period (incremented once per `record_usage` call).
pub task_count: u64,
/// Start of the current period: when the first sample was recorded, or the last reset.
pub period_start: DateTime<Utc>,
}
/// Tracks per-agent token usage and estimated cost, plus optional per-agent cost caps.
///
/// All methods take `&self`; state is held in `DashMap`s, so updates go through the map's
/// own locking rather than requiring `&mut self`.
pub struct CostTracker {
/// Running usage totals for each agent, keyed by agent id.
agent_costs: DashMap<String, AgentCostRecord>,
/// Cost caps in USD registered via `set_task_cap`, keyed by agent id.
task_caps: DashMap<String, f64>,
}
impl CostTracker {
    /// Creates an empty tracker with no usage records and no task caps.
    pub fn new() -> Self {
        Self {
            agent_costs: DashMap::new(),
            task_caps: DashMap::new(),
        }
    }

    /// Adds one usage sample to the running totals for `agent_id`.
    ///
    /// The first sample for an agent creates its record, with `period_start` set to the
    /// current time. Later samples add their token counts and estimated cost to the existing
    /// totals. Every call increments `task_count` by one.
    pub fn record_usage(&self, agent_id: &str, usage: &TokenUsage) {
        self.agent_costs
            .entry(agent_id.to_string())
            .and_modify(|record| {
                record.total_input_tokens += usage.input_tokens;
                record.total_output_tokens += usage.output_tokens;
                record.estimated_total_cost_usd += usage.estimated_cost_usd;
                record.task_count += 1;
            })
            .or_insert_with(|| AgentCostRecord {
                agent_id: agent_id.to_string(),
                total_input_tokens: usage.input_tokens,
                total_output_tokens: usage.output_tokens,
                estimated_total_cost_usd: usage.estimated_cost_usd,
                task_count: 1,
                period_start: Utc::now(),
            });
    }

    /// Checks `current_cost` against `cap`.
    ///
    /// Spending exactly the cap is allowed; only a cost strictly above it is rejected.
    /// `_agent_id` is currently unused: the cap is supplied by the caller and is not looked
    /// up from the caps stored via [`CostTracker::set_task_cap`].
    ///
    /// # Errors
    ///
    /// Returns [`CodingAgentError::CostCapExceeded`] when `current_cost` is greater than
    /// `cap`, or when either value is NaN.
    pub fn check_cost_cap(
        &self,
        _agent_id: &str,
        current_cost: f64,
        cap: f64,
    ) -> Result<(), CodingAgentError> {
        // Phrased as "within the cap" rather than "over the cap" so the check fails closed:
        // every comparison involving NaN is false, which routes a NaN cost or cap to the
        // error branch instead of silently passing it.
        if current_cost <= cap {
            Ok(())
        } else {
            Err(CodingAgentError::CostCapExceeded {
                spent_usd: current_cost,
                cap_usd: cap,
            })
        }
    }

    /// Returns an owned snapshot of `agent_id`'s current totals, or `None` if the agent has
    /// never recorded usage.
    ///
    /// The record is cloned so the caller does not hold a map guard.
    pub fn get_agent_stats(&self, agent_id: &str) -> Option<AgentCostRecord> {
        self.agent_costs.get(agent_id).map(|r| r.value().clone())
    }

    /// Starts a new accounting period for `agent_id`.
    ///
    /// Token totals, estimated cost and task count are zeroed, and `period_start` is set to
    /// now. Does nothing if the agent has no record. Caps set via `set_task_cap` are left
    /// untouched.
    pub fn reset_period(&self, agent_id: &str) {
        if let Some(mut record) = self.agent_costs.get_mut(agent_id) {
            record.total_input_tokens = 0;
            record.total_output_tokens = 0;
            record.estimated_total_cost_usd = 0.0;
            record.task_count = 0;
            record.period_start = Utc::now();
        }
    }

    /// Stores `cap_usd` as the cap for `agent_id`, replacing any previous value.
    ///
    /// The value is stored as given and is not validated.
    pub fn set_task_cap(&self, agent_id: &str, cap_usd: f64) {
        self.task_caps.insert(agent_id.to_string(), cap_usd);
    }

    /// Returns the cap stored for `agent_id`, or `None` if none has been set.
    pub fn get_task_cap(&self, agent_id: &str) -> Option<f64> {
        self.task_caps.get(agent_id).map(|v| *v.value())
    }
}
impl Default for CostTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a usage sample from raw token counts and an estimated cost.
    fn make_usage(input: u64, output: u64, cost: f64) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            estimated_cost_usd: cost,
        }
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} within {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[test]
    fn test_new_tracker_has_no_records() {
        assert!(CostTracker::new().get_agent_stats("agent-1").is_none());
    }

    #[test]
    fn test_record_usage_creates_new_record() {
        let tracker = CostTracker::new();
        tracker.record_usage("agent-1", &make_usage(100, 50, 0.01));

        let stats = tracker
            .get_agent_stats("agent-1")
            .expect("a record should exist after the first usage sample");
        assert_eq!(stats.agent_id, "agent-1");
        assert_eq!(
            (stats.total_input_tokens, stats.total_output_tokens),
            (100, 50)
        );
        assert_close(stats.estimated_total_cost_usd, 0.01, f64::EPSILON);
        assert_eq!(stats.task_count, 1);
    }

    #[test]
    fn test_record_usage_accumulates_totals() {
        let tracker = CostTracker::new();
        for (input, output, cost) in [(100, 50, 0.01), (200, 100, 0.02), (300, 150, 0.03)] {
            tracker.record_usage("agent-1", &make_usage(input, output, cost));
        }

        let stats = tracker
            .get_agent_stats("agent-1")
            .expect("a record should exist after recording usage");
        assert_eq!(stats.total_input_tokens, 600);
        assert_eq!(stats.total_output_tokens, 300);
        assert_close(stats.estimated_total_cost_usd, 0.06, 1e-10);
        assert_eq!(stats.task_count, 3);
    }

    #[test]
    fn test_record_usage_separate_agents() {
        let tracker = CostTracker::new();
        tracker.record_usage("agent-1", &make_usage(100, 50, 0.01));
        tracker.record_usage("agent-2", &make_usage(200, 100, 0.02));

        for (agent, expected_input) in [("agent-1", 100_u64), ("agent-2", 200_u64)] {
            let stats = tracker
                .get_agent_stats(agent)
                .expect("each agent should have its own record");
            assert_eq!(stats.total_input_tokens, expected_input);
            assert_eq!(stats.task_count, 1);
        }
    }

    #[test]
    fn test_check_cost_cap_within_limit() {
        assert!(CostTracker::new()
            .check_cost_cap("agent-1", 4.99, 5.0)
            .is_ok());
    }

    #[test]
    fn test_check_cost_cap_at_exact_limit() {
        assert!(CostTracker::new()
            .check_cost_cap("agent-1", 5.0, 5.0)
            .is_ok());
    }

    #[test]
    fn test_check_cost_cap_exceeded() {
        let outcome = CostTracker::new().check_cost_cap("agent-1", 5.01, 5.0);
        match outcome {
            Err(CodingAgentError::CostCapExceeded { spent_usd, cap_usd }) => {
                assert_close(spent_usd, 5.01, f64::EPSILON);
                assert_close(cap_usd, 5.0, f64::EPSILON);
            }
            Err(other) => panic!("Expected CostCapExceeded, got: {:?}", other),
            Ok(()) => panic!("expected an error when the cost exceeds the cap"),
        }
    }

    #[test]
    fn test_get_agent_stats_returns_none_for_unknown() {
        assert!(CostTracker::new().get_agent_stats("nonexistent").is_none());
    }

    #[test]
    fn test_reset_period_clears_totals() {
        let tracker = CostTracker::new();
        for (input, output, cost) in [(1000, 500, 1.50), (2000, 1000, 3.00)] {
            tracker.record_usage("agent-1", &make_usage(input, output, cost));
        }

        let before = tracker
            .get_agent_stats("agent-1")
            .expect("a record should exist before the reset");
        assert_eq!(before.total_input_tokens, 3000);
        assert_eq!(before.task_count, 2);

        tracker.reset_period("agent-1");

        let after = tracker
            .get_agent_stats("agent-1")
            .expect("the reset keeps the record in place");
        assert_eq!(after.agent_id, "agent-1");
        assert_eq!(after.total_input_tokens, 0);
        assert_eq!(after.total_output_tokens, 0);
        assert_eq!(after.task_count, 0);
        assert_close(after.estimated_total_cost_usd, 0.0, f64::EPSILON);
        assert!(after.period_start >= before.period_start);
    }

    #[test]
    fn test_reset_period_no_op_for_unknown_agent() {
        let tracker = CostTracker::new();
        tracker.reset_period("nonexistent");
        assert!(tracker.get_agent_stats("nonexistent").is_none());
    }

    #[test]
    fn test_set_and_get_task_cap() {
        let tracker = CostTracker::new();
        tracker.set_task_cap("agent-1", 10.0);
        assert_eq!(tracker.get_task_cap("agent-1"), Some(10.0));
    }

    #[test]
    fn test_get_task_cap_returns_none_for_unknown() {
        assert!(CostTracker::new().get_task_cap("nonexistent").is_none());
    }

    #[test]
    fn test_default_creates_empty_tracker() {
        assert!(CostTracker::default().get_agent_stats("any").is_none());
    }

    #[test]
    fn test_zero_usage_record() {
        let tracker = CostTracker::new();
        tracker.record_usage("agent-1", &make_usage(0, 0, 0.0));

        let stats = tracker
            .get_agent_stats("agent-1")
            .expect("even a zero sample creates a record");
        assert_eq!(stats.total_input_tokens, 0);
        assert_eq!(stats.total_output_tokens, 0);
        assert_close(stats.estimated_total_cost_usd, 0.0, f64::EPSILON);
        assert_eq!(stats.task_count, 1);
    }

    #[test]
    fn test_check_cost_cap_zero_cap() {
        assert!(CostTracker::new()
            .check_cost_cap("agent-1", 0.001, 0.0)
            .is_err());
    }

    #[test]
    fn test_check_cost_cap_zero_cost_zero_cap() {
        assert!(CostTracker::new()
            .check_cost_cap("agent-1", 0.0, 0.0)
            .is_ok());
    }
}