use serde_json::Value;
use std::collections::HashMap;
/// Strategy for estimating how many tokens a piece of text occupies.
///
/// Implementors must be `Send + Sync`.
pub trait TokenCounter: Send + Sync {
    /// Returns the estimated token count of `text`.
    fn count(&self, text: &str) -> usize;

    /// Returns the estimated token count of a list of JSON messages.
    ///
    /// Every message contributes a fixed overhead of 3 tokens. When a message
    /// has a `"content"` field holding a JSON string, the estimate for that
    /// string (via [`TokenCounter::count`]) is added on top; a message whose
    /// `"content"` is missing or not a string contributes only the overhead.
    /// A further 3 tokens are added once per call, so an empty slice yields 3.
    fn count_messages(&self, messages: &[Value]) -> usize {
        let per_message: usize = messages
            .iter()
            .map(|message| {
                let body = message
                    .get("content")
                    .and_then(|v| v.as_str())
                    .map_or(0, |content| self.count(content));
                3 + body
            })
            .sum();
        per_message + 3
    }

    /// Returns a short identifier for this counter.
    fn name(&self) -> &str;
}
/// Stateless token counter; it carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct SimpleTokenCounter;

impl SimpleTokenCounter {
    /// Creates a new counter. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        SimpleTokenCounter
    }
}
impl TokenCounter for SimpleTokenCounter {
    /// Estimates tokens as the number of whitespace-separated words divided
    /// by 0.75, rounded up.
    ///
    /// Empty and whitespace-only input contain no words and yield 0.
    fn count(&self, text: &str) -> usize {
        match text.split_whitespace().count() {
            0 => 0,
            words => (words as f64 / 0.75).ceil() as usize,
        }
    }

    /// Always returns `"simple"`.
    fn name(&self) -> &str {
        "simple"
    }
}
/// Token counter configured with a fixed characters-per-token ratio.
#[derive(Debug, Clone)]
pub struct CharBasedCounter {
    /// Ratio supplied at construction time.
    chars_per_token: f64,
}

impl CharBasedCounter {
    /// Ratio used by the [`Default`] implementation.
    const DEFAULT_CHARS_PER_TOKEN: f64 = 4.0;

    /// Creates a counter with the given ratio.
    ///
    /// The value is stored as-is; no validation is performed here, so callers
    /// should pass a finite, positive ratio.
    pub fn new(chars_per_token: f64) -> Self {
        CharBasedCounter { chars_per_token }
    }
}

impl Default for CharBasedCounter {
    /// Creates a counter with a ratio of 4.0.
    fn default() -> Self {
        Self {
            chars_per_token: Self::DEFAULT_CHARS_PER_TOKEN,
        }
    }
}
impl TokenCounter for CharBasedCounter {
    /// Estimates tokens as the length of `text` divided by the configured
    /// ratio, rounded up.
    ///
    /// The length is `str::len`, i.e. the UTF-8 byte length, so non-ASCII
    /// characters count once per byte. Empty input yields 0.
    fn count(&self, text: &str) -> usize {
        match text.len() {
            0 => 0,
            byte_len => (byte_len as f64 / self.chars_per_token).ceil() as usize,
        }
    }

    /// Always returns `"char_based"`.
    fn name(&self) -> &str {
        "char_based"
    }
}
/// Token counter whose characters-per-token ratio is chosen from a model name.
#[derive(Debug, Clone)]
pub struct ModelTokenCounter {
    /// Model name exactly as passed to [`ModelTokenCounter::for_model`].
    model_name: String,
    /// Ratio selected for that model.
    chars_per_token: f64,
}

impl ModelTokenCounter {
    /// Ratio used when the model name contains `"claude"`.
    const CLAUDE_CHARS_PER_TOKEN: f64 = 3.5;
    /// Ratio used for every other model name.
    const DEFAULT_CHARS_PER_TOKEN: f64 = 4.0;

