use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::ir::Usage;
/// Identifies the budget axis that was exceeded, carrying the configured
/// limit and the value that was compared against it.
///
/// Serialized as an internally tagged enum: the `axis` field holds the
/// snake_case variant name and the remaining fields are `limit` / `observed`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "axis", rename_all = "snake_case")]
#[non_exhaustive]
pub enum UsageLimitBreach {
    /// The request-count axis.
    Requests {
        /// Configured cap for this axis.
        limit: u64,
        /// Value that was compared against `limit`.
        observed: u64,
    },
    /// The input-token axis.
    InputTokens {
        /// Configured cap for this axis.
        limit: u64,
        /// Value that was compared against `limit`.
        observed: u64,
    },
    /// The output-token axis.
    OutputTokens {
        /// Configured cap for this axis.
        limit: u64,
        /// Value that was compared against `limit`.
        observed: u64,
    },
    /// The combined (input + output) token axis.
    TotalTokens {
        /// Configured cap for this axis.
        limit: u64,
        /// Value that was compared against `limit`.
        observed: u64,
    },
    /// The tool-call-count axis.
    ToolCalls {
        /// Configured cap for this axis.
        limit: u64,
        /// Value that was compared against `limit`.
        observed: u64,
    },
    /// The monetary (USD) axis, tracked as a decimal amount.
    CostUsd {
        /// Configured cap for this axis.
        limit: Decimal,
        /// Value that was compared against `limit`.
        observed: Decimal,
    },
}
impl UsageLimitBreach {
    /// Returns the static snake_case label of the breached axis.
    ///
    /// The same label is interpolated into the `Display` message for this type.
    #[must_use]
    pub const fn axis_name(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match *self {
            Self::Requests { .. } => "requests",
            Self::InputTokens { .. } => "input_tokens",
            Self::OutputTokens { .. } => "output_tokens",
            Self::TotalTokens { .. } => "total_tokens",
            Self::ToolCalls { .. } => "tool_calls",
            Self::CostUsd { .. } => "cost_usd",
        }
    }
}
impl std::fmt::Display for UsageLimitBreach {
    /// Formats the breach as
    /// `run budget exceeded on <axis> axis: observed <observed>, limit <limit>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // One format string for every axis; generic because the integer axes
        // carry `u64` while the cost axis carries `Decimal`.
        fn render<T: std::fmt::Display>(
            out: &mut std::fmt::Formatter<'_>,
            axis: &str,
            limit: &T,
            observed: &T,
        ) -> std::fmt::Result {
            write!(
                out,
                "run budget exceeded on {axis} axis: observed {observed}, limit {limit}"
            )
        }

