use std::sync::Arc;
use std::thread;
use token_budget_pool::{BudgetPool, Caps};
/// Records two large usages on a pool built with `BudgetPool::unconstrained()`
/// and checks that both are accepted and reflected in the running totals.
#[test]
fn unconstrained_accepts_everything() {
    let pool = BudgetPool::unconstrained();

    // Two identical records; each must succeed.
    for _ in 0..2 {
        pool.record(1_000_000, 500_000, 100.0).unwrap();
    }

    let totals = pool.totals();
    assert_eq!(totals.input_tokens, 2_000_000);
    assert_eq!(totals.calls, 2);
}
/// With an input-token cap of 1000, a record that would take the running
/// input total from 800 to 1001 must be rejected and must not change totals.
#[test]
fn input_cap_blocks_at_boundary() {
    let caps = Caps {
        max_input_tokens: Some(1000),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    // First record fits under the cap.
    pool.record(800, 0, 0.0).unwrap();

    // 800 + 201 goes one token past the cap.
    let rejection = pool.record(201, 0, 0.0).unwrap_err();
    assert_eq!(rejection.cap, "input_tokens");

    // The rejected record must not have been added to the totals.
    assert_eq!(pool.totals().input_tokens, 800);
}
/// A single record whose output tokens (600) exceed the output cap (500)
/// must fail and name `output_tokens` as the violated cap.
#[test]
fn output_cap_blocks() {
    let caps = Caps {
        max_output_tokens: Some(500),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    let rejection = pool.record(0, 600, 0.0).unwrap_err();
    assert_eq!(rejection.cap, "output_tokens");
}
/// Only a combined (input + output) cap is configured. The second record
/// pushes the combined count from 900 to 1050, so it must be rejected with
/// `total_tokens` as the violated cap.
#[test]
fn total_cap_blocks_when_split_fits_separately() {
    let caps = Caps {
        max_total_tokens: Some(1000),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    // 700 input + 200 output = 900 combined, still under the cap.
    pool.record(700, 200, 0.0).unwrap();

    // A further 50 + 100 = 150 would bring the combined count to 1050.
    let rejection = pool.record(50, 100, 0.0).unwrap_err();
    assert_eq!(rejection.cap, "total_tokens");
}
/// With a cost cap of 1.0, a 0.9 record is accepted and a following 0.2
/// record is rejected with `cost_usd` as the violated cap.
#[test]
fn cost_cap_blocks() {
    let caps = Caps {
        max_cost_usd: Some(1.0),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    pool.record(0, 0, 0.9).unwrap();

    let rejection = pool.record(0, 0, 0.2).unwrap_err();
    assert_eq!(rejection.cap, "cost_usd");
}
/// After `reset()`, a record that would not fit on top of the earlier usage
/// must be accepted again.
#[test]
fn reset_clears_totals() {
    let caps = Caps {
        max_input_tokens: Some(100),
        ..Default::default()
    };
    let pool = BudgetPool::with_caps(caps);

    pool.record(50, 0, 0.0).unwrap();
    pool.reset();

    // 50 + 80 is above the cap of 100, so this presumably only succeeds
    // because the earlier 50 no longer counts after the reset.
    pool.record(80, 0, 0.0).unwrap();
}
/// Ten threads each record one input token 100 times against a shared pool.
/// Every record must succeed and the final totals must show exactly 1000
/// input tokens and 1000 calls.
#[test]
fn concurrent_records_are_safe() {
    let caps = Caps {
        max_input_tokens: Some(10_000),
        ..Default::default()
    };
    let pool = Arc::new(BudgetPool::with_caps(caps));

    // Spawn every worker before joining any; `collect` drives the iterator
    // eagerly, so all ten threads are started here.
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let pool = Arc::clone(&pool);
            thread::spawn(move || {
                for _ in 0..100 {
                    pool.record(1, 0, 0.0).unwrap();
                }
            })
        })
        .collect();

    // A panic inside a worker surfaces here as a failed join.
    for worker in workers {
        worker.join().unwrap();
    }

    let totals = pool.totals();
    assert_eq!(totals.input_tokens, 1000);
    assert_eq!(totals.calls, 1000);
}