1use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::{Arc, Mutex};
46
47use async_trait::async_trait;
48use tracing::{instrument, trace};
49
50use crate::normalizer::{
51 ToolDispatchAction, ToolDispatchHook, ToolInvocation, ToolInvocationResult,
52};
53use crate::registry::KernelError;
54
55#[derive(Debug, thiserror::Error)]
63pub enum BudgetError {
64 #[error("budget backend error: {0}")]
66 Backend(String),
67}
68
69#[async_trait]
76pub trait BudgetGuard: Send + Sync {
77 async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError>;
80
81 async fn release(&self, cost: u64);
83}
84
85#[derive(Debug)]
92pub struct DispatchBudgetHook<G>
93where
94 G: BudgetGuard,
95{
96 guard: Arc<G>,
97 cost_per_invocation: u64,
98 reserved: Mutex<u64>,
99}
100
101impl<G> DispatchBudgetHook<G>
102where
103 G: BudgetGuard,
104{
105 pub fn new(guard: Arc<G>, cost_per_invocation: u64) -> Self {
107 Self {
108 guard,
109 cost_per_invocation,
110 reserved: Mutex::new(0),
111 }
112 }
113
114 pub fn cost_per_invocation(&self) -> u64 {
116 self.cost_per_invocation
117 }
118
119 pub fn guard(&self) -> &Arc<G> {
121 &self.guard
122 }
123
124 async fn release_one(&self) {
125 let should_release = {
126 let mut reserved = self
127 .reserved
128 .lock()
129 .unwrap_or_else(|poisoned| poisoned.into_inner());
130 if *reserved == 0 {
131 false
132 } else {
133 *reserved -= 1;
134 true
135 }
136 };
137
138 if should_release {
139 self.guard.release(self.cost_per_invocation).await;
140 }
141 }
142}
143
144#[async_trait]
145impl<G> ToolDispatchHook for DispatchBudgetHook<G>
146where
147 G: BudgetGuard,
148{
149 async fn before_invocation(
150 &self,
151 invocation: &ToolInvocation,
152 ) -> Result<ToolDispatchAction, KernelError> {
153 let reserved = self
154 .guard
155 .try_reserve(self.cost_per_invocation)
156 .await
157 .map_err(|error| KernelError::BudgetFailed(error.to_string()))?;
158
159 if reserved {
160 let mut count = self
161 .reserved
162 .lock()
163 .unwrap_or_else(|poisoned| poisoned.into_inner());
164 *count = count.saturating_add(1);
165 Ok(ToolDispatchAction::Continue)
166 } else {
167 Ok(ToolDispatchAction::Terminate {
168 reason: format!(
169 "budget denied `{}` at cost {}",
170 invocation.name, self.cost_per_invocation
171 ),
172 })
173 }
174 }
175
176 async fn after_invocation(&self, _result: &ToolInvocationResult) -> Result<(), KernelError> {
177 self.release_one().await;
178 Ok(())
179 }
180
181 async fn on_invocation_error(
182 &self,
183 _invocation: &ToolInvocation,
184 _error: &KernelError,
185 ) -> Result<(), KernelError> {
186 self.release_one().await;
187 Ok(())
188 }
189}
190
/// Callback that returns a number of tokens to the budget they came from.
pub type TokenRefund = Box<dyn FnOnce(u64) + Send + Sync>;

/// A hold on `estimate` tokens.
///
/// While the reservation is armed, dropping it invokes the refund callback
/// with the full estimate. Call [`TokenReservation::disarm`] to take over
/// the callback and suppress that automatic refund.
pub struct TokenReservation {
    /// Tokens held by this reservation.
    estimate: u64,
    /// Refund callback; `None` once the reservation has been disarmed.
    refund: Option<TokenRefund>,
}

impl std::fmt::Debug for TokenReservation {
    /// Prints the estimate and whether the refund is still armed; the
    /// callback itself is an opaque closure and is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let armed = self.refund.is_some();
        let mut out = f.debug_struct("TokenReservation");
        out.field("estimate", &self.estimate);
        out.field("armed", &armed);
        out.finish()
    }
}

impl TokenReservation {
    /// Creates an armed reservation for `estimate` tokens that calls
    /// `refund` with the estimate if it is dropped without being disarmed.
    pub fn new(estimate: u64, refund: TokenRefund) -> Self {
        let refund = Some(refund);
        Self { estimate, refund }
    }

    /// Number of tokens held by this reservation.
    pub fn estimate(&self) -> u64 {
        self.estimate
    }

    /// Removes and returns the refund callback, if it is still present.
    ///
    /// After this call, dropping the reservation refunds nothing.
    pub fn disarm(&mut self) -> Option<TokenRefund> {
        self.refund.take()
    }
}