    /// Creates a counter for `model_name`.
    ///
    /// Names containing `"claude"` (compared ASCII case-insensitively) get a
    /// ratio of 3.5; every other name, including the empty string, gets 4.0.
    /// The name itself is stored unchanged.
    pub fn for_model(model_name: &str) -> Self {
        // Lower-case only for the comparison so that "Claude-3", "CLAUDE" and
        // "claude-3" are all recognised instead of silently falling back to
        // the generic ratio; the stored name keeps the caller's spelling.
        let is_claude = model_name.to_ascii_lowercase().contains("claude");
        let chars_per_token = if is_claude {
            Self::CLAUDE_CHARS_PER_TOKEN
        } else {
            Self::DEFAULT_CHARS_PER_TOKEN
        };
        Self {
            model_name: model_name.to_string(),
            chars_per_token,
        }
    }

    /// Returns the ratio selected for this model.
    pub fn chars_per_token(&self) -> f64 {
        self.chars_per_token
    }
}
impl TokenCounter for ModelTokenCounter {
    /// Estimates tokens as the length of `text` divided by this model's
    /// ratio, rounded up.
    ///
    /// The length is `str::len`, i.e. the UTF-8 byte length. Empty input
    /// yields 0.
    fn count(&self, text: &str) -> usize {
        match text.len() {
            0 => 0,
            byte_len => (byte_len as f64 / self.chars_per_token).ceil() as usize,
        }
    }

    /// Returns the model name this counter was created for.
    fn name(&self) -> &str {
        self.model_name.as_str()
    }
}
/// Importance level, ordered from `Optional` (lowest) up to `Critical`
/// (highest).
///
/// The ordering is implemented by hand on top of [`Priority::weight`]; a
/// derived `Ord` would follow declaration order and rank `Critical` lowest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Priority {
    Critical,
    High,
    Normal,
    Low,
    Optional,
}

impl Priority {
    /// Returns the numeric rank of this priority: `Optional` is 0 and each
    /// step up adds 1, ending at 4 for `Critical`.
    pub fn weight(&self) -> u32 {
        match *self {
            Priority::Optional => 0,
            Priority::Low => 1,
            Priority::Normal => 2,
            Priority::High => 3,
            Priority::Critical => 4,
        }
    }
}

