Skip to main content

rig_compose/
budget.rs

1//! Cost-bounded coordination primitives.
2//!
3//! The kernel often needs to gate dispatch on a finite budget — rows
4//! parsed, LLM tokens spent, dollars burned, etc. This module exposes
5//! two domain-neutral traits and lock-free reference implementations:
6//!
7//! - [`BudgetGuard`] — generic unit-cost reservations (rows, requests,
8//!   queue slots).
9//! - [`TokenBudget`] — LLM-token accounting with optimistic reservation
10//!   and after-the-fact reconciliation against provider-reported usage.
11//!
12//! Both traits are intentionally narrow so a single coordinator can
13//! compose multiple budget policies (e.g. a row guard *and* a token
14//! guard) without coupling to any particular backend.
15//!
16//! ## Why two traits?
17//!
18//! [`BudgetGuard`] is a symmetric reserve/release pair: callers know the
19//! cost up front and either commit or roll back. [`TokenBudget`] is
20//! optimistic: callers reserve an estimate and receive a
21//! [`TokenReservation`], send the prompt, then pass that reservation to
22//! [`TokenBudget::record_usage`] with the provider's reported totals to
23//! reconcile the over- or under-estimate for that specific call.
24//!
25//! ## Reference implementations
26//!
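//! [`AtomicBudget`] and [`AtomicTokenBudget`] are lock-free token-bucket
//! counters built on `AtomicU64` compare-and-swap retries, which makes
//! them safe for high-contention dispatch loops. [`AtomicBudget`] can be
//! restored to full capacity on demand via [`AtomicBudget::refill`];
//! [`AtomicTokenBudget`] has no refill operation and only regains tokens
//! through reservation refunds.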
31//!
32//! ```no_run
33//! use rig_compose::budget::{AtomicBudget, BudgetGuard};
34//!
35//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
36//! let budget = AtomicBudget::new(1_000);
37//! if budget.try_reserve(250).await? {
38//!     // dispatch happens here ...
39//!     budget.release(250).await;
40//! }
41//! # Ok(()) }
42//! ```
43
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::{Arc, Mutex};
46
47use async_trait::async_trait;
48use tracing::{instrument, trace};
49
50use crate::normalizer::{
51    ToolDispatchAction, ToolDispatchHook, ToolInvocation, ToolInvocationResult,
52};
53use crate::registry::KernelError;
54
55/// Errors a budget implementation may surface.
56///
57/// Soft denial (the budget would be exceeded) is signalled by
58/// `Ok(false)` from [`BudgetGuard::try_reserve`] /
59/// [`TokenBudget::try_reserve_tokens`], not via this error. This enum is
60/// reserved for *infrastructure* failures — a remote budget store
61/// timing out, a persistence layer rejecting a write, etc.
62#[derive(Debug, thiserror::Error)]
63pub enum BudgetError {
64    /// The backing store could not service the request.
65    #[error("budget backend error: {0}")]
66    Backend(String),
67}
68
69/// Symmetric reserve/release budget guard.
70///
71/// Implementations gate work on a finite resource pool (rows parsed,
72/// dispatch slots, etc.). Returning `Ok(false)` from
73/// [`BudgetGuard::try_reserve`] is a soft denial — the caller should
74/// back off rather than treat it as an error.
75#[async_trait]
76pub trait BudgetGuard: Send + Sync {
77    /// Try to reserve `cost` units of budget for an upcoming operation.
78    /// Returns `Ok(true)` on success, `Ok(false)` on soft denial.
79    async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError>;
80
81    /// Release previously reserved units back to the pool.
82    async fn release(&self, cost: u64);
83}
84
/// Dispatch hook that gates each normalized tool invocation on a [`BudgetGuard`].
///
/// The hook reserves `cost_per_invocation` before each call. If the budget
/// denies the reservation, dispatch terminates before the tool runs. Successful
/// reservations are released after the invocation completes, after a synthetic
/// skip result is recorded, or when dispatch stops with an error.
///
/// The struct itself places no bound on `G`: the `G: BudgetGuard` requirement
/// lives on the impl blocks that actually call into the guard, so the type can
/// be named (and derives such as `Debug` apply) without repeating the bound.
#[derive(Debug)]
pub struct DispatchBudgetHook<G> {
    /// Shared budget that every invocation is charged against.
    guard: Arc<G>,
    /// Units reserved before each invocation and released after it.
    cost_per_invocation: u64,
    /// Count of reservations taken by this hook that have not been released
    /// yet; it keeps a release from being issued without a matching reserve.
    reserved: Mutex<u64>,
}
100
101impl<G> DispatchBudgetHook<G>
102where
103    G: BudgetGuard,
104{
105    /// Create a dispatch-budget hook over `guard`.
106    pub fn new(guard: Arc<G>, cost_per_invocation: u64) -> Self {
107        Self {
108            guard,
109            cost_per_invocation,
110            reserved: Mutex::new(0),
111        }
112    }
113
114    /// Units reserved for each invocation.
115    pub fn cost_per_invocation(&self) -> u64 {
116        self.cost_per_invocation
117    }
118
119    /// Underlying budget guard.
120    pub fn guard(&self) -> &Arc<G> {
121        &self.guard
122    }
123
124    async fn release_one(&self) {
125        let should_release = {
126            let mut reserved = self
127                .reserved
128                .lock()
129                .unwrap_or_else(|poisoned| poisoned.into_inner());
130            if *reserved == 0 {
131                false
132            } else {
133                *reserved -= 1;
134                true
135            }
136        };
137
138        if should_release {
139            self.guard.release(self.cost_per_invocation).await;
140        }
141    }
142}
143
144#[async_trait]
145impl<G> ToolDispatchHook for DispatchBudgetHook<G>
146where
147    G: BudgetGuard,
148{
149    async fn before_invocation(
150        &self,
151        invocation: &ToolInvocation,
152    ) -> Result<ToolDispatchAction, KernelError> {
153        let reserved = self
154            .guard
155            .try_reserve(self.cost_per_invocation)
156            .await
157            .map_err(|error| KernelError::BudgetFailed(error.to_string()))?;
158
159        if reserved {
160            let mut count = self
161                .reserved
162                .lock()
163                .unwrap_or_else(|poisoned| poisoned.into_inner());
164            *count = count.saturating_add(1);
165            Ok(ToolDispatchAction::Continue)
166        } else {
167            Ok(ToolDispatchAction::Terminate {
168                reason: format!(
169                    "budget denied `{}` at cost {}",
170                    invocation.name, self.cost_per_invocation
171                ),
172            })
173        }
174    }
175
176    async fn after_invocation(&self, _result: &ToolInvocationResult) -> Result<(), KernelError> {
177        self.release_one().await;
178        Ok(())
179    }
180
181    async fn on_invocation_error(
182        &self,
183        _invocation: &ToolInvocation,
184        _error: &KernelError,
185    ) -> Result<(), KernelError> {
186        self.release_one().await;
187        Ok(())
188    }
189}
190
/// Refund channel invoked when a [`TokenReservation`] is dropped
/// without being passed to [`TokenBudget::record_usage`].
///
/// The reservation calls this at most once, with its own estimate, so the
/// optimistic deduction is returned to the underlying pool.
pub type TokenRefund = Box<dyn FnOnce(u64) + Send + Sync>;

/// Reservation handle returned by [`TokenBudget::try_reserve_tokens`].
///
/// The handle carries the estimate for one model call so reconciliation
/// is per-call even when multiple prompts are outstanding concurrently.
///
/// # Cancellation semantics
///
/// Dropping a `TokenReservation` without calling
/// [`TokenBudget::record_usage`] is treated as cancellation: the full
/// estimate is refunded to the budget via the closure supplied at
/// construction. Implementations of [`TokenBudget::record_usage`] must
/// call [`TokenReservation::disarm`] to suppress the refund-on-drop
/// before performing their own reconciliation.
///
/// Because of this, a reservation that is not bound to a variable is
/// dropped — and therefore refunded — at the end of the statement that
/// created it. The type is `#[must_use]` so the compiler flags that mistake.
#[must_use = "dropping a `TokenReservation` immediately refunds its estimate; hold it until `record_usage`"]
pub struct TokenReservation {
    /// Tokens optimistically deducted for this call.
    estimate: u64,
    /// Refund closure; `Some` while armed, `None` once disarmed or dropped.
    refund: Option<TokenRefund>,
}

impl std::fmt::Debug for TokenReservation {
    /// Prints the estimate and whether the refund is still armed; the
    /// closure itself is opaque and is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenReservation")
            .field("estimate", &self.estimate)
            .field("armed", &self.refund.is_some())
            .finish()
    }
}

impl TokenReservation {
    /// Construct an armed reservation for `estimate` tokens that will call
    /// `refund` with that estimate if it is dropped without being disarmed.
    /// Intended for [`TokenBudget`] implementations only.
    pub fn new(estimate: u64, refund: TokenRefund) -> Self {
        Self {
            estimate,
            refund: Some(refund),
        }
    }

    /// Estimated tokens reserved for this call.
    pub fn estimate(&self) -> u64 {
        self.estimate
    }

    /// Take ownership of the refund closure, suppressing the
    /// refund-on-drop. Returns `None` if the reservation has already
    /// been disarmed (which would indicate misuse).
    ///
    /// [`TokenBudget::record_usage`] implementations must disarm the
    /// reservation they receive before reconciling against actual
    /// usage; otherwise the refund-on-drop would double-credit the
    /// pool.
    pub fn disarm(&mut self) -> Option<TokenRefund> {
        self.refund.take()
    }
}

impl Drop for TokenReservation {
    /// Refund the full estimate if the reservation is still armed.
    fn drop(&mut self) {
        // `take` leaves `None` behind, so the closure can never run twice.
        if let Some(refund) = self.refund.take() {
            refund(self.estimate);
        }
    }
}
260
261/// Soft-cap on cumulative LLM token spend.
262///
263/// `BudgetGuard` constrains *unit-cost work* (rows, dispatches);
264/// `TokenBudget` constrains *tokens burned by a model call*. A single
265/// prompt is typically a handful of tokens; a multi-round tool-call
266/// loop can be 10–100× larger.
267///
268/// Implementations are expected to be cheap and lock-free in the hot
269/// path. [`TokenBudget::try_reserve_tokens`] is called *before* a
270/// prompt is sent; [`TokenBudget::record_usage`] is called *after* with
271/// the observed totals from the provider so the next reservation
272/// reflects reality.
273#[async_trait]
274pub trait TokenBudget: Send + Sync {
275    /// Reserve `est` prompt+completion tokens optimistically.
276    ///
277    /// Returns `Ok(Some(reservation))` on success and `Ok(None)` on soft
278    /// denial.
279    async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError>;
280
281    /// Record the actual prompt + completion token usage from a
282    /// finished call. The implementation reconciles the supplied
283    /// reservation against the observed usage.
284    async fn record_usage(&self, reservation: TokenReservation, prompt: u64, completion: u64);
285
286    /// Total tokens consumed since construction (prompt + completion).
287    async fn tokens_consumed(&self) -> u64;
288}
289
/// Lock-free token-bucket budget built on one atomic counter.
///
/// This is the reference [`BudgetGuard`] implementation: the remaining units
/// live in a single `AtomicU64`.
#[derive(Debug)]
pub struct AtomicBudget {
    /// Size of the pool, fixed at construction.
    capacity: u64,
    /// Units that are currently unreserved.
    available: AtomicU64,
}

impl AtomicBudget {
    /// Create a budget of `capacity` units with every unit available.
    pub fn new(capacity: u64) -> Self {
        let available = AtomicU64::new(capacity);
        Self {
            capacity,
            available,
        }
    }

    /// The capacity the budget was constructed with.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Units that are available right now.
    pub fn available(&self) -> u64 {
        self.available.load(Ordering::Acquire)
    }

    /// Fraction of the capacity currently in use, in `[0, 1]`; handy for
    /// telemetry gauges. A zero-capacity budget reports `0.0`.
    pub fn utilization(&self) -> f64 {
        match self.capacity {
            0 => 0.0,
            total => {
                let in_use = total.saturating_sub(self.available());
                in_use as f64 / total as f64
            }
        }
    }

    /// Reset the available units to the full capacity. Typically called at
    /// the start of each scheduling epoch.
    pub fn refill(&self) {
        self.available.store(self.capacity, Ordering::Release);
    }
}
334
335#[async_trait]
336impl BudgetGuard for AtomicBudget {
337    #[instrument(name = "rig_compose.budget.try_reserve", skip(self), fields(cost))]
338    async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError> {
339        let mut current = self.available.load(Ordering::Acquire);
340        loop {
341            if current < cost {
342                trace!(current, "budget would be exceeded");
343                return Ok(false);
344            }
345            match self.available.compare_exchange_weak(
346                current,
347                current - cost,
348                Ordering::AcqRel,
349                Ordering::Acquire,
350            ) {
351                Ok(_) => return Ok(true),
352                Err(observed) => current = observed,
353            }
354        }
355    }
356
357    #[instrument(name = "rig_compose.budget.release", skip(self), fields(cost))]
358    async fn release(&self, cost: u64) {
359        let mut current = self.available.load(Ordering::Acquire);
360        loop {
361            let next = current.saturating_add(cost).min(self.capacity);
362            match self.available.compare_exchange_weak(
363                current,
364                next,
365                Ordering::AcqRel,
366                Ordering::Acquire,
367            ) {
368                Ok(_) => return,
369                Err(observed) => current = observed,
370            }
371        }
372    }
373}
374
/// Token-spend sibling of [`AtomicBudget`], also built on atomics.
///
/// `try_reserve_tokens` subtracts an estimate up front so that concurrent
/// prompts cannot collectively over-commit. `record_usage` then reconciles a
/// single [`TokenReservation`] with the totals the provider reported: an
/// estimate that turned out too large has the difference returned to the
/// pool, and one that was too small has the overage debited from the pool.
///
/// A reservation also refunds its estimate when dropped, so an error path
/// that abandons the handle without calling `record_usage` does not leak
/// tokens permanently.
#[derive(Debug)]
pub struct AtomicTokenBudget {
    /// Shared state; reservations keep a weak handle to it for refunds.
    inner: Arc<AtomicTokenBudgetInner>,
}

/// Counters behind an [`AtomicTokenBudget`].
#[derive(Debug)]
struct AtomicTokenBudgetInner {
    /// Size of the pool, fixed at construction.
    capacity: u64,
    /// Tokens that are currently unreserved.
    available: AtomicU64,
    /// Running total of tokens reported through `record_usage`.
    consumed: AtomicU64,
}

impl AtomicTokenBudgetInner {
    /// Credit `amount` tokens to the pool, clamped at the capacity.
    /// A zero amount leaves the counter untouched.
    fn refund(&self, amount: u64) {
        if amount == 0 {
            return;
        }
        let ceiling = self.capacity;
        // The closure always returns `Some`, so the update cannot fail.
        let _ = self
            .available
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |tokens| {
                Some(tokens.saturating_add(amount).min(ceiling))
            });
    }

    /// Remove `amount` tokens from the pool, clamped at zero.
    /// A zero amount leaves the counter untouched.
    fn debit(&self, amount: u64) {
        if amount == 0 {
            return;
        }
        // The closure always returns `Some`, so the update cannot fail.
        let _ = self
            .available
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |tokens| {
                Some(tokens.saturating_sub(amount))
            });
    }
}

impl AtomicTokenBudget {
    /// Create a token budget of `capacity` tokens with every token available
    /// and nothing consumed.
    pub fn new(capacity: u64) -> Self {
        let inner = AtomicTokenBudgetInner {
            capacity,
            available: AtomicU64::new(capacity),
            consumed: AtomicU64::new(0),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// The capacity the budget was constructed with.
    pub fn capacity(&self) -> u64 {
        self.inner.capacity
    }

    /// Tokens that are available right now.
    pub fn available(&self) -> u64 {
        self.inner.available.load(Ordering::Acquire)
    }
}
461
462#[async_trait]
463impl TokenBudget for AtomicTokenBudget {
464    #[instrument(name = "rig_compose.token_budget.try_reserve", skip(self), fields(est))]
465    async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError> {
466        let mut current = self.inner.available.load(Ordering::Acquire);
467        loop {
468            if current < est {
469                trace!(current, "token budget would be exceeded");
470                return Ok(None);
471            }
472            match self.inner.available.compare_exchange_weak(
473                current,
474                current - est,
475                Ordering::AcqRel,
476                Ordering::Acquire,
477            ) {
478                Ok(_) => {
479                    let weak = Arc::downgrade(&self.inner);
480                    let refund: TokenRefund = Box::new(move |amount| {
481                        if let Some(inner) = weak.upgrade() {
482                            inner.refund(amount);
483                        }
484                    });
485                    return Ok(Some(TokenReservation::new(est, refund)));
486                }
487                Err(observed) => current = observed,
488            }
489        }
490    }
491
492    #[instrument(name = "rig_compose.token_budget.record_usage", skip(self))]
493    async fn record_usage(&self, mut reservation: TokenReservation, prompt: u64, completion: u64) {
494        let actual = prompt.saturating_add(completion);
495        self.inner.consumed.fetch_add(actual, Ordering::AcqRel);
496        // Disarm the refund-on-drop before reconciling, otherwise the
497        // estimate would be returned twice.
498        let _ = reservation.disarm();
499        let estimate = reservation.estimate();
500        if estimate >= actual {
501            // Refund any over-reservation so concurrent callers see
502            // accurate availability before the next prompt fires.
503            self.inner.refund(estimate - actual);
504        } else {
505            // Actuals exceeded the reservation — debit the overage
506            // from future reservations rather than letting the bucket
507            // silently drift past capacity.
508            self.inner.debit(actual - estimate);
509        }
510    }
511
512    async fn tokens_consumed(&self) -> u64 {
513        self.inner.consumed.load(Ordering::Acquire)
514    }
515}
516
/// Unit tests for the reference budget implementations.
#[cfg(test)]
mod tests {
    use super::*;

    /// Reservations succeed until the pool is empty, and released units can
    /// be reserved again.
    #[tokio::test]
    async fn reserve_until_empty() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(60).await.unwrap());
        assert!(b.try_reserve(40).await.unwrap());
        assert!(!b.try_reserve(1).await.unwrap());
        b.release(50).await;
        assert!(b.try_reserve(50).await.unwrap());
    }

    /// Releasing more than was reserved never lifts the pool above capacity.
    #[tokio::test]
    async fn release_caps_at_capacity() {
        let b = AtomicBudget::new(10);
        assert!(b.try_reserve(5).await.unwrap());
        b.release(100).await;
        assert_eq!(b.available(), 10);
    }

    /// `refill` restores the full capacity regardless of what was reserved.
    #[tokio::test]
    async fn refill_restores_capacity() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(75).await.unwrap());
        assert_eq!(b.available(), 25);
        b.refill();
        assert_eq!(b.available(), 100);
    }

    /// Utilization is the reserved fraction of the capacity.
    #[tokio::test]
    async fn utilization_tracks_consumption() {
        let b = AtomicBudget::new(100);
        assert!((b.utilization() - 0.0).abs() < f64::EPSILON);
        assert!(b.try_reserve(40).await.unwrap());
        assert!((b.utilization() - 0.4).abs() < f64::EPSILON);
    }

    /// Over-estimates are refunded on `record_usage`, and consumed tokens
    /// stay unavailable to later reservations.
    #[tokio::test]
    async fn token_budget_reserves_records_and_reports() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        tb.record_usage(reservation, 120, 80).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        // Bind the second reservation so its drop-refund doesn't restore
        // the bucket before the deny-check below.
        let _hold = tb.try_reserve_tokens(800).await.unwrap().unwrap();
        assert!(tb.try_reserve_tokens(1).await.unwrap().is_none());
    }

    /// Usage above the estimate is debited from the pool.
    #[tokio::test]
    async fn token_budget_debits_overage() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(100).await.unwrap().unwrap();
        tb.record_usage(reservation, 150, 50).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        assert_eq!(tb.available(), 800);
    }

    /// Each outstanding reservation is reconciled against its own estimate.
    #[tokio::test]
    async fn token_budget_reconciles_each_reservation_independently() {
        let tb = AtomicTokenBudget::new(1_000);
        let first = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        let second = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        assert_eq!(tb.available(), 200);

        tb.record_usage(first, 100, 100).await;
        assert_eq!(tb.available(), 400);
        assert!(tb.try_reserve_tokens(401).await.unwrap().is_none());

        tb.record_usage(second, 200, 200).await;
        assert_eq!(tb.available(), 400);
        assert_eq!(tb.tokens_consumed().await, 600);
    }

    /// The handle reports the estimate it was created with.
    #[tokio::test]
    async fn token_reservation_reports_estimate() {
        let tb = AtomicTokenBudget::new(10);
        let reservation = tb.try_reserve_tokens(7).await.unwrap().unwrap();
        assert_eq!(reservation.estimate(), 7);
    }

    /// Dropping an armed reservation returns its estimate and records no
    /// consumption.
    #[tokio::test]
    async fn token_reservation_refunds_on_drop() {
        let tb = AtomicTokenBudget::new(1_000);
        {
            let _reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
            assert_eq!(tb.available(), 600);
        } // dropped without record_usage
        assert_eq!(tb.available(), 1_000);
        assert_eq!(tb.tokens_consumed().await, 0);
    }

    /// A refund larger than the outstanding deduction is clamped at capacity.
    #[tokio::test]
    async fn token_reservation_refund_is_capped_at_capacity() {
        let tb = AtomicTokenBudget::new(100);
        let mut r = tb.try_reserve_tokens(40).await.unwrap().unwrap();
        assert_eq!(tb.available(), 60);

        // Take the refund closure and deliberately over-credit the pool: only
        // 40 tokens are outstanding, so an uncapped refund would overshoot.
        let refund = r.disarm().expect("a fresh reservation is armed");
        refund(1_000);
        assert_eq!(tb.available(), 100);

        // The reservation is disarmed, so dropping it credits nothing more.
        drop(r);
        assert_eq!(tb.available(), 100);
    }
}