use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::config::BudgetConfig;
/// Plain-data record of consumed tokens and calls together with the limits
/// they are measured against.
#[derive(Debug, Clone, Default)]
pub struct BudgetUsage {
    /// Input tokens consumed.
    pub input_tokens: usize,
    /// Output tokens consumed.
    pub output_tokens: usize,
    /// Number of calls made.
    pub calls_made: usize,
    /// Token limit that [`BudgetUsage::total_tokens`] is compared against.
    pub max_tokens: usize,
    /// Call limit that `calls_made` is compared against.
    pub max_calls: usize,
}

impl BudgetUsage {
    /// Returns the combined number of input and output tokens.
    ///
    /// The sum saturates at `usize::MAX`. A plain `+` would panic in debug
    /// builds and wrap around in release builds, and a wrapped total would
    /// make an over-spent budget look almost unused.
    pub fn total_tokens(&self) -> usize {
        self.input_tokens.saturating_add(self.output_tokens)
    }

    /// Returns the fraction of the token limit that has been used, capped at
    /// `1.0`.
    ///
    /// Returns `0.0` when `max_tokens` is zero so the division is never
    /// performed with a zero denominator.
    pub fn token_utilization(&self) -> f32 {
        if self.max_tokens == 0 {
            return 0.0;
        }
        (self.total_tokens() as f32 / self.max_tokens as f32).min(1.0)
    }

    /// Returns the fraction of the call limit that has been used, capped at
    /// `1.0`.
    ///
    /// Returns `0.0` when `max_calls` is zero so the division is never
    /// performed with a zero denominator.
    pub fn call_utilization(&self) -> f32 {
        if self.max_calls == 0 {
            return 0.0;
        }
        (self.calls_made as f32 / self.max_calls as f32).min(1.0)
    }

    /// Returns `true` once either the token limit or the call limit has been
    /// reached.
    ///
    /// Both comparisons are `>=`, so a limit of zero always counts as
    /// exhausted (this includes a `Default`-constructed value), even though
    /// the utilization helpers report `0.0` for a zero limit.
    pub fn is_exhausted(&self) -> bool {
        self.total_tokens() >= self.max_tokens || self.calls_made >= self.max_calls
    }
}
/// Tracks token and call consumption against the limits in a [`BudgetConfig`].
///
/// The counters use interior mutability (atomics plus an `RwLock`-guarded
/// map), so every method takes `&self`.
pub struct BudgetController {
    /// Limits the counters are checked against.
    config: BudgetConfig,
    /// Running total of recorded input tokens.
    input_tokens: AtomicUsize,
    /// Running total of recorded output tokens.
    output_tokens: AtomicUsize,
    /// Number of calls recorded across all levels.
    calls_made: AtomicUsize,
    /// Number of calls recorded per level, keyed by level.
    level_calls: RwLock<HashMap<usize, usize>>,
}

impl BudgetController {
    /// Creates a controller with all counters at zero and the given limits.
    pub fn new(config: BudgetConfig) -> Self {
        Self {
            config,
            input_tokens: AtomicUsize::new(0),
            output_tokens: AtomicUsize::new(0),
            calls_made: AtomicUsize::new(0),
            level_calls: RwLock::new(HashMap::new()),
        }
    }

