#![deny(missing_docs)]
use std::sync::Mutex;
/// Optional upper bounds enforced by a [`BudgetPool`].
///
/// Each field is `None` when that dimension is unlimited, so `Caps::default()` imposes no
/// limits at all. A call is rejected only when it would push the corresponding running total
/// strictly above the cap; reaching the cap exactly is still allowed.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Caps {
    /// Maximum cumulative input tokens across all recorded calls.
    pub max_input_tokens: Option<u64>,
    /// Maximum cumulative output tokens across all recorded calls.
    pub max_output_tokens: Option<u64>,
    /// Maximum cumulative input + output tokens across all recorded calls.
    pub max_total_tokens: Option<u64>,
    /// Maximum cumulative cost, in US dollars, across all recorded calls.
    pub max_cost_usd: Option<f64>,
}
/// Running usage totals, as accumulated by a [`BudgetPool`].
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Totals {
    /// Input tokens summed over all recorded calls.
    pub input_tokens: u64,
    /// Output tokens summed over all recorded calls.
    pub output_tokens: u64,
    /// Cost, in US dollars, summed over all recorded calls.
    pub cost_usd: f64,
    /// Number of recorded calls.
    pub calls: u64,
}

impl Totals {
    /// Returns `input_tokens + output_tokens`.
    ///
    /// The sum saturates at `u64::MAX` rather than overflowing. Plain `+` would panic in
    /// debug builds and silently wrap in release builds.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// Error returned when recording usage would push a running total past its cap.
///
/// The offending call is not applied; the pool's totals are left as they were.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BudgetExceeded {
    /// Name of the violated cap. [`BudgetPool::record`] uses `"input_tokens"`,
    /// `"output_tokens"`, `"total_tokens"` or `"cost_usd"`.
    pub cap: &'static str,
    /// The configured limit. Token limits are converted to `f64` for reporting.
    pub limit: f64,
    /// The running total the rejected call would have produced.
    pub attempted: f64,
}

impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "budget cap `{}` exceeded: limit={}, attempted={}",
            self.cap, self.limit, self.attempted
        )
    }
}

impl std::error::Error for BudgetExceeded {}
/// Accumulates usage [`Totals`] and enforces a fixed set of [`Caps`].
///
/// All methods take `&self` and the running totals sit behind a [`Mutex`], so a single pool
/// can be shared between threads (for example via `Arc<BudgetPool>`).
#[derive(Debug)]
pub struct BudgetPool {
    /// Limits fixed at construction; never modified afterwards.
    caps: Caps,
    /// Running totals. The lock is held across each check-then-commit, so concurrent
    /// `record` calls cannot both pass the same cap.
    state: Mutex<Totals>,
}
impl BudgetPool {
    /// Creates a pool that enforces `caps`, starting from zeroed [`Totals`].
    pub fn with_caps(caps: Caps) -> Self {
        Self {
            caps,
            state: Mutex::new(Totals::default()),
        }
    }

    /// Creates a pool with no caps. [`BudgetPool::record`] then only accumulates and never
    /// returns an error.
    pub fn unconstrained() -> Self {
        Self::with_caps(Caps::default())
    }

    /// Records one call's usage, committing it only if every configured cap still holds.
    ///
    /// The prospective totals are computed first.
    ///
    /// - If any cap would be exceeded, the pool is left unchanged and the first violated cap
    ///   is reported. Caps are checked in the order: input tokens, output tokens, total
    ///   tokens, cost.
    /// - Otherwise the usage is added, the call counter is incremented, and a snapshot of
    ///   the updated totals is returned.
    ///
    /// Token sums saturate at `u64::MAX` instead of wrapping, so an oversized request can
    /// never wrap around to a value below a cap. A negative `cost_usd` is not rejected and
    /// lowers the running cost.
    ///
    /// # Errors
    ///
    /// Returns [`BudgetExceeded`] when the new totals would exceed a configured cap. Its
    /// `cap` field is `"input_tokens"`, `"output_tokens"`, `"total_tokens"` or `"cost_usd"`.
    /// When a cost cap is set, a NaN running cost is treated as exceeding it.
    pub fn record(
        &self,
        input_tokens: u64,
        output_tokens: u64,
        cost_usd: f64,
    ) -> Result<Totals, BudgetExceeded> {
        let mut state = self.lock_state();
        let next_in = state.input_tokens.saturating_add(input_tokens);
        let next_out = state.output_tokens.saturating_add(output_tokens);
        let next_total = next_in.saturating_add(next_out);
        let next_cost = state.cost_usd + cost_usd;

        Self::check_tokens(self.caps.max_input_tokens, "input_tokens", next_in)?;
        Self::check_tokens(self.caps.max_output_tokens, "output_tokens", next_out)?;
        Self::check_tokens(self.caps.max_total_tokens, "total_tokens", next_total)?;
        if let Some(cap) = self.caps.max_cost_usd {
            // `NaN > cap` is false. Without the explicit check, a NaN cost would be accepted,
            // poison the stored total, and silently disable this cap for every later call.
            if next_cost.is_nan() || next_cost > cap {
                return Err(BudgetExceeded {
                    cap: "cost_usd",
                    limit: cap,
                    attempted: next_cost,
                });
            }
        }

        // Commit only after every check passed, so a rejected call leaves no trace.
        state.input_tokens = next_in;
        state.output_tokens = next_out;
        state.cost_usd = next_cost;
        state.calls = state.calls.saturating_add(1);
        Ok(*state)
    }

    /// Returns a snapshot of the current running totals.
    pub fn totals(&self) -> Totals {
        *self.lock_state()
    }

    /// Returns the caps this pool was created with.
    pub fn caps(&self) -> Caps {
        self.caps
    }

    /// Resets all running totals (tokens, cost and call count) to zero. Caps are unchanged.
    pub fn reset(&self) {
        *self.lock_state() = Totals::default();
    }

    /// Locks the running totals, recovering the guard if a previous holder panicked.
    ///
    /// Recovery is sound because `Totals` is a `Copy` value. Every writer in this impl
    /// assigns it with infallible stores only after all checks pass, so a poisoned lock
    /// never guards half-applied totals.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, Totals> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns an error naming `name` if `cap` is set and `attempted` is strictly above it.
    ///
    /// Both values are reported as `f64`; values above 2^53 lose precision in the report only.
    fn check_tokens(
        cap: Option<u64>,
        name: &'static str,
        attempted: u64,
    ) -> Result<(), BudgetExceeded> {
        match cap {
            Some(limit) if attempted > limit => Err(BudgetExceeded {
                cap: name,
                limit: limit as f64,
                attempted: attempted as f64,
            }),
            _ => Ok(()),
        }
    }
}