        let axis = self.axis_name();
        match self {
            Self::Requests { limit, observed }
            | Self::InputTokens { limit, observed }
            | Self::OutputTokens { limit, observed }
            | Self::TotalTokens { limit, observed }
            | Self::ToolCalls { limit, observed } => render(f, axis, limit, observed),
            Self::CostUsd { limit, observed } => render(f, axis, limit, observed),
        }
    }
}
/// Per-axis usage limits plus the shared counters they are checked against.
///
/// Each limit is optional; `None` means the axis is unbounded. The counters
/// live behind an [`Arc`], so cloning a `RunBudget` copies the limits but
/// shares the running totals with the original.
#[derive(Clone, Debug, Default)]
pub struct RunBudget {
    /// Cap on the number of requests, if any.
    request_limit: Option<u32>,
    /// Cap on accumulated input tokens, if any.
    input_tokens_limit: Option<u64>,
    /// Cap on accumulated output tokens, if any.
    output_tokens_limit: Option<u64>,
    /// Cap on accumulated input + output tokens, if any.
    total_tokens_limit: Option<u64>,
    /// Cap on the number of tool calls, if any.
    tool_calls_limit: Option<u32>,
    /// Cap on accumulated cost in USD, if any.
    cost_usd_limit: Option<Decimal>,
    /// Running totals, shared between all clones of this budget.
    state: Arc<RunBudgetState>,
}
/// Running usage totals behind a [`RunBudget`].
///
/// Integer axes are atomics; the decimal cost total is guarded by a mutex.
#[derive(Debug, Default)]
struct RunBudgetState {
    /// Number of requests counted so far.
    requests: AtomicU32,
    /// Input tokens accumulated so far.
    input_tokens: AtomicU64,
    /// Output tokens accumulated so far.
    output_tokens: AtomicU64,
    /// Number of tool calls counted so far.
    tool_calls: AtomicU32,
    /// Cost in USD accumulated so far.
    cost_usd: Mutex<Decimal>,
}
impl RunBudget {
#[must_use]
pub fn unlimited() -> Self {
Self::default()
}
#[must_use]
pub const fn with_request_limit(mut self, n: u32) -> Self {
self.request_limit = Some(n);
self
}
#[must_use]
pub const fn with_input_tokens_limit(mut self, n: u64) -> Self {
self.input_tokens_limit = Some(n);
self
}
#[must_use]
pub const fn with_output_tokens_limit(mut self, n: u64) -> Self {
self.output_tokens_limit = Some(n);
self
}
#[must_use]
pub const fn with_total_tokens_limit(mut self, n: u64) -> Self {
self.total_tokens_limit = Some(n);
self
}
#[must_use]
pub const fn with_tool_calls_limit(mut self, n: u32) -> Self {
self.tool_calls_limit = Some(n);
self
}
#[must_use]
pub const fn with_cost_limit_usd(mut self, limit: Decimal) -> Self {
self.cost_usd_limit = Some(limit);
self
}
pub fn check_pre_request(&self) -> Result<()> {
if let Some(limit) = self.request_limit {
loop {
let current = self.state.requests.load(Ordering::Acquire);
if u64::from(current) >= u64::from(limit) {
return Err(Error::UsageLimitExceeded(UsageLimitBreach::Requests {
limit: u64::from(limit),
observed: u64::from(current),
}));
}
if self
.state
.requests
.compare_exchange_weak(
current,
current.saturating_add(1),
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
return Ok(());
}
}
}
Ok(())
}
#[must_use]
pub const fn request_limit(&self) -> Option<u32> {
self.request_limit
}
#[must_use]
pub const fn input_tokens_limit(&self) -> Option<u64> {
self.input_tokens_limit
}
#[must_use]
pub const fn output_tokens_limit(&self) -> Option<u64> {
self.output_tokens_limit
}
#[must_use]
pub const fn total_tokens_limit(&self) -> Option<u64> {
self.total_tokens_limit
}
#[must_use]
pub const fn tool_calls_limit(&self) -> Option<u32> {
self.tool_calls_limit
}
#[must_use]
pub const fn cost_usd_limit(&self) -> Option<Decimal> {
self.cost_usd_limit
}
pub fn check_pre_request_tokens(
&self,
estimated_input: u64,
estimated_output: u64,
) -> Result<()> {
let observed_in = self.state.input_tokens.load(Ordering::Acquire);
let observed_out = self.state.output_tokens.load(Ordering::Acquire);
let projected_in = observed_in.saturating_add(estimated_input);
let projected_out = observed_out.saturating_add(estimated_output);
if let Some(limit) = self.input_tokens_limit
&& projected_in > limit
{
return Err(Error::UsageLimitExceeded(UsageLimitBreach::InputTokens {
limit,
observed: projected_in,
}));
}
if let Some(limit) = self.output_tokens_limit
&& projected_out > limit
{
return Err(Error::UsageLimitExceeded(UsageLimitBreach::OutputTokens {
limit,
observed: projected_out,
}));
}
if let Some(limit) = self.total_tokens_limit {
let projected_total = projected_in.saturating_add(projected_out);
if projected_total > limit {
return Err(Error::UsageLimitExceeded(UsageLimitBreach::TotalTokens {
limit,
observed: projected_total,
}));
}
}
Ok(())
}
pub fn check_pre_request_cost(&self, estimated_charge: Decimal) -> Result<()> {
if let Some(limit) = self.cost_usd_limit {
let observed = *self
.state
.cost_usd
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let projected = observed.saturating_add(estimated_charge);
if projected > limit {
return Err(Error::UsageLimitExceeded(UsageLimitBreach::CostUsd {
limit,
observed: projected,
}));
}
}
Ok(())
}
pub fn check_pre_tool_call(&self) -> Result<()> {
if let Some(limit) = self.tool_calls_limit {
loop {
let current = self.state.tool_calls.load(Ordering::Acquire);
if u64::from(current) >= u64::from(limit) {
return Err(Error::UsageLimitExceeded(UsageLimitBreach::ToolCalls {
limit: u64::from(limit),
observed: u64::from(current),
}));
}
if self
.state
.tool_calls
.compare_exchange_weak(
current,
current.saturating_add(1),
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
return Ok(());
}
}
}
Ok(())
}
pub fn observe_usage(&self, usage: &Usage) -> Result<()> {
let new_in = self
.state
.input_tokens
.fetch_add(u64::from(usage.input_tokens), Ordering::AcqRel)
.saturating_add(u64::from(usage.input_tokens));
let new_out = self
.state
.output_tokens
.fetch_add(u64::from(usage.output_tokens), Ordering::AcqRel)
.saturating_add(u64::from(usage.output_tokens));
if let Some(limit) = self.input_tokens_limit
&& new_in > limit
{
return Err(Error::UsageLimitExceeded(UsageLimitBreach::InputTokens {
limit,
observed: new_in,
}));
}
if let Some(limit) = self.output_tokens_limit
&& new_out > limit
{
return Err(Error::UsageLimitExceeded(UsageLimitBreach::OutputTokens {
limit,
observed: new_out,
}));
}
if let Some(limit) = self.total_tokens_limit {
let total = new_in.saturating_add(new_out);
if total > limit {
return Err(Error::UsageLimitExceeded(UsageLimitBreach::TotalTokens {
limit,
observed: total,
}));
}
}
Ok(())
}
pub fn observe_cost(&self, charge_usd: Decimal) -> Result<()> {
let observed = {
let mut accumulated = self
.state
.cost_usd
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*accumulated = accumulated.saturating_add(charge_usd);
*accumulated
};
if let Some(limit) = self.cost_usd_limit
&& observed > limit
{
return Err(Error::UsageLimitExceeded(UsageLimitBreach::CostUsd {
limit,
observed,
}));
}
Ok(())
}
#[must_use]
pub fn snapshot(&self) -> UsageSnapshot {
let cost_usd = *self
.state
.cost_usd
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
UsageSnapshot {
requests: self.state.requests.load(Ordering::Acquire),
input_tokens: self.state.input_tokens.load(Ordering::Acquire),
output_tokens: self.state.output_tokens.load(Ordering::Acquire),
tool_calls: self.state.tool_calls.load(Ordering::Acquire),
cost_usd,
}
}
}
/// Plain-value copy of a budget's usage counters.
///
/// Holds owned values only, so it does not change after it has been taken.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct UsageSnapshot {
    /// Requests counted.
    pub requests: u32,
    /// Input tokens accumulated.
    pub input_tokens: u64,
    /// Output tokens accumulated.
    pub output_tokens: u64,
    /// Tool calls counted.
    pub tool_calls: u32,
    /// Cost in USD accumulated.
    pub cost_usd: Decimal,
}
impl UsageSnapshot {
    /// Returns input plus output tokens, saturating at `u64::MAX`.
    #[must_use]
    pub const fn total_tokens(&self) -> u64 {
        let (input, output) = (self.input_tokens, self.output_tokens);
        input.saturating_add(output)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::ir::Usage;

    /// With no limits configured, every check and observation succeeds.
    #[test]
    fn unlimited_budget_passes_every_check() {
        let budget = RunBudget::unlimited();
        budget.check_pre_request().unwrap();
        budget.check_pre_tool_call().unwrap();
        budget
            .observe_usage(&Usage::new(1_000_000, 1_000_000))
            .unwrap();
    }

    /// Two requests fit a limit of two; the third is rejected and not counted.
    #[test]
    fn request_limit_pre_check_increments_then_breaks() {
        let budget = RunBudget::unlimited().with_request_limit(2);
        budget.check_pre_request().unwrap();
        budget.check_pre_request().unwrap();

        let err = budget.check_pre_request().unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::Requests {
                limit: 2,
                observed: 2,
            }) => {}
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(budget.snapshot().requests, 2);
    }

