use std::sync::atomic::{AtomicBool, Ordering};
use crate::config::guardrails::CostGuardrails;
use crate::provider::pricing::session_cost_usd;
/// Latch recording whether `check` has already reported the warn threshold.
///
/// `check` flips it with `swap(true, ..)` the first time spend reaches the warn
/// threshold. It is never cleared in this module, so `Warned` is reported at
/// most once per process.
static WARNED: AtomicBool = AtomicBool::new(false);

/// Outcome of a single cost-guardrail evaluation.
#[derive(Clone, Copy, Debug)]
pub enum CostGuardStatus {
    /// Spend is below every configured threshold, or the warning was already issued.
    Ok,
    /// Spend reached the warn threshold. This is only produced the first time.
    Warned {
        /// Session spend observed at evaluation time, in USD.
        spent_usd: f64,
        /// Configured warn threshold that was met or exceeded, in USD.
        threshold_usd: f64,
    },
    /// Spend met or exceeded the hard limit; callers should stop spending.
    Block {
        /// Session spend observed at evaluation time, in USD.
        spent_usd: f64,
        /// Configured hard limit that was met or exceeded, in USD.
        limit_usd: f64,
    },
}
/// Compares the current session spend against the guardrails loaded from the environment.
///
/// Precedence:
/// 1. If a hard limit is configured and spend is at or above it, returns `Block`.
///    This does not touch the `WARNED` latch.
/// 2. Otherwise, if a warn threshold is configured and spend is at or above it,
///    returns `Warned`, but only if the `WARNED` latch was not already set.
///    The latch is set as part of that same check.
/// 3. Otherwise returns `Ok`.
fn check() -> CostGuardStatus {
    let guardrails = CostGuardrails::from_env();
    let spent_usd = session_cost_usd();

    match guardrails.hard_limit_usd {
        Some(limit_usd) if spent_usd >= limit_usd => {
            return CostGuardStatus::Block {
                spent_usd,
                limit_usd,
            };
        }
        _ => {}
    }

    // The latch is only swapped once spend has actually crossed the threshold,
    // thanks to `&&` short-circuiting inside the guard.
    match guardrails.warn_usd {
        Some(threshold_usd)
            if spent_usd >= threshold_usd && !WARNED.swap(true, Ordering::Relaxed) =>
        {
            CostGuardStatus::Warned {
                spent_usd,
                threshold_usd,
            }
        }
        _ => CostGuardStatus::Ok,
    }
}
/// Applies the session cost guardrails evaluated by `check`.
///
/// When the warn threshold is reported (see `check`), this logs a `tracing`
/// warning with the spend and threshold as fields, then lets the caller continue.
///
/// # Errors
///
/// Returns an error when `check` reports `Block`, meaning session spend met or
/// exceeded the configured hard limit. The error message tells the user how to
/// raise that limit.
pub fn enforce_cost_budget() -> anyhow::Result<()> {
    match check() {
        CostGuardStatus::Ok => {}
        CostGuardStatus::Warned {
            spent_usd,
            threshold_usd,
        } => {
            tracing::warn!(
                spent_usd,
                threshold_usd,
                "Cost guardrail warn threshold reached; set CODETETHER_COST_LIMIT_USD to cap spend"
            );
        }
        CostGuardStatus::Block {
            spent_usd,
            limit_usd,
        } => {
            anyhow::bail!(
                "Cost guardrail tripped: session has spent ~${:.2} which meets/exceeds the \
                 hard limit of ${:.2}. Raise CODETETHER_COST_LIMIT_USD (or \
                 `[guardrails] hard_limit_usd` in config) to continue.",
                spent_usd,
                limit_usd
            )
        }
    }
    Ok(())
}
/// Classifies the current session spend against the guardrails from the environment.
///
/// Unlike `check`, this never reads or sets the `WARNED` latch. It therefore
/// reports `OverWarn` on every call while spend sits between the warn threshold
/// and the hard limit. The hard limit takes precedence over the warn threshold.
pub fn cost_guard_level() -> CostGuardLevel {
    let guardrails = CostGuardrails::from_env();
    let spent = session_cost_usd();

    if guardrails.hard_limit_usd.is_some_and(|limit| spent >= limit) {
        CostGuardLevel::OverLimit
    } else if guardrails.warn_usd.is_some_and(|warn| spent >= warn) {
        CostGuardLevel::OverWarn
    } else {
        CostGuardLevel::Ok
    }
}
/// Stateless spend classification returned by `cost_guard_level`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CostGuardLevel {
    /// Spend is below every configured threshold, or no thresholds are configured.
    Ok,
    /// Spend is at or above the warn threshold but below any hard limit.
    OverWarn,
    /// Spend is at or above the configured hard limit.
    OverLimit,
}