token-budget-pool 0.1.0

A shared token and dollar budget across concurrent LLM tasks. Thread-safe; returns BudgetExceeded when a record would push a total past a cap. Zero dependencies.
Documentation
use std::sync::Arc;
use std::thread;
use token_budget_pool::{BudgetPool, Caps};

/// A pool created without caps accepts very large records and accumulates
/// them in its running totals.
#[test]
fn unconstrained_accepts_everything() {
    let pool = BudgetPool::unconstrained();

    // Two identical oversized records; neither may be rejected.
    for _ in 0..2 {
        pool.record(1_000_000, 500_000, 100.0).unwrap();
    }

    let totals = pool.totals();
    assert_eq!(totals.input_tokens, 2_000_000);
    assert_eq!(totals.calls, 2);
}

/// With only an input-token cap configured, a record that would take the
/// input total past the cap is rejected and the totals stay as they were.
#[test]
fn input_cap_blocks_at_boundary() {
    let caps = Caps {
        max_input_tokens: Some(1000),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    // 800 of 1000 consumed; 201 more would land on 1001.
    pool.record(800, 0, 0.0).unwrap();
    let rejection = pool.record(201, 0, 0.0).unwrap_err();
    assert_eq!(rejection.cap, "input_tokens");

    // The rejected record must not have been counted.
    assert_eq!(pool.totals().input_tokens, 800);
}

/// A single record whose output tokens exceed the output cap is rejected,
/// and the error names the output cap.
#[test]
fn output_cap_blocks() {
    let caps = Caps {
        max_output_tokens: Some(500),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    // 600 output tokens against a cap of 500.
    let rejection = pool.record(0, 600, 0.0).unwrap_err();
    assert_eq!(rejection.cap, "output_tokens");
}

/// With only a combined-token cap configured, a record is rejected once
/// input plus output together would pass that cap.
#[test]
fn total_cap_blocks_when_split_fits_separately() {
    let caps = Caps {
        max_total_tokens: Some(1000),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    // 700 + 200 = 900 tokens consumed so far.
    pool.record(700, 200, 0.0).unwrap();

    // 50 + 100 more would bring the combined total to 1050.
    let rejection = pool.record(50, 100, 0.0).unwrap_err();
    assert_eq!(rejection.cap, "total_tokens");
}

/// With only a dollar cap configured, a record whose cost would take the
/// running spend past the cap is rejected.
#[test]
fn cost_cap_blocks() {
    let caps = Caps {
        max_cost_usd: Some(1.0),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    // 0.9 spent; a further 0.2 would exceed the 1.0 cap.
    pool.record(0, 0, 0.9).unwrap();
    let rejection = pool.record(0, 0, 0.2).unwrap_err();
    assert_eq!(rejection.cap, "cost_usd");
}

/// After `reset`, earlier usage must no longer count against the cap and the
/// reported input total must start again from zero.
#[test]
fn reset_clears_totals() {
    let p = BudgetPool::with_caps(Caps {
        max_input_tokens: Some(100),
        ..Default::default()
    });
    p.record(50, 0, 0.0).unwrap();
    p.reset();

    // Inspect the totals directly: the test name promises they are cleared,
    // so do not rely only on the follow-up record happening to fit.
    assert_eq!(p.totals().input_tokens, 0);

    // 50 + 80 would exceed the cap of 100, so this succeeds only if the
    // earlier 50 tokens are no longer counted.
    p.record(80, 0, 0.0).unwrap();
    assert_eq!(p.totals().input_tokens, 80);
}

/// Records into one shared pool from several threads and checks that every
/// record was accepted and counted exactly once.
#[test]
fn concurrent_records_are_safe() {
    let caps = Caps {
        max_input_tokens: Some(10_000),
        ..Default::default()
    };
    let shared = Arc::new(BudgetPool::with_caps(caps));

    // Spawn every worker before joining any of them so they can run concurrently.
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let pool = Arc::clone(&shared);
            thread::spawn(move || {
                // 10 threads * 100 records each * 1 input token = 1000 — fits.
                for _ in 0..100 {
                    pool.record(1, 0, 0.0).unwrap();
                }
            })
        })
        .collect();

    // A panic inside any worker surfaces here through `join`.
    for worker in workers {
        worker.join().unwrap();
    }

    let totals = shared.totals();
    assert_eq!(totals.input_tokens, 1000);
    assert_eq!(totals.calls, 1000);
}