// ai_lib_rust/tokens/counter.rs

1//! Token counter implementations.
2
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use crate::types::Message;
6use crate::types::message::{ContentBlock, MessageContent};
7
8pub trait TokenCounter: Send + Sync {
9    fn count(&self, text: &str) -> usize;
10
11    fn count_messages(&self, messages: &[Message]) -> usize {
12        let mut total = 0;
13        for message in messages {
14            total += 1;
15            match &message.content {
16                MessageContent::Text(text) => { total += self.count(text); }
17                MessageContent::Blocks(blocks) => {
18                    for block in blocks {
19                        match block {
20                            ContentBlock::Text { text } => { total += self.count(text); }
21                            ContentBlock::Image { .. } => { total += 85; }
22                            ContentBlock::Audio { .. } => { total += 100; }
23                            ContentBlock::ToolUse { input, .. } => { total += self.count(&serde_json::to_string(input).unwrap_or_default()); }
24                            ContentBlock::ToolResult { content, .. } => { total += self.count(&serde_json::to_string(content).unwrap_or_default()); }
25                        }
26                    }
27                }
28            }
29        }
30        total + messages.len() * 3
31    }
32
33    fn truncate_to_limit(&self, text: &str, max_tokens: usize, suffix: &str) -> String {
34        let current = self.count(text);
35        if current <= max_tokens { return text.to_string(); }
36        let suffix_tokens = if suffix.is_empty() { 0 } else { self.count(suffix) };
37        let target = max_tokens.saturating_sub(suffix_tokens);
38        if target == 0 { return suffix.to_string(); }
39        let chars_per_token = text.len() as f64 / current as f64;
40        let mut truncated: String = text.chars().take((target as f64 * chars_per_token) as usize).collect();
41        while self.count(&truncated) > target && !truncated.is_empty() { truncated = truncated.chars().take((truncated.len() as f64 * 0.9) as usize).collect(); }
42        format!("{}{}", truncated, suffix)
43    }
44}
45
46#[derive(Debug, Clone)]
47pub struct CharacterEstimator { chars_per_token: f64 }
48impl CharacterEstimator {
49    pub fn new() -> Self { Self::with_ratio(4.0) }
50    pub fn with_ratio(r: f64) -> Self { Self { chars_per_token: r } }
51}
52impl Default for CharacterEstimator { fn default() -> Self { Self::new() } }
53impl TokenCounter for CharacterEstimator { fn count(&self, text: &str) -> usize { (text.len() as f64 / self.chars_per_token).ceil() as usize } }
54
55#[derive(Debug, Clone)]
56pub struct AnthropicEstimator { chars_per_token: f64 }
57impl AnthropicEstimator { pub fn new() -> Self { Self { chars_per_token: 3.5 } } }
58impl Default for AnthropicEstimator { fn default() -> Self { Self::new() } }
59impl TokenCounter for AnthropicEstimator {
60    fn count(&self, text: &str) -> usize {
61        let base = (text.len() as f64 / self.chars_per_token).ceil() as usize;
62        let ws = text.chars().filter(|c| c.is_whitespace()).count();
63        base + (ws as f64 * 0.1) as usize
64    }
65}
66
67pub struct CachingCounter { inner: Box<dyn TokenCounter>, cache: Arc<RwLock<HashMap<String, usize>>>, max_size: usize }
68impl CachingCounter {
69    pub fn new(inner: Box<dyn TokenCounter>, max_size: usize) -> Self { Self { inner, cache: Arc::new(RwLock::new(HashMap::new())), max_size } }
70    pub fn clear_cache(&self) { self.cache.write().unwrap().clear(); }
71}
72impl TokenCounter for CachingCounter {
73    fn count(&self, text: &str) -> usize {
74        { let c = self.cache.read().unwrap(); if let Some(&n) = c.get(text) { return n; } }
75        let n = self.inner.count(text);
76        { let mut c = self.cache.write().unwrap(); if c.len() < self.max_size { c.insert(text.to_string(), n); } }
77        n
78    }
79}
80
81static COUNTERS: once_cell::sync::Lazy<RwLock<HashMap<String, Arc<dyn TokenCounter>>>> = once_cell::sync::Lazy::new(|| RwLock::new(HashMap::new()));
82
83pub fn get_token_counter(model: &str) -> Arc<dyn TokenCounter> {
84    let ml = model.to_lowercase();
85    { let c = COUNTERS.read().unwrap(); if let Some(x) = c.get(&ml) { return x.clone(); } }
86    let counter: Arc<dyn TokenCounter> = if ml.contains("gpt") || ml.contains("o1") { Arc::new(CharacterEstimator::new()) }
87    else if ml.contains("claude") { Arc::new(AnthropicEstimator::new()) }
88    else { Arc::new(CharacterEstimator::new()) };
89    { let mut c = COUNTERS.write().unwrap(); c.insert(ml, counter.clone()); }
90    counter
91}