use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
use crate::types::AgentId;
/// Spending limits configured for a single agent.
///
/// Both budgets apply per window of `window_secs` seconds; usage starts over once a
/// window has elapsed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BudgetLimit {
    /// Agent these limits belong to.
    pub agent_id: AgentId,
    /// Maximum number of tokens that may be reserved within one window.
    pub token_budget: u64,
    /// Maximum number of calls that may be tracked within one window.
    pub calls_budget: u64,
    /// Length of a budget window, in seconds.
    pub window_secs: u64,
}
/// Consumption recorded for one agent within its current budget window.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens reserved (net of releases) since `window_start`.
    pub tokens_used: u64,
    /// Calls tracked since `window_start`.
    pub calls_used: u64,
    /// Instant at which the current window began.
    pub window_start: DateTime<Utc>,
}
/// Point-in-time summary of what an agent may still consume.
#[derive(Clone, Debug)]
pub struct BudgetInfo {
    /// Tokens that can still be reserved in the current window.
    pub tokens_remaining: u64,
    /// Calls that can still be tracked in the current window.
    pub calls_remaining: u64,
    /// Whole seconds left before the current window ends.
    pub window_remaining_secs: u64,
    /// Whether the agent should be treated as out of budget.
    pub is_exhausted: bool,
}
/// Identifies which of an agent's budgets rejected a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BudgetKind {
    /// The token budget (see `BudgetLimit::token_budget`).
    Token,
    /// The call budget (see `BudgetLimit::calls_budget`).
    Call,
}
/// Error returned when a request for an agent cannot be admitted.
///
/// Produced both when a budget would be exceeded and when no budget is configured
/// for the agent at all; `message` distinguishes the two.
#[derive(Clone, Debug)]
pub struct BudgetExceeded {
    /// Agent whose request was refused.
    pub agent_id: AgentId,
    /// Budget that refused the request.
    pub kind: BudgetKind,
    /// Human-readable explanation; this is exactly what `Display` prints.
    pub message: String,
}

impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the message verbatim; width/precision flags on `f` are not applied.
        f.write_str(&self.message)
    }
}

