use crate::types::AgentRole;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
/// Per-role resource limits evaluated by `BudgetTracker::check_budget`.
///
/// A resource only counts as exceeded when its usage is strictly greater
/// than the corresponding limit; usage equal to the limit is still allowed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenBudget {
/// Maximum cumulative input tokens.
pub max_input_tokens: u64,
/// Maximum cumulative output tokens.
pub max_output_tokens: u64,
/// Maximum cumulative input + output tokens.
pub max_total_tokens: u64,
/// Maximum number of recorded tool calls.
pub max_tool_calls: u32,
/// Maximum elapsed time, in whole seconds, since tracking started (or was last reset).
pub max_duration_secs: u64,
}
impl TokenBudget {
    /// Builds a budget from explicit limits.
    ///
    /// This is a `const fn`, so budgets can be declared as constants
    /// (e.g. `const SMALL: TokenBudget = TokenBudget::new(...)`).
    ///
    /// # Arguments
    /// * `max_input_tokens` - cap on cumulative input tokens.
    /// * `max_output_tokens` - cap on cumulative output tokens.
    /// * `max_total_tokens` - cap on cumulative input + output tokens.
    /// * `max_tool_calls` - cap on the number of tool calls.
    /// * `max_duration_secs` - cap on elapsed seconds since tracking started.
    #[must_use]
    pub const fn new(
        max_input_tokens: u64,
        max_output_tokens: u64,
        max_total_tokens: u64,
        max_tool_calls: u32,
        max_duration_secs: u64,
    ) -> Self {
        Self {
            max_input_tokens,
            max_output_tokens,
            max_total_tokens,
            max_tool_calls,
            max_duration_secs,
        }
    }
}
/// Returns the built-in default budget for `role`.
///
/// Roles without a dedicated entry receive the same limits as `Tester`.
pub fn default_budget(role: &AgentRole) -> TokenBudget {
    // Tuple layout: (input, output, total, tool_calls, duration_secs).
    let (input, output, total, tool_calls, duration_secs) = match role {
        AgentRole::Orchestrator => (50_000, 10_000, 60_000, 100, 300),
        AgentRole::Coder => (100_000, 50_000, 150_000, 50, 600),
        AgentRole::Tester => (50_000, 20_000, 70_000, 30, 300),
        AgentRole::Reviewer => (30_000, 10_000, 40_000, 10, 120),
        _ => (50_000, 20_000, 70_000, 30, 300),
    };
    TokenBudget::new(input, output, total, tool_calls, duration_secs)
}
/// Running resource usage for one tracked agent role.
///
/// Serialization goes through a hand-written `Serialize` impl that emits
/// `elapsed_secs` in place of `start_time`.
#[derive(Debug, Clone)]
pub struct AgentUsage {
/// Cumulative input tokens recorded for this role.
pub input_tokens: u64,
/// Cumulative output tokens recorded for this role.
pub output_tokens: u64,
/// Number of tool calls recorded for this role.
pub tool_calls: u32,
/// When tracking began (or was last reset); elapsed time is measured from here.
pub start_time: Instant,
/// The role this usage belongs to.
pub role: AgentRole,
}
impl AgentUsage {
    /// Creates zeroed usage for `role` and starts the duration clock now.
    fn new(role: AgentRole) -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            tool_calls: 0,
            start_time: Instant::now(),
            role,
        }
    }

    /// Whole seconds elapsed since `start_time`; the fractional part is truncated.
    #[must_use]
    pub fn elapsed_secs(&self) -> u64 {
        self.start_time.elapsed().as_secs()
    }

    /// Sum of input and output tokens.
    ///
    /// Saturates at `u64::MAX` instead of overflowing. Plain `+` would panic in
    /// debug builds and silently wrap in release builds.
    #[must_use]
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
impl Serialize for AgentUsage {
    /// Emits the usage counters, the derived `elapsed_secs`, and the role.
    ///
    /// `start_time` itself is never written; the elapsed seconds, measured
    /// at serialization time, are emitted in its place.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        const FIELD_COUNT: usize = 5;
        let elapsed = self.elapsed_secs();

