1use std::{
8 collections::BTreeMap,
9 sync::{
10 Arc, Mutex,
11 atomic::{AtomicU64, Ordering},
12 },
13};
14
15use thiserror::Error;
16
17use crate::{AgentError, RequestExtensions};
18
/// Projected resource consumption for a request, supplied when reserving budget.
///
/// The shared-pool manager in this file only reads `total_tokens` and
/// `cost_micros_usd`; the input/output split is carried along untouched.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UsageEstimate {
    /// Estimated input-side token count.
    pub input_tokens: u64,
    /// Estimated output-side token count.
    pub output_tokens: u64,
    /// Estimated total token count; this is the figure checked against limits.
    pub total_tokens: u64,
    /// Estimated cost, in micro-USD.
    pub cost_micros_usd: u64,
}

impl UsageEstimate {
    /// The all-zero estimate, kept as a constant so [`UsageEstimate::zero`]
    /// remains usable in const contexts.
    const ZERO: Self = Self {
        input_tokens: 0,
        output_tokens: 0,
        total_tokens: 0,
        cost_micros_usd: 0,
    };

    /// Returns an estimate with every counter set to zero.
    ///
    /// Handy as the base of a struct-update expression
    /// (`UsageEstimate { total_tokens: 5, ..UsageEstimate::zero() }`).
    pub const fn zero() -> Self {
        Self::ZERO
    }
}
37
38#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
39pub struct Usage {
40 pub input_tokens: u64,
41 pub output_tokens: u64,
42 pub total_tokens: u64,
43 pub cost_micros_usd: u64,
44}
45
46impl Usage {
47 pub const fn zero() -> Self {
48 Self {
49 input_tokens: 0,
50 output_tokens: 0,
51 total_tokens: 0,
52 cost_micros_usd: 0,
53 }
54 }
55}
56
/// Snapshot of how much budget is still available.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Remaining {
    /// Tokens still available.
    pub tokens: u64,
    /// Cost still available, in micro-USD.
    pub cost_micros_usd: u64,
    /// Whether the budget counts as exhausted. The shared-pool manager in this
    /// file sets it when either remaining figure is at or below its stop
    /// threshold, and also when its state lock is poisoned.
    pub below_threshold: bool,
}
63
/// Optional per-request ceilings on tokens and cost.
///
/// A `None` limit leaves that dimension unconstrained.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RequestBudget {
    /// Maximum total tokens for the request, if limited.
    pub tokens: Option<u64>,
    /// Maximum cost for the request in micro-USD, if limited.
    pub cost_micros_usd: Option<u64>,
}

impl RequestBudget {
    /// A budget that constrains neither tokens nor cost.
    pub const fn unlimited() -> Self {
        Self::with_limits(None, None)
    }

    /// A budget that caps tokens only.
    pub const fn from_tokens(tokens: u64) -> Self {
        Self::with_limits(Some(tokens), None)
    }

    /// A budget that caps cost (micro-USD) only.
    pub const fn from_cost_micros_usd(cost_micros_usd: u64) -> Self {
        Self::with_limits(None, Some(cost_micros_usd))
    }

    /// A budget with each limit given explicitly; `None` means "no limit".
    pub const fn with_limits(tokens: Option<u64>, cost_micros_usd: Option<u64>) -> Self {
        Self {
            tokens,
            cost_micros_usd,
        }
    }