    /// Creates a controller using `BudgetConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(BudgetConfig::default())
    }

    /// Returns `true` while both the per-query token limit and the per-query
    /// call limit are still strictly above the recorded totals.
    pub fn can_call(&self) -> bool {
        self.total_tokens() < self.config.max_tokens_per_query
            && self.calls_made.load(Ordering::Relaxed) < self.config.max_calls_per_query
    }

    /// Returns `true` if [`Self::can_call`] holds and fewer than
    /// `max_calls_per_level` calls have been recorded for `level`.
    ///
    /// # Panics
    ///
    /// Panics if the per-level lock is poisoned.
    pub fn can_call_at_level(&self, level: usize) -> bool {
        // Short-circuits, so the lock is only taken when the global check passes.
        self.can_call() && self.calls_at_level(level) < self.config.max_calls_per_level
    }

    /// Estimates the token cost of sending `context`.
    ///
    /// Characters in `U+4E00..=U+9FFF` (CJK Unified Ideographs) are weighted
    /// at 1.5 characters per token and every other character at 4 characters
    /// per token. The rounded-up sum is returned with a fixed 100 tokens
    /// added on top.
    pub fn estimate_cost(&self, context: &str) -> usize {
        // NOTE(review): presumably headroom for prompt framing and/or the
        // response; the original code does not say what the 100 stands for.
        const FIXED_OVERHEAD_TOKENS: usize = 100;

        // Classify every character in a single pass over the string.
        let mut cjk_chars = 0usize;
        let mut other_chars = 0usize;
        for c in context.chars() {
            if ('\u{4E00}'..='\u{9FFF}').contains(&c) {
                cjk_chars += 1;
            } else {
                other_chars += 1;
            }
        }

        let input_tokens = (cjk_chars as f32 / 1.5 + other_chars as f32 / 4.0).ceil() as usize;
        input_tokens + FIXED_OVERHEAD_TOKENS
    }

    /// Returns `true` if `estimated_cost` fits into both the remaining
    /// per-query token budget and the per-call token limit.
    pub fn can_afford(&self, estimated_cost: usize) -> bool {
        estimated_cost <= self.remaining_tokens()
            && estimated_cost <= self.config.max_tokens_per_call
    }

    /// Returns how many tokens are left before the per-query limit is hit
    /// (zero once the limit is reached or exceeded).
    pub fn remaining_tokens(&self) -> usize {
        self.config
            .max_tokens_per_query
            .saturating_sub(self.total_tokens())
    }

    /// Returns how many calls are left before the per-query limit is hit
    /// (zero once the limit is reached or exceeded).
    pub fn remaining_calls(&self) -> usize {
        self.config
            .max_calls_per_query
            .saturating_sub(self.calls_made.load(Ordering::Relaxed))
    }

    /// Records one call at `level` that consumed the given token counts.
    ///
    /// Every counter saturates at `usize::MAX` instead of wrapping; a wrapped
    /// counter would make an exhausted budget look fresh again.
    ///
    /// # Panics
    ///
    /// Panics if the per-level lock is poisoned.
    pub fn record_usage(&self, input_tokens: usize, output_tokens: usize, level: usize) {
        Self::saturating_fetch_add(&self.input_tokens, input_tokens);
        Self::saturating_fetch_add(&self.output_tokens, output_tokens);
        Self::saturating_fetch_add(&self.calls_made, 1);

        let mut level_calls = self
            .level_calls
            .write()
            .expect("level_calls lock poisoned");
        let count = level_calls.entry(level).or_insert(0);
        *count = count.saturating_add(1);
    }

    /// Adds `amount` to `counter`, clamping at `usize::MAX`.
    fn saturating_fetch_add(counter: &AtomicUsize, amount: usize) {
        // The closure always returns `Some`, so `fetch_update` cannot fail;
        // the returned previous value is not needed.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            Some(current.saturating_add(amount))
        });
    }

    /// Returns the recorded input plus output tokens, saturating at
    /// `usize::MAX`.
    pub fn total_tokens(&self) -> usize {
        let input = self.input_tokens.load(Ordering::Relaxed);
        let output = self.output_tokens.load(Ordering::Relaxed);
        input.saturating_add(output)
    }

    /// Returns the current counters together with the per-query limits.
    ///
    /// Each counter is loaded separately, so concurrent `record_usage` calls
    /// may be reflected in some fields and not others.
    pub fn usage(&self) -> BudgetUsage {
        BudgetUsage {
            input_tokens: self.input_tokens.load(Ordering::Relaxed),
            output_tokens: self.output_tokens.load(Ordering::Relaxed),
            calls_made: self.calls_made.load(Ordering::Relaxed),
            max_tokens: self.config.max_tokens_per_query,
            max_calls: self.config.max_calls_per_query,
        }
    }

    /// Returns the number of calls recorded for `level` (zero if none).
    ///
    /// # Panics
    ///
    /// Panics if the per-level lock is poisoned.
    pub fn calls_at_level(&self, level: usize) -> usize {
        self.level_calls
            .read()
            .expect("level_calls lock poisoned")
            .get(&level)
            .copied()
            .unwrap_or(0)
    }

    /// Resets every counter to zero and clears the per-level call counts.
    ///
    /// # Panics
    ///
    /// Panics if the per-level lock is poisoned.
    pub fn reset(&self) {
        self.input_tokens.store(0, Ordering::Relaxed);
        self.output_tokens.store(0, Ordering::Relaxed);
        self.calls_made.store(0, Ordering::Relaxed);
        self.level_calls
            .write()
            .expect("level_calls lock poisoned")
            .clear();
    }

    /// Returns the limits this controller was created with.
    pub fn config(&self) -> &BudgetConfig {
        &self.config
    }

    /// Returns the `hard_limit` flag from the configuration.
    pub fn is_hard_limit(&self) -> bool {
        self.config.hard_limit
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built controller allows calls and reports the full call budget.
    #[test]
    fn test_budget_controller_new() {
        let limits = BudgetConfig::default();
        let expected_calls = limits.max_calls_per_query;

        let controller = BudgetController::new(limits);

        assert!(controller.can_call());
        assert_eq!(controller.remaining_calls(), expected_calls);
    }

    /// `can_call` flips to `false` once the per-query limits are reached.
    #[test]
    fn test_budget_can_call() {
        let controller = BudgetController::new(BudgetConfig {
            max_tokens_per_query: 100,
            max_calls_per_query: 2,
            ..Default::default()
        });
        assert!(controller.can_call());

        // 80 tokens / 1 call: still under both limits.
        controller.record_usage(50, 30, 0);
        assert!(controller.can_call());

        // 160 tokens / 2 calls: both limits are now reached.
        controller.record_usage(50, 30, 0);
        assert!(!controller.can_call());
    }

    /// The per-level limit blocks only the level that reached it.
    ///
    /// NOTE(review): relies on the default token limit being above the 40
    /// tokens recorded here - confirm against `BudgetConfig::default()`.
    #[test]
    fn test_budget_level_limit() {
        let controller = BudgetController::new(BudgetConfig {
            max_calls_per_query: 10,
            max_calls_per_level: 2,
            ..Default::default()
        });
        assert!(controller.can_call_at_level(0));

        controller.record_usage(10, 10, 0);
        controller.record_usage(10, 10, 0);

        // Level 0 has used both of its calls; level 1 has none recorded.
        assert!(!controller.can_call_at_level(0));
        assert!(controller.can_call_at_level(1));
    }

    /// Cost estimates include the fixed overhead for English and Chinese text.
    #[test]
    fn test_budget_estimate_cost() {
        let controller = BudgetController::with_defaults();

        let english = "Hello world this is a test";
        let cost = controller.estimate_cost(english);
        assert!(
            cost > 100 && cost < 150,
            "Expected cost between 100-150, got {}",
            cost
        );

        let chinese = "这是一个测试";
        let cost_chinese = controller.estimate_cost(chinese);
        assert!(
            cost_chinese > 100,
            "Expected Chinese cost > 100, got {}",
            cost_chinese
        );
    }

    /// `can_afford` honours both the per-call cap and the remaining budget.
    #[test]
    fn test_budget_can_afford() {
        let controller = BudgetController::new(BudgetConfig {
            max_tokens_per_query: 200,
            max_tokens_per_call: 100,
            ..Default::default()
        });

        // Limited by the per-call cap of 100 while nothing has been spent.
        assert!(controller.can_afford(50));
        assert!(controller.can_afford(100));
        assert!(!controller.can_afford(150));

        // 150 of 200 tokens spent: exactly 50 remain.
        controller.record_usage(100, 50, 0);
        assert!(controller.can_afford(50));
        assert!(!controller.can_afford(60));
    }

    /// `reset` returns the token and call counters to zero.
    #[test]
    fn test_budget_reset() {
        let controller = BudgetController::with_defaults();

        controller.record_usage(100, 50, 0);
        assert_eq!(controller.total_tokens(), 150);
        assert_eq!(controller.calls_made.load(Ordering::Relaxed), 1);

        controller.reset();
        assert_eq!(controller.total_tokens(), 0);
        assert_eq!(controller.calls_made.load(Ordering::Relaxed), 0);
    }

    /// `usage` reports the recorded counters.
    #[test]
    fn test_budget_usage_stats() {
        let controller = BudgetController::with_defaults();
        controller.record_usage(100, 50, 0);

        let snapshot = controller.usage();
        assert_eq!(snapshot.input_tokens, 100);
        assert_eq!(snapshot.output_tokens, 50);
        assert_eq!(snapshot.calls_made, 1);
        assert_eq!(snapshot.total_tokens(), 150);
    }
}