impl Ord for Priority {
    /// Compares two priorities by their [`Priority::weight`].
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.weight(), other.weight());
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for Priority {
    /// Delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// A piece of text together with its priority, token count and optional
/// label.
#[derive(Debug, Clone)]
pub struct ContextItem {
    /// The stored text.
    pub text: String,
    /// Importance of this item.
    pub priority: Priority,
    /// Token count recorded for `text`.
    pub token_count: usize,
    /// Optional label attached to the item.
    pub label: Option<String>,
}

impl ContextItem {
    /// Serializes the item as a JSON object with the keys `"text"`,
    /// `"priority"`, `"token_count"` and `"label"`.
    ///
    /// `"priority"` is the `Debug` rendering of the variant (for example
    /// `"High"`); `"label"` is `null` when no label is set.
    pub fn to_json(&self) -> Value {
        let priority_name = format!("{:?}", self.priority);
        let Self {
            text,
            token_count,
            label,
            ..
        } = self;
        serde_json::json!({
            "text": text,
            "priority": priority_name,
            "token_count": token_count,
            "label": label,
        })
    }
}
/// A list of [`ContextItem`]s whose combined token count is checked against a
/// fixed budget whenever an item is added.
pub struct ContextWindow {
    /// Budget that `add` / `add_labeled` refuse to exceed.
    max_tokens: usize,
    /// Estimator used to size every incoming text.
    counter: Box<dyn TokenCounter>,
    /// Accepted items, in insertion order.
    items: Vec<ContextItem>,
}

impl ContextWindow {
    /// Creates an empty window with a budget of `max_tokens`, sized by
    /// `counter`.
    pub fn new(max_tokens: usize, counter: Box<dyn TokenCounter>) -> Self {
        Self {
            max_tokens,
            counter,
            items: Vec::new(),
        }
    }

    /// Appends `text` without a label.
    ///
    /// Returns `true` when the item was stored, or `false` (leaving the
    /// window untouched) when it would push the total above the budget.
    pub fn add(&mut self, text: &str, priority: Priority) -> bool {
        self.try_push(text, priority, None)
    }

    /// Appends `text` tagged with `label`.
    ///
    /// Same acceptance rule and return value as [`ContextWindow::add`].
    pub fn add_labeled(&mut self, text: &str, priority: Priority, label: &str) -> bool {
        self.try_push(text, priority, Some(label))
    }

    /// Shared implementation of `add` and `add_labeled`: sizes `text`, checks
    /// the budget and stores the item when it fits.
    fn try_push(&mut self, text: &str, priority: Priority, label: Option<&str>) -> bool {
        let token_count = self.counter.count(text);
        // `checked_add` instead of `+`: a counter may report a huge value
        // (for example a saturated float-to-int cast), and a wrapped sum
        // could otherwise slip under the budget in release builds or panic
        // in debug builds. An overflowing total can never fit.
        let fits = matches!(
            self.used_tokens().checked_add(token_count),
            Some(total) if total <= self.max_tokens
        );
        if !fits {
            return false;
        }
        self.items.push(ContextItem {
            text: text.to_string(),
            priority,
            token_count,
            label: label.map(String::from),
        });
        true
    }

    /// Returns how many tokens are still available, or 0 when the window is
    /// at or over budget.
    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.used_tokens())
    }

    /// Returns the sum of the token counts of all stored items.
    pub fn used_tokens(&self) -> usize {
        self.items.iter().map(|item| item.token_count).sum()
    }

    /// Returns used tokens divided by the budget, or 0.0 when the budget is 0.
    pub fn utilization(&self) -> f64 {
        if self.max_tokens == 0 {
            return 0.0;
        }
        self.used_tokens() as f64 / self.max_tokens as f64
    }

    /// Returns references to the stored items in insertion order.
    pub fn content(&self) -> Vec<&ContextItem> {
        self.items.iter().collect()
    }

    /// Removes every stored item.
    pub fn clear(&mut self) {
        self.items.clear();
    }

    /// Evicts items until the stored total fits the budget.
    ///
    /// Each step removes the non-`Critical` item with the lowest
    /// [`Priority::weight`]; among equals the earliest-inserted one goes
    /// first. `Critical` items are never removed, so the window can still be
    /// over budget afterwards if only critical items are left.
    pub fn trim_to_fit(&mut self) {
        // Keep a running total rather than re-summing every item per pass.
        let mut used = self.used_tokens();
        while used > self.max_tokens {
            let victim = self
                .items
                .iter()
                .enumerate()
                .filter(|(_, item)| item.priority != Priority::Critical)
                // `min_by_key` yields the first of several equal minima.
                .min_by_key(|(_, item)| item.priority.weight())
                .map(|(idx, _)| idx);
            match victim {
                Some(idx) => {
                    // `used` is the sum of all stored counts, so this
                    // subtraction cannot underflow.
                    used -= self.items.remove(idx).token_count;
                }
                // Only critical items remain; nothing more may be evicted.
                None => break,
            }
        }
    }
}
/// Tracks an overall token budget split into named sections.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Overall budget.
    total: usize,
    /// Tokens allocated to each section.
    allocations: HashMap<String, usize>,
    /// Tokens recorded as used by each section.
    used: HashMap<String, usize>,
}

impl TokenBudget {
    /// Creates a budget of `total` tokens with no sections.
    pub fn new(total: usize) -> Self {
        TokenBudget {
            total,
            allocations: HashMap::new(),
            used: HashMap::new(),
        }
    }

    /// Adds `tokens` to the allocation of `section`, creating the section if
    /// needed. Repeated calls accumulate.
    ///
    /// Also makes sure the section has a usage entry (starting at 0).
    /// Allocations are not checked against the overall budget.
    pub fn allocate(&mut self, section: &str, tokens: usize) {
        let key = section.to_string();
        self.used.entry(key.clone()).or_default();
        *self.allocations.entry(key).or_default() += tokens;
    }

    /// Records `tokens` as used by `section`.
    ///
    /// The section does not need an allocation, and usage is not capped at
    /// the allocated amount.
    pub fn use_tokens(&mut self, section: &str, tokens: usize) {
        let spent = self.used.entry(section.to_string()).or_default();
        *spent += tokens;
    }