    /// A tool-call limit of one rejects the second call.
    #[test]
    fn tool_calls_limit_pre_check_breaks() {
        let budget = RunBudget::unlimited().with_tool_calls_limit(1);
        budget.check_pre_tool_call().unwrap();

        let err = budget.check_pre_tool_call().unwrap_err();
        assert!(matches!(
            err,
            Error::UsageLimitExceeded(UsageLimitBreach::ToolCalls { .. })
        ));
    }

    /// Observed input tokens accumulate and breach once the total passes the limit.
    #[test]
    fn input_tokens_limit_post_observe_breaks() {
        let budget = RunBudget::unlimited().with_input_tokens_limit(100);
        budget.observe_usage(&Usage::new(50, 0)).unwrap();

        let err = budget.observe_usage(&Usage::new(60, 0)).unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::InputTokens {
                limit: 100,
                observed: 110,
            }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }

    /// Observed output tokens accumulate and breach once the total passes the limit.
    #[test]
    fn output_tokens_limit_post_observe_breaks() {
        let budget = RunBudget::unlimited().with_output_tokens_limit(100);
        budget.observe_usage(&Usage::new(0, 99)).unwrap();

        let err = budget.observe_usage(&Usage::new(0, 2)).unwrap_err();
        assert!(matches!(
            err,
            Error::UsageLimitExceeded(UsageLimitBreach::OutputTokens { .. })
        ));
    }

