use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use async_trait::async_trait;
use tracing::{instrument, trace};
/// Error type returned by the fallible methods of the budget traits in this
/// file (`BudgetGuard::try_reserve` and `TokenBudget::try_reserve_tokens`).
#[derive(Debug, thiserror::Error)]
pub enum BudgetError {
/// A budget backend error; the string payload is interpolated into the
/// rendered message after the `budget backend error: ` prefix.
#[error("budget backend error: {0}")]
Backend(String),
}
/// Async reserve/release interface over a budget counted in `u64` cost units.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait BudgetGuard: Send + Sync {
/// Attempts to reserve `cost` units.
///
/// The `AtomicBudget` implementation in this file returns `Ok(true)` when the
/// units were deducted and `Ok(false)` when fewer than `cost` units remained;
/// it never produces an `Err`.
async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError>;
/// Hands `cost` units back to the budget.
///
/// The `AtomicBudget` implementation in this file clamps the resulting
/// available amount to its capacity.
async fn release(&self, cost: u64);
}
/// Callback that receives the number of tokens to hand back.
///
/// It is boxed so a [`TokenReservation`] can own any `Send + Sync` closure,
/// and it is `FnOnce` because a refund is issued at most once.
pub type TokenRefund = Box<dyn FnOnce(u64) + Send + Sync>;

/// A token reservation that fires its refund callback when dropped.
///
/// While the callback is still attached ("armed"), dropping the reservation
/// invokes it with the full estimate. [`TokenReservation::disarm`] detaches
/// the callback, after which dropping the reservation does nothing.
pub struct TokenReservation {
    /// Token count supplied at construction time.
    estimate: u64,
    /// Pending refund callback; `None` once it has been detached or fired.
    refund: Option<TokenRefund>,
}

impl std::fmt::Debug for TokenReservation {
    /// Prints the estimate and an `armed` flag; the boxed closure itself has
    /// no `Debug` representation, so only its presence is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let armed = self.refund.is_some();
        let mut out = f.debug_struct("TokenReservation");
        out.field("estimate", &self.estimate);
        out.field("armed", &armed);
        out.finish()
    }
}

impl TokenReservation {
    /// Creates an armed reservation for `estimate` tokens that will call
    /// `refund` with that estimate if it is dropped without being disarmed.
    pub fn new(estimate: u64, refund: TokenRefund) -> Self {
        let refund = Some(refund);
        TokenReservation { estimate, refund }
    }

    /// Returns the estimate this reservation was created with.
    pub fn estimate(&self) -> u64 {
        self.estimate
    }

    /// Detaches the refund callback and returns it, or `None` if it was
    /// already detached. After this call, dropping `self` issues no refund.
    pub fn disarm(&mut self) -> Option<TokenRefund> {
        std::mem::take(&mut self.refund)
    }
}

impl Drop for TokenReservation {
    /// Refunds the full estimate if the reservation is still armed.
    fn drop(&mut self) {
        let amount = self.estimate;
        if let Some(give_back) = self.disarm() {
            give_back(amount);
        }
    }
}
/// Async interface for reserving tokens up front and reconciling the
/// reservation against actual usage afterwards.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait TokenBudget: Send + Sync {
/// Attempts to reserve `est` tokens.
///
/// The `AtomicTokenBudget` implementation in this file returns
/// `Ok(Some(reservation))` when the tokens were deducted and `Ok(None)` when
/// fewer than `est` tokens were available; it never produces an `Err`.
async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError>;
/// Consumes `reservation` and records the actual `prompt` and `completion`
/// token counts against it.
///
/// The `AtomicTokenBudget` implementation in this file adds
/// `prompt + completion` to its consumed total, then refunds the unused part
/// of the estimate or debits the overage.
async fn record_usage(&self, reservation: TokenReservation, prompt: u64, completion: u64);
/// Returns the number of tokens consumed so far.
///
/// For `AtomicTokenBudget` this is the running total accumulated by
/// `record_usage`.
async fn tokens_consumed(&self) -> u64;
}
/// A fixed-capacity budget whose remaining amount is kept in an atomic
/// counter, so it can be shared by reference without a lock.
#[derive(Debug)]
pub struct AtomicBudget {
    /// Ceiling that `available` starts at and that `refill` resets it to.
    capacity: u64,
    /// Units currently left to hand out.
    available: AtomicU64,
}

