use crate::context::types::{CHARS_PER_TOKEN_ASIAN, CHARS_PER_TOKEN_CODE, CHARS_PER_TOKEN_DEFAULT};
use crate::conversation::message::{Message, MessageContent};
19
/// Flat token cost added to every message on top of its content tokens
/// (accounts for role markers / message framing; see `estimate_message_tokens`).
const MESSAGE_OVERHEAD_TOKENS: usize = 4;

/// Heuristic token estimator: approximates LLM token counts from character
/// statistics (script, code-likeness, special characters) without running a
/// real tokenizer. All methods are associated functions; the struct holds no state.
pub struct TokenEstimator;
27
28impl TokenEstimator {
29 pub fn estimate_tokens(text: &str) -> usize {
56 if text.is_empty() {
57 return 0;
58 }
59
60 let chars_per_token = if Self::has_asian_chars(text) {
62 CHARS_PER_TOKEN_ASIAN
63 } else if Self::is_code(text) {
64 CHARS_PER_TOKEN_CODE
65 } else {
66 CHARS_PER_TOKEN_DEFAULT
67 };
68
69 let char_count = text.chars().count();
71
72 let base_tokens = (char_count as f64 / chars_per_token).ceil() as usize;
74
75 let special_weight = Self::calculate_special_weight(text);
77
78 base_tokens + special_weight
79 }
80
81 pub fn has_asian_chars(text: &str) -> bool {
91 let total_chars = text.chars().count();
92 if total_chars == 0 {
93 return false;
94 }
95
96 let asian_count = text.chars().filter(|c| Self::is_asian_char(*c)).count();
97
98 (asian_count as f64 / total_chars as f64) > 0.2
100 }
101
102 fn is_asian_char(c: char) -> bool {
104 matches!(c,
105 '\u{4E00}'..='\u{9FFF}' |
107 '\u{3400}'..='\u{4DBF}' |
109 '\u{20000}'..='\u{2A6DF}' |
111 '\u{F900}'..='\u{FAFF}' |
113 '\u{3040}'..='\u{309F}' |
115 '\u{30A0}'..='\u{30FF}' |
117 '\u{AC00}'..='\u{D7AF}' |
119 '\u{1100}'..='\u{11FF}' |
121 '\u{3100}'..='\u{312F}'
123 )
124 }
125
126 pub fn is_code(text: &str) -> bool {
141 if text.contains("```") || text.contains("~~~") {
143 return true;
144 }
145
146 let code_indicators = [
148 '{', '}', '[', ']', '(', ')', ';', '=', '+', '-', '*', '/', '<', '>', '&', '|', '!',
149 ];
150
151 let total_chars = text.chars().count();
152 if total_chars == 0 {
153 return false;
154 }
155
156 let code_char_count = text.chars().filter(|c| code_indicators.contains(c)).count();
157
158 let has_code_patterns = text.contains("fn ")
160 || text.contains("def ")
161 || text.contains("function ")
162 || text.contains("class ")
163 || text.contains("const ")
164 || text.contains("let ")
165 || text.contains("var ")
166 || text.contains("import ")
167 || text.contains("pub ")
168 || text.contains("async ")
169 || text.contains("await ")
170 || text.contains("return ")
171 || text.contains("if ")
172 || text.contains("for ")
173 || text.contains("while ");
174
175 let has_indentation_with_code = text.lines().any(|line| {
178 let trimmed = line.trim_start();
179 let indent_size = line.len() - trimmed.len();
180 indent_size >= 2
182 && (trimmed.contains('{')
183 || trimmed.contains('}')
184 || trimmed.contains(';')
185 || trimmed.starts_with("let ")
186 || trimmed.starts_with("const ")
187 || trimmed.starts_with("return ")
188 || trimmed.starts_with("if ")
189 || trimmed.starts_with("for ")
190 || trimmed.starts_with("while ")
191 || trimmed.starts_with("//")
192 || trimmed.starts_with("#"))
193 });
194
195 (code_char_count as f64 / total_chars as f64) > 0.05
200 || has_code_patterns
201 || has_indentation_with_code
202 }
203
204 fn calculate_special_weight(text: &str) -> usize {
206 let newline_count = text.chars().filter(|c| *c == '\n').count();
207 let special_count = text
208 .chars()
209 .filter(|c| {
210 matches!(
211 c,
212 '\t' | '\r' | '\\' | '"' | '\'' | '`' | '~' | '@' | '#' | '$' | '%' | '^'
213 )
214 })
215 .count();
216
217 (newline_count as f64 * 0.5).ceil() as usize + (special_count as f64 * 0.25).ceil() as usize
219 }
220
221 pub fn estimate_message_tokens(message: &Message) -> usize {
233 let content_tokens: usize = message
234 .content
235 .iter()
236 .map(Self::estimate_content_tokens)
237 .sum();
238
239 content_tokens + MESSAGE_OVERHEAD_TOKENS
240 }
241
242 fn estimate_content_tokens(content: &MessageContent) -> usize {
244 match content {
245 MessageContent::Text(text_content) => Self::estimate_tokens(&text_content.text),
246 MessageContent::Image(_) => {
247 1600
250 }
251 MessageContent::ToolRequest(tool_request) => {
252 let mut tokens = 10; if let Ok(call) = &tool_request.tool_call {
256 tokens += Self::estimate_tokens(&call.name);
257 if let Some(args) = &call.arguments {
258 let args_str = serde_json::to_string(args).unwrap_or_default();
259 tokens += Self::estimate_tokens(&args_str);
260 }
261 }
262
263 tokens
264 }
265 MessageContent::ToolResponse(tool_response) => {
266 let mut tokens = 10; if let Ok(result) = &tool_response.tool_result {
269 for content in &result.content {
270 if let Some(text) = content.as_text() {
271 tokens += Self::estimate_tokens(&text.text);
272 }
273 }
274 }
275
276 tokens
277 }
278 MessageContent::Thinking(thinking) => Self::estimate_tokens(&thinking.thinking),
279 MessageContent::RedactedThinking(_) => 50, MessageContent::ToolConfirmationRequest(req) => {
281 let args_str = serde_json::to_string(&req.arguments).unwrap_or_default();
282 10 + Self::estimate_tokens(&req.tool_name) + Self::estimate_tokens(&args_str)
283 }
284 MessageContent::ActionRequired(action) => {
285 match &action.data {
286 crate::conversation::message::ActionRequiredData::ToolConfirmation {
287 tool_name,
288 arguments,
289 ..
290 } => {
291 let args_str = serde_json::to_string(arguments).unwrap_or_default();
292 10 + Self::estimate_tokens(tool_name) + Self::estimate_tokens(&args_str)
293 }
294 crate::conversation::message::ActionRequiredData::Elicitation {
295 message,
296 ..
297 } => 10 + Self::estimate_tokens(message),
298 crate::conversation::message::ActionRequiredData::ElicitationResponse {
299 ..
300 } => 20, }
302 }
303 MessageContent::FrontendToolRequest(req) => {
304 let mut tokens = 10;
305 if let Ok(call) = &req.tool_call {
306 tokens += Self::estimate_tokens(&call.name);
307 if let Some(args) = &call.arguments {
308 let args_str = serde_json::to_string(args).unwrap_or_default();
309 tokens += Self::estimate_tokens(&args_str);
310 }
311 }
312 tokens
313 }
314 MessageContent::SystemNotification(notification) => {
315 Self::estimate_tokens(¬ification.msg)
316 }
317 }
318 }
319
320 pub fn estimate_total_tokens(messages: &[Message]) -> usize {
330 messages.iter().map(Self::estimate_message_tokens).sum()
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    // Empty input must cost exactly zero tokens.
    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(TokenEstimator::estimate_tokens(""), 0);
    }

    // Short English prose: nonzero but small (default chars-per-token ratio).
    #[test]
    fn test_estimate_tokens_english() {
        let text = "Hello, world! This is a test.";
        let tokens = TokenEstimator::estimate_tokens(text);
        assert!(tokens > 0);
        assert!(tokens < 20);
    }

    // Chinese text should use the denser Asian ratio, keeping the count low.
    #[test]
    fn test_estimate_tokens_chinese() {
        let text = "你好世界,这是一个测试。";
        let tokens = TokenEstimator::estimate_tokens(text);
        assert!(tokens > 0);
        assert!(tokens < 15);
    }

    // Code input: only sanity-checks that estimation produces a positive count.
    #[test]
    fn test_estimate_tokens_code() {
        let text = r#"
fn main() {
    println!("Hello, world!");
}
"#;
        let tokens = TokenEstimator::estimate_tokens(text);
        assert!(tokens > 0);
    }

    // Pure Chinese and mixed English/Chinese both exceed the 20% threshold.
    #[test]
    fn test_has_asian_chars_chinese() {
        assert!(TokenEstimator::has_asian_chars("你好世界"));
        assert!(TokenEstimator::has_asian_chars("Hello 你好"));
    }

    // Hiragana and katakana ranges are both detected.
    #[test]
    fn test_has_asian_chars_japanese() {
        assert!(TokenEstimator::has_asian_chars("こんにちは"));
        assert!(TokenEstimator::has_asian_chars("カタカナ"));
    }

    // Hangul syllables are detected.
    #[test]
    fn test_has_asian_chars_korean() {
        assert!(TokenEstimator::has_asian_chars("안녕하세요"));
    }

    // Plain English and empty input must not be flagged as Asian text.
    #[test]
    fn test_has_asian_chars_english() {
        assert!(!TokenEstimator::has_asian_chars("Hello, world!"));
        assert!(!TokenEstimator::has_asian_chars(""));
    }

    // Rust snippet: keywords, braces, and indentation all point to code.
    #[test]
    fn test_is_code_rust() {
        let code = r#"
fn main() {
    let x = 5;
    println!("{}", x);
}
"#;
        assert!(TokenEstimator::is_code(code));
    }

    // JavaScript snippet triggers via `function `/`const `/`return `.
    #[test]
    fn test_is_code_javascript() {
        let code = r#"
function hello() {
    const x = 5;
    return x + 1;
}
"#;
        assert!(TokenEstimator::is_code(code));
    }

    // Python snippet: `def ` keyword plus indented `return`.
    #[test]
    fn test_is_code_python() {
        let code = r#"
def hello():
    x = 5
    return x + 1
"#;
        assert!(TokenEstimator::is_code(code));
    }

    // A markdown fence alone is sufficient to classify as code.
    #[test]
    fn test_is_code_markdown_block() {
        let text = "```rust\nfn main() {}\n```";
        assert!(TokenEstimator::is_code(text));
    }

    // Ordinary prose with no code punctuation or keywords is not code.
    #[test]
    fn test_is_code_plain_text() {
        let text = "This is just plain English text without any code.";
        assert!(!TokenEstimator::is_code(text));
    }

    // Message-level estimate always includes the fixed per-message overhead.
    #[test]
    fn test_estimate_message_tokens() {
        let message = Message::user().with_text("Hello, world!");
        let tokens = TokenEstimator::estimate_message_tokens(&message);
        assert!(tokens >= MESSAGE_OVERHEAD_TOKENS);
    }

    // Totals sum per-message estimates, so overhead applies once per message.
    #[test]
    fn test_estimate_total_tokens() {
        let messages = vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there!"),
        ];
        let total = TokenEstimator::estimate_total_tokens(&messages);
        assert!(total > 0);
        assert!(total >= MESSAGE_OVERHEAD_TOKENS * 2);
    }

    // Newlines add special weight; just assert the estimate is positive.
    #[test]
    fn test_estimate_tokens_with_newlines() {
        let text = "Line 1\nLine 2\nLine 3";
        let tokens = TokenEstimator::estimate_tokens(text);
        assert!(tokens > 0);
    }

    // Special characters (@ # $ %) contribute fractional extra weight.
    #[test]
    fn test_estimate_tokens_with_special_chars() {
        let text = "Hello @user #tag $var %percent";
        let tokens = TokenEstimator::estimate_tokens(text);
        assert!(tokens > 0);
    }
}