use harness_core::Usage;
use serde::{Deserialize, Serialize};
/// Limits on token consumption plus an iteration cap, stored as plain serializable data.
///
/// Each `Option` limit is disabled when it is `None`. [`BudgetState::exceeded`] reports a
/// limit only once the running count is strictly greater than it, so landing exactly on a
/// limit still counts as within budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBudget {
/// Cap on accumulated input tokens; `None` disables the check.
pub max_input_tokens: Option<u64>,
/// Cap on accumulated output tokens; `None` disables the check.
pub max_output_tokens: Option<u64>,
/// Cap on input and output tokens combined; `None` disables the check.
pub max_total_tokens: Option<u64>,
/// Iteration cap exposed through [`BudgetState::max_iters`] (12 by default).
/// `BudgetState::exceeded` does not check it — presumably the caller's loop enforces it;
/// confirm against the caller.
pub max_iters_per_round: u32,
}
impl Default for TokenBudget {
fn default() -> Self {
Self {
max_input_tokens: None,
max_output_tokens: None,
max_total_tokens: None,
max_iters_per_round: 12,
}
}
}
impl TokenBudget {
    /// Creates a budget with the given iteration cap and all token limits disabled.
    pub fn iters(max_iters_per_round: u32) -> Self {
        Self::default().with_max_iters_per_round(max_iters_per_round)
    }

    /// Returns this budget with the combined input + output token limit set to `n`.
    pub fn with_max_total_tokens(self, n: u64) -> Self {
        Self {
            max_total_tokens: Some(n),
            ..self
        }
    }

    /// Returns this budget with the input token limit set to `n`.
    pub fn with_max_input_tokens(self, n: u64) -> Self {
        Self {
            max_input_tokens: Some(n),
            ..self
        }
    }

    /// Returns this budget with the output token limit set to `n`.
    pub fn with_max_output_tokens(self, n: u64) -> Self {
        Self {
            max_output_tokens: Some(n),
            ..self
        }
    }

    /// Returns this budget with the per-round iteration cap replaced by `n`.
    pub fn with_max_iters_per_round(self, n: u32) -> Self {
        Self {
            max_iters_per_round: n,
            ..self
        }
    }
}
/// Identifies which limit of a [`TokenBudget`] was exceeded, as reported by
/// [`BudgetState::exceeded`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetLimit {
/// Accumulated input tokens went over `max_input_tokens`.
Input,
/// Accumulated output tokens went over `max_output_tokens`.
Output,
/// Input and output tokens combined went over `max_total_tokens`.
Total,
}
impl BudgetLimit {
pub fn label(self) -> &'static str {
match self {
BudgetLimit::Input => "input-tokens",
BudgetLimit::Output => "output-tokens",
BudgetLimit::Total => "total-tokens",
}
}
}
/// Running token counters paired with the [`TokenBudget`] they are checked against.
#[derive(Debug, Clone, Copy)]
pub struct BudgetState {
// Limits consulted by `exceeded` and `max_iters`; stored by `new`.
budget: TokenBudget,
/// Input tokens accumulated so far (increased by `add`; also writable directly).
pub input_tokens: u64,
/// Output tokens accumulated so far (increased by `add`; also writable directly).
pub output_tokens: u64,
}
impl BudgetState {
    /// Creates a tracker for `budget` with both token counters starting at zero.
    pub fn new(budget: TokenBudget) -> Self {
        Self {
            budget,
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Adds the `input_tokens` and `output_tokens` of `usage` to the running counters.
    ///
    /// Only those two fields of `usage` are read. The counters saturate at `u64::MAX`
    /// instead of overflowing: a wrapped counter would look like a fresh budget, and an
    /// overflow panic would abort the caller in debug builds.
    pub fn add(&mut self, usage: &Usage) {
        self.input_tokens = self
            .input_tokens
            .saturating_add(u64::from(usage.input_tokens));
        self.output_tokens = self
            .output_tokens
            .saturating_add(u64::from(usage.output_tokens));
    }

    /// Returns input plus output tokens accumulated so far, saturating at `u64::MAX`.
    pub fn total_tokens(&self) -> u64 {
        // Both fields are public and may hold arbitrary values, so the sum must not overflow.
        self.input_tokens.saturating_add(self.output_tokens)
    }

    /// Returns the per-round iteration cap configured in the budget.
    pub fn max_iters(&self) -> u32 {
        self.budget.max_iters_per_round
    }

    /// Returns the first limit that has been exceeded, or `None` if all are respected.
    ///
    /// Limits are checked in the order input, output, total, and only the first hit is
    /// reported. A limit counts as exceeded when the counter is strictly greater than it;
    /// a limit of `None` is never exceeded.
    pub fn exceeded(&self) -> Option<BudgetLimit> {
        // True when a limit is configured and the counter has gone past it.
        let over = |limit: Option<u64>, used: u64| matches!(limit, Some(max) if used > max);

        if over(self.budget.max_input_tokens, self.input_tokens) {
            Some(BudgetLimit::Input)
        } else if over(self.budget.max_output_tokens, self.output_tokens) {
            Some(BudgetLimit::Output)
        } else if over(self.budget.max_total_tokens, self.total_tokens()) {
            Some(BudgetLimit::Total)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Usage` with the given input/output counts and zero cached input tokens.
    fn usage(input: u32, output: u32) -> Usage {
        Usage {
            cached_input_tokens: 0,
            output_tokens: output,
            input_tokens: input,
        }
    }

    #[test]
    fn no_limits_never_exceeds() {
        let budget = TokenBudget::iters(8);
        let mut state = BudgetState::new(budget);
        state.add(&usage(1_000_000, 1_000_000));
        assert_eq!(state.exceeded(), None);
        assert_eq!(state.max_iters(), 8);
    }

    #[test]
    fn total_limit_trips() {
        let budget = TokenBudget::iters(8).with_max_total_tokens(100);
        let mut state = BudgetState::new(budget);

        // 60 + 30 = 90, still under the cap of 100.
        state.add(&usage(60, 30));
        assert_eq!(state.exceeded(), None);

        // Another 20 brings the total to 110, past the cap.
        state.add(&usage(20, 0));
        assert_eq!(state.exceeded(), Some(BudgetLimit::Total));
    }

    #[test]
    fn input_and_output_limits_trip_independently() {
        let both_capped = TokenBudget::iters(8)
            .with_max_input_tokens(50)
            .with_max_output_tokens(50);
        let mut input_heavy = BudgetState::new(both_capped);
        input_heavy.add(&usage(51, 1));
        assert_eq!(input_heavy.exceeded(), Some(BudgetLimit::Input));

        let output_capped = TokenBudget::iters(8).with_max_output_tokens(50);
        let mut output_heavy = BudgetState::new(output_capped);
        output_heavy.add(&usage(1, 51));
        assert_eq!(output_heavy.exceeded(), Some(BudgetLimit::Output));
    }
}