token_budget_pool/lib.rs
1//! # token-budget-pool
2//!
3//! Shared token + dollar budget across N concurrent LLM tasks.
4//!
//! Create one `BudgetPool` at the top of an agent run; pass `&pool` to every
//! task that issues LLM calls (wrap it in an `Arc` if the tasks must be
//! `'static`); call [`BudgetPool::record`] after each response. The pool
//! serializes the updates behind a mutex and returns [`BudgetExceeded`] when a
//! record would push past any cap.
9//!
10//! ## Example
11//!
12//! ```
13//! use token_budget_pool::{BudgetPool, Caps};
14//!
15//! let pool = BudgetPool::with_caps(Caps {
16//! max_input_tokens: Some(10_000),
17//! max_output_tokens: Some(5_000),
18//! max_total_tokens: None,
19//! max_cost_usd: Some(1.0),
20//! });
21//!
22//! pool.record(1_000, 500, 0.05).unwrap(); // fits
23//! let err = pool.record(20_000, 0, 0.0).unwrap_err(); // input cap blown
24//! assert!(format!("{err}").contains("input_tokens"));
25//! ```
26
27#![deny(missing_docs)]
28
29use std::sync::Mutex;
30
/// Per-pool spending limits.
///
/// Every limit is optional: a field set to `None` is never checked. Limits are
/// inclusive — a recorded call that lands exactly on a limit is accepted; only
/// going past it is rejected.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Caps {
    /// Upper bound on the sum of input tokens over every accepted call.
    pub max_input_tokens: Option<u64>,
    /// Upper bound on the sum of output tokens over every accepted call.
    pub max_output_tokens: Option<u64>,
    /// Upper bound on input plus output tokens combined.
    pub max_total_tokens: Option<u64>,
    /// Upper bound on accumulated spend, in US dollars.
    pub max_cost_usd: Option<f64>,
}
43
/// Running totals accumulated by a budget pool.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Totals {
    /// Cumulative input tokens recorded.
    pub input_tokens: u64,
    /// Cumulative output tokens recorded.
    pub output_tokens: u64,
    /// Cumulative dollars recorded.
    pub cost_usd: f64,
    /// Number of accepted `record` calls (rejected calls are not counted).
    pub calls: u64,
}

impl Totals {
    /// Sum of input + output tokens, saturating at `u64::MAX`.
    ///
    /// Saturates rather than using `+`: a plain add would panic on overflow in
    /// debug builds and silently wrap to a small number in release builds.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
63
/// Rejection returned by `BudgetPool::record` when a call would overrun a cap.
///
/// Caps are checked in a fixed order (input tokens, output tokens, total
/// tokens, cost) and only the first one breached is reported, so later caps
/// may be over as well. **A rejected call leaves the pool's totals
/// untouched** — nothing from it is recorded.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BudgetExceeded {
    /// Name of the breached cap: `"input_tokens"`, `"output_tokens"`,
    /// `"total_tokens"`, or `"cost_usd"`.
    pub cap: &'static str,
    /// Value of the breached cap.
    pub limit: f64,
    /// Running total the call would have produced had it been accepted.
    pub attempted: f64,
}

impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            cap,
            limit,
            attempted,
        } = self;
        write!(
            f,
            "budget cap `{cap}` exceeded: limit={limit}, attempted={attempted}"
        )
    }
}

impl std::error::Error for BudgetExceeded {}
91
92/// Shared budget. Cheap to construct; record() takes a mutex.
93#[derive(Debug)]
94pub struct BudgetPool {
95 caps: Caps,
96 state: Mutex<Totals>,
97}
98
99impl BudgetPool {
100 /// Build a pool with the given caps. All caps default to `None`.
101 pub fn with_caps(caps: Caps) -> Self {
102 Self {
103 caps,
104 state: Mutex::new(Totals::default()),
105 }
106 }
107
108 /// Build an unconstrained pool (no caps).
109 pub fn unconstrained() -> Self {
110 Self::with_caps(Caps::default())
111 }
112
113 /// Record one call's usage. Returns the updated totals on success, or
114 /// [`BudgetExceeded`] (totals unchanged) on cap breach.
115 pub fn record(
116 &self,
117 input_tokens: u64,
118 output_tokens: u64,
119 cost_usd: f64,
120 ) -> Result<Totals, BudgetExceeded> {
121 let mut s = self.state.lock().unwrap();
122
123 let next_in = s.input_tokens + input_tokens;
124 let next_out = s.output_tokens + output_tokens;
125 let next_total = next_in + next_out;
126 let next_cost = s.cost_usd + cost_usd;
127
128 if let Some(cap) = self.caps.max_input_tokens {
129 if next_in > cap {
130 return Err(BudgetExceeded {
131 cap: "input_tokens",
132 limit: cap as f64,
133 attempted: next_in as f64,
134 });
135 }
136 }
137 if let Some(cap) = self.caps.max_output_tokens {
138 if next_out > cap {
139 return Err(BudgetExceeded {
140 cap: "output_tokens",
141 limit: cap as f64,
142 attempted: next_out as f64,
143 });
144 }
145 }
146 if let Some(cap) = self.caps.max_total_tokens {
147 if next_total > cap {
148 return Err(BudgetExceeded {
149 cap: "total_tokens",
150 limit: cap as f64,
151 attempted: next_total as f64,
152 });
153 }
154 }
155 if let Some(cap) = self.caps.max_cost_usd {
156 if next_cost > cap {
157 return Err(BudgetExceeded {
158 cap: "cost_usd",
159 limit: cap,
160 attempted: next_cost,
161 });
162 }
163 }
164
165 s.input_tokens = next_in;
166 s.output_tokens = next_out;
167 s.cost_usd = next_cost;
168 s.calls += 1;
169 Ok(*s)
170 }
171
172 /// Read current totals.
173 pub fn totals(&self) -> Totals {
174 *self.state.lock().unwrap()
175 }
176
177 /// Read the caps.
178 pub fn caps(&self) -> Caps {
179 self.caps
180 }
181
182 /// Reset the pool to zero totals (caps unchanged).
183 pub fn reset(&self) {
184 *self.state.lock().unwrap() = Totals::default();
185 }
186}