    /// Returns the allocation of `section` minus its usage, floored at 0.
    /// Sections without an allocation yield 0.
    pub fn remaining(&self, section: &str) -> usize {
        let allocated = match self.allocations.get(section) {
            Some(&tokens) => tokens,
            None => 0,
        };
        allocated.saturating_sub(self.used(section))
    }

    /// Returns the tokens recorded as used by `section`, or 0 if unknown.
    pub fn used(&self, section: &str) -> usize {
        match self.used.get(section) {
            Some(&tokens) => tokens,
            None => 0,
        }
    }

    /// Returns the overall budget minus the usage of all sections, floored
    /// at 0.
    pub fn total_remaining(&self) -> usize {
        let spent = self.used.values().sum::<usize>();
        self.total.saturating_sub(spent)
    }

    /// Returns the names of all sections that have an allocation, in the
    /// map's (unspecified) iteration order.
    pub fn sections(&self) -> Vec<&str> {
        self.allocations.keys().map(String::as_str).collect()
    }

    /// Serializes the budget as a JSON object with `"total"`,
    /// `"total_remaining"` and `"sections"`; each allocated section maps to
    /// an object with `"allocated"`, `"used"` and `"remaining"`.
    pub fn to_json(&self) -> Value {
        let mut sections: HashMap<&str, Value> = HashMap::new();
        for (name, allocated) in &self.allocations {
            let entry = serde_json::json!({
                "allocated": *allocated,
                "used": self.used(name),
                "remaining": self.remaining(name),
            });
            sections.insert(name.as_str(), entry);
        }
        serde_json::json!({
            "total": self.total,
            "total_remaining": self.total_remaining(),
            "sections": sections,
        })
    }
}
/// Token usage of one recorded request.
#[derive(Debug, Clone)]
struct UsageRecord {
    /// Model name the request was recorded under.
    model: String,
    /// Prompt tokens of the request.
    prompt_tokens: usize,
    /// Completion tokens of the request.
    completion_tokens: usize,
}

impl UsageRecord {
    /// Prompt plus completion tokens of this request.
    fn combined(&self) -> usize {
        self.prompt_tokens + self.completion_tokens
    }
}

/// Accumulates per-request token usage and reports totals.
#[derive(Debug, Clone, Default)]
pub struct TokenUsageTracker {
    /// Every recorded request, in the order it was recorded.
    records: Vec<UsageRecord>,
}

impl TokenUsageTracker {
    /// Creates a tracker with no recorded requests.
    pub fn new() -> Self {
        TokenUsageTracker {
            records: Vec::new(),
        }
    }

    /// Records one request made against `model`.
    pub fn record(&mut self, model: &str, prompt_tokens: usize, completion_tokens: usize) {
        let entry = UsageRecord {
            model: model.to_string(),
            prompt_tokens,
            completion_tokens,
        };
        self.records.push(entry);
    }

    /// Returns prompt plus completion tokens summed over all requests.
    pub fn total_tokens(&self) -> usize {
        self.records.iter().map(UsageRecord::combined).sum()
    }

    /// Returns prompt plus completion tokens summed per model name.
    pub fn by_model(&self) -> HashMap<String, usize> {
        self.records
            .iter()
            .fold(HashMap::new(), |mut totals, rec| {
                *totals.entry(rec.model.clone()).or_default() += rec.combined();
                totals
            })
    }

    /// Returns the mean tokens per recorded request, or 0.0 when nothing has
    /// been recorded.
    pub fn avg_per_request(&self) -> f64 {
        match self.records.len() {
            0 => 0.0,
            requests => self.total_tokens() as f64 / requests as f64,
        }
    }

