token-budget-pool 0.1.0

Shared token + dollar budget across concurrent LLM tasks. Thread-safe, returns BudgetExceeded when a record would push past a cap. Zero deps.
Documentation
//! # token-budget-pool
//!
//! Shared token + dollar budget across N concurrent LLM tasks.
//!
//! Create a `BudgetPool` at the top of an agent run; pass `&pool` to every
//! task that issues LLM calls; call [`BudgetPool::record`] after each
//! response. The pool serializes the updates and returns
//! [`BudgetExceeded`] (leaving the totals unchanged) when a record would
//! push past any cap.
//!
//! ## Example
//!
//! ```
//! use token_budget_pool::{BudgetPool, Caps};
//!
//! let pool = BudgetPool::with_caps(Caps {
//!     max_input_tokens: Some(10_000),
//!     max_output_tokens: Some(5_000),
//!     max_total_tokens: None,
//!     max_cost_usd: Some(1.0),
//! });
//!
//! pool.record(1_000, 500, 0.05).unwrap(); // fits
//! let err = pool.record(20_000, 0, 0.0).unwrap_err(); // input cap blown
//! assert!(format!("{err}").contains("input_tokens"));
//! ```

#![deny(missing_docs)]

use std::sync::Mutex;

/// Optional ceilings for one pool.
///
/// Each field is independent of the others. A field set to `None` places no
/// limit on that dimension, so `Caps::default()` describes a pool with no
/// limits at all.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Caps {
    /// Ceiling for the running sum of input tokens over every recorded call.
    pub max_input_tokens: Option<u64>,
    /// Ceiling for the running sum of output tokens over every recorded call.
    pub max_output_tokens: Option<u64>,
    /// Ceiling for the running sum of input and output tokens combined.
    pub max_total_tokens: Option<u64>,
    /// Ceiling for the running spend, in US dollars.
    pub max_cost_usd: Option<f64>,
}

/// Running totals accumulated by a pool.
///
/// A plain `Copy` snapshot: reading it never blocks and holding one does not
/// keep any lock alive.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Totals {
    /// Cumulative input tokens recorded.
    pub input_tokens: u64,
    /// Cumulative output tokens recorded.
    pub output_tokens: u64,
    /// Cumulative dollars recorded.
    pub cost_usd: f64,
    /// Number of `record` calls counted.
    pub calls: u64,
}

impl Totals {
    /// Sum of input + output tokens.
    ///
    /// The addition saturates at `u64::MAX` rather than overflowing, so this
    /// never panics and never wraps around to a misleadingly small number.
    pub fn total_tokens(&self) -> u64 {
        // Both fields are public, so nothing stops a caller from building a
        // `Totals` whose fields sum past `u64::MAX`; plain `+` would panic in
        // overflow-checked builds and wrap in others.
        self.input_tokens.saturating_add(self.output_tokens)
    }
}

/// Rejection produced when recording a call would take a running total past
/// one of the pool's caps.
///
/// Only the first breached cap is reported; other caps may have been breached
/// by the same call as well. A rejected call is dropped entirely: **the
/// pool's totals are left exactly as they were**.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BudgetExceeded {
    /// Name of the breached cap: `"input_tokens"`, `"output_tokens"`,
    /// `"total_tokens"`, or `"cost_usd"`.
    pub cap: &'static str,
    /// The configured limit for that cap.
    pub limit: f64,
    /// The value the running total would have reached had the call been
    /// accepted.
    pub attempted: f64,
}

impl std::fmt::Display for BudgetExceeded {
    /// Renders the cap name, its limit, and the attempted total on one line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The struct is `Copy`, so pull the fields out by value and let the
        // format string name them directly.
        let Self {
            cap,
            limit,
            attempted,
        } = *self;
        write!(
            f,
            "budget cap `{cap}` exceeded: limit={limit}, attempted={attempted}"
        )
    }
}

impl std::error::Error for BudgetExceeded {}