    /// Returns `true` when both amounts fit within this budget.
    ///
    /// Limits are inclusive: an amount equal to its limit is allowed.
    fn allows(self, tokens: u64, cost_micros_usd: u64) -> bool {
        let within = |limit: Option<u64>, amount: u64| match limit {
            Some(cap) => amount <= cap,
            None => true,
        };

        within(self.tokens, tokens) && within(self.cost_micros_usd, cost_micros_usd)
    }
}
106
107#[derive(Clone, Debug, Eq, PartialEq)]
108pub struct BudgetLease {
109 id: u64,
110 reserved: UsageEstimate,
111 request_budget: RequestBudget,
112}
113
114impl BudgetLease {
115 pub fn new(id: u64, reserved: UsageEstimate, request_budget: RequestBudget) -> Self {
116 Self {
117 id,
118 reserved,
119 request_budget,
120 }
121 }
122
123 pub fn id(&self) -> u64 {
124 self.id
125 }
126
127 pub fn reserved(&self) -> UsageEstimate {
128 self.reserved
129 }
130
131 pub fn request_budget(&self) -> RequestBudget {
132 self.request_budget
133 }
134}
135
136pub trait BudgetManager: Send + Sync + 'static {
137 fn remaining(&self, extensions: &RequestExtensions) -> Remaining;
138 fn reserve(
139 &self,
140 extensions: &RequestExtensions,
141 estimate: &UsageEstimate,
142 request_budget: RequestBudget,
143 ) -> Result<BudgetLease, AgentError>;
144 fn record_used(&self, lease: BudgetLease, usage: Usage) -> Result<(), AgentError>;
145}
146
147impl<T> BudgetManager for Arc<T>
148where
149 T: BudgetManager + ?Sized,
150{
151 fn remaining(&self, extensions: &RequestExtensions) -> Remaining {
152 (**self).remaining(extensions)
153 }
154
155 fn reserve(
156 &self,
157 extensions: &RequestExtensions,
158 estimate: &UsageEstimate,
159 request_budget: RequestBudget,
160 ) -> Result<BudgetLease, AgentError> {
161 (**self).reserve(extensions, estimate, request_budget)
162 }
163
164 fn record_used(&self, lease: BudgetLease, usage: Usage) -> Result<(), AgentError> {
165 (**self).record_used(lease, usage)
166 }
167}
168
/// Configuration for the shared-pool budget manager: total capacity plus the
/// stop thresholds, each given for tokens and for cost.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SharedPoolBudgetOptions {
    /// Total tokens the pool may hand out.
    pub capacity_tokens: u64,
    /// Total cost the pool may hand out, in micro-USD.
    pub capacity_cost_micros_usd: u64,
    /// Token threshold used for stop decisions.
    pub stop_threshold_tokens: u64,
    /// Cost threshold used for stop decisions, in micro-USD.
    pub stop_threshold_cost_micros_usd: u64,
}

