use std::sync::Mutex;
use dashmap::DashMap;
use rust_decimal::Decimal;
use tokio::sync::broadcast;
use aa_core::AgentId;
use rust_decimal::prelude::ToPrimitive;
use crate::budget::{
pricing::PricingTable,
types::{BudgetAlert, BudgetState, BudgetStatus, BudgetWindow},
};
/// Per-agent overrides of the tracker-wide daily/monthly limits.
///
/// A `None` field means "no override": `BudgetTracker::resolve_limit` then falls back
/// to the tracker-wide limit of the same kind.
#[derive(Debug, Clone)]
pub(crate) struct AgentLimit {
/// Daily USD limit override for this agent, if any.
pub daily_usd: Option<Decimal>,
/// Monthly USD limit override for this agent, if any.
pub monthly_usd: Option<Decimal>,
}
/// Capacity passed to `broadcast::channel` by the constructors that create their own
/// alert channel (`new`, `with_state`).
const ALERT_CHANNEL_CAPACITY: usize = 64;
/// Upper alert threshold, as a percentage of the limit.
const ALERT_PCT_HIGH: u8 = 95;
/// Lower alert threshold, as a percentage of the limit.
const ALERT_PCT_LOW: u8 = 80;

/// Classifies `spent` against `limit`.
///
/// Returns:
/// - [`BudgetStatus::LimitExceeded`] once `spent >= limit`;
/// - [`BudgetStatus::ThresholdAlert`] with `pct` = [`ALERT_PCT_HIGH`] or [`ALERT_PCT_LOW`]
///   once spend has reached that percentage of the limit;
/// - [`BudgetStatus::WithinBudget`] otherwise, with spent/remaining converted to `f64`
///   (unrepresentable values become `0.0`).
///
/// Thresholds are compared exactly (`spent * 100 >= limit * pct`) instead of on a
/// percentage rounded to a whole number, so e.g. 79.6% of the limit no longer triggers
/// the 80% alert, and 94.5% vs 94.51% no longer land in different bands because of
/// round-half-even. No division is performed, so a zero limit paired with negative
/// spend no longer panics.
fn compute_status(spent: Decimal, limit: Decimal) -> BudgetStatus {
    if spent >= limit {
        return BudgetStatus::LimitExceeded;
    }
    let reached = |pct: u8| spent * Decimal::ONE_HUNDRED >= limit * Decimal::from(pct);
    if reached(ALERT_PCT_HIGH) {
        BudgetStatus::ThresholdAlert { pct: ALERT_PCT_HIGH }
    } else if reached(ALERT_PCT_LOW) {
        BudgetStatus::ThresholdAlert { pct: ALERT_PCT_LOW }
    } else {
        let spent_f = spent.to_f64().unwrap_or(0.0);
        let limit_f = limit.to_f64().unwrap_or(0.0);
        BudgetStatus::WithinBudget {
            spent_usd: spent_f,
            remaining_usd: limit_f - spent_f,
        }
    }
}
/// Calendar date of the current instant as observed in `tz`.
fn today_in_tz(tz: chrono_tz::Tz) -> chrono::NaiveDate {
    let local_now = chrono::Utc::now().with_timezone(&tz);
    local_now.date_naive()
}
/// Tracks USD spend per agent, per team, per org and globally, and evaluates it
/// against configured daily/monthly limits, broadcasting threshold alerts.
pub struct BudgetTracker {
/// Current-window spend state per agent.
pub(crate) per_agent: DashMap<AgentId, BudgetState>,
/// Aggregate spend per team id (updated when a team id is passed to a record call).
pub(crate) team_budgets: DashMap<String, BudgetState>,
/// Aggregate spend per org id (updated when an org id is passed; not included in `snapshot`).
pub(crate) org_budgets: DashMap<String, BudgetState>,
/// Tracker-wide total; only the recording paths (`record_cost`) add to it.
pub(crate) global: Mutex<BudgetState>,
/// Token pricing used by `record_usage`.
pricing: PricingTable,
/// Tracker-wide per-agent daily limit.
daily_limit_usd: Option<Decimal>,
/// Tracker-wide per-agent monthly limit.
monthly_limit_usd: Option<Decimal>,
/// Daily limit applied to each team aggregate.
team_daily_limit_usd: Option<Decimal>,
/// Monthly limit applied to each team aggregate.
team_monthly_limit_usd: Option<Decimal>,
/// Daily limit applied to each org aggregate.
org_daily_limit_usd: Option<Decimal>,
/// Monthly limit applied to each org aggregate.
org_monthly_limit_usd: Option<Decimal>,
/// Per-agent limit overrides, consulted by `resolve_limit` (ancestor preflight).
pub(crate) agent_limits: DashMap<AgentId, AgentLimit>,
/// One mutex per ancestor, held by `check_and_decrement` across check-then-charge.
pub(crate) parent_locks: DashMap<AgentId, std::sync::Arc<parking_lot::Mutex<()>>>,
/// Per-agent spend keyed by local date; entries are never pruned in this file.
pub(crate) spend_history: DashMap<AgentId, std::collections::BTreeMap<chrono::NaiveDate, Decimal>>,
/// Sender for threshold alerts.
alert_tx: broadcast::Sender<BudgetAlert>,
/// Timezone that defines "today" for date-based windows and history.
timezone: chrono_tz::Tz,
/// Reset window for all tracked states (constructors default to `BudgetWindow::Daily`).
window: BudgetWindow,
}
impl BudgetTracker {
/// Builds a tracker that owns a fresh alert channel of `ALERT_CHANNEL_CAPACITY` slots.
///
/// Subscribe to alerts with `subscribe_alerts`.
pub fn new(
    pricing: PricingTable,
    daily_limit_usd: Option<Decimal>,
    monthly_limit_usd: Option<Decimal>,
    timezone: chrono_tz::Tz,
) -> Self {
    let (sender, _) = broadcast::channel(ALERT_CHANNEL_CAPACITY);
    Self::new_with_alert_sender(pricing, daily_limit_usd, monthly_limit_usd, timezone, sender)
}
/// Builds a tracker that broadcasts alerts on a caller-supplied `alert_tx`.
///
/// All maps start empty, team/org limits are unset, the window is `BudgetWindow::Daily`,
/// and the global state starts on today's date in `timezone`.
pub fn new_with_alert_sender(
    pricing: PricingTable,
    daily_limit_usd: Option<Decimal>,
    monthly_limit_usd: Option<Decimal>,
    timezone: chrono_tz::Tz,
    alert_tx: broadcast::Sender<BudgetAlert>,
) -> Self {
    let start_date = today_in_tz(timezone);
    let global = Mutex::new(BudgetState::new_for_date(start_date));
    Self {
        per_agent: DashMap::new(),
        team_budgets: DashMap::new(),
        org_budgets: DashMap::new(),
        global,
        pricing,
        daily_limit_usd,
        monthly_limit_usd,
        team_daily_limit_usd: None,
        team_monthly_limit_usd: None,
        org_daily_limit_usd: None,
        org_monthly_limit_usd: None,
        agent_limits: DashMap::new(),
        parent_locks: DashMap::new(),
        spend_history: DashMap::new(),
        alert_tx,
        timezone,
        window: BudgetWindow::Daily,
    }
}
pub fn with_team_daily_limit(mut self, limit: Decimal) -> Self {
self.team_daily_limit_usd = Some(limit);
self
}
pub fn with_team_monthly_limit(mut self, limit: Decimal) -> Self {
self.team_monthly_limit_usd = Some(limit);
self
}
pub fn with_org_daily_limit(mut self, limit: Decimal) -> Self {
self.org_daily_limit_usd = Some(limit);
self
}
pub fn with_org_monthly_limit(mut self, limit: Decimal) -> Self {
self.org_monthly_limit_usd = Some(limit);
self
}
pub fn with_agent_limit(self, agent_id: AgentId, daily_usd: Option<Decimal>, monthly_usd: Option<Decimal>) -> Self {
self.agent_limits
.insert(agent_id, AgentLimit { daily_usd, monthly_usd });
self
}
/// Restores a tracker from a persisted snapshot, creating its own alert channel
/// of `ALERT_CHANNEL_CAPACITY` slots.
pub fn with_state(
    pricing: PricingTable,
    daily_limit_usd: Option<Decimal>,
    monthly_limit_usd: Option<Decimal>,
    initial: crate::budget::persistence::PersistedBudget,
) -> Self {
    let (sender, _) = broadcast::channel(ALERT_CHANNEL_CAPACITY);
    Self::with_state_and_alert_sender(pricing, daily_limit_usd, monthly_limit_usd, initial, sender)
}
/// Restores a tracker from `initial`, broadcasting alerts on `alert_tx`.
///
/// Per-agent entries whose `agent_id_hex` fails to decode are skipped. Org budgets,
/// agent limit overrides, team/org limits, parent locks and spend history start empty/unset,
/// and the window is `BudgetWindow::Daily`; the timezone comes from the snapshot.
pub fn with_state_and_alert_sender(
    pricing: PricingTable,
    daily_limit_usd: Option<Decimal>,
    monthly_limit_usd: Option<Decimal>,
    initial: crate::budget::persistence::PersistedBudget,
    alert_tx: broadcast::Sender<BudgetAlert>,
) -> Self {
    let crate::budget::persistence::PersistedBudget {
        per_agent: saved_agents,
        team_budgets: saved_teams,
        global,
        timezone,
    } = initial;
    let per_agent: DashMap<AgentId, BudgetState> = DashMap::new();
    for saved in saved_agents {
        // Undecodable ids are dropped rather than failing the whole restore.
        if let Ok(id) = crate::budget::persistence::hex_to_agent_id(&saved.agent_id_hex) {
            per_agent.insert(id, saved.state);
        }
    }
    let team_budgets: DashMap<String, BudgetState> = saved_teams.into_iter().collect();
    Self {
        per_agent,
        team_budgets,
        org_budgets: DashMap::new(),
        global: Mutex::new(global),
        pricing,
        daily_limit_usd,
        monthly_limit_usd,
        team_daily_limit_usd: None,
        team_monthly_limit_usd: None,
        org_daily_limit_usd: None,
        org_monthly_limit_usd: None,
        agent_limits: DashMap::new(),
        parent_locks: DashMap::new(),
        spend_history: DashMap::new(),
        alert_tx,
        timezone,
        window: BudgetWindow::Daily,
    }
}
pub fn with_window(mut self, window: BudgetWindow) -> Self {
self.window = window;
self
}
#[allow(dead_code)]
pub fn window(&self) -> BudgetWindow {
self.window
}
/// Returns the shared mutex for `ancestor_id`, inserting a new one on first use.
fn get_or_create_parent_lock(&self, ancestor_id: AgentId) -> std::sync::Arc<parking_lot::Mutex<()>> {
    let slot = self.parent_locks.entry(ancestor_id).or_default();
    std::sync::Arc::clone(slot.value())
}
/// Effective limit of `kind` for `agent_id`: the agent's override when set, otherwise
/// the tracker-wide limit. `BudgetKind::Global` has no per-agent limit and yields `None`.
fn resolve_limit(&self, agent_id: &AgentId, kind: crate::budget::types::BudgetKind) -> Option<Decimal> {
    use crate::budget::types::BudgetKind;
    let (override_usd, fallback) = match kind {
        BudgetKind::Global => return None,
        BudgetKind::Daily => (
            self.agent_limits.get(agent_id).and_then(|l| l.daily_usd),
            self.daily_limit_usd,
        ),
        BudgetKind::Monthly => (
            self.agent_limits.get(agent_id).and_then(|l| l.monthly_usd),
            self.monthly_limit_usd,
        ),
    };
    override_usd.or(fallback)
}
/// Returns a new receiver for threshold alerts sent after this call.
pub fn subscribe_alerts(&self) -> broadcast::Receiver<BudgetAlert> {
    broadcast::Sender::subscribe(&self.alert_tx)
}
/// Timezone that defines local dates for this tracker.
pub fn timezone(&self) -> chrono_tz::Tz {
    let tz = self.timezone;
    tz
}
/// Tracker-wide per-agent daily limit, if configured.
pub fn daily_limit_usd(&self) -> Option<Decimal> {
    let limit = self.daily_limit_usd;
    limit
}
/// Tracker-wide per-agent monthly limit, if configured.
pub fn monthly_limit_usd(&self) -> Option<Decimal> {
    let limit = self.monthly_limit_usd;
    limit
}
/// Whether `agent_id`'s current-window spend has reached `limit`.
///
/// Rolls the agent's stored window forward first. Unknown agents report `false`.
pub fn check_daily(&self, agent_id: &AgentId, limit: Decimal) -> bool {
    match self.per_agent.get_mut(agent_id) {
        None => false,
        Some(mut state) => {
            state.maybe_reset_window(chrono::Utc::now(), self.window, self.timezone);
            state.spent_usd >= limit
        }
    }
}
/// Whether `agent_id`'s monthly spend has reached `limit`.
///
/// Rolls the agent's stored window forward first. Unknown agents, and agents without a
/// monthly counter, report `false`.
pub fn check_monthly(&self, agent_id: &AgentId, limit: Decimal) -> bool {
    match self.per_agent.get_mut(agent_id) {
        None => false,
        Some(mut state) => {
            state.maybe_reset_window(chrono::Utc::now(), self.window, self.timezone);
            matches!(state.monthly_spent_usd, Some(monthly) if monthly >= limit)
        }
    }
}
/// Records an already-priced USD amount for `agent_id` and returns the resulting status.
pub fn record_raw_spend(
    &self,
    agent_id: AgentId,
    team_id: Option<&str>,
    org_id: Option<&str>,
    amount_usd: Decimal,
) -> BudgetStatus {
    let priced = amount_usd;
    self.record_cost(agent_id, team_id, org_id, priced)
}
/// Prices `input_tokens`/`output_tokens` for `provider`/`model` with the tracker's
/// pricing table, records the cost, and returns the resulting status.
#[allow(clippy::too_many_arguments)] // public signature; bundling would break callers
pub fn record_usage(
    &self,
    agent_id: AgentId,
    team_id: Option<&str>,
    org_id: Option<&str>,
    provider: crate::budget::types::Provider,
    model: crate::budget::types::Model,
    input_tokens: u64,
    output_tokens: u64,
) -> BudgetStatus {
    let priced = self.pricing.cost_usd(provider, model, input_tokens, output_tokens);
    self.record_cost(agent_id, team_id, org_id, priced)
}
/// Computes the status of `spent` against `limit` and, when it is a threshold alert,
/// broadcasts a `BudgetAlert` tagged with `agent_id` and `team_id`.
fn check_limit_and_alert(
    &self,
    agent_id: AgentId,
    team_id: Option<&str>,
    spent: Decimal,
    limit: Decimal,
) -> BudgetStatus {
    let status = compute_status(spent, limit);
    if let BudgetStatus::ThresholdAlert { pct: threshold_pct } = status {
        let alert = BudgetAlert {
            agent_id,
            team_id: team_id.map(String::from),
            threshold_pct,
            spent_usd: spent.to_f64().unwrap_or(0.0),
            limit_usd: limit.to_f64().unwrap_or(0.0),
        };
        // A send error only means nobody is subscribed; alerts are best-effort.
        let _ = self.alert_tx.send(alert);
    }
    status
}
fn record_cost(
&self,
agent_id: AgentId,
team_id: Option<&str>,
org_id: Option<&str>,
cost: Decimal,
) -> BudgetStatus {
let has_monthly = self.monthly_limit_usd.is_some()
|| self.team_monthly_limit_usd.is_some()
|| self.org_monthly_limit_usd.is_some();
let now = chrono::Utc::now();
let today = today_in_tz(self.timezone);
let window = self.window;
let tz = self.timezone;
self.per_agent
.entry(agent_id)
.and_modify(|s| {
s.maybe_reset_window(now, window, tz);
s.spent_usd += cost;
if let Some(m) = s.monthly_spent_usd.as_mut() {
*m += cost;
}
})
.or_insert_with(|| {
let mut s = BudgetState::new_for_date(today);
s.spent_usd += cost;
if has_monthly {
s.monthly_spent_usd = Some(cost);
}
if matches!(window, BudgetWindow::Duration(_)) {
s.last_reset_at = Some(now);
}
s
});
self.spend_history
.entry(agent_id)
.or_default()
.entry(today)
.and_modify(|v| *v += cost)
.or_insert(cost);
if let Some(tid) = team_id {
self.team_budgets
.entry(tid.to_string())
.and_modify(|s| {
s.maybe_reset_window(now, window, tz);
s.spent_usd += cost;
if let Some(m) = s.monthly_spent_usd.as_mut() {
*m += cost;
}
})
.or_insert_with(|| {
let mut s = BudgetState::new_for_date(today);
s.spent_usd += cost;
if has_monthly {
s.monthly_spent_usd = Some(cost);
}
if matches!(window, BudgetWindow::Duration(_)) {
s.last_reset_at = Some(now);
}
s
});
if let Some(team_monthly_limit) = self.team_monthly_limit_usd {
if let Some(team_state) = self.team_budgets.get(tid) {
if let Some(team_monthly) = team_state.monthly_spent_usd {
let status = self.check_limit_and_alert(agent_id, Some(tid), team_monthly, team_monthly_limit);
if status == BudgetStatus::LimitExceeded {
return BudgetStatus::LimitExceeded;
}
}
}
}
if let Some(team_daily_limit) = self.team_daily_limit_usd {
if let Some(team_state) = self.team_budgets.get(tid) {
let status =
self.check_limit_and_alert(agent_id, Some(tid), team_state.spent_usd, team_daily_limit);
if status == BudgetStatus::LimitExceeded {
return BudgetStatus::LimitExceeded;
}
}
}
}
if let Some(oid) = org_id {
self.org_budgets
.entry(oid.to_string())
.and_modify(|s| {
s.maybe_reset_window(now, window, tz);
s.spent_usd += cost;
if let Some(m) = s.monthly_spent_usd.as_mut() {
*m += cost;
}
})
.or_insert_with(|| {
let mut s = BudgetState::new_for_date(today);
s.spent_usd += cost;
if has_monthly {
s.monthly_spent_usd = Some(cost);
}
if matches!(window, BudgetWindow::Duration(_)) {
s.last_reset_at = Some(now);
}
s
});
if let Some(org_monthly_limit) = self.org_monthly_limit_usd {
if let Some(org_state) = self.org_budgets.get(oid) {
if let Some(org_monthly) = org_state.monthly_spent_usd {
let status = self.check_limit_and_alert(agent_id, None, org_monthly, org_monthly_limit);
if status == BudgetStatus::LimitExceeded {
return BudgetStatus::LimitExceeded;
}
}
}
}
if let Some(org_daily_limit) = self.org_daily_limit_usd {
if let Some(org_state) = self.org_budgets.get(oid) {
let status = self.check_limit_and_alert(agent_id, None, org_state.spent_usd, org_daily_limit);
if status == BudgetStatus::LimitExceeded {
return BudgetStatus::LimitExceeded;
}
}
}
}
let (spent, monthly_spent) = self
.per_agent
.get(&agent_id)
.map(|s| (s.spent_usd, s.monthly_spent_usd))
.unwrap_or((cost, None));
if let Ok(mut g) = self.global.lock() {
g.maybe_reset_window(now, window, tz);
g.spent_usd += cost;
}
if let (Some(limit), Some(m_spent)) = (self.monthly_limit_usd, monthly_spent) {
let status = self.check_limit_and_alert(agent_id, None, m_spent, limit);
if matches!(status, BudgetStatus::LimitExceeded) {
return BudgetStatus::LimitExceeded;
}
}
match self.daily_limit_usd {
None => BudgetStatus::WithinBudget {
spent_usd: spent.to_f64().unwrap_or(0.0),
remaining_usd: f64::INFINITY,
},
Some(limit) => self.check_limit_and_alert(agent_id, None, spent, limit),
}
}
fn preflight_ancestors(
&self,
ancestors: &[[u8; 16]],
amount: Decimal,
) -> Result<(), crate::budget::types::BudgetError> {
use crate::budget::types::{BudgetError, BudgetKind};
let now = chrono::Utc::now();
for &ancestor_bytes in ancestors {
let ancestor_id = AgentId::from_bytes(ancestor_bytes);
if let Some(limit) = self.resolve_limit(&ancestor_id, BudgetKind::Daily) {
let spent = self
.per_agent
.get(&ancestor_id)
.map(|s| {
let mut copy = s.clone();
copy.maybe_reset_window(now, self.window, self.timezone);
copy.spent_usd
})
.unwrap_or(Decimal::ZERO);
if spent + amount > limit {
return Err(BudgetError::AncestorBudgetExhausted {
ancestor_id: ancestor_bytes,
kind: BudgetKind::Daily,
});
}
}
}
Ok(())
}
pub fn check_and_decrement(
&self,
agent_id: AgentId,
ancestors: &[[u8; 16]],
amount: Decimal,
) -> Result<(), crate::budget::types::BudgetError> {
let ancestor_arcs: Vec<_> = ancestors
.iter()
.rev()
.map(|&bytes| self.get_or_create_parent_lock(AgentId::from_bytes(bytes)))
.collect();
let lock_wait_start = std::time::Instant::now();
let _guards: Vec<_> = ancestor_arcs.iter().map(|arc| arc.lock()).collect();
metrics::histogram!("budget_parent_lock_wait_seconds").record(lock_wait_start.elapsed().as_secs_f64());
self.preflight_ancestors(ancestors, amount)?;
let now = chrono::Utc::now();
let today = today_in_tz(self.timezone);
let window = self.window;
let tz = self.timezone;
self.per_agent
.entry(agent_id)
.and_modify(|s| {
s.maybe_reset_window(now, window, tz);
s.spent_usd += amount;
})
.or_insert_with(|| {
let mut s = BudgetState::new_for_date(today);
s.spent_usd = amount;
if matches!(window, BudgetWindow::Duration(_)) {
s.last_reset_at = Some(now);
}
s
});
for &ancestor_bytes in ancestors {
let ancestor_id = AgentId::from_bytes(ancestor_bytes);
self.per_agent
.entry(ancestor_id)
.and_modify(|s| {
s.maybe_reset_window(now, window, tz);
s.spent_usd += amount;
})
.or_insert_with(|| {
let mut s = BudgetState::new_for_date(today);
s.spent_usd = amount;
if matches!(window, BudgetWindow::Duration(_)) {
s.last_reset_at = Some(now);
}
s
});
}
Ok(())
}
/// Sums the window-adjusted spend of `agent_id` and its `descendants`.
///
/// Stored states are not mutated. Only agents with strictly positive spend contribute
/// to `usd` and `agents_counted`; `tokens` is always reported as `0`.
pub fn subtree_spend(&self, agent_id: &AgentId, descendants: &[[u8; 16]]) -> crate::budget::types::SubtreeSpend {
    let now = chrono::Utc::now();
    let (usd, agents_counted) = std::iter::once(*agent_id.as_bytes())
        .chain(descendants.iter().copied())
        .filter_map(|bytes| {
            let state = self.per_agent.get(&AgentId::from_bytes(bytes))?;
            let mut projected = state.clone();
            projected.maybe_reset_window(now, self.window, self.timezone);
            Some(projected.spent_usd)
        })
        .filter(|spent| *spent > Decimal::ZERO)
        .fold((Decimal::ZERO, 0usize), |(sum, count), spent| (sum + spent, count + 1));
    crate::budget::types::SubtreeSpend {
        tokens: 0,
        usd,
        agents_counted,
    }
}
/// Snapshot of the aggregate state for `team_id`, if any spend was recorded for it.
pub fn team_state(&self, team_id: &str) -> Option<BudgetState> {
    self.team_budgets.get(team_id).map(|entry| entry.value().clone())
}
/// Snapshot of the aggregate state for `org_id`, if any spend was recorded for it.
pub fn org_state(&self, org_id: &str) -> Option<BudgetState> {
    self.org_budgets.get(org_id).map(|entry| entry.value().clone())
}
/// Snapshot of `agent_id`'s state, if tracked.
pub fn agent_state(&self, agent_id: &AgentId) -> Option<BudgetState> {
    self.per_agent.get(agent_id).map(|entry| entry.value().clone())
}
/// Returns exactly `days` `(date, spend)` pairs, oldest first, ending with today in the
/// tracker's timezone. Days with no recorded spend report zero; `days == 0` yields an
/// empty vector.
pub fn agent_spend_history(&self, agent_id: &AgentId, days: u32) -> Vec<(chrono::NaiveDate, Decimal)> {
    if days == 0 {
        return Vec::new();
    }
    let first_day = today_in_tz(self.timezone) - chrono::Duration::days(i64::from(days) - 1);
    let recorded = self.spend_history.get(agent_id);
    let spend_on = |date: &chrono::NaiveDate| -> Decimal {
        recorded
            .as_ref()
            .and_then(|per_day| per_day.value().get(date).copied())
            .unwrap_or(Decimal::ZERO)
    };
    (0..i64::from(days))
        .map(|offset| first_day + chrono::Duration::days(offset))
        .map(|date| (date, spend_on(&date)))
        .collect()
}
/// Snapshot of the tracker-wide state.
///
/// If the global mutex is poisoned, a fresh state dated today is returned instead.
pub fn global_state(&self) -> BudgetState {
    match self.global.lock() {
        Ok(guard) => guard.clone(),
        Err(_) => BudgetState::new_for_date(today_in_tz(self.timezone)),
    }
}
/// Applies `maybe_reset_window` at the current instant to every tracked state:
/// per-agent, team, org and global.
///
/// Org budgets were previously skipped. An idle org's state was therefore not rolled
/// forward by this call, unlike the team and per-agent states.
pub fn flush_window(&self) {
    let now = chrono::Utc::now();
    let (window, tz) = (self.window, self.timezone);
    for mut entry in self.per_agent.iter_mut() {
        entry.maybe_reset_window(now, window, tz);
    }
    for aggregates in [&self.team_budgets, &self.org_budgets] {
        for mut entry in aggregates.iter_mut() {
            entry.maybe_reset_window(now, window, tz);
        }
    }
    if let Ok(mut g) = self.global.lock() {
        g.maybe_reset_window(now, window, tz);
    }
}
/// Captures per-agent, team and global state plus the timezone for persistence.
///
/// Org budgets, limits, agent overrides and spend history are not part of the snapshot.
pub fn snapshot(&self) -> crate::budget::persistence::PersistedBudget {
    use crate::budget::persistence::{agent_id_to_hex, PersistedAgentEntry, PersistedBudget};
    let mut per_agent = Vec::new();
    for entry in self.per_agent.iter() {
        per_agent.push(PersistedAgentEntry {
            agent_id_hex: agent_id_to_hex(entry.key()),
            state: entry.value().clone(),
        });
    }
    let team_budgets = self
        .team_budgets
        .iter()
        .map(|entry| {
            let (name, state) = entry.pair();
            (name.clone(), state.clone())
        })
        .collect();
    PersistedBudget {
        per_agent,
        team_budgets,
        global: self.global_state(),
        timezone: self.timezone,
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::budget::pricing::PricingTable;
use rust_decimal::Decimal;
fn new_tracker() -> BudgetTracker {
BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
}
fn agent(b: u8) -> AgentId {
AgentId::from_bytes([b; 16])
}
fn tracker_with_limit(s: &str) -> BudgetTracker {
BudgetTracker::new(
PricingTable::default_table(),
Some(s.parse().unwrap()),
None,
chrono_tz::UTC,
)
}
#[test]
fn new_tracker_has_empty_per_agent_map() {
let t = new_tracker();
assert!(t.per_agent.is_empty());
}
#[test]
fn agent_spend_history_returns_dense_zero_filled_window_for_unknown_agent() {
let t = new_tracker();
let aid = AgentId::from_bytes([0xAA; 16]);
let history = t.agent_spend_history(&aid, 7);
assert_eq!(history.len(), 7, "should return one entry per requested day");
for (_, amount) in &history {
assert_eq!(*amount, Decimal::ZERO);
}
for win in history.windows(2) {
assert!(win[0].0 < win[1].0);
}
}
#[test]
fn agent_spend_history_records_todays_spend() {
let t = new_tracker();
let aid = AgentId::from_bytes([0xAA; 16]);
t.record_raw_spend(aid, None, None, Decimal::new(125, 2));
t.record_raw_spend(aid, None, None, Decimal::new(375, 2));
let history = t.agent_spend_history(&aid, 1);
assert_eq!(history.len(), 1);
assert_eq!(history[0].1, Decimal::new(500, 2), "should accumulate same-day spend");
}
#[test]
fn agent_spend_history_zero_days_returns_empty_vec() {
let t = new_tracker();
let aid = AgentId::from_bytes([0xAA; 16]);
t.record_raw_spend(aid, None, None, Decimal::ONE);
assert!(t.agent_spend_history(&aid, 0).is_empty());
}
#[test]
fn daily_limit_usd_returns_configured_limit() {
let t = tracker_with_limit("50.00");
assert_eq!(t.daily_limit_usd(), Some(Decimal::new(5000, 2)));
}
#[test]
fn daily_limit_usd_returns_none_when_unset() {
let t = new_tracker();
assert_eq!(t.daily_limit_usd(), None);
}
#[test]
fn monthly_limit_usd_returns_configured_limit() {
let t = BudgetTracker::new(
PricingTable::default_table(),
None,
Some("1000.00".parse().unwrap()),
chrono_tz::UTC,
);
assert_eq!(t.monthly_limit_usd(), Some(Decimal::new(100000, 2)));
}
#[test]
fn monthly_limit_usd_returns_none_when_unset() {
let t = new_tracker();
assert_eq!(t.monthly_limit_usd(), None);
}
#[test]
fn compute_status_returns_within_budget_below_80() {
use crate::budget::types::BudgetStatus;
fn d(s: &str) -> Decimal {
s.parse().unwrap()
}
let status = compute_status(d("7.00"), d("10.00")); assert!(matches!(status, BudgetStatus::WithinBudget { .. }));
}
#[test]
fn compute_status_returns_alert_at_80() {
use crate::budget::types::BudgetStatus;
fn d(s: &str) -> Decimal {
s.parse().unwrap()
}
let status = compute_status(d("8.00"), d("10.00")); assert_eq!(status, BudgetStatus::ThresholdAlert { pct: 80 });
}
#[test]
fn compute_status_returns_alert_at_95() {
use crate::budget::types::BudgetStatus;
fn d(s: &str) -> Decimal {
s.parse().unwrap()
}
let status = compute_status(d("9.50"), d("10.00")); assert_eq!(status, BudgetStatus::ThresholdAlert { pct: 95 });
}
#[test]
fn compute_status_returns_limit_exceeded_at_100() {
use crate::budget::types::BudgetStatus;
fn d(s: &str) -> Decimal {
s.parse().unwrap()
}
assert_eq!(compute_status(d("10.00"), d("10.00")), BudgetStatus::LimitExceeded);
assert_eq!(compute_status(d("11.00"), d("10.00")), BudgetStatus::LimitExceeded);
}
#[test]
fn record_usage_no_limit_returns_within_budget() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = new_tracker();
let s = t.record_usage(agent(1), None, None, Provider::OpenAi, Model::Gpt4o, 100, 100);
assert!(matches!(s, BudgetStatus::WithinBudget { .. }));
}
#[test]
fn record_usage_over_limit_returns_limit_exceeded() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = tracker_with_limit("1.00");
let s = t.record_usage(agent(2), None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 40_000);
assert_eq!(s, BudgetStatus::LimitExceeded);
}
#[test]
fn record_usage_alert_at_80_pct() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = tracker_with_limit("1.00");
let s = t.record_usage(agent(3), None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 20_000);
assert_eq!(s, BudgetStatus::ThresholdAlert { pct: 80 });
}
#[test]
fn record_usage_resets_on_old_date() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = tracker_with_limit("1.00");
let id = agent(4);
t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 30_000); t.per_agent.alter(&id, |_, mut s| {
s.date = chrono::Utc::now().date_naive() - chrono::Duration::days(1);
s
});
let s = t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100, 0);
assert!(matches!(s, BudgetStatus::WithinBudget { .. }));
}
#[test]
fn subscribe_alerts_returns_receiver() {
let t = new_tracker();
let _rx = t.subscribe_alerts(); }
#[test]
fn with_state_restores_per_agent_entries() {
use crate::budget::persistence::{agent_id_to_hex, PersistedAgentEntry, PersistedBudget};
use chrono::Datelike;
let id = AgentId::from_bytes([42u8; 16]);
let today = chrono::Utc::now().date_naive();
let state = BudgetState {
spent_usd: "5.00".parse::<Decimal>().unwrap(),
date: today,
month: today.year() as u32 * 100 + today.month(),
monthly_spent_usd: None,
last_reset_at: None,
};
let persisted = PersistedBudget {
per_agent: vec![PersistedAgentEntry {
agent_id_hex: agent_id_to_hex(&id),
state: state.clone(),
}],
team_budgets: Default::default(),
global: BudgetState::new_today(),
timezone: chrono_tz::UTC,
};
let t = BudgetTracker::with_state(PricingTable::default_table(), None, None, persisted);
let entry = t.per_agent.get(&id).unwrap();
assert_eq!(entry.spent_usd, state.spent_usd);
assert_eq!(t.timezone(), chrono_tz::UTC);
}
#[test]
fn snapshot_includes_per_agent_and_global() {
use crate::budget::types::{Model, Provider};
let t = new_tracker();
let id = agent(7);
t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 1_000, 0);
let snap = t.snapshot();
assert_eq!(snap.per_agent.len(), 1);
assert_eq!(snap.global.spent_usd, snap.per_agent[0].state.spent_usd);
}
#[test]
fn global_state_accumulates_all_agents() {
use crate::budget::types::{Model, Provider};
let t = new_tracker();
t.record_usage(agent(5), None, None, Provider::OpenAi, Model::Gpt4o, 1_000, 0); t.record_usage(agent(6), None, None, Provider::OpenAi, Model::Gpt4o, 1_000, 0); let g = t.global_state();
let expected: Decimal = "0.010".parse().unwrap();
assert_eq!(g.spent_usd, expected);
}
#[test]
fn record_usage_timezone_offset_resets_at_local_midnight() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let tz = chrono_tz::Asia::Tokyo;
let t = BudgetTracker::new(PricingTable::default_table(), Some("1.00".parse().unwrap()), None, tz);
let id = agent(10);
t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 30_000); let yesterday_tokyo = today_in_tz(tz) - chrono::Duration::days(1);
t.per_agent.alter(&id, |_, mut s| {
s.date = yesterday_tokyo;
s
});
let s = t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100, 0);
assert!(
matches!(s, BudgetStatus::WithinBudget { .. }),
"Expected reset after Tokyo midnight, got: {:?}",
s
);
}
fn tracker_with_monthly_limit(monthly: &str) -> BudgetTracker {
BudgetTracker::new(
PricingTable::default_table(),
None,
Some(monthly.parse().unwrap()),
chrono_tz::UTC,
)
}
#[test]
fn monthly_limit_exceeded_blocks_usage() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = tracker_with_monthly_limit("1.00");
let s = t.record_usage(agent(20), None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 40_000);
assert_eq!(s, BudgetStatus::LimitExceeded);
}
#[test]
fn monthly_within_budget_returns_within_budget() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = tracker_with_monthly_limit("10.00");
let s = t.record_usage(agent(21), None, None, Provider::OpenAi, Model::Gpt4o, 1_000, 0);
assert!(matches!(s, BudgetStatus::WithinBudget { .. }));
}
#[test]
fn monthly_accumulates_across_daily_resets() {
use crate::budget::types::{BudgetStatus, Model, Provider};
let t = tracker_with_monthly_limit("1.00");
let id = agent(22);
t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 0);
t.per_agent.alter(&id, |_, mut s| {
s.date = chrono::Utc::now().date_naive() - chrono::Duration::days(1);
s
});
let s = t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 0, 40_000);
assert_eq!(s, BudgetStatus::LimitExceeded);
}
#[test]
fn monthly_resets_on_month_change() {
use crate::budget::types::{BudgetStatus, Model, Provider};
use chrono::Datelike;
let t = tracker_with_monthly_limit("1.00");
let id = agent(23);
t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100_000, 30_000);
let last_month = chrono::Utc::now().date_naive() - chrono::Duration::days(32);
t.per_agent.alter(&id, |_, mut s| {
s.date = last_month;
s.month = last_month.year() as u32 * 100 + last_month.month();
s
});
let s = t.record_usage(id, None, None, Provider::OpenAi, Model::Gpt4o, 100, 0);
assert!(
matches!(s, BudgetStatus::WithinBudget { .. }),
"Expected within budget after monthly reset, got: {:?}",
s
);
}
#[test]
fn check_daily_returns_false_for_new_agent() {
let t = tracker_with_limit("10.00");
assert!(!t.check_daily(&agent(30), "10.00".parse().unwrap()));
}
#[test]
fn check_daily_returns_true_when_exceeded() {
let t = tracker_with_limit("1.00");
let id = agent(31);
t.record_raw_spend(id, None, None, "1.00".parse().unwrap());
assert!(t.check_daily(&id, "1.00".parse().unwrap()));
}
#[test]
fn check_monthly_returns_false_for_new_agent() {
let t = tracker_with_monthly_limit("100.00");
assert!(!t.check_monthly(&agent(32), "100.00".parse().unwrap()));
}
#[test]
fn check_monthly_returns_true_when_exceeded() {
let t = tracker_with_monthly_limit("5.00");
let id = agent(33);
t.record_raw_spend(id, None, None, "5.00".parse().unwrap());
assert!(t.check_monthly(&id, "5.00".parse().unwrap()));
}
#[test]
fn record_raw_spend_accumulates() {
let t = tracker_with_limit("10.00");
let id = agent(34);
t.record_raw_spend(id, None, None, "3.00".parse().unwrap());
t.record_raw_spend(id, None, None, "4.00".parse().unwrap());
assert!(t.check_daily(&id, "7.00".parse().unwrap()));
assert!(!t.check_daily(&id, "8.00".parse().unwrap()));
}
#[test]
fn record_raw_spend_fires_80_pct_alert() {
let t = tracker_with_limit("10.00");
let mut rx = t.subscribe_alerts();
let id = agent(35);
t.record_raw_spend(id, None, None, "8.00".parse().unwrap());
let alert = rx.try_recv().expect("expected 80% alert");
assert_eq!(alert.threshold_pct, 80);
assert_eq!(alert.agent_id, id);
}
#[test]
fn record_raw_spend_fires_95_pct_alert() {
let t = tracker_with_limit("10.00");
let mut rx = t.subscribe_alerts();
let id = agent(36);
t.record_raw_spend(id, None, None, "9.50".parse().unwrap());
let alert = rx.try_recv().expect("expected 95% alert");
assert_eq!(alert.threshold_pct, 95);
}
#[test]
fn new_with_alert_sender_uses_external_channel() {
let (tx, mut rx) = broadcast::channel::<BudgetAlert>(64);
let t = BudgetTracker::new_with_alert_sender(
PricingTable::default_table(),
Some("10.00".parse().unwrap()),
None,
chrono_tz::UTC,
tx,
);
let id = agent(37);
t.record_raw_spend(id, None, None, "8.00".parse().unwrap());
let alert = rx.try_recv().expect("alert should arrive on external channel");
assert_eq!(alert.threshold_pct, 80);
}
#[test]
fn check_daily_exact_limit_is_exceeded() {
let t = tracker_with_limit("1.00");
let id = agent(40);
t.record_raw_spend(id, None, None, "1.00".parse().unwrap());
assert!(t.check_daily(&id, "1.00".parse().unwrap()));
}
#[test]
fn check_daily_resets_on_new_date() {
let t = tracker_with_limit("1.00");
let id = agent(41);
t.record_raw_spend(id, None, None, "0.90".parse().unwrap());
t.per_agent.alter(&id, |_, mut s| {
s.date = chrono::Utc::now().date_naive() - chrono::Duration::days(1);
s
});
assert!(!t.check_daily(&id, "1.00".parse().unwrap()));
}
#[test]
fn check_monthly_accumulates_raw_spend() {
let t = tracker_with_monthly_limit("7.00");
let id = agent(42);
t.record_raw_spend(id, None, None, "3.00".parse().unwrap());
t.record_raw_spend(id, None, None, "4.00".parse().unwrap());
assert!(t.check_monthly(&id, "7.00".parse().unwrap()));
assert!(!t.check_monthly(&id, "8.00".parse().unwrap()));
}
#[test]
fn check_monthly_resets_on_month_change() {
use chrono::Datelike;
let t = tracker_with_monthly_limit("5.00");
let id = agent(43);
t.record_raw_spend(id, None, None, "5.00".parse().unwrap());
let last_month = chrono::Utc::now().date_naive() - chrono::Duration::days(32);
t.per_agent.alter(&id, |_, mut s| {
s.date = last_month;
s.month = last_month.year() as u32 * 100 + last_month.month();
s
});
assert!(!t.check_monthly(&id, "5.00".parse().unwrap()));
}
fn tracker_with_team_daily_limit(daily: &str) -> BudgetTracker {
BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
.with_team_daily_limit(daily.parse().unwrap())
}
fn tracker_with_team_monthly_limit(monthly: &str) -> BudgetTracker {
BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
.with_team_monthly_limit(monthly.parse().unwrap())
}
#[test]
fn team_daily_limit_exceeded_blocks_agent_in_same_team() {
let t = tracker_with_team_daily_limit("5.00");
let id = agent(50);
let status = t.record_raw_spend(id, Some("team-alpha"), None, "5.00".parse().unwrap());
assert_eq!(status, BudgetStatus::LimitExceeded);
}
#[test]
fn team_daily_limit_not_exceeded_before_threshold() {
let t = tracker_with_team_daily_limit("10.00");
let id = agent(51);
let status = t.record_raw_spend(id, Some("team-alpha"), None, "3.00".parse().unwrap());
assert!(matches!(status, BudgetStatus::WithinBudget { .. }));
}
#[test]
fn team_daily_limit_aggregates_across_multiple_agents() {
let t = tracker_with_team_daily_limit("5.00");
let id_a = agent(52);
let id_b = agent(53);
t.record_raw_spend(id_a, Some("team-beta"), None, "3.00".parse().unwrap());
let status = t.record_raw_spend(id_b, Some("team-beta"), None, "2.00".parse().unwrap());
assert_eq!(status, BudgetStatus::LimitExceeded);
}
#[test]
fn team_monthly_limit_exceeded_blocks() {
let t = tracker_with_team_monthly_limit("10.00");
let id = agent(54);
let status = t.record_raw_spend(id, Some("team-gamma"), None, "10.00".parse().unwrap());
assert_eq!(status, BudgetStatus::LimitExceeded);
}
#[test]
fn team_with_no_team_id_ignores_team_limits() {
let t = tracker_with_team_daily_limit("1.00");
let id = agent(55);
let status = t.record_raw_spend(id, None, None, "100.00".parse().unwrap());
assert!(matches!(status, BudgetStatus::WithinBudget { .. }));
}
#[test]
fn team_daily_80_pct_fires_alert_with_team_id() {
let t = tracker_with_team_daily_limit("10.00");
let mut rx = t.subscribe_alerts();
let id = agent(60);
t.record_raw_spend(id, Some("team-delta"), None, "8.00".parse().unwrap());
let alert = rx.try_recv().expect("expected 80% team alert");
assert_eq!(alert.threshold_pct, 80);
assert_eq!(alert.team_id.as_deref(), Some("team-delta"));
}
#[test]
fn team_daily_95_pct_fires_alert_with_team_id() {
let t = tracker_with_team_daily_limit("10.00");
let mut rx = t.subscribe_alerts();
let id = agent(61);
t.record_raw_spend(id, Some("team-epsilon"), None, "9.50".parse().unwrap());
let alert = rx.try_recv().expect("expected 95% team alert");
assert_eq!(alert.threshold_pct, 95);
assert_eq!(alert.team_id.as_deref(), Some("team-epsilon"));
}
#[test]
fn proptest_concurrent_decrement_invariants_hold() {
    use proptest::prelude::*;
    use std::sync::Arc;
    // Each case spawns up to 20 OS threads, so keep the case count modest.
    let config = proptest::test_runner::Config {
        cases: 256,
        ..Default::default()
    };
    let mut runner = proptest::test_runner::TestRunner::new(config);
    runner
        .run(&(1usize..=20usize), |n_threads| {
            let root = AgentId::from_bytes([0xD0u8; 16]);
            let child = AgentId::from_bytes([0xD1u8; 16]);
            // Limits sit far above the worst case (20 x $0.01), so no call may be rejected.
            let t = Arc::new(
                BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
                    .with_agent_limit(root, Some("10000.00".parse().unwrap()), None)
                    .with_agent_limit(child, Some("10000.00".parse().unwrap()), None),
            );
            // A fixed-size array is `Copy`, so each worker gets its own copy without a Vec clone.
            let ancestors = [*child.as_bytes(), *root.as_bytes()];
            let amount: Decimal = "0.01".parse().unwrap();
            let mut handles = Vec::with_capacity(n_threads);
            for idx in 0..n_threads {
                let t2 = Arc::clone(&t);
                handles.push(std::thread::spawn(move || {
                    // Distinct leaf per worker; idx < 20, so the `as u8` cast is lossless.
                    let leaf = AgentId::from_bytes({
                        let mut b = [0xE0u8; 16];
                        b[1] = idx as u8;
                        b
                    });
                    t2.check_and_decrement(leaf, &ancestors, amount).unwrap();
                }));
            }
            for h in handles {
                h.join().unwrap();
            }
            let expected = amount * Decimal::from(n_threads);
            // Every call charges each ancestor once, so BOTH levels of the chain must hold
            // exactly n_threads x amount; a lost or doubled concurrent update breaks this.
            let root_spent = t
                .per_agent
                .get(&root)
                .expect("root must have a spend entry")
                .spent_usd;
            let child_spent = t
                .per_agent
                .get(&child)
                .expect("child must have a spend entry")
                .spent_usd;
            prop_assert_eq!(root_spent, expected);
            prop_assert_eq!(child_spent, expected);
            Ok(())
        })
        .unwrap();
}
#[test]
fn cross_ancestor_lock_ordering_completes_without_deadlock() {
    use std::sync::{mpsc, Arc};
    use std::time::Duration;

    // Number of concurrent workers; even indices charge via child_a, odd via child_b.
    const WORKERS: u8 = 50;
    // Generous upper bound for all workers to finish. The per-worker work is trivial,
    // so exceeding this almost certainly means the ancestor locks deadlocked.
    const DEADLOCK_TIMEOUT: Duration = Duration::from_secs(30);

    let root = AgentId::from_bytes([0xF0u8; 16]);
    let child_a = AgentId::from_bytes([0xF1u8; 16]);
    let child_b = AgentId::from_bytes([0xF2u8; 16]);
    let t = Arc::new(
        BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
            .with_agent_limit(root, Some("1000.00".parse().unwrap()), None)
            .with_agent_limit(child_a, Some("1000.00".parse().unwrap()), None)
            .with_agent_limit(child_b, Some("1000.00".parse().unwrap()), None),
    );
    let amount: Decimal = "0.01".parse().unwrap();
    let (done_tx, done_rx) = mpsc::channel::<()>();
    let mut handles = Vec::with_capacity(usize::from(WORKERS));
    for idx in 0..WORKERS {
        let t2 = Arc::clone(&t);
        let done = done_tx.clone();
        // Alternate the intermediate ancestor so both subtrees contend on the shared root.
        let (child, leaf_prefix) = if idx % 2 == 0 {
            (*child_a.as_bytes(), 0xA0u8)
        } else {
            (*child_b.as_bytes(), 0xB0u8)
        };
        let mut leaf_bytes = [leaf_prefix; 16];
        leaf_bytes[1] = idx;
        let leaf = AgentId::from_bytes(leaf_bytes);
        handles.push(std::thread::spawn(move || {
            let ancestors = [child, *root.as_bytes()];
            t2.check_and_decrement(leaf, &ancestors, amount).unwrap();
            // The receiver only disappears once the test has already failed; ignore send errors.
            let _ = done.send(());
        }));
    }
    // Drop the original sender so the channel disconnects once every worker has exited.
    drop(done_tx);
    // `JoinHandle::join` blocks forever on a deadlocked thread, so wait on the channel with a
    // timeout first: this turns a hang into a test failure instead of a stuck test suite.
    for _ in 0..WORKERS {
        match done_rx.recv_timeout(DEADLOCK_TIMEOUT) {
            Ok(()) => {}
            Err(mpsc::RecvTimeoutError::Timeout) => panic!(
                "workers did not finish within {:?}; ancestor locks likely deadlocked",
                DEADLOCK_TIMEOUT
            ),
            // A worker panicked before signalling; the join loop below surfaces that panic.
            Err(mpsc::RecvTimeoutError::Disconnected) => break,
        }
    }
    for h in handles {
        h.join().expect("worker thread must not panic");
    }
    // All 50 charges of $0.01 reach the shared root.
    let root_spent = t.per_agent.get(&root).unwrap().spent_usd;
    let expected: Decimal = "0.50".parse().unwrap();
    assert_eq!(root_spent, expected);
    // Each intermediate ancestor carried exactly half of the charges.
    let per_child: Decimal = "0.25".parse().unwrap();
    assert_eq!(t.per_agent.get(&child_a).unwrap().spent_usd, per_child);
    assert_eq!(t.per_agent.get(&child_b).unwrap().spent_usd, per_child);
}
#[test]
fn subtree_spend_leaf_agent_equals_own_spend() {
    // With no descendants, the subtree total is exactly the agent's own spend.
    let tracker = new_tracker();
    let leaf = agent(80);
    let spend = Decimal::new(300, 2);
    tracker.record_raw_spend(leaf, None, None, spend);
    let subtree = tracker.subtree_spend(&leaf, &[]);
    assert_eq!(subtree.usd, spend);
    assert_eq!(subtree.agents_counted, 1);
}
#[test]
fn subtree_spend_three_deep_seven_node_tree_returns_correct_total() {
    // Seven agents (a root, then by id two children and four grandchildren) each spend $1.00.
    // The root goes first; the remaining six are passed as its descendants.
    let nodes: Vec<AgentId> = (81..=87).map(agent).collect();
    let tracker = new_tracker();
    let unit = Decimal::new(100, 2);
    for &node in &nodes {
        tracker.record_raw_spend(node, None, None, unit);
    }
    let (root, rest) = nodes.split_first().expect("tree has a root");
    let descendants: Vec<[u8; 16]> = rest.iter().map(|id| *id.as_bytes()).collect();
    let total = tracker.subtree_spend(root, &descendants);
    assert_eq!(total.usd, Decimal::new(700, 2));
    assert_eq!(total.agents_counted, 7);
}
#[test]
fn subtree_spend_no_usage_returns_zero() {
    // An agent that never spent anything contributes neither dollars nor a count.
    let tracker = new_tracker();
    let idle = agent(88);
    let subtree = tracker.subtree_spend(&idle, &[]);
    assert_eq!(
        (subtree.usd, subtree.agents_counted),
        (Decimal::ZERO, 0)
    );
}
#[test]
fn check_and_decrement_linear_chain_all_sufficient_decrements_all_levels() {
    let (root, child, grandchild) = (agent(70), agent(71), agent(72));
    let t = BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
        .with_agent_limit(root, Some("10.00".parse().unwrap()), None)
        .with_agent_limit(child, Some("10.00".parse().unwrap()), None)
        .with_agent_limit(grandchild, Some("10.00".parse().unwrap()), None);
    let charge = Decimal::new(100, 2);
    // Ancestors are listed nearest-first: child, then root.
    t.check_and_decrement(grandchild, &[*child.as_bytes(), *root.as_bytes()], charge)
        .unwrap();
    // With every level under its limit, the leaf and each ancestor record the full charge.
    for level in [grandchild, child, root] {
        assert_eq!(t.per_agent.get(&level).unwrap().spent_usd, charge);
    }
}
#[test]
fn check_and_decrement_ancestor_exhausted_blocks_all_decrements() {
    let root = agent(73);
    let child = agent(74);
    let grandchild = agent(75);
    let t = BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
        .with_agent_limit(root, Some("5.00".parse().unwrap()), None)
        .with_agent_limit(child, Some("50.00".parse().unwrap()), None);
    // Pre-load the root so it sits exactly at its $5.00 cap. The scope releases the
    // map entry guard before check_and_decrement touches the map again.
    {
        let mut root_state = t
            .per_agent
            .entry(root)
            .or_insert_with(|| BudgetState::new_for_date(today_in_tz(chrono_tz::UTC)));
        root_state.spent_usd = Decimal::new(500, 2);
    }
    let chain = [*child.as_bytes(), *root.as_bytes()];
    let outcome = t.check_and_decrement(grandchild, &chain, Decimal::new(100, 2));
    assert!(
        matches!(
            outcome,
            Err(crate::budget::types::BudgetError::AncestorBudgetExhausted { .. })
        ),
        "expected AncestorBudgetExhausted, got: {:?}",
        outcome
    );
    // The rejection is all-or-nothing: neither the child nor the leaf was charged.
    for untouched in [child, grandchild] {
        assert!(t.per_agent.get(&untouched).is_none());
    }
}
#[test]
fn check_and_decrement_concurrent_calls_preserve_sum_invariant() {
    use std::sync::Arc;
    let root = AgentId::from_bytes([0xA0u8; 16]);
    let child = AgentId::from_bytes([0xB0u8; 16]);
    let tracker = Arc::new(
        BudgetTracker::new(PricingTable::default_table(), None, None, chrono_tz::UTC)
            .with_agent_limit(root, Some("1000.00".parse().unwrap()), None)
            .with_agent_limit(child, Some("1000.00".parse().unwrap()), None),
    );
    let chain = [*child.as_bytes(), *root.as_bytes()];
    let charge = Decimal::new(10, 2);
    // 100 workers, each with its own leaf id, all charging through the same child -> root chain.
    let workers: Vec<_> = (0u8..100)
        .map(|n| {
            let tracker = Arc::clone(&tracker);
            std::thread::spawn(move || {
                let mut leaf_bytes = [0u8; 16];
                leaf_bytes[0] = 0xC0u8 | (n % 64);
                leaf_bytes[1] = n;
                tracker
                    .check_and_decrement(AgentId::from_bytes(leaf_bytes), &chain, charge)
                    .unwrap();
            })
        })
        .collect();
    workers.into_iter().for_each(|w| w.join().unwrap());
    let root_spent = tracker.per_agent.get(&root).unwrap().spent_usd;
    assert_eq!(
        root_spent,
        Decimal::new(1000, 2),
        "root must accumulate all 100 × $0.10"
    );
}
#[test]
fn team_monthly_80_pct_fires_alert_with_team_id() {
    let tracker = tracker_with_team_monthly_limit("10.00");
    // Subscribe before spending: a broadcast receiver only sees alerts sent after it exists.
    let mut alerts = tracker.subscribe_alerts();
    let member = agent(62);
    // $8.00 of a $10.00 monthly cap lands on the 80% threshold.
    tracker.record_raw_spend(member, Some("team-zeta"), None, Decimal::new(800, 2));
    let fired = alerts
        .try_recv()
        .expect("expected 80% monthly team alert");
    assert_eq!(
        (fired.threshold_pct, fired.team_id.as_deref()),
        (80, Some("team-zeta"))
    );
}
}