/// Running tally of whitespace-delimited tokens across observed pieces of text.
///
/// Tokens are counted with [`str::split_whitespace`], so any run of Unicode
/// whitespace separates tokens and leading/trailing whitespace is ignored.
/// The tally never underflows or overflows: both [`TokenCounter::observe`]
/// and [`TokenCounter::subtract`] saturate at the bounds of `usize`.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct TokenCounter {
    /// Tokens observed minus tokens subtracted, clamped to `0..=usize::MAX`.
    total: usize,
}

impl TokenCounter {
    /// Creates a counter whose total is zero (equivalent to [`TokenCounter::default`]).
    #[must_use]
    pub const fn new() -> Self {
        Self { total: 0 }
    }

    /// Returns the number of whitespace-separated tokens in `text`.
    ///
    /// This is a plain [`str::split_whitespace`] count; empty or
    /// whitespace-only input yields `0`.
    #[must_use]
    pub fn count_tokens(text: &str) -> usize {
        text.split_whitespace().count()
    }

    /// Adds the token count of `text` to the running total.
    ///
    /// Saturates at `usize::MAX` rather than overflowing.
    pub fn observe(&mut self, text: &str) {
        // A plain `+=` panics on overflow in debug builds and silently wraps in
        // release builds; saturate instead, mirroring `subtract`.
        self.total = self.total.saturating_add(Self::count_tokens(text));
    }

    /// Removes the token count of `text` from the running total.
    ///
    /// Clamps at zero if `text` has more tokens than the current total.
    pub fn subtract(&mut self, text: &str) {
        self.total = self.total.saturating_sub(Self::count_tokens(text));
    }

    /// Returns the current running total of tokens.
    #[must_use]
    pub const fn total(&self) -> usize {
        self.total
    }

    /// Returns `true` if the running total does not exceed `max`.
    ///
    /// The bound is inclusive: a total equal to `max` counts as under budget.
    #[must_use]
    pub const fn under_budget(&self, max: usize) -> bool {
        self.total <= max
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `TokenCounter`: whitespace token splitting, accumulation
    // via `observe`, clamped `subtract`, and the inclusive `under_budget` check.
    use super::*;

    #[test]
    fn test_token_counter_new() {
        let counter = TokenCounter::new();
        assert_eq!(counter.total(), 0);
    }

    #[test]
    fn test_new_matches_default() {
        assert_eq!(TokenCounter::new().total(), TokenCounter::default().total());
    }

    #[test]
    fn test_token_counting() {
        assert_eq!(TokenCounter::count_tokens("Hello world"), 2);
        assert_eq!(TokenCounter::count_tokens(""), 0);
        assert_eq!(TokenCounter::count_tokens(" "), 0);
        assert_eq!(TokenCounter::count_tokens("one two three"), 3);
    }

    #[test]
    fn test_token_counting_mixed_whitespace() {
        // Tabs, newlines, repeated separators, and leading/trailing whitespace
        // must all be treated as token boundaries rather than tokens.
        assert_eq!(TokenCounter::count_tokens("  one\ttwo\n\nthree  "), 3);
        assert_eq!(TokenCounter::count_tokens("\t\n"), 0);
    }

    #[test]
    fn test_observe_and_total() {
        let mut counter = TokenCounter::default();
        counter.observe("Hello world");
        assert_eq!(counter.total(), 2);
        counter.observe("another message");
        assert_eq!(counter.total(), 4);
    }

    #[test]
    fn test_observe_empty_text_is_noop() {
        let mut counter = TokenCounter::default();
        counter.observe("Hello world");
        counter.observe("");
        counter.observe("   ");
        assert_eq!(counter.total(), 2);
    }

    #[test]
    fn test_subtract() {
        let mut counter = TokenCounter::default();
        counter.observe("Hello world another message");
        assert_eq!(counter.total(), 4);
        counter.subtract("Hello world");
        assert_eq!(counter.total(), 2);
        counter.subtract("way more tokens than we have now");
        assert_eq!(counter.total(), 0);
    }

    #[test]
    fn test_subtract_from_empty_stays_zero() {
        let mut counter = TokenCounter::new();
        counter.subtract("anything at all");
        assert_eq!(counter.total(), 0);
    }

    #[test]
    fn test_under_budget() {
        let mut counter = TokenCounter::default();
        counter.observe("Hello world");
        assert!(counter.under_budget(2));
        assert!(counter.under_budget(3));
        assert!(!counter.under_budget(1));
    }

    #[test]
    fn test_under_budget_zero() {
        // An empty counter is within even a zero budget (the bound is inclusive).
        let counter = TokenCounter::new();
        assert!(counter.under_budget(0));
    }
}