1use std::sync::Arc;
45use std::sync::atomic::{AtomicU64, Ordering};
46
47use async_trait::async_trait;
48use tracing::{instrument, trace};
49
50#[derive(Debug, thiserror::Error)]
58pub enum BudgetError {
59 #[error("budget backend error: {0}")]
61 Backend(String),
62}
63
64#[async_trait]
71pub trait BudgetGuard: Send + Sync {
72 async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError>;
75
76 async fn release(&self, cost: u64);
78}
79
/// One-shot callback that hands tokens back to whichever budget issued a
/// reservation. The argument is the number of tokens to return.
pub type TokenRefund = Box<dyn FnOnce(u64) + Send + Sync>;

/// Guard for tokens that were set aside before their real cost is known.
///
/// While the guard is *armed* it owns a [`TokenRefund`]; dropping an armed
/// guard invokes that callback with the full estimate.
/// [`TokenReservation::disarm`] removes the callback, after which dropping
/// the guard refunds nothing.
#[must_use = "dropping an armed reservation immediately refunds its estimate"]
pub struct TokenReservation {
    /// Number of tokens that were reserved.
    estimate: u64,
    /// Pending refund callback; `None` once disarmed or already fired.
    refund: Option<TokenRefund>,
}

impl std::fmt::Debug for TokenReservation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed closure has no `Debug` impl, so report only whether the
        // guard still holds it.
        f.debug_struct("TokenReservation")
            .field("estimate", &self.estimate)
            .field("armed", &self.refund.is_some())
            .finish()
    }
}

impl TokenReservation {
    /// Creates an armed reservation for `estimate` tokens that calls `refund`
    /// with the estimate if it is dropped without being disarmed.
    pub fn new(estimate: u64, refund: TokenRefund) -> Self {
        Self {
            estimate,
            refund: Some(refund),
        }
    }

    /// Returns the number of tokens this reservation was created for.
    pub fn estimate(&self) -> u64 {
        self.estimate
    }

    /// Takes the refund callback out of the guard so dropping it no longer
    /// refunds anything.
    ///
    /// Returns the callback if the guard was still armed, `None` otherwise.
    /// The caller decides whether to invoke or discard it.
    pub fn disarm(&mut self) -> Option<TokenRefund> {
        self.refund.take()
    }
}

impl Drop for TokenReservation {
    /// Refunds the full estimate when the guard is still armed.
    fn drop(&mut self) {
        if let Some(refund) = self.refund.take() {
            refund(self.estimate);
        }
    }
}
149
150#[async_trait]
163pub trait TokenBudget: Send + Sync {
164 async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError>;
169
170 async fn record_usage(&self, reservation: TokenReservation, prompt: u64, completion: u64);
174
175 async fn tokens_consumed(&self) -> u64;
177}
178
/// Fixed-capacity budget whose remaining balance lives in a single atomic
/// counter.
#[derive(Debug)]
pub struct AtomicBudget {
    /// Upper bound for `available`; set once at construction.
    capacity: u64,
    /// Units that can still be reserved.
    available: AtomicU64,
}

impl AtomicBudget {
    /// Creates a budget with every one of its `capacity` units available.
    pub fn new(capacity: u64) -> Self {
        let available = AtomicU64::new(capacity);
        Self {
            capacity,
            available,
        }
    }

    /// Returns the capacity the budget was created with.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Returns the number of units currently available.
    pub fn available(&self) -> u64 {
        self.available.load(Ordering::Acquire)
    }