impl Drop for TokenReservation {
    /// Refunds the whole estimate if the reservation is still armed.
    fn drop(&mut self) {
        if let Some(give_back) = self.disarm() {
            give_back(self.estimate);
        }
    }
}
260
261#[async_trait]
274pub trait TokenBudget: Send + Sync {
275 async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError>;
280
281 async fn record_usage(&self, reservation: TokenReservation, prompt: u64, completion: u64);
285
286 async fn tokens_consumed(&self) -> u64;
288}
289
/// Fixed-capacity budget whose remaining balance lives in a single atomic
/// counter.
#[derive(Debug)]
pub struct AtomicBudget {
    /// Upper bound on the balance; also the value `refill` restores.
    capacity: u64,
    /// Units currently available for reservation.
    available: AtomicU64,
}

impl AtomicBudget {
    /// Creates a budget that starts out full, with `capacity` units
    /// available.
    pub fn new(capacity: u64) -> Self {
        let available = AtomicU64::new(capacity);
        Self {
            capacity,
            available,
        }
    }

    /// Maximum number of units the budget can hold.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Units currently available.
    pub fn available(&self) -> u64 {
        self.available.load(Ordering::Acquire)
    }

    /// Fraction of the capacity currently in use, as `used / capacity`.
    ///
    /// A zero-capacity budget reports `0.0` rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        match self.capacity {
            0 => 0.0,
            total => {
                // Saturating so a balance above capacity reads as "nothing used".
                let spent = total.saturating_sub(self.available());
                spent as f64 / total as f64
            }
        }
    }

    /// Resets the available balance to the full capacity.
    pub fn refill(&self) {
        self.available.store(self.capacity, Ordering::Release);
    }
}
334
335#[async_trait]
336impl BudgetGuard for AtomicBudget {
337 #[instrument(name = "rig_compose.budget.try_reserve", skip(self), fields(cost))]
338 async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError> {
339 let mut current = self.available.load(Ordering::Acquire);
340 loop {
341 if current < cost {
342 trace!(current, "budget would be exceeded");
343 return Ok(false);
344 }
345 match self.available.compare_exchange_weak(
346 current,
347 current - cost,
348 Ordering::AcqRel,
349 Ordering::Acquire,
350 ) {
351 Ok(_) => return Ok(true),
352 Err(observed) => current = observed,
353 }
354 }
355 }
356
357 #[instrument(name = "rig_compose.budget.release", skip(self), fields(cost))]
358 async fn release(&self, cost: u64) {
359 let mut current = self.available.load(Ordering::Acquire);
360 loop {
361 let next = current.saturating_add(cost).min(self.capacity);
362 match self.available.compare_exchange_weak(
363 current,
364 next,
365 Ordering::AcqRel,
366 Ordering::Acquire,
367 ) {
368 Ok(_) => return,
369 Err(observed) => current = observed,
370 }
371 }
372 }
373}
374
375#[derive(Debug)]
388pub struct AtomicTokenBudget {
389 inner: Arc<AtomicTokenBudgetInner>,
390}
391
/// Counters shared between an `AtomicTokenBudget` and the refund callbacks
/// of the reservations it hands out.
#[derive(Debug)]
struct AtomicTokenBudgetInner {
    /// Upper bound on `available`.
    capacity: u64,
    /// Tokens currently available for reservation.
    available: AtomicU64,
    /// Running total of tokens recorded as used.
    consumed: AtomicU64,
}