        let mut out = serializer.serialize_struct("AgentUsage", FIELD_COUNT)?;
        out.serialize_field("input_tokens", &self.input_tokens)?;
        out.serialize_field("output_tokens", &self.output_tokens)?;
        out.serialize_field("tool_calls", &self.tool_calls)?;
        out.serialize_field("elapsed_secs", &elapsed)?;
        out.serialize_field("role", &self.role)?;
        out.end()
    }
}
/// Outcome of `BudgetTracker::check_budget` for a single role.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BudgetStatus {
/// No resource is over its limit or above the warning threshold.
/// Also returned for roles that have no budget or are not being tracked.
WithinBudget,
/// A resource is above the warning threshold (80% of its limit) but not over it.
Warning {
/// Resource name, e.g. `"input_tokens"` or `"tool_calls"`.
resource: String,
/// Usage as a ratio of the limit (e.g. `0.85`), NOT a 0–100 percentage.
usage_pct: f64,
},
/// A resource's usage is strictly greater than its limit.
Exceeded {
/// Resource name, e.g. `"input_tokens"` or `"tool_calls"`.
resource: String,
/// The configured limit for that resource.
limit: u64,
/// The observed usage that exceeded the limit.
used: u64,
},
}
/// Aggregate usage across every tracked role, as produced by `BudgetTracker::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetSummary {
/// Sum of input tokens over all tracked roles.
pub total_input_tokens: u64,
/// Sum of output tokens over all tracked roles.
pub total_output_tokens: u64,
/// Sum of tool calls over all tracked roles.
pub total_tool_calls: u32,
/// One entry per tracked role; `BudgetTracker::summary` sorts them by `role.to_string()`.
pub per_agent: Vec<AgentUsageEntry>,
/// Left as `None` by `BudgetTracker::summary`; use `BudgetTracker::total_cost_estimate`
/// to compute a cost.
pub estimated_cost_usd: Option<f64>,
}
/// Point-in-time snapshot of one role's usage inside a [`BudgetSummary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentUsageEntry {
/// The role this snapshot describes.
pub role: AgentRole,
/// Cumulative input tokens at snapshot time.
pub input_tokens: u64,
/// Cumulative output tokens at snapshot time.
pub output_tokens: u64,
/// Tool calls at snapshot time.
pub tool_calls: u32,
/// Whole seconds since tracking started, measured when the snapshot was taken.
pub elapsed_secs: u64,
}
/// Mutable state shared by all clones of a `BudgetTracker`, guarded by its lock.
#[derive(Debug)]
struct Inner {
/// Configured limits, keyed by role.
budgets: HashMap<AgentRole, TokenBudget>,
/// Usage for roles on which tracking has been started.
usage: HashMap<AgentRole, AgentUsage>,
}
impl Inner {
    /// Empty state: no budgets configured and no roles tracked.
    fn new() -> Self {
        let budgets = HashMap::new();
        let usage = HashMap::new();
        Self { budgets, usage }
    }
}
/// Tracks per-role token, tool-call and duration usage against configured budgets.
///
/// State lives behind an `Arc<tokio::sync::RwLock<_>>`, so cloning the tracker is
/// cheap and every clone observes and mutates the same underlying state.
#[derive(Debug, Clone)]
pub struct BudgetTracker {
inner: Arc<RwLock<Inner>>,
}
impl BudgetTracker {
    /// Usage ratio (used / limit) above which [`check_budget`](Self::check_budget)
    /// reports a [`BudgetStatus::Warning`].
    const WARNING_THRESHOLD: f64 = 0.8;

