use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use tokio::sync::Mutex;
use crate::llm::costs;
/// Limits enforced by `CostGuard`.
///
/// Each limit is optional: a `None` field means the corresponding check is
/// skipped entirely, so the `Default` value imposes no limits.
#[derive(Debug, Clone, Default)]
pub struct CostGuardConfig {
/// Daily spend ceiling in cents. The day is the UTC calendar date, and the
/// limit counts as exceeded once spend is greater than or equal to it.
pub max_cost_per_day_cents: Option<u64>,
/// Ceiling on the number of recorded LLM calls inside the trailing
/// 3600-second window; reached when the count is greater than or equal to it.
pub max_actions_per_hour: Option<u64>,
}
/// Reason an action was rejected by the cost guard.
///
/// Implements [`std::fmt::Display`] for user-facing text and
/// [`std::error::Error`] so it can be boxed or propagated with `?` alongside
/// other error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CostLimitExceeded {
    /// The daily spend reached the configured ceiling (both values in cents).
    DailyBudget { spent_cents: u64, limit_cents: u64 },
    /// The number of actions in the trailing hour reached the configured ceiling.
    HourlyRate { actions: u64, limit: u64 },
}

impl std::fmt::Display for CostLimitExceeded {
    /// Renders a human-readable description; monetary amounts are shown in
    /// dollars with exactly two decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Dollars and cents are split with integer arithmetic so the
            // printed amount is exact for every `u64` (no float rounding).
            Self::DailyBudget {
                spent_cents,
                limit_cents,
            } => write!(
                f,
                "Daily cost limit exceeded: spent ${}.{:02} of ${}.{:02} allowed",
                spent_cents / 100,
                spent_cents % 100,
                limit_cents / 100,
                limit_cents % 100
            ),
            Self::HourlyRate { actions, limit } => write!(
                f,
                "Hourly action limit exceeded: {} actions of {} allowed per hour",
                actions, limit
            ),
        }
    }
}