    /// Returns the fraction of the capacity that is currently in use.
    ///
    /// A zero-capacity budget reports `0.0` rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        match self.capacity {
            0 => 0.0,
            total => {
                let spent = total.saturating_sub(self.available());
                spent as f64 / total as f64
            }
        }
    }

    /// Resets the available balance to the full capacity.
    pub fn refill(&self) {
        self.available.store(self.capacity, Ordering::Release);
    }
}
223
224#[async_trait]
225impl BudgetGuard for AtomicBudget {
226 #[instrument(name = "rig_compose.budget.try_reserve", skip(self), fields(cost))]
227 async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError> {
228 let mut current = self.available.load(Ordering::Acquire);
229 loop {
230 if current < cost {
231 trace!(current, "budget would be exceeded");
232 return Ok(false);
233 }
234 match self.available.compare_exchange_weak(
235 current,
236 current - cost,
237 Ordering::AcqRel,
238 Ordering::Acquire,
239 ) {
240 Ok(_) => return Ok(true),
241 Err(observed) => current = observed,
242 }
243 }
244 }
245
246 #[instrument(name = "rig_compose.budget.release", skip(self), fields(cost))]
247 async fn release(&self, cost: u64) {
248 let mut current = self.available.load(Ordering::Acquire);
249 loop {
250 let next = current.saturating_add(cost).min(self.capacity);
251 match self.available.compare_exchange_weak(
252 current,
253 next,
254 Ordering::AcqRel,
255 Ordering::Acquire,
256 ) {
257 Ok(_) => return,
258 Err(observed) => current = observed,
259 }
260 }
261 }
262}
263
264#[derive(Debug)]
277pub struct AtomicTokenBudget {
278 inner: Arc<AtomicTokenBudgetInner>,
279}
280
/// Counter state behind an `AtomicTokenBudget`.
#[derive(Debug)]
struct AtomicTokenBudgetInner {
    /// Upper bound for `available`.
    capacity: u64,
    /// Tokens that can still be reserved.
    available: AtomicU64,
    /// Running total of tokens reported as used.
    consumed: AtomicU64,
}