    /// Creates an empty tracker with no budgets and no tracked roles.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(Inner::new())),
        }
    }

    /// Sets (or replaces) the budget enforced for `role`.
    pub async fn set_budget(&self, role: AgentRole, budget: TokenBudget) {
        info!(role = %role, max_input = budget.max_input_tokens, max_output = budget.max_output_tokens, "budget set");
        self.inner.write().await.budgets.insert(role, budget);
    }

    /// Starts (or restarts) usage tracking for `role`.
    ///
    /// Any usage already recorded for `role` is discarded and its duration
    /// clock restarts from now.
    pub async fn start_tracking(&self, role: AgentRole) {
        info!(role = %role, "start tracking budget");
        let usage = AgentUsage::new(role.clone());
        self.inner.write().await.usage.insert(role, usage);
    }

    /// Adds `input` and `output` tokens to `role`'s usage.
    ///
    /// Counters saturate at `u64::MAX` instead of overflowing. If `role` is not
    /// being tracked, the tokens are dropped and a warning is logged.
    pub async fn record_tokens(&self, role: &AgentRole, input: u64, output: u64) {
        let mut guard = self.inner.write().await;
        match guard.usage.get_mut(role) {
            Some(usage) => {
                usage.input_tokens = usage.input_tokens.saturating_add(input);
                usage.output_tokens = usage.output_tokens.saturating_add(output);
            }
            None => {
                warn!(role = %role, input, output, "tokens recorded for untracked role; ignoring");
            }
        }
    }

    /// Records one tool call for `role`.
    ///
    /// The counter saturates at `u32::MAX`. If `role` is not being tracked,
    /// the call is dropped and a warning is logged.
    pub async fn record_tool_call(&self, role: &AgentRole) {
        let mut guard = self.inner.write().await;
        match guard.usage.get_mut(role) {
            Some(usage) => usage.tool_calls = usage.tool_calls.saturating_add(1),
            None => {
                warn!(role = %role, "tool call recorded for untracked role; ignoring");
            }
        }
    }

    /// Checks `role`'s usage against its budget.
    ///
    /// Resources are examined in the order input tokens, output tokens, total
    /// tokens, tool calls, duration. The first resource whose usage is strictly
    /// greater than its limit yields [`BudgetStatus::Exceeded`]. Otherwise, the
    /// first resource using more than 80% of a non-zero limit yields
    /// [`BudgetStatus::Warning`]. Otherwise the result is
    /// [`BudgetStatus::WithinBudget`]. A role with no budget, or one that is
    /// not being tracked, is always within budget.
    pub async fn check_budget(&self, role: &AgentRole) -> BudgetStatus {
        let guard = self.inner.read().await;
        let (budget, usage) = match (guard.budgets.get(role), guard.usage.get(role)) {
            (Some(b), Some(u)) => (b, u),
            _ => return BudgetStatus::WithinBudget,
        };

        // (resource name, used, limit); a fixed-size array avoids a heap allocation.
        let checks: [(&str, u64, u64); 5] = [
            ("input_tokens", usage.input_tokens, budget.max_input_tokens),
            ("output_tokens", usage.output_tokens, budget.max_output_tokens),
            ("total_tokens", usage.total_tokens(), budget.max_total_tokens),
            (
                "tool_calls",
                u64::from(usage.tool_calls),
                u64::from(budget.max_tool_calls),
            ),
            ("duration_secs", usage.elapsed_secs(), budget.max_duration_secs),
        ];

        // Hard limits take priority over warnings, so scan for them first.
        for &(resource, used, limit) in &checks {
            if used > limit {
                warn!(role = %role, resource, used, limit, "budget exceeded");
                return BudgetStatus::Exceeded {
                    resource: resource.to_string(),
                    limit,
                    used,
                };
            }
        }

        for &(resource, used, limit) in &checks {
            // A zero limit cannot yield a meaningful ratio; it was handled above.
            if limit == 0 {
                continue;
            }
            let pct = used as f64 / limit as f64;
            if pct > Self::WARNING_THRESHOLD {
                warn!(role = %role, resource, usage_pct = pct, "budget warning");
                return BudgetStatus::Warning {
                    resource: resource.to_string(),
                    usage_pct: pct,
                };
            }
        }

        BudgetStatus::WithinBudget
    }

    /// Returns a snapshot of `role`'s usage, or `None` if it is not tracked.
    pub async fn usage(&self, role: &AgentRole) -> Option<AgentUsage> {
        self.inner.read().await.usage.get(role).cloned()
    }

    /// Estimates the total cost of every tracked role's tokens.
    ///
    /// Prices are per 1,000 tokens; budgets are not consulted.
    pub async fn total_cost_estimate(
        &self,
        price_per_1k_input: f64,
        price_per_1k_output: f64,
    ) -> f64 {
        let guard = self.inner.read().await;
        guard.usage.values().fold(0.0, |acc, u| {
            acc + (u.input_tokens as f64 / 1_000.0) * price_per_1k_input
                + (u.output_tokens as f64 / 1_000.0) * price_per_1k_output
        })
    }

    /// Aggregates usage over every tracked role.
    ///
    /// Totals saturate instead of overflowing. Per-agent entries are sorted by
    /// `role.to_string()`. `estimated_cost_usd` is left as `None`.
    pub async fn summary(&self) -> BudgetSummary {
        let guard = self.inner.read().await;
        let mut total_input: u64 = 0;
        let mut total_output: u64 = 0;
        let mut total_tools: u32 = 0;
        let mut per_agent = Vec::with_capacity(guard.usage.len());
        for usage in guard.usage.values() {
            total_input = total_input.saturating_add(usage.input_tokens);
            total_output = total_output.saturating_add(usage.output_tokens);
            total_tools = total_tools.saturating_add(usage.tool_calls);
            per_agent.push(AgentUsageEntry {
                role: usage.role.clone(),
                input_tokens: usage.input_tokens,
                output_tokens: usage.output_tokens,
                tool_calls: usage.tool_calls,
                elapsed_secs: usage.elapsed_secs(),
            });
        }
        // `to_string` allocates, so compute each key once instead of on every comparison.
        per_agent.sort_by_cached_key(|entry| entry.role.to_string());
        BudgetSummary {
            total_input_tokens: total_input,
            total_output_tokens: total_output,
            total_tool_calls: total_tools,
            per_agent,
            estimated_cost_usd: None,
        }
    }

    /// Zeroes `role`'s usage and restarts its duration clock.
    ///
    /// Does nothing if `role` is not being tracked.
    pub async fn reset(&self, role: &AgentRole) {
        let mut guard = self.inner.write().await;
        if let Some(usage) = guard.usage.get_mut(role) {
            info!(role = %role, "budget usage reset");
            *usage = AgentUsage::new(role.clone());
        }
    }
}
impl Default for BudgetTracker {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `BudgetTracker`, `default_budget`, and the hand-written
// `Serialize` impl of `AgentUsage`.
#[cfg(test)]
// unwrap/expect are acceptable here: a failure should abort the test with a panic.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
// Shorthand constructors for the roles used throughout these tests.
fn coder() -> AgentRole {
AgentRole::Coder
}
fn reviewer() -> AgentRole {
AgentRole::Reviewer
}
// Token counts accumulate across successive `record_tokens` calls.
#[tokio::test]
async fn test_record_tokens() {
let tracker = BudgetTracker::new();
let role = coder();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
tracker.record_tokens(&role, 1_000, 500).await;
tracker.record_tokens(&role, 2_000, 300).await;
let usage = tracker.usage(&role).await.unwrap();
assert_eq!(usage.input_tokens, 3_000);
assert_eq!(usage.output_tokens, 800);
assert_eq!(usage.total_tokens(), 3_800);
}
// Each `record_tool_call` increments the counter by one.
#[tokio::test]
async fn test_record_tool_calls() {
let tracker = BudgetTracker::new();
let role = coder();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
for _ in 0..5 {
tracker.record_tool_call(&role).await;
}
let usage = tracker.usage(&role).await.unwrap();
assert_eq!(usage.tool_calls, 5);
}
// Small usage relative to the coder budget stays within budget.
#[tokio::test]
async fn test_within_budget() {
let tracker = BudgetTracker::new();
let role = coder();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
tracker.record_tokens(&role, 100, 50).await;
let status = tracker.check_budget(&role).await;
assert_eq!(status, BudgetStatus::WithinBudget);
}
// 85_000 of the coder's 100_000 input tokens (85%) crosses the 80% warning threshold.
#[tokio::test]
async fn test_warning_threshold() {
let tracker = BudgetTracker::new();
let role = coder();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
tracker.record_tokens(&role, 85_000, 0).await;
let status = tracker.check_budget(&role).await;
match status {
BudgetStatus::Warning {
resource,
usage_pct,
} => {
assert_eq!(resource, "input_tokens");
assert!(usage_pct > 0.8);
}
other => panic!("expected Warning, got {other:?}"),
}
}
// 35_000 input tokens exceeds the reviewer's 30_000 input limit.
#[tokio::test]
async fn test_exceeded_budget() {
let tracker = BudgetTracker::new();
let role = reviewer();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
tracker.record_tokens(&role, 35_000, 0).await;
let status = tracker.check_budget(&role).await;
match status {
BudgetStatus::Exceeded {
resource,
limit,
used,
} => {
assert_eq!(resource, "input_tokens");
assert_eq!(limit, 30_000);
assert_eq!(used, 35_000);
}
other => panic!("expected Exceeded, got {other:?}"),
}
}
// Cost sums over all tracked roles: 12_000 input and 6_000 output tokens in total.
#[tokio::test]
async fn test_cost_estimation() {
let tracker = BudgetTracker::new();
let c = coder();
tracker.set_budget(c.clone(), default_budget(&c)).await;
tracker.start_tracking(c.clone()).await;
tracker.record_tokens(&c, 10_000, 5_000).await;
let r = reviewer();
tracker.set_budget(r.clone(), default_budget(&r)).await;
tracker.start_tracking(r.clone()).await;
tracker.record_tokens(&r, 2_000, 1_000).await;
let cost = tracker.total_cost_estimate(0.01, 0.03).await;
let expected = (12_000.0 / 1_000.0) * 0.01 + (6_000.0 / 1_000.0) * 0.03;
assert!((cost - expected).abs() < f64::EPSILON);
}
// The summary aggregates tracked roles even when no budget has been set for them.
#[tokio::test]
async fn test_summary() {
let tracker = BudgetTracker::new();
let c = coder();
tracker.start_tracking(c.clone()).await;
tracker.record_tokens(&c, 1_000, 500).await;
tracker.record_tool_call(&c).await;
let r = reviewer();
tracker.start_tracking(r.clone()).await;
tracker.record_tokens(&r, 2_000, 1_000).await;
let summary = tracker.summary().await;
assert_eq!(summary.total_input_tokens, 3_000);
assert_eq!(summary.total_output_tokens, 1_500);
assert_eq!(summary.total_tool_calls, 1);
assert_eq!(summary.per_agent.len(), 2);
}
// Reset zeroes all counters, which brings the role back within budget.
#[tokio::test]
async fn test_reset() {
let tracker = BudgetTracker::new();
let role = coder();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
tracker.record_tokens(&role, 50_000, 10_000).await;
tracker.reset(&role).await;
let usage = tracker.usage(&role).await.unwrap();
assert_eq!(usage.input_tokens, 0);
assert_eq!(usage.output_tokens, 0);
assert_eq!(usage.tool_calls, 0);
let status = tracker.check_budget(&role).await;
assert_eq!(status, BudgetStatus::WithinBudget);
}
// An untracked role has no usage and is always reported as within budget.
#[tokio::test]
async fn test_unknown_role() {
let tracker = BudgetTracker::new();
let role = AgentRole::Custom("unknown".to_string());
assert!(tracker.usage(&role).await.is_none());
assert_eq!(
tracker.check_budget(&role).await,
BudgetStatus::WithinBudget
);
}
// Pins the built-in per-role defaults; custom roles fall back to the shared default.
#[test]
fn test_default_budgets() {
let orchestrator = default_budget(&AgentRole::Orchestrator);
assert_eq!(orchestrator.max_input_tokens, 50_000);
assert_eq!(orchestrator.max_output_tokens, 10_000);
assert_eq!(orchestrator.max_tool_calls, 100);
assert_eq!(orchestrator.max_duration_secs, 300);
let coder = default_budget(&AgentRole::Coder);
assert_eq!(coder.max_input_tokens, 100_000);
assert_eq!(coder.max_output_tokens, 50_000);
assert_eq!(coder.max_tool_calls, 50);
assert_eq!(coder.max_duration_secs, 600);
let tester = default_budget(&AgentRole::Tester);
assert_eq!(tester.max_input_tokens, 50_000);
let reviewer = default_budget(&AgentRole::Reviewer);
assert_eq!(reviewer.max_input_tokens, 30_000);
assert_eq!(reviewer.max_duration_secs, 120);
let custom = default_budget(&AgentRole::Custom("x".into()));
assert_eq!(custom.max_input_tokens, 50_000);
assert_eq!(custom.max_output_tokens, 20_000);
}
// 11 tool calls exceeds the reviewer's limit of 10; 10 alone would still be allowed.
#[tokio::test]
async fn test_tool_call_exceeded() {
let tracker = BudgetTracker::new();
let role = reviewer();
tracker
.set_budget(role.clone(), default_budget(&role))
.await;
tracker.start_tracking(role.clone()).await;
for _ in 0..11 {
tracker.record_tool_call(&role).await;
}
let status = tracker.check_budget(&role).await;
match status {
BudgetStatus::Exceeded {
resource,
limit,
used,
} => {
assert_eq!(resource, "tool_calls");
assert_eq!(limit, 10);
assert_eq!(used, 11);
}
other => panic!("expected Exceeded for tool_calls, got {other:?}"),
}
}
// The manual `Serialize` impl emits `elapsed_secs` and never leaks `start_time`.
#[tokio::test]
async fn test_usage_serialization() {
let tracker = BudgetTracker::new();
let role = coder();
tracker.start_tracking(role.clone()).await;
tracker.record_tokens(&role, 42, 7).await;
let usage = tracker.usage(&role).await.unwrap();
let json = serde_json::to_string(&usage).unwrap();
assert!(json.contains("\"input_tokens\":42"));
assert!(json.contains("\"output_tokens\":7"));
assert!(json.contains("\"elapsed_secs\":"));
assert!(!json.contains("start_time"));
}
}