
mermaid_cli/utils/tokenizer.rs

use anyhow::Result;

/// Token counting utility using character-based estimation
/// Uses ~4 characters per token as a reasonable approximation
pub struct Tokenizer {
    model_name: String,
}

impl Tokenizer {
    /// Create a new tokenizer for the given model
    pub fn new(model_name: &str) -> Self {
        Self {
            model_name: model_name.to_string(),
        }
    }

    /// Count tokens in a single text string (~4 chars per token)
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        Ok((text.len() + 3) / 4)
    }

    /// Count tokens in a chat message format
    pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
        let total_chars: usize = messages
            .iter()
            .map(|(role, content)| role.len() + content.len() + 4) // +4 for message overhead
            .sum();
        Ok((total_chars + 3) / 4)
    }

    /// Get the maximum tokens for a model
    pub fn get_max_tokens(&self) -> usize {
        let model_name = self.get_base_model_name();

        if model_name.contains("gpt-4o") {
            128000
        } else if model_name.contains("gpt-4-turbo") || model_name.contains("gpt-4-1106") {
            128000
        } else if model_name.contains("gpt-4-32k") {
            32768
        } else if model_name.contains("gpt-4") {
            8192
        } else if model_name.contains("gpt-3.5-turbo-16k") {
            16384
        } else if model_name.contains("gpt-3.5-turbo") {
            4096
        } else if model_name.contains("claude-3") {
            200000
        } else if model_name.contains("claude") {
            100000
        } else if model_name.contains("llama-3") {
            8192
        } else if model_name.contains("llama-2") {
            4096
        } else if model_name.contains("codellama") {
            16384
        } else if model_name.contains("deepseek-coder") {
            65536
        } else if model_name.contains("qwen") {
            32768
        } else if model_name.contains("mistral") || model_name.contains("mixtral") {
            32768
        } else {
            8192 // Conservative default
        }
    }

    /// Calculate remaining tokens in context window
    pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
        let max_tokens = self.get_max_tokens();
        max_tokens.saturating_sub(used_tokens)
    }

    /// Get the base model name (strip provider prefix)
    fn get_base_model_name(&self) -> String {
        if let Some(idx) = self.model_name.find('/') {
            self.model_name[idx + 1..].to_string()
        } else {
            self.model_name.clone()
        }
    }
}

/// Count tokens in file contents (~4 chars per token)
pub fn count_file_tokens(content: &str, _model_name: &str) -> usize {
    (content.len() + 3) / 4
}
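
// A minimal usage sketch, not part of the original module: it shows how the
// estimator above might gate a prompt against a model's context window. The
// helper name `prompt_fits_in_context` and the `reserved_for_response` budget
// are illustrative assumptions only.
#[allow(dead_code)]
pub fn prompt_fits_in_context(
    model_name: &str,
    prompt: &str,
    reserved_for_response: usize,
) -> Result<bool> {
    let tokenizer = Tokenizer::new(model_name);
    // Estimate the prompt size (~4 chars per token), then compare it against
    // the context window minus the tokens reserved for the model's reply.
    let prompt_tokens = tokenizer.count_tokens(prompt)?;
    let available = tokenizer.remaining_tokens(reserved_for_response);
    Ok(prompt_tokens <= available)
}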

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let text = "Hello, world! This is a test message.";
        let count = tokenizer.count_tokens(text).unwrap();
        assert!(count > 0);
        assert!(count < text.len());
    }

    #[test]
    fn test_model_name_extraction() {
        let tokenizer = Tokenizer::new("ollama/gpt-4");
        assert_eq!(tokenizer.get_base_model_name(), "gpt-4");

        let tokenizer = Tokenizer::new("unknown-model");
        assert_eq!(tokenizer.get_base_model_name(), "unknown-model");
    }

    #[test]
    fn test_max_tokens() {
        let tokenizer = Tokenizer::new("gpt-4");
        assert_eq!(tokenizer.get_max_tokens(), 8192);

        let tokenizer = Tokenizer::new("gpt-4o");
        assert_eq!(tokenizer.get_max_tokens(), 128000);

        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        assert_eq!(tokenizer.get_max_tokens(), 4096);
    }
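
    // Additional illustrative tests, a sketch rather than part of the original
    // file: they exercise the ~4 chars-per-token heuristic and the context
    // budget documented above, so expected values follow directly from that
    // approximation and the gpt-4 limit of 8192 tokens.
    #[test]
    fn test_chat_token_counting() {
        let tokenizer = Tokenizer::new("gpt-4");
        let messages = vec![
            ("system".to_string(), "You are helpful.".to_string()),
            ("user".to_string(), "Hello!".to_string()),
        ];
        // Each message contributes role + content + 4 overhead characters.
        let count = tokenizer.count_chat_tokens(&messages).unwrap();
        assert!(count > 0);
    }

    #[test]
    fn test_remaining_tokens() {
        let tokenizer = Tokenizer::new("gpt-4");
        // 8192-token window minus 1000 used, saturating at zero when over budget.
        assert_eq!(tokenizer.remaining_tokens(1000), 7192);
        assert_eq!(tokenizer.remaining_tokens(10_000), 0);
    }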
}