use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Serializable resource limits for a unit of work.
///
/// `BudgetTracker::new` copies `max_tokens` and `max_duration` out of this
/// struct; the remaining fields are not read by any code visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CognitiveBudget {
/// Token ceiling. A tracker reports exhaustion once the tokens consumed
/// reach this value.
pub max_tokens: u32,
/// Optional wall-clock ceiling; `None` disables the time check.
///
/// Encoded through `duration_serde` as a floating-point number of seconds.
/// The field is omitted from serialized output when `None`, and a missing
/// field deserializes to `None` (the field-level `default`), not to the
/// value used by `CognitiveBudget::default()`.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "crate::cognitive_budget::duration_serde"
)]
pub max_duration: Option<Duration>,
/// NOTE(review): presumably a cap on iterations; not enforced by any code
/// visible in this file — confirm against callers.
pub max_iterations: u32,
/// NOTE(review): presumably a fraction of some larger budget; units and
/// valid range are not established in this file — confirm against callers.
pub budget_fraction: f64,
}
impl Default for CognitiveBudget {
    /// Builds the stock budget: 2000 tokens, a 10-second duration limit,
    /// 5 iterations, and a budget fraction of 0.15.
    fn default() -> Self {
        let max_duration = Some(Duration::from_secs(10));
        Self {
            max_tokens: 2_000,
            max_iterations: 5,
            budget_fraction: 0.15,
            max_duration,
        }
    }
}
#[derive(Debug)]
pub struct BudgetTracker {
max_tokens: u32,
max_duration: Option<Duration>,
tokens_used: u32,
start: Instant,
}
impl BudgetTracker {
pub fn new(budget: &CognitiveBudget) -> Self {
Self {
max_tokens: budget.max_tokens,
max_duration: budget.max_duration,
tokens_used: 0,
start: Instant::now(),
}
}
pub fn consume_tokens(&mut self, tokens: u32) {
self.tokens_used = self.tokens_used.saturating_add(tokens);
}
pub fn remaining_tokens(&self) -> u32 {
self.max_tokens.saturating_sub(self.tokens_used)
}
pub fn tokens_used(&self) -> u32 {
self.tokens_used
}
pub fn elapsed(&self) -> Duration {
self.start.elapsed()
}
pub fn is_exhausted(&self) -> bool {
if self.tokens_used >= self.max_tokens {
return true;
}
if let Some(max_dur) = self.max_duration {
if self.start.elapsed() >= max_dur {
return true;
}
}
false
}
}
mod duration_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::Duration;
pub fn serialize<S: Serializer>(
duration: &Option<Duration>,
serializer: S,
) -> Result<S::Ok, S::Error> {
duration.map(|d| d.as_secs_f64()).serialize(serializer)
}
pub fn deserialize<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Option<Duration>, D::Error> {
let secs: Option<f64> = Option::deserialize(deserializer)?;
Ok(secs.map(Duration::from_secs_f64))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Budget limited only by tokens: no duration cap, every other field at
    /// its default.
    fn token_capped(max_tokens: u32) -> CognitiveBudget {
        CognitiveBudget {
            max_tokens,
            max_duration: None,
            ..Default::default()
        }
    }

    /// The default budget carries the documented stock values.
    #[test]
    fn test_budget_defaults() {
        let CognitiveBudget {
            max_tokens,
            max_duration,
            max_iterations,
            budget_fraction,
        } = CognitiveBudget::default();

        assert_eq!(max_tokens, 2000);
        assert_eq!(max_iterations, 5);
        assert!((budget_fraction - 0.15).abs() < f64::EPSILON);
        assert_eq!(max_duration, Some(Duration::from_secs(10)));
    }

    /// Consuming tokens moves them from "remaining" to "used".
    #[test]
    fn test_tracker_token_consumption() {
        let mut tracker = BudgetTracker::new(&CognitiveBudget::default());
        assert_eq!(tracker.remaining_tokens(), 2000);
        assert_eq!(tracker.tokens_used(), 0);

        tracker.consume_tokens(500);

        assert_eq!(tracker.remaining_tokens(), 1500);
        assert_eq!(tracker.tokens_used(), 500);
    }

    /// The tracker flips to exhausted exactly when the ceiling is reached.
    #[test]
    fn test_tracker_exhaustion_by_tokens() {
        let mut tracker = BudgetTracker::new(&token_capped(100));
        assert!(!tracker.is_exhausted());

        tracker.consume_tokens(99);
        assert!(!tracker.is_exhausted());

        tracker.consume_tokens(1);
        assert!(tracker.is_exhausted());
    }

    /// Over-consumption saturates and leaves nothing remaining.
    #[test]
    fn test_tracker_token_saturation() {
        let mut tracker = BudgetTracker::new(&token_capped(100));

        tracker.consume_tokens(u32::MAX);

        assert!(tracker.is_exhausted());
        assert_eq!(tracker.remaining_tokens(), 0);
    }

    /// A JSON round trip preserves every field of the default budget.
    #[test]
    fn test_budget_serialization() {
        let original = CognitiveBudget::default();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<CognitiveBudget>(&encoded).unwrap();

        assert_eq!(decoded.max_tokens, original.max_tokens);
        assert_eq!(decoded.max_iterations, original.max_iterations);
        assert_eq!(decoded.max_duration, original.max_duration);
        assert!((decoded.budget_fraction - original.budget_fraction).abs() < f64::EPSILON);
    }
}