impl AtomicBudget {
    /// Builds a budget with all `capacity` units initially available.
    pub fn new(capacity: u64) -> Self {
        let available = AtomicU64::new(capacity);
        AtomicBudget { capacity, available }
    }

    /// Returns the configured capacity.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Returns the number of units currently available (acquire load).
    pub fn available(&self) -> u64 {
        self.available.load(Ordering::Acquire)
    }

    /// Returns the used share of the budget as a fraction of capacity.
    ///
    /// A zero-capacity budget reports `0.0` instead of dividing by zero, and
    /// the subtraction saturates so the result never goes negative.
    pub fn utilization(&self) -> f64 {
        match self.capacity {
            0 => 0.0,
            total => {
                let in_use = total.saturating_sub(self.available());
                in_use as f64 / total as f64
            }
        }
    }

    /// Resets the available amount to the full capacity (release store).
    pub fn refill(&self) {
        let full = self.capacity;
        self.available.store(full, Ordering::Release);
    }
}
#[async_trait]
impl BudgetGuard for AtomicBudget {
    /// Atomically deducts `cost` units from the available amount.
    ///
    /// Returns `Ok(true)` when the deduction happened and `Ok(false)` when
    /// fewer than `cost` units were available, in which case nothing is
    /// changed. This implementation never returns `Err`.
    // A bare `fields(cost)` declares an *empty* span field, and because it
    // shares the argument's name it also replaces the automatic capture of
    // that argument. The span therefore never carried the value. Leaving the
    // argument to `#[instrument]`'s default capture records it.
    #[instrument(name = "rig_compose.budget.try_reserve", skip(self))]
    async fn try_reserve(&self, cost: u64) -> Result<bool, BudgetError> {
        // `checked_sub` yields `None` exactly when `current < cost`, which
        // makes `fetch_update` abort and report the observed value.
        let outcome = self.available.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |current| current.checked_sub(cost),
        );
        match outcome {
            Ok(_) => Ok(true),
            Err(current) => {
                trace!(current, "budget would be exceeded");
                Ok(false)
            }
        }
    }

    /// Atomically returns `cost` units to the budget.
    ///
    /// The addition saturates and the result is clamped to the capacity, so
    /// over-releasing can never push the available amount above it.
    // See `try_reserve` for why `fields(cost)` was removed.
    #[instrument(name = "rig_compose.budget.release", skip(self))]
    async fn release(&self, cost: u64) {
        let capacity = self.capacity;
        // The closure always yields `Some`, so the update cannot fail and the
        // returned previous value is of no interest.
        let _ = self
            .available
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_add(cost).min(capacity))
            });
    }
}
/// Token budget handle backed by atomic counters.
///
/// The counters sit behind an `Arc` so they can be referenced from outside
/// the handle through shared or weak pointers.
#[derive(Debug)]
pub struct AtomicTokenBudget {
    /// Shared counter state.
    inner: Arc<AtomicTokenBudgetInner>,
}

/// Counter state shared behind the `Arc` in [`AtomicTokenBudget`].
#[derive(Debug)]
struct AtomicTokenBudgetInner {
    /// Upper bound that refunds clamp `available` to.
    capacity: u64,
    /// Tokens currently available for reservation.
    available: AtomicU64,
    /// Running total of consumed tokens; starts at zero.
    consumed: AtomicU64,
}

impl AtomicTokenBudgetInner {
    /// Adds `amount` back to `available`, saturating and clamped to
    /// `capacity`. A zero amount leaves the counter untouched.
    fn refund(&self, amount: u64) {
        if amount != 0 {
            let ceiling = self.capacity;
            // The closure always yields `Some`, so the update cannot fail.
            let _ = self
                .available
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |have| {
                    Some(have.saturating_add(amount).min(ceiling))
                });
        }
    }

    /// Removes `amount` from `available`, saturating at zero. A zero amount
    /// leaves the counter untouched.
    fn debit(&self, amount: u64) {
        if amount != 0 {
            // The closure always yields `Some`, so the update cannot fail.
            let _ = self
                .available
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |have| {
                    Some(have.saturating_sub(amount))
                });
        }
    }
}

