//! mermaid_cli/utils/tokenizer.rs — character-based token estimation.
use anyhow::Result;

3/// Token counting utility using character-based estimation
4/// Uses ~4 characters per token as a reasonable approximation
5pub struct Tokenizer {
6    /// Base model name (provider prefix stripped at construction time)
7    base_model_name: String,
8}
9
10impl Tokenizer {
11    /// Create a new tokenizer for the given model
12    pub fn new(model_name: &str) -> Self {
13        let base = if let Some(idx) = model_name.find('/') {
14            // Safe: '/' is ASCII, so byte offset == char offset
15            &model_name[idx + 1..]
16        } else {
17            model_name
18        };
19        Self {
20            base_model_name: base.to_lowercase(),
21        }
22    }
23
24    /// Count tokens in a single text string (~4 chars per token)
25    pub fn count_tokens(&self, text: &str) -> Result<usize> {
26        Ok(text.len().div_ceil(4))
27    }
28
29    /// Count tokens in a chat message format
30    pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
31        let total_chars: usize = messages
32            .iter()
33            .map(|(role, content)| role.len() + content.len() + 4) // +4 for message overhead
34            .sum();
35        Ok(total_chars.div_ceil(4))
36    }
37
38    /// Get the maximum context window for a model (in tokens).
39    ///
40    /// Focused on Ollama model families — the actual models Mermaid supports.
41    /// These are conservative defaults; Ollama may use different context sizes
42    /// depending on the specific model variant and user's num_ctx setting.
43    pub fn get_max_tokens(&self) -> usize {
44        let model_name = &self.base_model_name;
45
46        // Large-context models (128k+)
47        if model_name.contains("qwen3-coder")
48            || model_name.contains("qwen2.5-coder")
49            || model_name.contains("deepseek-v3")
50            || model_name.contains("deepseek-r1")
51            || model_name.contains("kimi")
52        {
53            131072
54        }
55        // 64k context models
56        else if model_name.contains("deepseek-coder") || model_name.contains("command-r") {
57            65536
58        }
59        // 32k context models
60        else if model_name.contains("qwen")
61            || model_name.contains("mistral")
62            || model_name.contains("mixtral")
63            || model_name.contains("gemma2")
64        {
65            32768
66        }
67        // 16k context models
68        else if model_name.contains("codellama") || model_name.contains("phi") {
69            16384
70        }
71        // 8k context models (llama3 default)
72        else if model_name.contains("llama3")
73            || model_name.contains("llama-3")
74            || model_name.contains("gemma")
75        {
76            8192
77        }
78        // 4k context models (older)
79        else if model_name.contains("llama2")
80            || model_name.contains("llama-2")
81            || model_name.contains("tinyllama")
82        {
83            4096
84        } else {
85            8192 // Conservative default for unknown models
86        }
87    }
88
89    /// Calculate remaining tokens in context window
90    pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
91        let max_tokens = self.get_max_tokens();
92        max_tokens.saturating_sub(used_tokens)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-empty ASCII text yields a positive count strictly smaller than
    /// its character length (since ~4 chars map to one token).
    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let text = "Hello, world! This is a test message.";
        let count = tokenizer.count_tokens(text).unwrap();
        assert!(count > 0);
        assert!(count < text.len());
    }

    /// Provider prefixes are stripped; bare names pass through unchanged.
    #[test]
    fn test_model_name_extraction() {
        let tokenizer = Tokenizer::new("ollama/gpt-4");
        assert_eq!(tokenizer.base_model_name, "gpt-4");

        let tokenizer = Tokenizer::new("unknown-model");
        assert_eq!(tokenizer.base_model_name, "unknown-model");
    }

    /// Context-window lookup matches the expected model families and falls
    /// back to 8k for unknown names.
    #[test]
    fn test_max_tokens() {
        let tokenizer = Tokenizer::new("ollama/qwen3-coder:30b");
        assert_eq!(tokenizer.get_max_tokens(), 131072);

        let tokenizer = Tokenizer::new("ollama/llama3:8b");
        assert_eq!(tokenizer.get_max_tokens(), 8192);

        let tokenizer = Tokenizer::new("tinyllama");
        assert_eq!(tokenizer.get_max_tokens(), 4096);

        let tokenizer = Tokenizer::new("ollama/mistral");
        assert_eq!(tokenizer.get_max_tokens(), 32768);

        let tokenizer = Tokenizer::new("unknown-model");
        assert_eq!(tokenizer.get_max_tokens(), 8192);
    }
}
135}