impl AtomicTokenBudgetInner {
    /// Adds `amount` tokens back to the available balance, clamped at
    /// `capacity`. A zero amount is a no-op.
    fn refund(&self, amount: u64) {
        if amount != 0 {
            let ceiling = self.capacity;
            // The closure always yields `Some`, so the update cannot fail.
            let _ = self
                .available
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |held| {
                    Some(held.saturating_add(amount).min(ceiling))
                });
        }
    }

    /// Removes `amount` tokens from the available balance, bottoming out at
    /// zero. A zero amount is a no-op.
    fn debit(&self, amount: u64) {
        if amount != 0 {
            // The closure always yields `Some`, so the update cannot fail.
            let _ = self
                .available
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |held| {
                    Some(held.saturating_sub(amount))
                });
        }
    }
}
438
439impl AtomicTokenBudget {
440 pub fn new(capacity: u64) -> Self {
442 Self {
443 inner: Arc::new(AtomicTokenBudgetInner {
444 capacity,
445 available: AtomicU64::new(capacity),
446 consumed: AtomicU64::new(0),
447 }),
448 }
449 }
450
451 pub fn capacity(&self) -> u64 {
453 self.inner.capacity
454 }
455
456 pub fn available(&self) -> u64 {
458 self.inner.available.load(Ordering::Acquire)
459 }
460}
461
462#[async_trait]
463impl TokenBudget for AtomicTokenBudget {
464 #[instrument(name = "rig_compose.token_budget.try_reserve", skip(self), fields(est))]
465 async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError> {
466 let mut current = self.inner.available.load(Ordering::Acquire);
467 loop {
468 if current < est {
469 trace!(current, "token budget would be exceeded");
470 return Ok(None);
471 }
472 match self.inner.available.compare_exchange_weak(
473 current,
474 current - est,
475 Ordering::AcqRel,
476 Ordering::Acquire,
477 ) {
478 Ok(_) => {
479 let weak = Arc::downgrade(&self.inner);
480 let refund: TokenRefund = Box::new(move |amount| {
481 if let Some(inner) = weak.upgrade() {
482 inner.refund(amount);
483 }
484 });
485 return Ok(Some(TokenReservation::new(est, refund)));
486 }
487 Err(observed) => current = observed,
488 }
489 }
490 }
491
492 #[instrument(name = "rig_compose.token_budget.record_usage", skip(self))]
493 async fn record_usage(&self, mut reservation: TokenReservation, prompt: u64, completion: u64) {
494 let actual = prompt.saturating_add(completion);
495 self.inner.consumed.fetch_add(actual, Ordering::AcqRel);
496 let _ = reservation.disarm();
499 let estimate = reservation.estimate();
500 if estimate >= actual {
501 self.inner.refund(estimate - actual);
504 } else {
505 self.inner.debit(actual - estimate);
509 }
510 }
511
512 async fn tokens_consumed(&self) -> u64 {
513 self.inner.consumed.load(Ordering::Acquire)
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Reservations succeed until the pool is empty and again after a release.
    #[tokio::test]
    async fn reserve_until_empty() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(60).await.unwrap());
        assert!(b.try_reserve(40).await.unwrap());
        assert!(!b.try_reserve(1).await.unwrap());
        b.release(50).await;
        assert!(b.try_reserve(50).await.unwrap());
    }

    /// Releasing more than was taken cannot lift the balance above capacity.
    #[tokio::test]
    async fn release_caps_at_capacity() {
        let b = AtomicBudget::new(10);
        assert!(b.try_reserve(5).await.unwrap());
        b.release(100).await;
        assert_eq!(b.available(), 10);
    }

    /// `refill` restores the full capacity regardless of the current balance.
    #[tokio::test]
    async fn refill_restores_capacity() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(75).await.unwrap());
        assert_eq!(b.available(), 25);
        b.refill();
        assert_eq!(b.available(), 100);
    }

    /// Utilization is the used fraction of the capacity.
    #[tokio::test]
    async fn utilization_tracks_consumption() {
        let b = AtomicBudget::new(100);
        assert!((b.utilization() - 0.0).abs() < f64::EPSILON);
        assert!(b.try_reserve(40).await.unwrap());
        assert!((b.utilization() - 0.4).abs() < f64::EPSILON);
    }

    /// Recording usage below the estimate refunds the surplus.
    #[tokio::test]
    async fn token_budget_reserves_records_and_reports() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        tb.record_usage(reservation, 120, 80).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        // 200 of the 400 reserved were used, so exactly 800 remain.
        let _hold = tb.try_reserve_tokens(800).await.unwrap().unwrap();
        assert!(tb.try_reserve_tokens(1).await.unwrap().is_none());
    }

    /// Recording usage above the estimate debits the shortfall.
    #[tokio::test]
    async fn token_budget_debits_overage() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(100).await.unwrap().unwrap();
        tb.record_usage(reservation, 150, 50).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        assert_eq!(tb.available(), 800);
    }

    /// Settling one reservation does not disturb another one in flight.
    #[tokio::test]
    async fn token_budget_reconciles_each_reservation_independently() {
        let tb = AtomicTokenBudget::new(1_000);
        let first = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        let second = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        assert_eq!(tb.available(), 200);

        tb.record_usage(first, 100, 100).await;
        assert_eq!(tb.available(), 400);
        assert!(tb.try_reserve_tokens(401).await.unwrap().is_none());

        tb.record_usage(second, 200, 200).await;
        assert_eq!(tb.available(), 400);
        assert_eq!(tb.tokens_consumed().await, 600);
    }

    /// A reservation reports the estimate it was created with.
    #[tokio::test]
    async fn token_reservation_reports_estimate() {
        let tb = AtomicTokenBudget::new(10);
        let reservation = tb.try_reserve_tokens(7).await.unwrap().unwrap();
        assert_eq!(reservation.estimate(), 7);
    }

    /// Dropping an unsettled reservation returns its tokens and records no
    /// consumption.
    #[tokio::test]
    async fn token_reservation_refunds_on_drop() {
        let tb = AtomicTokenBudget::new(1_000);
        {
            let _reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
            assert_eq!(tb.available(), 600);
        }
        assert_eq!(tb.available(), 1_000);
        assert_eq!(tb.tokens_consumed().await, 0);
    }

    /// A refund that would overshoot the capacity is clamped to it.
    #[tokio::test]
    async fn token_reservation_refund_is_capped_at_capacity() {
        let tb = AtomicTokenBudget::new(100);
        let r = tb.try_reserve_tokens(40).await.unwrap().unwrap();
        assert_eq!(tb.available(), 60);
        // Top the pool back up while the reservation is still armed, so the
        // refund on drop would reach 140 if it were not clamped.
        tb.inner.refund(40);
        assert_eq!(tb.available(), 100);
        drop(r);
        assert_eq!(tb.available(), 100);
    }
}