    /// The total-token axis is checked against input + output.
    #[test]
    fn total_tokens_limit_combines_input_and_output() {
        let budget = RunBudget::unlimited().with_total_tokens_limit(100);
        budget.observe_usage(&Usage::new(40, 40)).unwrap();

        let err = budget.observe_usage(&Usage::new(20, 20)).unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::TotalTokens {
                limit: 100,
                observed: 120,
            }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }

    /// A charge that pushes the total over the cap errors but is still recorded.
    #[test]
    fn cost_usd_limit_post_observe_breaks() {
        use rust_decimal::Decimal;

        // `Decimal::new(50, 2)` is 0.50.
        let cap = Decimal::new(50, 2);
        let budget = RunBudget::unlimited().with_cost_limit_usd(cap);
        budget.observe_cost(Decimal::new(30, 2)).unwrap();

        let err = budget.observe_cost(Decimal::new(25, 2)).unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::CostUsd { limit, observed }) => {
                assert_eq!(limit, cap);
                assert_eq!(observed, Decimal::new(55, 2));
            }
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(budget.snapshot().cost_usd, Decimal::new(55, 2));
    }

    /// Without a cost limit, charges accumulate and never breach.
    #[test]
    fn cost_unlimited_accumulates_without_breaching() {
        use rust_decimal::Decimal;

        let budget = RunBudget::unlimited();
        budget.observe_cost(Decimal::new(100, 2)).unwrap();
        budget.observe_cost(Decimal::new(200, 2)).unwrap();
        assert_eq!(budget.snapshot().cost_usd, Decimal::new(300, 2));
    }

    /// A clone draws from the same request counter as its source.
    #[test]
    fn clone_shares_atomic_state() {
        let parent = RunBudget::unlimited().with_request_limit(2);
        let child = parent.clone();
        parent.check_pre_request().unwrap();
        child.check_pre_request().unwrap();

        let err = parent.check_pre_request().unwrap_err();
        assert!(matches!(
            err,
            Error::UsageLimitExceeded(UsageLimitBreach::Requests { .. })
        ));
    }

    /// A clone accumulates cost into the same total as its source.
    #[test]
    fn cost_clone_shares_arc_state() {
        use rust_decimal::Decimal;

        let cap = Decimal::new(100, 2);
        let parent = RunBudget::unlimited().with_cost_limit_usd(cap);
        let child = parent.clone();
        parent.observe_cost(Decimal::new(60, 2)).unwrap();
        child.observe_cost(Decimal::new(30, 2)).unwrap();

        let err = child.observe_cost(Decimal::new(20, 2)).unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::CostUsd { limit, observed }) => {
                assert_eq!(limit, cap);
                assert_eq!(observed, Decimal::new(110, 2));
            }
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(parent.snapshot().cost_usd, Decimal::new(110, 2));
        assert_eq!(child.snapshot().cost_usd, Decimal::new(110, 2));
    }

    /// Each accessor returns what its builder stored, and `None` by default.
    #[test]
    fn limit_accessors_reflect_configuration() {
        let budget = RunBudget::unlimited()
            .with_request_limit(10)
            .with_input_tokens_limit(100)
            .with_output_tokens_limit(50)
            .with_total_tokens_limit(140)
            .with_tool_calls_limit(5)
            .with_cost_limit_usd(Decimal::new(150, 2));
        assert_eq!(budget.request_limit(), Some(10));
        assert_eq!(budget.input_tokens_limit(), Some(100));
        assert_eq!(budget.output_tokens_limit(), Some(50));
        assert_eq!(budget.total_tokens_limit(), Some(140));
        assert_eq!(budget.tool_calls_limit(), Some(5));
        assert_eq!(budget.cost_usd_limit(), Some(Decimal::new(150, 2)));

        let unbounded = RunBudget::unlimited();
        assert_eq!(unbounded.request_limit(), None);
        assert_eq!(unbounded.cost_usd_limit(), None);
    }

    /// An estimate that would overshoot is rejected without being recorded.
    #[test]
    fn check_pre_request_cost_blocks_when_estimate_overshoots() {
        let budget = RunBudget::unlimited().with_cost_limit_usd(Decimal::new(100, 2));
        budget.observe_cost(Decimal::new(98, 2)).unwrap();

        let err = budget
            .check_pre_request_cost(Decimal::new(5, 2))
            .unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::CostUsd { limit, observed }) => {
                assert_eq!(limit, Decimal::new(100, 2));
                assert_eq!(observed, Decimal::new(103, 2));
            }
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(budget.snapshot().cost_usd, Decimal::new(98, 2));
    }

    /// An estimate that fits passes and leaves the accumulated cost untouched.
    #[test]
    fn check_pre_request_cost_passes_when_estimate_fits() {
        let budget = RunBudget::unlimited().with_cost_limit_usd(Decimal::new(100, 2));
        budget.observe_cost(Decimal::new(50, 2)).unwrap();
        budget.check_pre_request_cost(Decimal::new(30, 2)).unwrap();
        assert_eq!(budget.snapshot().cost_usd, Decimal::new(50, 2));
    }

    /// Without a cost limit, any estimate passes.
    #[test]
    fn check_pre_request_cost_no_op_when_axis_unbounded() {
        let budget = RunBudget::unlimited();
        budget
            .check_pre_request_cost(Decimal::new(10_000_000, 0))
            .unwrap();
    }

    /// A projected input total over the limit is rejected without being recorded.
    #[test]
    fn check_pre_request_tokens_blocks_on_input_axis() {
        let budget = RunBudget::unlimited().with_input_tokens_limit(100);
        budget.observe_usage(&Usage::new(80, 0)).unwrap();

        let err = budget.check_pre_request_tokens(30, 0).unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::InputTokens { limit, observed }) => {
                assert_eq!(limit, 100);
                assert_eq!(observed, 110);
            }
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(budget.snapshot().input_tokens, 80);
    }

    /// A projected output total over the limit is rejected.
    #[test]
    fn check_pre_request_tokens_blocks_on_output_axis() {
        let budget = RunBudget::unlimited().with_output_tokens_limit(100);
        budget.observe_usage(&Usage::new(0, 80)).unwrap();

        let err = budget.check_pre_request_tokens(0, 30).unwrap_err();
        assert!(
            matches!(
                err,
                Error::UsageLimitExceeded(UsageLimitBreach::OutputTokens { .. })
            ),
            "got: {err:?}"
        );
    }

    /// A projected input + output total over the limit is rejected.
    #[test]
    fn check_pre_request_tokens_blocks_on_total_axis() {
        let budget = RunBudget::unlimited().with_total_tokens_limit(150);
        budget.observe_usage(&Usage::new(50, 50)).unwrap();

        let err = budget.check_pre_request_tokens(40, 40).unwrap_err();
        match err {
            Error::UsageLimitExceeded(UsageLimitBreach::TotalTokens { limit, observed }) => {
                assert_eq!(limit, 150);
                assert_eq!(observed, 180);
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    /// When both input and total would breach, the input axis is reported.
    #[test]
    fn check_pre_request_tokens_input_fires_before_total() {
        let budget = RunBudget::unlimited()
            .with_input_tokens_limit(50)
            .with_total_tokens_limit(60);
        budget.observe_usage(&Usage::new(40, 0)).unwrap();

        let err = budget.check_pre_request_tokens(20, 20).unwrap_err();
        assert!(matches!(
            err,
            Error::UsageLimitExceeded(UsageLimitBreach::InputTokens { .. })
        ));
    }

    /// Without token limits, even very large estimates pass.
    #[test]
    fn check_pre_request_tokens_no_op_when_all_axes_unbounded() {
        let budget = RunBudget::unlimited();
        budget
            .check_pre_request_tokens(u64::MAX / 2, u64::MAX / 2)
            .unwrap();
    }

    /// A snapshot does not change when the budget is updated afterwards.
    #[test]
    fn snapshot_returns_owned_values() {
        let budget = RunBudget::unlimited().with_request_limit(100);
        budget.check_pre_request().unwrap();
        budget.observe_usage(&Usage::new(10, 5)).unwrap();

        let snap = budget.snapshot();
        assert_eq!(snap.requests, 1);
        assert_eq!(snap.input_tokens, 10);
        assert_eq!(snap.output_tokens, 5);
        assert_eq!(snap.total_tokens(), 15);

        // Later activity must not alter the snapshot already taken.
        budget.check_pre_request().unwrap();
        assert_eq!(snap.requests, 1);
    }
}