impl AtomicTokenBudget {
    /// Builds a budget with all `capacity` tokens available and nothing
    /// consumed.
    pub fn new(capacity: u64) -> Self {
        let state = AtomicTokenBudgetInner {
            capacity,
            available: AtomicU64::new(capacity),
            consumed: AtomicU64::new(0),
        };
        AtomicTokenBudget {
            inner: Arc::new(state),
        }
    }

    /// Returns the configured capacity.
    pub fn capacity(&self) -> u64 {
        let state = &self.inner;
        state.capacity
    }

    /// Returns the number of tokens currently available (acquire load).
    pub fn available(&self) -> u64 {
        let state = &self.inner;
        state.available.load(Ordering::Acquire)
    }
}
#[async_trait]
impl TokenBudget for AtomicTokenBudget {
    /// Atomically deducts `est` tokens and wraps them in a reservation.
    ///
    /// Returns `Ok(Some(reservation))` when the tokens were deducted and
    /// `Ok(None)` when fewer than `est` tokens were available, in which case
    /// nothing is changed. This implementation never returns `Err`.
    ///
    /// Dropping the returned reservation without passing it to
    /// `record_usage` refunds the full estimate.
    // A bare `fields(est)` declares an *empty* span field, and because it
    // shares the argument's name it also replaces the automatic capture of
    // that argument. The span therefore never carried the value. Leaving the
    // argument to `#[instrument]`'s default capture records it.
    #[instrument(name = "rig_compose.token_budget.try_reserve", skip(self))]
    async fn try_reserve_tokens(&self, est: u64) -> Result<Option<TokenReservation>, BudgetError> {
        // `checked_sub` yields `None` exactly when `current < est`, which
        // makes `fetch_update` abort and report the observed value.
        let claimed = self.inner.available.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |current| current.checked_sub(est),
        );
        if let Err(current) = claimed {
            trace!(current, "token budget would be exceeded");
            return Ok(None);
        }