impl std::error::Error for CostLimitExceeded {}
/// Running usage totals for one model, accumulated by
/// `CostGuard::record_llm_call` and returned by `CostGuard::model_usage`.
#[derive(Debug, Clone, Default)]
pub struct ModelTokens {
/// Sum of the `input_tokens` argument over all recorded calls. Cache token
/// counts are not added on top of this.
pub input_tokens: u64,
/// Sum of the `output_tokens` argument over all recorded calls.
pub output_tokens: u64,
/// Sum of the per-call cost computed by `record_llm_call`.
/// NOTE(review): presumably US dollars, given the `$` in the log messages
/// and the `usd` parameter of `to_cents` — confirm against `crate::llm::costs`.
pub cost: Decimal,
}
/// Tracks LLM spend and call frequency and reports when a configured limit
/// has been reached.
///
/// All mutable state sits behind `tokio::sync::Mutex` or an atomic, so the
/// methods take `&self`.
pub struct CostGuard {
// Limits to enforce; `None` fields disable the corresponding check.
config: CostGuardConfig,
// Spend accumulated for the current UTC date.
daily_cost: Mutex<DailyCost>,
// Timestamps of recorded calls, oldest at the front (pushed at the back,
// expired entries popped from the front).
action_window: Mutex<VecDeque<Instant>>,
// Latched to `true` when daily spend reaches the limit; cleared when the
// daily counter is reset for a new date.
budget_exceeded: AtomicBool,
// Per-model usage totals keyed by the model name passed to `record_llm_call`.
model_tokens: Mutex<HashMap<String, ModelTokens>>,
}
/// Spend counter tied to a single UTC calendar date.
struct DailyCost {
// Cost accumulated since `reset_date`.
total: Decimal,
// UTC date (`chrono::Utc::now().date_naive()`) that `total` belongs to; a
// different current date means the counter is stale.
reset_date: chrono::NaiveDate,
}
impl CostGuard {
pub fn new(config: CostGuardConfig) -> Self {
Self {
config,
daily_cost: Mutex::new(DailyCost {
total: Decimal::ZERO,
reset_date: chrono::Utc::now().date_naive(),
}),
action_window: Mutex::new(VecDeque::new()),
budget_exceeded: AtomicBool::new(false),
model_tokens: Mutex::new(HashMap::new()),
}
}
pub async fn check_allowed(&self) -> Result<(), CostLimitExceeded> {
if self.budget_exceeded.load(Ordering::Relaxed) {
let daily = self.daily_cost.lock().await;
let spent_cents = to_cents(daily.total);
return Err(CostLimitExceeded::DailyBudget {
spent_cents,
limit_cents: self.config.max_cost_per_day_cents.unwrap_or(0),
});
}
if let Some(limit_cents) = self.config.max_cost_per_day_cents {
let daily = self.daily_cost.lock().await;
let spent_cents = to_cents(daily.total);
if spent_cents >= limit_cents {
self.budget_exceeded.store(true, Ordering::Relaxed);
return Err(CostLimitExceeded::DailyBudget {
spent_cents,
limit_cents,
});
}
}
if let Some(limit) = self.config.max_actions_per_hour {
let mut window = self.action_window.lock().await;
if let Some(cutoff) = Instant::now().checked_sub(std::time::Duration::from_secs(3600)) {
while window.front().is_some_and(|t| *t < cutoff) {
window.pop_front();
}
}
let count = window.len() as u64;
if count >= limit {
return Err(CostLimitExceeded::HourlyRate {
actions: count,
limit,
});
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn record_llm_call(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
cache_read_input_tokens: u32,
cache_creation_input_tokens: u32,
cache_read_discount: Decimal,
cache_write_multiplier: Decimal,
cost_per_token: Option<(Decimal, Decimal)>,
) -> Decimal {
let (input_rate, output_rate) = cost_per_token
.unwrap_or_else(|| costs::model_cost(model).unwrap_or_else(costs::default_cost));
let cached_total = cache_read_input_tokens.saturating_add(cache_creation_input_tokens);
let uncached_input = input_tokens.saturating_sub(cached_total);
let effective_discount = if cache_read_discount.is_zero() {
Decimal::ONE
} else {
cache_read_discount
};
let cache_read_cost =
input_rate * Decimal::from(cache_read_input_tokens) / effective_discount;
let cache_write_cost =
input_rate * Decimal::from(cache_creation_input_tokens) * cache_write_multiplier;
let cost = input_rate * Decimal::from(uncached_input)
+ cache_read_cost
+ cache_write_cost
+ output_rate * Decimal::from(output_tokens);
{
let mut daily = self.daily_cost.lock().await;
let today = chrono::Utc::now().date_naive();
if today != daily.reset_date {
daily.total = Decimal::ZERO;
daily.reset_date = today;
self.budget_exceeded.store(false, Ordering::Relaxed);
tracing::info!("Cost guard: daily counter reset for {}", today);
}
daily.total += cost;
if let Some(limit_cents) = self.config.max_cost_per_day_cents {
let spent_cents = to_cents(daily.total);
if spent_cents >= limit_cents {
self.budget_exceeded.store(true, Ordering::Relaxed);
tracing::warn!(
"Daily cost limit reached: ${:.2} of ${:.2}",
daily.total,
Decimal::from(limit_cents) / dec!(100)
);
}
let warn_threshold = limit_cents * 80 / 100;
if spent_cents >= warn_threshold && spent_cents < limit_cents {
tracing::warn!(
"Approaching daily cost limit: ${:.2} of ${:.2} ({}%)",
daily.total,
Decimal::from(limit_cents) / dec!(100),
spent_cents * 100 / limit_cents
);
}
}
}
{
let mut window = self.action_window.lock().await;
window.push_back(Instant::now());
}
{
let mut tokens = self.model_tokens.lock().await;
let entry = tokens.entry(model.to_string()).or_default();
entry.input_tokens += u64::from(input_tokens);
entry.output_tokens += u64::from(output_tokens);
entry.cost += cost;
}
cost
}
pub async fn daily_spend(&self) -> Decimal {
let daily = self.daily_cost.lock().await;
let today = chrono::Utc::now().date_naive();
if today != daily.reset_date {
Decimal::ZERO
} else {
daily.total
}
}
pub async fn actions_this_hour(&self) -> u64 {
let mut window = self.action_window.lock().await;
if let Some(cutoff) = Instant::now().checked_sub(std::time::Duration::from_secs(3600)) {
while window.front().is_some_and(|t| *t < cutoff) {
window.pop_front();
}
}
window.len() as u64
}
pub async fn model_usage(&self) -> HashMap<String, ModelTokens> {
self.model_tokens.lock().await.clone()
}
}
fn to_cents(usd: Decimal) -> u64 {
let cents = (usd * dec!(100)).trunc();
cents.to_string().parse::<u64>().unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model used by the cache-pricing tests.
    const OPUS: &str = "claude-opus-4-6";

    /// Records a call with no cache traffic, neutral cache factors and the
    /// built-in price table.
    async fn record_plain(guard: &CostGuard, model: &str, input: u32, output: u32) -> Decimal {
        guard
            .record_llm_call(model, input, output, 0, 0, Decimal::ONE, Decimal::ONE, None)
            .await
    }

    /// Cost of an uncached 1000-in / 500-out call on a fresh, unlimited guard.
    async fn uncached_opus_cost() -> Decimal {
        let guard = CostGuard::new(CostGuardConfig::default());
        record_plain(&guard, OPUS, 1000, 500).await
    }

    // With no limits configured, even a large call never trips the guard.
    #[tokio::test]
    async fn test_unlimited_allows_everything() {
        let guard = CostGuard::new(CostGuardConfig::default());
        assert!(guard.check_allowed().await.is_ok());

        record_plain(&guard, "gpt-4o", 100_000, 100_000).await;

        assert!(guard.check_allowed().await.is_ok());
    }

    // A one-cent budget is exhausted by a single sizeable call.
    #[tokio::test]
    async fn test_daily_budget_enforcement() {
        let config = CostGuardConfig {
            max_cost_per_day_cents: Some(1),
            max_actions_per_hour: None,
        };
        let guard = CostGuard::new(config);
        assert!(guard.check_allowed().await.is_ok());

        record_plain(&guard, "gpt-4o", 10_000, 10_000).await;

        let result = guard.check_allowed().await;
        assert!(result.is_err());
        match result.unwrap_err() {
            CostLimitExceeded::DailyBudget { limit_cents, .. } => {
                assert_eq!(limit_cents, 1);
            }
            other => panic!("Expected DailyBudget, got {:?}", other),
        }
    }

    // Three calls are allowed under a limit of three; the fourth check fails.
    #[tokio::test]
    async fn test_hourly_rate_enforcement() {
        let config = CostGuardConfig {
            max_cost_per_day_cents: None,
            max_actions_per_hour: Some(3),
        };
        let guard = CostGuard::new(config);

        for _ in 0..3 {
            assert!(guard.check_allowed().await.is_ok());
            record_plain(&guard, "gpt-4o", 10, 10).await;
        }

        let result = guard.check_allowed().await;
        assert!(result.is_err());
        match result.unwrap_err() {
            CostLimitExceeded::HourlyRate { actions, limit } => {
                assert_eq!(actions, 3);
                assert_eq!(limit, 3);
            }
            other => panic!("Expected HourlyRate, got {:?}", other),
        }
    }

    // The daily total equals the cost returned for the only recorded call.
    #[tokio::test]
    async fn test_daily_spend_tracking() {
        let guard = CostGuard::new(CostGuardConfig::default());
        assert_eq!(guard.daily_spend().await, Decimal::ZERO);

        let cost = record_plain(&guard, "gpt-4o", 1000, 500).await;

        assert!(cost > Decimal::ZERO);
        assert_eq!(guard.daily_spend().await, cost);
    }

    // Each recorded call counts as one action in the current hour.
    #[tokio::test]
    async fn test_actions_this_hour() {
        let guard = CostGuard::new(CostGuardConfig::default());
        assert_eq!(guard.actions_this_hour().await, 0);

        for _ in 0..2 {
            record_plain(&guard, "gpt-4o", 10, 10).await;
        }

        assert_eq!(guard.actions_this_hour().await, 2);
    }

    #[test]
    fn test_to_cents() {
        assert_eq!(to_cents(dec!(1.50)), 150);
        assert_eq!(to_cents(dec!(0.01)), 1);
        assert_eq!(to_cents(Decimal::ZERO), 0);
    }

    #[test]
    fn test_cost_limit_display() {
        let budget = CostLimitExceeded::DailyBudget {
            spent_cents: 1050,
            limit_cents: 1000,
        };
        assert!(budget.to_string().contains("$10.50"));
        assert!(budget.to_string().contains("$10.00"));

        let rate = CostLimitExceeded::HourlyRate {
            actions: 101,
            limit: 100,
        };
        assert!(rate.to_string().contains("101 actions"));
        assert!(rate.to_string().contains("100 allowed"));
    }

    // Usage is accumulated per model name and kept separate between models.
    #[tokio::test]
    async fn test_model_usage_per_model_tracking() {
        let guard = CostGuard::new(CostGuardConfig::default());
        assert!(guard.model_usage().await.is_empty());

        record_plain(&guard, "gpt-4o", 1000, 500).await;
        record_plain(&guard, "gpt-4o", 2000, 1000).await;
        record_plain(&guard, "claude-3-5-sonnet-20241022", 500, 200).await;

        let usage = guard.model_usage().await;
        assert_eq!(usage.len(), 2);

        let gpt = usage.get("gpt-4o").expect("gpt-4o should be tracked");
        assert_eq!(gpt.input_tokens, 3000);
        assert_eq!(gpt.output_tokens, 1500);
        assert!(gpt.cost > Decimal::ZERO);

        let claude = usage
            .get("claude-3-5-sonnet-20241022")
            .expect("claude should be tracked");
        assert_eq!(claude.input_tokens, 500);
        assert_eq!(claude.output_tokens, 200);
        assert!(claude.cost > Decimal::ZERO);

        assert_ne!(gpt.cost, claude.cost);
    }

    // A fully cache-read request with a 10x discount saves 90% of input cost.
    #[tokio::test]
    async fn test_cache_discount_reduces_cost() {
        let full_cost = uncached_opus_cost().await;

        let cached_guard = CostGuard::new(CostGuardConfig::default());
        let cached_cost = cached_guard
            .record_llm_call(OPUS, 1000, 500, 1000, 0, dec!(10), Decimal::ONE, None)
            .await;

        assert!(
            cached_cost < full_cost,
            "cached_cost ({}) should be less than full_cost ({})",
            cached_cost,
            full_cost
        );

        let (input_rate, _) = costs::model_cost(OPUS).unwrap();
        let expected_savings = input_rate * Decimal::from(1000u32) * dec!(9) / dec!(10);
        let actual_savings = full_cost - cached_cost;
        assert_eq!(
            actual_savings, expected_savings,
            "savings should be 90% of input cost for fully-cached request"
        );
    }

    // A 1.25x write multiplier adds 25% of the input cost.
    #[tokio::test]
    async fn test_cache_write_surcharge_increases_cost() {
        let full_cost = uncached_opus_cost().await;

        let write_guard = CostGuard::new(CostGuardConfig::default());
        let short_multiplier = Decimal::new(125, 2);
        let write_cost = write_guard
            .record_llm_call(OPUS, 1000, 500, 0, 1000, Decimal::ONE, short_multiplier, None)
            .await;

        assert!(
            write_cost > full_cost,
            "write_cost ({}) should be greater than full_cost ({})",
            write_cost,
            full_cost
        );

        let (input_rate, _) = costs::model_cost(OPUS).unwrap();
        let expected_surcharge = input_rate * Decimal::from(1000u32) * dec!(0.25);
        let actual_surcharge = write_cost - full_cost;
        assert_eq!(
            actual_surcharge, expected_surcharge,
            "surcharge should be 25% of input cost for 5m cache writes"
        );
    }

    // A 2x write multiplier adds 100% of the input cost.
    #[tokio::test]
    async fn test_cache_write_surcharge_long_ttl() {
        let full_cost = uncached_opus_cost().await;

        let write_guard = CostGuard::new(CostGuardConfig::default());
        let long_multiplier = Decimal::TWO;
        let write_cost = write_guard
            .record_llm_call(OPUS, 1000, 500, 0, 1000, Decimal::ONE, long_multiplier, None)
            .await;

        assert!(write_cost > full_cost);

        let (input_rate, _) = costs::model_cost(OPUS).unwrap();
        let expected_surcharge = input_rate * Decimal::from(1000u32);
        let actual_surcharge = write_cost - full_cost;
        assert_eq!(
            actual_surcharge, expected_surcharge,
            "surcharge should be 100% of input cost for 1h cache writes"
        );
    }

    // Window pruning must not panic on a guard that was only just created.
    #[tokio::test]
    async fn test_checked_sub_no_panic_on_fresh_guard() {
        let config = CostGuardConfig {
            max_cost_per_day_cents: None,
            max_actions_per_hour: Some(100),
        };
        let guard = CostGuard::new(config);

        assert!(guard.check_allowed().await.is_ok());
        assert_eq!(guard.actions_this_hour().await, 0);

        record_plain(&guard, "gpt-4o", 10, 10).await;

        assert!(guard.check_allowed().await.is_ok());
        assert_eq!(guard.actions_this_hour().await, 1);
    }

    // Documents the std behaviour the pruning code relies on.
    #[test]
    fn test_instant_checked_sub_returns_none_for_overflow() {
        let earliest = Instant::now().checked_sub(std::time::Duration::MAX);
        assert!(earliest.is_none());
    }
}