impl Default for SharedPoolBudgetOptions {
    /// Maximum representable capacity with both stop thresholds at zero.
    fn default() -> Self {
        let unbounded = u64::MAX;
        let no_threshold = 0;

        Self {
            capacity_tokens: unbounded,
            capacity_cost_micros_usd: unbounded,
            stop_threshold_tokens: no_threshold,
            stop_threshold_cost_micros_usd: no_threshold,
        }
    }
}
187
188#[derive(Debug, Error, Clone, Eq, PartialEq)]
189pub enum SharedPoolBudgetError {
190 #[error(
191 "request budget exceeded: requested {requested_tokens} tokens / {requested_cost_micros_usd} micros exceeds tokens={budget_tokens:?}, cost={budget_cost_micros_usd:?}"
192 )]
193 RequestBudgetExceeded {
194 requested_tokens: u64,
195 requested_cost_micros_usd: u64,
196 budget_tokens: Option<u64>,
197 budget_cost_micros_usd: Option<u64>,
198 },
199 #[error(
200 "reserving {requested_tokens} tokens / {requested_cost_micros_usd} micros would cross the stop threshold"
201 )]
202 ThresholdExceeded {
203 requested_tokens: u64,
204 requested_cost_micros_usd: u64,
205 remaining_tokens: u64,
206 remaining_cost_micros_usd: u64,
207 },
208 #[error("unknown budget lease {lease_id}")]
209 UnknownLease { lease_id: u64 },
210 #[error("shared budget state poisoned")]
211 Poisoned,
212}
213
214#[derive(Clone)]
215pub struct SharedPoolBudgetManager {
216 options: SharedPoolBudgetOptions,
217 next_lease_id: Arc<AtomicU64>,
218 state: Arc<Mutex<SharedPoolBudgetState>>,
219}
220
221#[derive(Debug, Default)]
222struct SharedPoolBudgetState {
223 committed_tokens: u64,
224 committed_cost_micros_usd: u64,
225 reserved_tokens: u64,
226 reserved_cost_micros_usd: u64,
227 leases: BTreeMap<u64, (UsageEstimate, RequestBudget)>,
228}
229
230impl SharedPoolBudgetManager {
231 pub fn new(options: SharedPoolBudgetOptions) -> Self {
232 Self {
233 options,
234 next_lease_id: Arc::new(AtomicU64::new(1)),
235 state: Arc::new(Mutex::new(SharedPoolBudgetState::default())),
236 }
237 }
238
239 fn remaining_with_state(&self, state: &SharedPoolBudgetState) -> Remaining {
240 let tokens = self
241 .options
242 .capacity_tokens
243 .saturating_sub(state.committed_tokens.saturating_add(state.reserved_tokens));
244 let cost_micros_usd = self.options.capacity_cost_micros_usd.saturating_sub(
245 state
246 .committed_cost_micros_usd
247 .saturating_add(state.reserved_cost_micros_usd),
248 );
249
250 Remaining {
251 tokens,
252 cost_micros_usd,
253 below_threshold: tokens <= self.options.stop_threshold_tokens
254 || cost_micros_usd <= self.options.stop_threshold_cost_micros_usd,
255 }
256 }
257}
258
259impl BudgetManager for SharedPoolBudgetManager {
260 fn remaining(&self, _extensions: &RequestExtensions) -> Remaining {
261 let state = self
262 .state
263 .lock()
264 .map_err(|_| AgentError::budget(SharedPoolBudgetError::Poisoned));
265 match state {
266 Ok(state) => self.remaining_with_state(&state),
267 Err(_) => Remaining {
268 tokens: 0,
269 cost_micros_usd: 0,
270 below_threshold: true,
271 },
272 }
273 }
274
275 fn reserve(
276 &self,
277 _extensions: &RequestExtensions,
278 estimate: &UsageEstimate,
279 request_budget: RequestBudget,
280 ) -> Result<BudgetLease, AgentError> {
281 if !request_budget.allows(estimate.total_tokens, estimate.cost_micros_usd) {
282 return Err(AgentError::budget(
283 SharedPoolBudgetError::RequestBudgetExceeded {
284 requested_tokens: estimate.total_tokens,
285 requested_cost_micros_usd: estimate.cost_micros_usd,
286 budget_tokens: request_budget.tokens,
287 budget_cost_micros_usd: request_budget.cost_micros_usd,
288 },
289 ));
290 }
291
292 let mut state = self
293 .state
294 .lock()
295 .map_err(|_| AgentError::budget(SharedPoolBudgetError::Poisoned))?;
296 let remaining = self.remaining_with_state(&state);
297
298 let remaining_after_tokens = remaining.tokens.saturating_sub(estimate.total_tokens);
299 let remaining_after_cost = remaining
300 .cost_micros_usd
301 .saturating_sub(estimate.cost_micros_usd);
302 let denied = estimate.total_tokens > remaining.tokens
303 || estimate.cost_micros_usd > remaining.cost_micros_usd
304 || remaining_after_tokens < self.options.stop_threshold_tokens
305 || remaining_after_cost < self.options.stop_threshold_cost_micros_usd;
306
307 if denied {
308 return Err(AgentError::budget(
309 SharedPoolBudgetError::ThresholdExceeded {
310 requested_tokens: estimate.total_tokens,
311 requested_cost_micros_usd: estimate.cost_micros_usd,
312 remaining_tokens: remaining.tokens,
313 remaining_cost_micros_usd: remaining.cost_micros_usd,
314 },
315 ));
316 }
317
318 let id = self.next_lease_id.fetch_add(1, Ordering::Relaxed);
319 state.reserved_tokens = state.reserved_tokens.saturating_add(estimate.total_tokens);
320 state.reserved_cost_micros_usd = state
321 .reserved_cost_micros_usd
322 .saturating_add(estimate.cost_micros_usd);
323 state.leases.insert(id, (*estimate, request_budget));
324
325 Ok(BudgetLease::new(id, *estimate, request_budget))
326 }
327
328 fn record_used(&self, lease: BudgetLease, usage: Usage) -> Result<(), AgentError> {
329 let mut state = self
330 .state
331 .lock()
332 .map_err(|_| AgentError::budget(SharedPoolBudgetError::Poisoned))?;
333 let Some((reserved, request_budget)) = state.leases.remove(&lease.id) else {
334 return Err(AgentError::budget(SharedPoolBudgetError::UnknownLease {
335 lease_id: lease.id,
336 }));
337 };
338
339 state.reserved_tokens = state.reserved_tokens.saturating_sub(reserved.total_tokens);
340 state.reserved_cost_micros_usd = state
341 .reserved_cost_micros_usd
342 .saturating_sub(reserved.cost_micros_usd);
343 state.committed_tokens = state.committed_tokens.saturating_add(usage.total_tokens);
344 state.committed_cost_micros_usd = state
345 .committed_cost_micros_usd
346 .saturating_add(usage.cost_micros_usd);
347
348 if !request_budget.allows(usage.total_tokens, usage.cost_micros_usd) {
349 return Err(AgentError::budget(
350 SharedPoolBudgetError::RequestBudgetExceeded {
351 requested_tokens: usage.total_tokens,
352 requested_cost_micros_usd: usage.cost_micros_usd,
353 budget_tokens: request_budget.tokens,
354 budget_cost_micros_usd: request_budget.cost_micros_usd,
355 },
356 ));
357 }
358
359 Ok(())
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the `SharedPoolBudgetError` carried by a budget `AgentError`,
    /// panicking if the error is of any other shape.
    fn shared_pool_error(err: &AgentError) -> &SharedPoolBudgetError {
        match err {
            AgentError::Budget(source) => source
                .downcast_ref::<SharedPoolBudgetError>()
                .expect("shared pool budget error source"),
            other => panic!("expected budget error, got {other}"),
        }
    }

    /// Estimate with only the total-token and cost figures set.
    fn estimate_of(total_tokens: u64, cost_micros_usd: u64) -> UsageEstimate {
        UsageEstimate {
            total_tokens,
            cost_micros_usd,
            ..UsageEstimate::zero()
        }
    }

    /// Usage with only the total-token and cost figures set.
    fn usage_of(total_tokens: u64, cost_micros_usd: u64) -> Usage {
        Usage {
            total_tokens,
            cost_micros_usd,
            ..Usage::zero()
        }
    }

    #[test]
    fn shared_pool_reserves_and_refunds_difference() {
        let options = SharedPoolBudgetOptions {
            capacity_tokens: 100,
            capacity_cost_micros_usd: 1_000,
            stop_threshold_tokens: 10,
            stop_threshold_cost_micros_usd: 100,
        };
        let manager = SharedPoolBudgetManager::new(options);
        let extensions = RequestExtensions::new();

        // Reserving 20 tokens leaves 80 while the lease is outstanding.
        let lease = manager
            .reserve(&extensions, &estimate_of(20, 200), RequestBudget::unlimited())
            .unwrap();
        assert_eq!(manager.remaining(&extensions).tokens, 80);

        // Only 12 were used, so the unused part of the reservation returns.
        manager.record_used(lease, usage_of(12, 120)).unwrap();

        let Remaining {
            tokens,
            cost_micros_usd,
            ..
        } = manager.remaining(&extensions);
        assert_eq!(tokens, 88);
        assert_eq!(cost_micros_usd, 880);
    }

    #[test]
    fn shared_pool_blocks_when_threshold_would_be_crossed() {
        let options = SharedPoolBudgetOptions {
            capacity_tokens: 100,
            capacity_cost_micros_usd: 1_000,
            stop_threshold_tokens: 10,
            stop_threshold_cost_micros_usd: 0,
        };
        let manager = SharedPoolBudgetManager::new(options);

        // 91 of 100 would leave 9, which is under the threshold of 10.
        let outcome = manager.reserve(
            &RequestExtensions::new(),
            &estimate_of(91, 0),
            RequestBudget::unlimited(),
        );
        let err = outcome.unwrap_err();

        assert!(matches!(
            shared_pool_error(&err),
            SharedPoolBudgetError::ThresholdExceeded { .. }
        ));
    }

    #[test]
    fn request_budget_can_restrict_reservation() {
        let manager = SharedPoolBudgetManager::new(SharedPoolBudgetOptions::default());

        let outcome = manager.reserve(
            &RequestExtensions::new(),
            &estimate_of(32, 0),
            RequestBudget::from_tokens(16),
        );
        let err = outcome.unwrap_err();

        assert!(matches!(
            shared_pool_error(&err),
            SharedPoolBudgetError::RequestBudgetExceeded {
                requested_tokens: 32,
                budget_tokens: Some(16),
                ..
            }
        ));
    }

    #[test]
    fn request_budget_can_fail_after_actual_usage_is_higher_than_estimate() {
        let manager = SharedPoolBudgetManager::new(SharedPoolBudgetOptions::default());

        // The estimate of 8 fits the budget of 10, so the reservation succeeds.
        let lease = manager
            .reserve(
                &RequestExtensions::new(),
                &estimate_of(8, 0),
                RequestBudget::from_tokens(10),
            )
            .unwrap();

        // Actual usage of 12 overshoots the budget of 10.
        let err = manager.record_used(lease, usage_of(12, 0)).unwrap_err();

        assert!(matches!(
            shared_pool_error(&err),
            SharedPoolBudgetError::RequestBudgetExceeded {
                requested_tokens: 12,
                budget_tokens: Some(10),
                ..
            }
        ));
    }
}