//! Token counting utilities: counts BPE tokens for text, chat messages, and
//! tool schemas using a shared tiktoken `o200k_base` tokenizer with a bounded
//! memoization cache.
use ahash::AHasher;
use dashmap::DashMap;
use rmcp::model::Tool;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tiktoken_rs::CoreBPE;
use tokio::sync::OnceCell;

use crate::conversation::message::Message;

// Process-wide tokenizer, initialized lazily on first use (see `get_tokenizer`).
static TOKENIZER: OnceCell<Arc<CoreBPE>> = OnceCell::const_new();

// Upper bound on memoized (text-hash -> token-count) entries in `TokenCounter`.
const MAX_TOKEN_CACHE_SIZE: usize = 10_000;

// Fixed token overheads for the various parts of a serialized tool call.
// NOTE(review): these look like heuristics for OpenAI-style function-calling
// token accounting — confirm against the target provider's serialization.
const FUNC_INIT: usize = 7; // per-function preamble
const PROP_INIT: usize = 3; // opening the properties object
const PROP_KEY: usize = 3; // per property key
const ENUM_INIT: isize = -3; // negative: an enum list offsets part of the property cost
const ENUM_ITEM: usize = 3; // per enum value
const FUNC_END: usize = 12; // closing overhead for the whole tool list
22
/// Token counter backed by a shared tiktoken `CoreBPE` tokenizer, with a
/// bounded, thread-safe memoization cache keyed by the text's ahash digest.
pub struct TokenCounter {
    tokenizer: Arc<CoreBPE>, // shared o200k_base encoder
    token_cache: Arc<DashMap<u64, usize>>, // ahash(text) -> token count
}
27
28impl TokenCounter {
29    pub async fn new() -> Result<Self, String> {
30        let tokenizer = get_tokenizer().await?;
31        Ok(Self {
32            tokenizer,
33            token_cache: Arc::new(DashMap::new()),
34        })
35    }
36
37    pub fn count_tokens(&self, text: &str) -> usize {
38        let mut hasher = AHasher::default();
39        text.hash(&mut hasher);
40        let hash = hasher.finish();
41
42        if let Some(count) = self.token_cache.get(&hash) {
43            return *count;
44        }
45
46        let tokens = self.tokenizer.encode_with_special_tokens(text);
47        let count = tokens.len();
48
49        if self.token_cache.len() >= MAX_TOKEN_CACHE_SIZE {
50            if let Some(entry) = self.token_cache.iter().next() {
51                let old_hash = *entry.key();
52                self.token_cache.remove(&old_hash);
53            }
54        }
55
56        self.token_cache.insert(hash, count);
57        count
58    }
59
60    pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
61        let mut func_token_count = 0;
62        if !tools.is_empty() {
63            for tool in tools {
64                func_token_count += FUNC_INIT;
65                let name = &tool.name;
66                let description = &tool
67                    .description
68                    .as_ref()
69                    .map(|d| d.as_ref())
70                    .unwrap_or_default()
71                    .trim_end_matches('.');
72
73                let line = format!("{}:{}", name, description);
74                func_token_count += self.count_tokens(&line);
75
76                if let Some(serde_json::Value::Object(properties)) =
77                    tool.input_schema.get("properties")
78                {
79                    if !properties.is_empty() {
80                        func_token_count += PROP_INIT;
81                        for (key, value) in properties {
82                            func_token_count += PROP_KEY;
83                            let p_name = key;
84                            let p_type = value.get("type").and_then(|v| v.as_str()).unwrap_or("");
85                            let p_desc = value
86                                .get("description")
87                                .and_then(|v| v.as_str())
88                                .unwrap_or("")
89                                .trim_end_matches('.');
90
91                            let line = format!("{}:{}:{}", p_name, p_type, p_desc);
92                            func_token_count += self.count_tokens(&line);
93
94                            if let Some(enum_values) = value.get("enum").and_then(|v| v.as_array())
95                            {
96                                func_token_count =
97                                    func_token_count.saturating_add_signed(ENUM_INIT);
98                                for item in enum_values {
99                                    if let Some(item_str) = item.as_str() {
100                                        func_token_count += ENUM_ITEM;
101                                        func_token_count += self.count_tokens(item_str);
102                                    }
103                                }
104                            }
105                        }
106                    }
107                }
108            }
109            func_token_count += FUNC_END;
110        }
111
112        func_token_count
113    }
114
115    pub fn count_chat_tokens(
116        &self,
117        system_prompt: &str,
118        messages: &[Message],
119        tools: &[Tool],
120    ) -> usize {
121        let tokens_per_message = 4;
122        let mut num_tokens = 0;
123
124        if !system_prompt.is_empty() {
125            num_tokens += self.count_tokens(system_prompt) + tokens_per_message;
126        }
127
128        for message in messages {
129            if !message.metadata.agent_visible {
130                continue;
131            }
132            num_tokens += tokens_per_message;
133            for content in &message.content {
134                if let Some(content_text) = content.as_text() {
135                    num_tokens += self.count_tokens(content_text);
136                } else if let Some(tool_request) = content.as_tool_request() {
137                    if let Ok(tool_call) = tool_request.tool_call.as_ref() {
138                        let text = format!(
139                            "{}:{}:{:?}",
140                            tool_request.id, tool_call.name, tool_call.arguments
141                        );
142                        num_tokens += self.count_tokens(&text);
143                    }
144                } else if let Some(tool_response_text) = content.as_tool_response_text() {
145                    num_tokens += self.count_tokens(&tool_response_text);
146                }
147            }
148        }
149
150        if !tools.is_empty() {
151            num_tokens += self.count_tokens_for_tools(tools);
152        }
153
154        num_tokens += 3; // Reply primer
155
156        num_tokens
157    }
158
159    pub fn count_everything(
160        &self,
161        system_prompt: &str,
162        messages: &[Message],
163        tools: &[Tool],
164        resources: &[String],
165    ) -> usize {
166        let mut num_tokens = self.count_chat_tokens(system_prompt, messages, tools);
167
168        if !resources.is_empty() {
169            for resource in resources {
170                num_tokens += self.count_tokens(resource);
171            }
172        }
173        num_tokens
174    }
175
176    pub fn clear_cache(&self) {
177        self.token_cache.clear();
178    }
179
180    pub fn cache_size(&self) -> usize {
181        self.token_cache.len()
182    }
183}
184
185async fn get_tokenizer() -> Result<Arc<CoreBPE>, String> {
186    let tokenizer = TOKENIZER
187        .get_or_init(|| async {
188            match tiktoken_rs::o200k_base() {
189                Ok(bpe) => Arc::new(bpe),
190                Err(e) => panic!("Failed to initialize o200k_base tokenizer: {}", e),
191            }
192        })
193        .await;
194    Ok(tokenizer.clone())
195}
196
197pub async fn create_token_counter() -> Result<TokenCounter, String> {
198    TokenCounter::new().await
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_token_caching() {
        let counter = create_token_counter().await.unwrap();

        let text = "This is a test for caching functionality";

        // First call tokenizes and caches.
        let first = counter.count_tokens(text);
        assert_eq!(counter.cache_size(), 1);

        // Second call is served from the cache and must agree.
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
        assert_eq!(counter.cache_size(), 1);

        // A distinct string creates a second entry with a different count.
        let third = counter.count_tokens("Different text");
        assert_eq!(counter.cache_size(), 2);
        assert_ne!(first, third);
    }

    #[tokio::test]
    async fn test_cache_management() {
        let counter = create_token_counter().await.unwrap();

        for text in ["First text", "Second text", "Third text"] {
            counter.count_tokens(text);
        }
        assert_eq!(counter.cache_size(), 3);

        // Clearing empties the cache entirely.
        counter.clear_cache();
        assert_eq!(counter.cache_size(), 0);

        // Counting again repopulates it.
        let count = counter.count_tokens("First text");
        assert!(count > 0);
        assert_eq!(counter.cache_size(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_token_counter_creation() {
        let tasks: Vec<_> = (0..10)
            .map(|_| tokio::spawn(async { create_token_counter().await.unwrap() }))
            .collect();

        let counters: Vec<_> = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|joined| joined.unwrap())
            .collect();

        // Every counter shares the same global tokenizer, so counts agree.
        let text = "Test concurrent creation";
        let baseline = counters[0].count_tokens(text);
        for counter in &counters {
            assert_eq!(counter.count_tokens(text), baseline);
        }
    }

    #[tokio::test]
    async fn test_cache_eviction_behavior() {
        let counter = create_token_counter().await.unwrap();

        // Populate the cache with a batch of distinct strings.
        let texts: Vec<String> = (0..50)
            .map(|i| format!("Test string number {}", i))
            .collect();
        for text in &texts {
            counter.count_tokens(text);
        }

        assert!(counter.cache_size() <= MAX_TOKEN_CACHE_SIZE);

        // Re-counting an already-cached string must not grow the cache.
        let recent_text = texts.last().unwrap();
        let size_before = counter.cache_size();
        counter.count_tokens(recent_text);
        assert_eq!(counter.cache_size(), size_before);
    }

    #[tokio::test]
    async fn test_concurrent_cache_operations() {
        let counter = std::sync::Arc::new(create_token_counter().await.unwrap());

        // 20 tasks over 5 distinct strings: exercises concurrent hits and misses.
        let tasks: Vec<_> = (0..20)
            .map(|i| {
                let shared = counter.clone();
                tokio::spawn(async move {
                    let text = format!("Concurrent test {}", i % 5);
                    shared.count_tokens(&text)
                })
            })
            .collect();

        for count in futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|joined| joined.unwrap())
        {
            assert!(count > 0);
        }

        assert!(counter.cache_size() > 0);
        assert!(counter.cache_size() <= MAX_TOKEN_CACHE_SIZE);
    }
}