    /// Serializes the tracker as a JSON object with `"total_tokens"`,
    /// `"request_count"`, `"avg_per_request"` and `"by_model"`.
    pub fn to_json(&self) -> Value {
        let request_count = self.records.len();
        serde_json::json!({
            "total_tokens": self.total_tokens(),
            "request_count": request_count,
            "avg_per_request": self.avg_per_request(),
            "by_model": self.by_model(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- SimpleTokenCounter -------------------------------------------

    #[test]
    fn simple_counter_empty() {
        let counter = SimpleTokenCounter::new();
        assert_eq!(counter.count(""), 0);
    }

    #[test]
    fn simple_counter_single_word() {
        // One word: ceil(1 / 0.75) = 2.
        let counter = SimpleTokenCounter::new();
        assert_eq!(counter.count("hello"), 2);
    }

    #[test]
    fn simple_counter_multiple_words() {
        // Three words: ceil(3 / 0.75) = 4.
        let counter = SimpleTokenCounter::new();
        assert_eq!(counter.count("hello beautiful world"), 4);
    }

    #[test]
    fn simple_counter_name() {
        let counter = SimpleTokenCounter::new();
        assert_eq!(counter.name(), "simple");
    }

    #[test]
    fn simple_counter_messages() {
        let counter = SimpleTokenCounter::new();
        let messages = vec![
            serde_json::json!({"content": "hello"}),
            serde_json::json!({"content": "world"}),
        ];
        // Two messages at (3 overhead + 2 content) each, plus the final 3.
        assert_eq!(counter.count_messages(&messages), 13);
    }

    // ----- CharBasedCounter ---------------------------------------------

    #[test]
    fn char_counter_empty() {
        let counter = CharBasedCounter::default();
        assert_eq!(counter.count(""), 0);
    }

    #[test]
    fn char_counter_default_ratio() {
        // 5 bytes at 4.0 per token: ceil(1.25) = 2.
        let counter = CharBasedCounter::default();
        assert_eq!(counter.count("hello"), 2);
    }

    #[test]
    fn char_counter_custom_ratio() {
        // 5 bytes at 2.0 per token: ceil(2.5) = 3.
        let counter = CharBasedCounter::new(2.0);
        assert_eq!(counter.count("hello"), 3);
    }

    #[test]
    fn char_counter_name() {
        let counter = CharBasedCounter::default();
        assert_eq!(counter.name(), "char_based");
    }

    #[test]
    fn char_counter_long_text() {
        let counter = CharBasedCounter::new(4.0);
        let long_text = "a".repeat(100);
        // 100 bytes at 4.0 per token divides evenly.
        assert_eq!(counter.count(&long_text), 25);
    }

    // ----- ModelTokenCounter --------------------------------------------

    #[test]
    fn model_counter_gpt4() {
        let counter = ModelTokenCounter::for_model("gpt-4");
        assert!((counter.chars_per_token() - 4.0).abs() < f64::EPSILON);
        assert_eq!(counter.count("hello"), 2);
    }

    #[test]
    fn model_counter_claude() {
        let counter = ModelTokenCounter::for_model("claude-3-opus");
        assert!((counter.chars_per_token() - 3.5).abs() < f64::EPSILON);
        assert_eq!(counter.count("hello"), 2);
        assert_eq!(counter.count("abcdefg"), 2);
    }

    #[test]
    fn model_counter_gemini() {
        let counter = ModelTokenCounter::for_model("gemini-pro");
        assert!((counter.chars_per_token() - 4.0).abs() < f64::EPSILON);
    }

    #[test]
    fn model_counter_unknown_default() {
        let counter = ModelTokenCounter::for_model("llama-3");
        assert!((counter.chars_per_token() - 4.0).abs() < f64::EPSILON);
    }

    #[test]
    fn model_counter_name() {
        let counter = ModelTokenCounter::for_model("gpt-4");
        assert_eq!(counter.name(), "gpt-4");
    }

    #[test]
    fn model_counter_empty() {
        let counter = ModelTokenCounter::for_model("gpt-4");
        assert_eq!(counter.count(""), 0);
    }

    // ----- Priority -----------------------------------------------------

    #[test]
    fn priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
        assert!(Priority::Low > Priority::Optional);
    }

    #[test]
    fn priority_weights() {
        assert_eq!(Priority::Critical.weight(), 4);
        assert_eq!(Priority::High.weight(), 3);
        assert_eq!(Priority::Normal.weight(), 2);
        assert_eq!(Priority::Low.weight(), 1);
        assert_eq!(Priority::Optional.weight(), 0);
    }

    // ----- ContextItem --------------------------------------------------

    #[test]
    fn context_item_to_json() {
        let item = ContextItem {
            text: "hello".to_string(),
            priority: Priority::High,
            token_count: 2,
            label: Some("greeting".to_string()),
        };
        let rendered = item.to_json();
        assert_eq!(rendered["text"], "hello");
        assert_eq!(rendered["priority"], "High");
        assert_eq!(rendered["token_count"], 2);
        assert_eq!(rendered["label"], "greeting");
    }

    #[test]
    fn context_item_to_json_no_label() {
        let item = ContextItem {
            text: "x".to_string(),
            priority: Priority::Low,
            token_count: 1,
            label: None,
        };
        let rendered = item.to_json();
        assert!(rendered["label"].is_null());
    }

    // ----- ContextWindow ------------------------------------------------

    #[test]
    fn context_window_add_and_counts() {
        let mut window = ContextWindow::new(100, Box::new(CharBasedCounter::default()));
        assert!(window.add("hello world!", Priority::Normal));
        assert!(window.used_tokens() > 0);
        assert!(window.remaining_tokens() < 100);
    }

    #[test]
    fn context_window_add_exceeds_budget() {
        let mut window = ContextWindow::new(1, Box::new(CharBasedCounter::default()));
        assert!(!window.add("hello world", Priority::Normal));
        assert_eq!(window.used_tokens(), 0);
    }

    #[test]
    fn context_window_utilization() {
        // One token per byte: 5 of 10 tokens used.
        let mut window = ContextWindow::new(10, Box::new(CharBasedCounter::new(1.0)));
        window.add("hello", Priority::Normal);
        assert!((window.utilization() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn context_window_utilization_empty_budget() {
        let window = ContextWindow::new(0, Box::new(CharBasedCounter::default()));
        assert!((window.utilization() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn context_window_clear() {
        let mut window = ContextWindow::new(100, Box::new(CharBasedCounter::default()));
        window.add("stuff", Priority::Normal);
        window.clear();
        assert_eq!(window.used_tokens(), 0);
        assert!(window.content().is_empty());
    }

    #[test]
    fn context_window_content_returns_items() {
        let mut window = ContextWindow::new(100, Box::new(CharBasedCounter::default()));
        window.add("a", Priority::High);
        window.add("b", Priority::Low);
        assert_eq!(window.content().len(), 2);
    }

    #[test]
    fn context_window_trim_removes_lowest_priority() {
        let per_byte = CharBasedCounter::new(1.0);
        let mut window = ContextWindow::new(100, Box::new(per_byte));
        // Three items of 4 tokens each.
        window.add("aaaa", Priority::Optional);
        window.add("bbbb", Priority::High);
        window.add("cccc", Priority::Normal);
        // Shrink the budget so exactly one item has to go.
        window.max_tokens = 8;
        window.trim_to_fit();
        assert_eq!(window.content().len(), 2);
        assert!(window
            .content()
            .iter()
            .all(|item| item.priority != Priority::Optional));
    }

    #[test]
    fn context_window_trim_preserves_critical() {
        let per_byte = CharBasedCounter::new(1.0);
        let mut window = ContextWindow::new(100, Box::new(per_byte));
        // Two items of 8 tokens each.
        window.add("critical", Priority::Critical);
        window.add("optional", Priority::Optional);
        window.max_tokens = 10;
        window.trim_to_fit();
        assert_eq!(window.content().len(), 1);
        assert_eq!(window.content()[0].priority, Priority::Critical);
    }

    #[test]
    fn context_window_add_labeled() {
        let mut window = ContextWindow::new(100, Box::new(CharBasedCounter::default()));
        assert!(window.add_labeled("data", Priority::Normal, "my_label"));
        let stored = window.content();
        assert_eq!(stored[0].label.as_deref(), Some("my_label"));
    }

    // ----- TokenBudget --------------------------------------------------

    #[test]
    fn budget_allocate_and_remaining() {
        let mut budget = TokenBudget::new(1000);
        budget.allocate("system", 200);
        budget.allocate("history", 500);
        assert_eq!(budget.remaining("system"), 200);
        assert_eq!(budget.remaining("history"), 500);
    }

    #[test]
    fn budget_use_tokens() {
        let mut budget = TokenBudget::new(1000);
        budget.allocate("system", 200);
        budget.use_tokens("system", 50);
        assert_eq!(budget.remaining("system"), 150);
        assert_eq!(budget.used("system"), 50);
    }

    #[test]
    fn budget_total_remaining() {
        let mut budget = TokenBudget::new(1000);
        budget.allocate("a", 400);
        budget.allocate("b", 400);
        budget.use_tokens("a", 100);
        budget.use_tokens("b", 200);
        assert_eq!(budget.total_remaining(), 700);
    }

    #[test]
    fn budget_sections() {
        let mut budget = TokenBudget::new(1000);
        budget.allocate("system", 200);
        budget.allocate("history", 500);
        // Sort before comparing; the returned order is not fixed.
        let mut names = budget.sections();
        names.sort();
        assert_eq!(names, vec!["history", "system"]);
    }

    #[test]
    fn budget_unknown_section() {
        let budget = TokenBudget::new(1000);
        assert_eq!(budget.remaining("nonexistent"), 0);
        assert_eq!(budget.used("nonexistent"), 0);
    }

    #[test]
    fn budget_to_json() {
        let mut budget = TokenBudget::new(500);
        budget.allocate("prompt", 300);
        budget.use_tokens("prompt", 100);
        let rendered = budget.to_json();
        assert_eq!(rendered["total"], 500);
        assert_eq!(rendered["total_remaining"], 400);
        assert_eq!(rendered["sections"]["prompt"]["allocated"], 300);
        assert_eq!(rendered["sections"]["prompt"]["used"], 100);
        assert_eq!(rendered["sections"]["prompt"]["remaining"], 200);
    }

    #[test]
    fn budget_cumulative_allocation() {
        let mut budget = TokenBudget::new(1000);
        budget.allocate("x", 100);
        budget.allocate("x", 50);
        assert_eq!(budget.remaining("x"), 150);
    }

    // ----- TokenUsageTracker --------------------------------------------

    #[test]
    fn tracker_empty() {
        let tracker = TokenUsageTracker::new();
        assert_eq!(tracker.total_tokens(), 0);
        assert_eq!(tracker.avg_per_request(), 0.0);
        assert!(tracker.by_model().is_empty());
    }

    #[test]
    fn tracker_record_and_total() {
        let mut tracker = TokenUsageTracker::new();
        tracker.record("gpt-4", 100, 50);
        tracker.record("gpt-4", 200, 100);
        assert_eq!(tracker.total_tokens(), 450);
    }

    #[test]
    fn tracker_by_model() {
        let mut tracker = TokenUsageTracker::new();
        tracker.record("gpt-4", 100, 50);
        tracker.record("claude", 200, 100);
        tracker.record("gpt-4", 50, 25);
        let totals = tracker.by_model();
        assert_eq!(totals["gpt-4"], 225);
        assert_eq!(totals["claude"], 300);
    }

    #[test]
    fn tracker_avg_per_request() {
        let mut tracker = TokenUsageTracker::new();
        // 200 + 400 tokens over two requests.
        tracker.record("gpt-4", 100, 100);
        tracker.record("gpt-4", 200, 200);
        assert!((tracker.avg_per_request() - 300.0).abs() < f64::EPSILON);
    }

    #[test]
    fn tracker_to_json() {
        let mut tracker = TokenUsageTracker::new();
        tracker.record("gpt-4", 100, 50);
        let rendered = tracker.to_json();
        assert_eq!(rendered["total_tokens"], 150);
        assert_eq!(rendered["request_count"], 1);
        assert_eq!(rendered["avg_per_request"], 150.0);
    }

    #[test]
    fn tracker_default() {
        let tracker = TokenUsageTracker::default();
        assert_eq!(tracker.total_tokens(), 0);
    }

    // ----- TokenCounter::count_messages edge cases ----------------------

    #[test]
    fn count_messages_empty_content() {
        let counter = CharBasedCounter::default();
        let messages = vec![serde_json::json!({"role": "user"})];
        // No "content" field: 3 per-message overhead plus the final 3.
        assert_eq!(counter.count_messages(&messages), 6);
    }

    #[test]
    fn count_messages_empty_slice() {
        let counter = CharBasedCounter::default();
        assert_eq!(counter.count_messages(&[]), 3);
    }
}