impl AtomicTokenBudgetInner {
    /// Credits `amount` tokens back to the available balance, clamping the
    /// result to the capacity. A zero amount does nothing.
    fn refund(&self, amount: u64) {
        if amount != 0 {
            let ceiling = self.capacity;
            // The closure always yields `Some`, so the update cannot fail.
            let _ = self
                .available
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |balance| {
                    Some(balance.saturating_add(amount).min(ceiling))
                });
        }
    }

    /// Removes `amount` tokens from the available balance, bottoming out at
    /// zero. A zero amount does nothing.
    fn debit(&self, amount: u64) {
        if amount != 0 {
            // The closure always yields `Some`, so the update cannot fail.
            let _ = self
                .available
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |balance| {
                    Some(balance.saturating_sub(amount))
                });
        }
    }
}
327
328impl AtomicTokenBudget {
329 pub fn new(capacity: u64) -> Self {
331 Self {
332 inner: Arc::new(AtomicTokenBudgetInner {
333 capacity,
334 available: AtomicU64::new(capacity),
335 consumed: AtomicU64::new(0),
336 }),
337 }
338 }
339
340 pub fn capacity(&self) -> u64 {
342 self.inner.capacity
343 }
344
345 pub fn available(&self) -> u64 {
347 self.inner.available.load(Ordering::Acquire)
348 }
349}
350
351#[async_trait]
352impl TokenBudget for AtomicTokenBudget {
353 #[instrument(name = "rig_compose.token_budget.try_reserve", skip(self), fields(est))]
354 async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError> {
355 let mut current = self.inner.available.load(Ordering::Acquire);
356 loop {
357 if current < est {
358 trace!(current, "token budget would be exceeded");
359 return Ok(None);
360 }
361 match self.inner.available.compare_exchange_weak(
362 current,
363 current - est,
364 Ordering::AcqRel,
365 Ordering::Acquire,
366 ) {
367 Ok(_) => {
368 let weak = Arc::downgrade(&self.inner);
369 let refund: TokenRefund = Box::new(move |amount| {
370 if let Some(inner) = weak.upgrade() {
371 inner.refund(amount);
372 }
373 });
374 return Ok(Some(TokenReservation::new(est, refund)));
375 }
376 Err(observed) => current = observed,
377 }
378 }
379 }
380
381 #[instrument(name = "rig_compose.token_budget.record_usage", skip(self))]
382 async fn record_usage(&self, mut reservation: TokenReservation, prompt: u64, completion: u64) {
383 let actual = prompt.saturating_add(completion);
384 self.inner.consumed.fetch_add(actual, Ordering::AcqRel);
385 let _ = reservation.disarm();
388 let estimate = reservation.estimate();
389 if estimate >= actual {
390 self.inner.refund(estimate - actual);
393 } else {
394 self.inner.debit(actual - estimate);
398 }
399 }
400
401 async fn tokens_consumed(&self) -> u64 {
402 self.inner.consumed.load(Ordering::Acquire)
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Reservations succeed until the balance is exhausted, and released
    /// units become reservable again.
    #[tokio::test]
    async fn reserve_until_empty() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(60).await.unwrap());
        assert!(b.try_reserve(40).await.unwrap());
        assert!(!b.try_reserve(1).await.unwrap());
        b.release(50).await;
        assert!(b.try_reserve(50).await.unwrap());
    }

    /// Releasing more than was reserved never lifts the balance above capacity.
    #[tokio::test]
    async fn release_caps_at_capacity() {
        let b = AtomicBudget::new(10);
        assert!(b.try_reserve(5).await.unwrap());
        b.release(100).await;
        assert_eq!(b.available(), 10);
    }

    /// `refill` restores the full capacity regardless of prior reservations.
    #[tokio::test]
    async fn refill_restores_capacity() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(75).await.unwrap());
        assert_eq!(b.available(), 25);
        b.refill();
        assert_eq!(b.available(), 100);
    }

    /// Utilization is the reserved fraction of the capacity.
    #[tokio::test]
    async fn utilization_tracks_consumption() {
        let b = AtomicBudget::new(100);
        assert!((b.utilization() - 0.0).abs() < f64::EPSILON);
        assert!(b.try_reserve(40).await.unwrap());
        assert!((b.utilization() - 0.4).abs() < f64::EPSILON);
    }

    /// Recording usage below the estimate credits the unused tokens back.
    #[tokio::test]
    async fn token_budget_reserves_records_and_reports() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        tb.record_usage(reservation, 120, 80).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        // 800 tokens remain after reconciliation; hold them all so that even
        // a single extra token is refused.
        let _hold = tb.try_reserve_tokens(800).await.unwrap().unwrap();
        assert!(tb.try_reserve_tokens(1).await.unwrap().is_none());
    }

    /// Usage above the estimate debits the overage from the balance.
    #[tokio::test]
    async fn token_budget_debits_overage() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(100).await.unwrap().unwrap();
        tb.record_usage(reservation, 150, 50).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        assert_eq!(tb.available(), 800);
    }

    /// Each reservation is reconciled against its own estimate.
    #[tokio::test]
    async fn token_budget_reconciles_each_reservation_independently() {
        let tb = AtomicTokenBudget::new(1_000);
        let first = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        let second = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        assert_eq!(tb.available(), 200);

        tb.record_usage(first, 100, 100).await;
        assert_eq!(tb.available(), 400);
        assert!(tb.try_reserve_tokens(401).await.unwrap().is_none());

        tb.record_usage(second, 200, 200).await;
        assert_eq!(tb.available(), 400);
        assert_eq!(tb.tokens_consumed().await, 600);
    }

    /// A reservation reports the estimate it was created with.
    #[tokio::test]
    async fn token_reservation_reports_estimate() {
        let tb = AtomicTokenBudget::new(10);
        let reservation = tb.try_reserve_tokens(7).await.unwrap().unwrap();
        assert_eq!(reservation.estimate(), 7);
    }

    /// Dropping a reservation without recording usage refunds the estimate
    /// and records no consumption.
    #[tokio::test]
    async fn token_reservation_refunds_on_drop() {
        let tb = AtomicTokenBudget::new(1_000);
        {
            let _reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
            assert_eq!(tb.available(), 600);
        }
        assert_eq!(tb.available(), 1_000);
        assert_eq!(tb.tokens_consumed().await, 0);
    }

    /// Refunds never lift the balance above capacity, even when the refunded
    /// amount exceeds what was reserved.
    #[tokio::test]
    async fn token_reservation_refund_is_capped_at_capacity() {
        let tb = AtomicTokenBudget::new(100);
        let r = tb.try_reserve_tokens(40).await.unwrap().unwrap();
        drop(r);
        assert_eq!(tb.available(), 100);

        // A drop-time refund only returns what was reserved, so it can never
        // exceed capacity on its own. Force an over-refund through the
        // disarmed callback to exercise the cap.
        let mut r = tb.try_reserve_tokens(40).await.unwrap().unwrap();
        assert_eq!(tb.available(), 60);
        let refund = r.disarm().expect("a fresh reservation is armed");
        refund(1_000);
        assert_eq!(tb.available(), 100);

        // The guard is disarmed, so dropping it must not credit anything more.
        drop(r);
        assert_eq!(tb.available(), 100);
    }
}