impl std::error::Error for BudgetExceeded {}
/// Tracks per-agent token and call budgets over fixed-length time windows.
///
/// Configured limits and recorded usage live in two separate maps, each behind its
/// own `RwLock`.
pub struct BudgetManager {
    /// Configured limit for each agent.
    budgets: RwLock<HashMap<AgentId, BudgetLimit>>,
    /// Consumption within each agent's current window.
    usage: RwLock<HashMap<AgentId, Usage>>,
}
impl BudgetManager {
pub fn new() -> Self {
Self {
budgets: RwLock::new(HashMap::new()),
usage: RwLock::new(HashMap::new()),
}
}
pub fn set_budget(&self, limit: BudgetLimit) {
let agent_id = limit.agent_id;
let now = Utc::now();
{
let mut budgets = self.budgets.write();
budgets.insert(agent_id, limit);
}
let mut usage = self.usage.write();
usage.entry(agent_id).or_insert(Usage {
tokens_used: 0,
calls_used: 0,
window_start: now,
});
}
pub fn remove_budget(&self, agent_id: &AgentId) {
let mut budgets = self.budgets.write();
let mut usage = self.usage.write();
budgets.remove(agent_id);
usage.remove(agent_id);
}
pub fn reserve(&self, agent_id: &AgentId, tokens: u64) -> Result<(), BudgetExceeded> {
let limit = {
let budgets = self.budgets.read();
budgets.get(agent_id).cloned()
};
let limit = match limit {
Some(l) => l,
None => {
return Err(BudgetExceeded {
agent_id: *agent_id,
kind: BudgetKind::Token,
message: format!("No budget configured for agent {}", agent_id),
});
}
};
{
let mut usage = self.usage.write();
let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
tokens_used: 0,
calls_used: 0,
window_start: Utc::now(),
});
reset_if_expired(usage_entry, limit.window_secs);
if usage_entry.tokens_used + tokens > limit.token_budget {
return Err(BudgetExceeded {
agent_id: *agent_id,
kind: BudgetKind::Token,
message: format!(
"Token budget exceeded: requested {} but only {} remaining",
tokens,
limit.token_budget.saturating_sub(usage_entry.tokens_used)
),
});
}
usage_entry.tokens_used += tokens;
}
Ok(())
}
pub fn release(&self, agent_id: &AgentId, tokens_used: u64) {
let mut usage = self.usage.write();
if let Some(entry) = usage.get_mut(agent_id) {
entry.tokens_used = entry.tokens_used.saturating_sub(tokens_used);
}
}
pub fn track_call(&self, agent_id: &AgentId) -> Result<(), BudgetExceeded> {
let limit = {
let budgets = self.budgets.read();
budgets.get(agent_id).cloned()
};
let limit = match limit {
Some(l) => l,
None => {
return Err(BudgetExceeded {
agent_id: *agent_id,
kind: BudgetKind::Call,
message: format!("No budget configured for agent {}", agent_id),
});
}
};
{
let mut usage = self.usage.write();
let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
tokens_used: 0,
calls_used: 0,
window_start: Utc::now(),
});
reset_if_expired(usage_entry, limit.window_secs);
if usage_entry.calls_used >= limit.calls_budget {
return Err(BudgetExceeded {
agent_id: *agent_id,
kind: BudgetKind::Call,
message: format!(
"Call budget exceeded: {} calls used, limit is {}",
usage_entry.calls_used, limit.calls_budget
),
});
}
usage_entry.calls_used += 1;
}
Ok(())
}
pub fn remaining(&self, agent_id: &AgentId) -> BudgetInfo {
let limit = {
let budgets = self.budgets.read();
budgets.get(agent_id).cloned()
};
match limit {
Some(limit) => {
let usage = self.usage.read();
let usage_entry = usage.get(agent_id);
if let Some(entry) = usage_entry {
let elapsed = Utc::now()
.signed_duration_since(entry.window_start)
.to_std()
.unwrap_or(Duration::ZERO);
let window_remaining = Duration::from_secs(limit.window_secs)
.saturating_sub(elapsed)
.as_secs();
let tokens_remaining = limit.token_budget.saturating_sub(entry.tokens_used);
let calls_remaining = limit.calls_budget.saturating_sub(entry.calls_used);
let is_exhausted = tokens_remaining == 0 || calls_remaining == 0;
BudgetInfo {
tokens_remaining,
calls_remaining,
window_remaining_secs: window_remaining,
is_exhausted,
}
} else {
BudgetInfo {
tokens_remaining: limit.token_budget,
calls_remaining: limit.calls_budget,
window_remaining_secs: limit.window_secs,
is_exhausted: false,
}
}
}
None => BudgetInfo {
tokens_remaining: 0,
calls_remaining: 0,
window_remaining_secs: 0,
is_exhausted: true,
},
}
}
pub fn can_schedule(&self, agent_id: &AgentId) -> bool {
!self.remaining(agent_id).is_exhausted
}
pub fn reset_window(&self, agent_id: &AgentId) {
let mut usage = self.usage.write();
if let Some(entry) = usage.get_mut(agent_id) {
entry.tokens_used = 0;
entry.calls_used = 0;
entry.window_start = Utc::now();
}
}
pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
let budgets = self.budgets.read();
let usage = self.usage.read();
let data = PersistedBudgets {
budgets: budgets.clone(),
usage: usage.clone(),
};
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let json = serde_json::to_string_pretty(&data)?;
std::fs::write(path, json)?;
Ok(())
}
pub fn restore(&self, path: &Path) -> anyhow::Result<()> {
if !path.exists() {
return Ok(());
}
let json = std::fs::read_to_string(path)?;
let data: PersistedBudgets = serde_json::from_str(&json)?;
{
let mut budgets = self.budgets.write();
*budgets = data.budgets;
}
{
let mut usage = self.usage.write();
*usage = data.usage;
}
Ok(())
}
}
/// On-disk snapshot written by `BudgetManager::persist` and read by
/// `BudgetManager::restore`.
#[derive(Deserialize, Serialize)]
struct PersistedBudgets {
    /// Configured limits, keyed by agent.
    budgets: HashMap<AgentId, BudgetLimit>,
    /// Usage for the window that was in progress when the snapshot was taken.
    usage: HashMap<AgentId, Usage>,
}
fn reset_if_expired(usage: &mut Usage, window_secs: u64) {
let window_duration = chrono::Duration::seconds(window_secs as i64);
let elapsed = Utc::now().signed_duration_since(usage.window_start);
if elapsed >= window_duration {
usage.tokens_used = 0;
usage.calls_used = 0;
usage.window_start = Utc::now();
}
}
impl Default for BudgetManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    fn test_agent_id() -> AgentId {
        uuid::Uuid::new_v4()
    }

    /// Builds a [`BudgetLimit`] so each test states only the numbers it cares about.
    fn make_limit(
        agent_id: AgentId,
        token_budget: u64,
        calls_budget: u64,
        window_secs: u64,
    ) -> BudgetLimit {
        BudgetLimit {
            agent_id,
            token_budget,
            calls_budget,
            window_secs,
        }
    }

    /// A unique (not yet created) directory under the system temp dir; callers clean up.
    fn scratch_dir() -> std::path::PathBuf {
        std::env::temp_dir().join(format!("budget-manager-test-{}", uuid::Uuid::new_v4()))
    }

    #[test]
    fn test_budget_creation() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 1000);
        assert_eq!(info.calls_remaining, 10);
        assert!(!info.is_exhausted);
    }

    #[test]
    fn test_reserve_success() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        assert!(manager.reserve(&agent_id, 500).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);
    }

    #[test]
    fn test_exhaust_tokens() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        assert!(manager.reserve(&agent_id, 1000).is_ok());
        let err = manager.reserve(&agent_id, 1).unwrap_err();
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Token);
    }

    #[test]
    fn test_failed_reserve_leaves_usage_unchanged() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 600).unwrap();
        assert!(manager.reserve(&agent_id, 500).is_err());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 400);
    }

    #[test]
    fn test_exhaust_calls() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 3, 60));
        for _ in 0..3 {
            assert!(manager.track_call(&agent_id).is_ok());
        }
        let err = manager.track_call(&agent_id).unwrap_err();
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Call);
    }

    #[test]
    fn test_window_reset() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 100, 5, 1));
        manager.reserve(&agent_id, 100).unwrap();
        assert!(manager.reserve(&agent_id, 1).is_err());
        // Just over one window length is enough for the next reserve to reset usage.
        thread::sleep(Duration::from_millis(1100));
        assert!(manager.reserve(&agent_id, 50).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 50);
    }

    #[test]
    fn test_can_schedule() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        assert!(manager.can_schedule(&agent_id));
        for _ in 0..10 {
            manager.track_call(&agent_id).unwrap();
        }
        assert!(!manager.can_schedule(&agent_id));
    }

    #[test]
    fn test_no_budget_configured() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        let err = manager.reserve(&agent_id, 100).unwrap_err();
        assert!(err.message.contains("No budget configured"));
        assert_eq!(err.kind, BudgetKind::Token);
        let err = manager.track_call(&agent_id).unwrap_err();
        assert!(err.message.contains("No budget configured"));
        assert_eq!(err.kind, BudgetKind::Call);
    }

    #[test]
    fn test_remaining_without_budget_is_exhausted() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 0);
        assert_eq!(info.calls_remaining, 0);
        assert_eq!(info.window_remaining_secs, 0);
        assert!(info.is_exhausted);
        assert!(!manager.can_schedule(&agent_id));
    }

    #[test]
    fn test_remove_budget() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 100).unwrap();
        manager.remove_budget(&agent_id);
        assert!(manager.reserve(&agent_id, 100).is_err());
    }

    #[test]
    fn test_release_tokens() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);
        manager.release(&agent_id, 200);
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 700);
    }

    #[test]
    fn test_release_saturates_at_zero() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 100).unwrap();
        manager.release(&agent_id, 500);
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 1000);
    }

    #[test]
    fn test_reset_window() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);
        manager.reset_window(&agent_id);
        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 1000);
        assert_eq!(info.calls_remaining, 10);
    }

    #[test]
    fn test_multiple_agents() {
        let manager = BudgetManager::new();
        let agent1 = test_agent_id();
        let agent2 = test_agent_id();
        manager.set_budget(make_limit(agent1, 1000, 10, 60));
        manager.set_budget(make_limit(agent2, 500, 5, 60));
        manager.reserve(&agent1, 300).unwrap();
        manager.reserve(&agent2, 200).unwrap();
        assert_eq!(manager.remaining(&agent1).tokens_remaining, 700);
        assert_eq!(manager.remaining(&agent2).tokens_remaining, 300);
    }

    #[test]
    fn test_persist_and_restore_roundtrip() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 250).unwrap();
        manager.track_call(&agent_id).unwrap();

        let dir = scratch_dir();
        // A nested path also exercises parent-directory creation in `persist`.
        let path = dir.join("nested").join("budgets.json");
        manager.persist(&path).unwrap();

        let restored = BudgetManager::new();
        restored.restore(&path).unwrap();
        let info = restored.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 750);
        assert_eq!(info.calls_remaining, 9);
        assert!(!info.is_exhausted);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_restore_missing_file_is_noop() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, 1000, 10, 60));
        manager.reserve(&agent_id, 100).unwrap();
        let path = scratch_dir().join("does-not-exist.json");
        manager.restore(&path).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 900);
    }
}