1use ahash::AHasher;
2use dashmap::DashMap;
3use rmcp::model::Tool;
4use std::hash::{Hash, Hasher};
5use std::sync::Arc;
6use tiktoken_rs::CoreBPE;
7use tokio::sync::OnceCell;
8
9use crate::conversation::message::Message;
10
// Process-wide tokenizer, lazily initialized once on first use (async-safe).
static TOKENIZER: OnceCell<Arc<CoreBPE>> = OnceCell::const_new();

// Upper bound on cached (text-hash -> token-count) entries before eviction kicks in.
const MAX_TOKEN_CACHE_SIZE: usize = 10_000;

// Fixed token overheads used to approximate how tool/function definitions are
// serialized for the model.
// NOTE(review): these look like empirical function-calling heuristics — confirm
// against the target model's actual serialization format.
const FUNC_INIT: usize = 7; // overhead per function definition
const PROP_INIT: usize = 3; // overhead when a function declares properties
const PROP_KEY: usize = 3; // overhead per property key
const ENUM_INIT: isize = -3; // adjustment applied once when a property has an enum (negative)
const ENUM_ITEM: usize = 3; // overhead per enum item
const FUNC_END: usize = 12; // overhead closing out the function block
22
/// Counts BPE tokens for text, chat messages, and tool definitions,
/// memoizing per-text results in a bounded concurrent cache.
pub struct TokenCounter {
    // Shared o200k_base tokenizer (initialized once process-wide).
    tokenizer: Arc<CoreBPE>,
    // AHash of input text -> token count; bounded by MAX_TOKEN_CACHE_SIZE.
    token_cache: Arc<DashMap<u64, usize>>,
}
27
28impl TokenCounter {
29 pub async fn new() -> Result<Self, String> {
30 let tokenizer = get_tokenizer().await?;
31 Ok(Self {
32 tokenizer,
33 token_cache: Arc::new(DashMap::new()),
34 })
35 }
36
37 pub fn count_tokens(&self, text: &str) -> usize {
38 let mut hasher = AHasher::default();
39 text.hash(&mut hasher);
40 let hash = hasher.finish();
41
42 if let Some(count) = self.token_cache.get(&hash) {
43 return *count;
44 }
45
46 let tokens = self.tokenizer.encode_with_special_tokens(text);
47 let count = tokens.len();
48
49 if self.token_cache.len() >= MAX_TOKEN_CACHE_SIZE {
50 if let Some(entry) = self.token_cache.iter().next() {
51 let old_hash = *entry.key();
52 self.token_cache.remove(&old_hash);
53 }
54 }
55
56 self.token_cache.insert(hash, count);
57 count
58 }
59
60 pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
61 let mut func_token_count = 0;
62 if !tools.is_empty() {
63 for tool in tools {
64 func_token_count += FUNC_INIT;
65 let name = &tool.name;
66 let description = &tool
67 .description
68 .as_ref()
69 .map(|d| d.as_ref())
70 .unwrap_or_default()
71 .trim_end_matches('.');
72
73 let line = format!("{}:{}", name, description);
74 func_token_count += self.count_tokens(&line);
75
76 if let Some(serde_json::Value::Object(properties)) =
77 tool.input_schema.get("properties")
78 {
79 if !properties.is_empty() {
80 func_token_count += PROP_INIT;
81 for (key, value) in properties {
82 func_token_count += PROP_KEY;
83 let p_name = key;
84 let p_type = value.get("type").and_then(|v| v.as_str()).unwrap_or("");
85 let p_desc = value
86 .get("description")
87 .and_then(|v| v.as_str())
88 .unwrap_or("")
89 .trim_end_matches('.');
90
91 let line = format!("{}:{}:{}", p_name, p_type, p_desc);
92 func_token_count += self.count_tokens(&line);
93
94 if let Some(enum_values) = value.get("enum").and_then(|v| v.as_array())
95 {
96 func_token_count =
97 func_token_count.saturating_add_signed(ENUM_INIT);
98 for item in enum_values {
99 if let Some(item_str) = item.as_str() {
100 func_token_count += ENUM_ITEM;
101 func_token_count += self.count_tokens(item_str);
102 }
103 }
104 }
105 }
106 }
107 }
108 }
109 func_token_count += FUNC_END;
110 }
111
112 func_token_count
113 }
114
115 pub fn count_chat_tokens(
116 &self,
117 system_prompt: &str,
118 messages: &[Message],
119 tools: &[Tool],
120 ) -> usize {
121 let tokens_per_message = 4;
122 let mut num_tokens = 0;
123
124 if !system_prompt.is_empty() {
125 num_tokens += self.count_tokens(system_prompt) + tokens_per_message;
126 }
127
128 for message in messages {
129 if !message.metadata.agent_visible {
130 continue;
131 }
132 num_tokens += tokens_per_message;
133 for content in &message.content {
134 if let Some(content_text) = content.as_text() {
135 num_tokens += self.count_tokens(content_text);
136 } else if let Some(tool_request) = content.as_tool_request() {
137 if let Ok(tool_call) = tool_request.tool_call.as_ref() {
138 let text = format!(
139 "{}:{}:{:?}",
140 tool_request.id, tool_call.name, tool_call.arguments
141 );
142 num_tokens += self.count_tokens(&text);
143 }
144 } else if let Some(tool_response_text) = content.as_tool_response_text() {
145 num_tokens += self.count_tokens(&tool_response_text);
146 }
147 }
148 }
149
150 if !tools.is_empty() {
151 num_tokens += self.count_tokens_for_tools(tools);
152 }
153
154 num_tokens += 3; num_tokens
157 }
158
159 pub fn count_everything(
160 &self,
161 system_prompt: &str,
162 messages: &[Message],
163 tools: &[Tool],
164 resources: &[String],
165 ) -> usize {
166 let mut num_tokens = self.count_chat_tokens(system_prompt, messages, tools);
167
168 if !resources.is_empty() {
169 for resource in resources {
170 num_tokens += self.count_tokens(resource);
171 }
172 }
173 num_tokens
174 }
175
176 pub fn clear_cache(&self) {
177 self.token_cache.clear();
178 }
179
180 pub fn cache_size(&self) -> usize {
181 self.token_cache.len()
182 }
183}
184
185async fn get_tokenizer() -> Result<Arc<CoreBPE>, String> {
186 let tokenizer = TOKENIZER
187 .get_or_init(|| async {
188 match tiktoken_rs::o200k_base() {
189 Ok(bpe) => Arc::new(bpe),
190 Err(e) => panic!("Failed to initialize o200k_base tokenizer: {}", e),
191 }
192 })
193 .await;
194 Ok(tokenizer.clone())
195}
196
197pub async fn create_token_counter() -> Result<TokenCounter, String> {
198 TokenCounter::new().await
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    // Repeated counting of identical text must hit the cache (size stays at 1),
    // while distinct text adds a second entry with a different count.
    #[tokio::test]
    async fn test_token_caching() {
        let counter = create_token_counter().await.unwrap();

        let text = "This is a test for caching functionality";

        let count1 = counter.count_tokens(text);
        assert_eq!(counter.cache_size(), 1);

        let count2 = counter.count_tokens(text);
        assert_eq!(count1, count2);
        assert_eq!(counter.cache_size(), 1);

        let count3 = counter.count_tokens("Different text");
        assert_eq!(counter.cache_size(), 2);
        assert_ne!(count1, count3);
    }

    // clear_cache must empty the cache, and counting afterwards repopulates it.
    #[tokio::test]
    async fn test_cache_management() {
        let counter = create_token_counter().await.unwrap();

        counter.count_tokens("First text");
        counter.count_tokens("Second text");
        counter.count_tokens("Third text");

        assert_eq!(counter.cache_size(), 3);

        counter.clear_cache();
        assert_eq!(counter.cache_size(), 0);

        let count = counter.count_tokens("First text");
        assert!(count > 0);
        assert_eq!(counter.cache_size(), 1);
    }

    // All counters share the process-wide tokenizer, so concurrent creation
    // must yield counters that agree on every token count.
    #[tokio::test]
    async fn test_concurrent_token_counter_creation() {
        let handles: Vec<_> = (0..10)
            .map(|_| tokio::spawn(async { create_token_counter().await.unwrap() }))
            .collect();

        let counters: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        let text = "Test concurrent creation";
        let expected_count = counters[0].count_tokens(text);

        for counter in &counters {
            assert_eq!(counter.count_tokens(text), expected_count);
        }
    }

    // The cache never exceeds MAX_TOKEN_CACHE_SIZE, and re-counting an
    // already-cached text does not grow it.
    #[tokio::test]
    async fn test_cache_eviction_behavior() {
        let counter = create_token_counter().await.unwrap();

        let mut cached_texts = Vec::new();
        for i in 0..50 {
            let text = format!("Test string number {}", i);
            counter.count_tokens(&text);
            cached_texts.push(text);
        }

        assert!(counter.cache_size() <= MAX_TOKEN_CACHE_SIZE);

        // A cache hit must leave the cache size unchanged.
        let recent_text = &cached_texts[cached_texts.len() - 1];
        let start_size = counter.cache_size();

        counter.count_tokens(recent_text);
        assert_eq!(counter.cache_size(), start_size);
    }

    // Concurrent counting over a shared counter (5 distinct texts across 20
    // tasks) must produce positive counts and keep the cache within bounds.
    #[tokio::test]
    async fn test_concurrent_cache_operations() {
        let counter = std::sync::Arc::new(create_token_counter().await.unwrap());

        let handles: Vec<_> = (0..20)
            .map(|i| {
                let counter_clone = counter.clone();
                tokio::spawn(async move {
                    let text = format!("Concurrent test {}", i % 5);
                    counter_clone.count_tokens(&text)
                })
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        for result in results {
            assert!(result > 0);
        }

        assert!(counter.cache_size() > 0);
        assert!(counter.cache_size() <= MAX_TOKEN_CACHE_SIZE);
    }
}