        // The callback holds only a weak handle, so an outstanding
        // reservation does not keep the counters alive. If they are already
        // gone, the refund is silently skipped.
        let weak = Arc::downgrade(&self.inner);
        let refund: TokenRefund = Box::new(move |amount| {
            if let Some(inner) = weak.upgrade() {
                inner.refund(amount);
            }
        });
        Ok(Some(TokenReservation::new(est, refund)))
    }

    /// Reconciles `reservation` against the tokens actually used.
    ///
    /// `prompt + completion` (saturating) is added to the consumed total.
    /// The reservation is disarmed so its drop does not refund a second
    /// time. Then the unused part of the estimate is refunded, or the
    /// overage beyond the estimate is debited from the available tokens.
    #[instrument(name = "rig_compose.token_budget.record_usage", skip(self))]
    async fn record_usage(
        &self,
        mut reservation: TokenReservation,
        prompt: u64,
        completion: u64,
    ) {
        let actual = prompt.saturating_add(completion);

        // Saturate rather than wrap: every other counter in this file uses
        // saturating arithmetic, and a wrapped total would report a tiny
        // number after an overflow. The closure always yields `Some`, so the
        // update cannot fail.
        let _ = self
            .inner
            .consumed
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |total| {
                Some(total.saturating_add(actual))
            });

        // Detach the drop-time refund; the estimate is settled explicitly
        // below.
        let _ = reservation.disarm();
        let estimate = reservation.estimate();
        if estimate >= actual {
            self.inner.refund(estimate - actual);
        } else {
            self.inner.debit(actual - estimate);
        }
    }

    /// Returns the running total of tokens recorded by `record_usage`.
    async fn tokens_consumed(&self) -> u64 {
        self.inner.consumed.load(Ordering::Acquire)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reservations succeed until the budget is exhausted, and released
    /// units become reservable again.
    #[tokio::test]
    async fn reserve_until_empty() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(60).await.unwrap());
        assert!(b.try_reserve(40).await.unwrap());
        assert!(!b.try_reserve(1).await.unwrap());
        b.release(50).await;
        assert!(b.try_reserve(50).await.unwrap());
    }

    /// Releasing more than was reserved never pushes the budget past its
    /// capacity.
    #[tokio::test]
    async fn release_caps_at_capacity() {
        let b = AtomicBudget::new(10);
        assert!(b.try_reserve(5).await.unwrap());
        b.release(100).await;
        assert_eq!(b.available(), 10);
    }

    /// `refill` resets the available amount to the full capacity.
    #[tokio::test]
    async fn refill_restores_capacity() {
        let b = AtomicBudget::new(100);
        assert!(b.try_reserve(75).await.unwrap());
        assert_eq!(b.available(), 25);
        b.refill();
        assert_eq!(b.available(), 100);
    }

    /// Utilization reflects the reserved share of the capacity.
    #[tokio::test]
    async fn utilization_tracks_consumption() {
        let b = AtomicBudget::new(100);
        assert!((b.utilization() - 0.0).abs() < f64::EPSILON);
        assert!(b.try_reserve(40).await.unwrap());
        assert!((b.utilization() - 0.4).abs() < f64::EPSILON);
    }

    /// Recording less than the estimate refunds the difference and adds the
    /// actual usage to the consumed total.
    #[tokio::test]
    async fn token_budget_reserves_records_and_reports() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        tb.record_usage(reservation, 120, 80).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        // 800 tokens are available again, so this takes all of them.
        let _hold = tb.try_reserve_tokens(800).await.unwrap().unwrap();
        assert!(tb.try_reserve_tokens(1).await.unwrap().is_none());
    }

    /// Recording more than the estimate debits the overage.
    #[tokio::test]
    async fn token_budget_debits_overage() {
        let tb = AtomicTokenBudget::new(1_000);
        let reservation = tb.try_reserve_tokens(100).await.unwrap().unwrap();
        tb.record_usage(reservation, 150, 50).await;
        assert_eq!(tb.tokens_consumed().await, 200);
        assert_eq!(tb.available(), 800);
    }

    /// Two outstanding reservations are settled independently of each other.
    #[tokio::test]
    async fn token_budget_reconciles_each_reservation_independently() {
        let tb = AtomicTokenBudget::new(1_000);
        let first = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        let second = tb.try_reserve_tokens(400).await.unwrap().unwrap();
        assert_eq!(tb.available(), 200);
        tb.record_usage(first, 100, 100).await;
        assert_eq!(tb.available(), 400);
        assert!(tb.try_reserve_tokens(401).await.unwrap().is_none());
        tb.record_usage(second, 200, 200).await;
        assert_eq!(tb.available(), 400);
        assert_eq!(tb.tokens_consumed().await, 600);
    }

    /// A reservation reports the estimate it was created with.
    #[tokio::test]
    async fn token_reservation_reports_estimate() {
        let tb = AtomicTokenBudget::new(10);
        let reservation = tb.try_reserve_tokens(7).await.unwrap().unwrap();
        assert_eq!(reservation.estimate(), 7);
    }

    /// Dropping an unrecorded reservation refunds the full estimate and
    /// consumes nothing.
    #[tokio::test]
    async fn token_reservation_refunds_on_drop() {
        let tb = AtomicTokenBudget::new(1_000);
        {
            let _reservation = tb.try_reserve_tokens(400).await.unwrap().unwrap();
            assert_eq!(tb.available(), 600);
        }
        assert_eq!(tb.available(), 1_000);
        assert_eq!(tb.tokens_consumed().await, 0);
    }

    /// A drop-time refund that would overshoot is clamped to the capacity.
    #[tokio::test]
    async fn token_reservation_refund_is_capped_at_capacity() {
        let tb = AtomicTokenBudget::new(100);
        let r = tb.try_reserve_tokens(40).await.unwrap().unwrap();
        assert_eq!(tb.available(), 60);
        // Top the budget back up while the reservation is still armed, so
        // that the refund on drop would reach 140 without the cap.
        tb.inner.refund(40);
        assert_eq!(tb.available(), 100);
        drop(r);
        assert_eq!(tb.available(), 100);
    }
}