/// Shared budget: a fixed set of [`Caps`] plus mutex-guarded running
/// [`Totals`].
///
/// Cheap to construct. Every method takes `&self`; reads and writes of the
/// totals go through an internal [`Mutex`], so one pool can be shared by
/// reference between concurrent tasks.
#[derive(Debug)]
pub struct BudgetPool {
    /// Limits chosen at construction; no method of the pool modifies them.
    caps: Caps,
    /// Running totals. Only ever touched while the mutex is held.
    state: Mutex<Totals>,
}

impl BudgetPool {
    /// Build a pool that enforces `caps`, starting from zero totals.
    ///
    /// Any cap left as `None` is simply not enforced.
    pub fn with_caps(caps: Caps) -> Self {
        Self {
            caps,
            state: Mutex::new(Totals::default()),
        }
    }

    /// Build an unconstrained pool (no caps).
    pub fn unconstrained() -> Self {
        Self::with_caps(Caps::default())
    }

    /// Record one call's usage.
    ///
    /// On success the totals are advanced and the updated snapshot is
    /// returned. Token sums and the call counter saturate at `u64::MAX`
    /// instead of overflowing.
    ///
    /// # Errors
    ///
    /// Returns [`BudgetExceeded`] and leaves the totals unchanged when the
    /// new running total would exceed a configured cap. Caps are checked in
    /// the order input tokens, output tokens, total tokens, cost; the first
    /// breach is the one reported. With a cost cap configured, a running cost
    /// that is NaN is also reported as a `"cost_usd"` breach, because a NaN
    /// total would make every later comparison against the cap false.
    pub fn record(
        &self,
        input_tokens: u64,
        output_tokens: u64,
        cost_usd: f64,
    ) -> Result<Totals, BudgetExceeded> {
        let mut state = self.lock_state();

        // Saturating arithmetic: plain `+` would either panic while the lock
        // is held (poisoning it) or wrap to a small value that slips under
        // the cap.
        let next_in = state.input_tokens.saturating_add(input_tokens);
        let next_out = state.output_tokens.saturating_add(output_tokens);
        let next_total = next_in.saturating_add(next_out);
        let next_cost = state.cost_usd + cost_usd;

        Self::check_token_cap("input_tokens", self.caps.max_input_tokens, next_in)?;
        Self::check_token_cap("output_tokens", self.caps.max_output_tokens, next_out)?;
        Self::check_token_cap("total_tokens", self.caps.max_total_tokens, next_total)?;

        if let Some(limit) = self.caps.max_cost_usd {
            // `NaN > limit` is false, so NaN has to be rejected explicitly.
            if next_cost.is_nan() || next_cost > limit {
                return Err(BudgetExceeded {
                    cap: "cost_usd",
                    limit,
                    attempted: next_cost,
                });
            }
        }

        // All checks passed; nothing below can fail or panic, so the totals
        // are never left half-updated.
        state.input_tokens = next_in;
        state.output_tokens = next_out;
        state.cost_usd = next_cost;
        state.calls = state.calls.saturating_add(1);
        Ok(*state)
    }

    /// Read current totals.
    pub fn totals(&self) -> Totals {
        *self.lock_state()
    }

    /// Read the caps.
    pub fn caps(&self) -> Caps {
        self.caps
    }

    /// Reset the pool to zero totals (caps unchanged).
    pub fn reset(&self) {
        *self.lock_state() = Totals::default();
    }

    /// Lock the totals, recovering the guard if the mutex was poisoned.
    ///
    /// The guarded value is a plain `Copy` struct and `record` writes it only
    /// after every fallible step has passed, so the data behind a poisoned
    /// lock is still consistent and safe to keep using.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, Totals> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Check one token cap, producing the error for `cap` when `attempted`
    /// is above `limit`. A `None` limit always passes.
    fn check_token_cap(
        cap: &'static str,
        limit: Option<u64>,
        attempted: u64,
    ) -> Result<(), BudgetExceeded> {
        match limit {
            // The error type carries `f64` fields; the `as` casts may round
            // very large counts, which only affects the reported numbers.
            Some(limit) if attempted > limit => Err(BudgetExceeded {
                cap,
                limit: limit as f64,
                attempted: attempted as f64,
            }),
            _ => Ok(()),
        }
    }
}