// mermaid_cli/utils/tokenizer.rs

use anyhow::Result;

/// Token counting utility using character-based estimation.
///
/// Uses ~4 characters per token as a reasonable approximation; no
/// model-specific tokenizer vocabulary is loaded.
pub struct Tokenizer {
    // Model identifier, optionally carrying a provider prefix
    // (e.g. "ollama/gpt-4"); used only to look up context-window sizes.
    model_name: String,
}
8
9impl Tokenizer {
10    /// Create a new tokenizer for the given model
11    pub fn new(model_name: &str) -> Self {
12        Self {
13            model_name: model_name.to_string(),
14        }
15    }
16
17    /// Count tokens in a single text string (~4 chars per token)
18    pub fn count_tokens(&self, text: &str) -> Result<usize> {
19        Ok(text.len().div_ceil(4))
20    }
21
22    /// Count tokens in a chat message format
23    pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
24        let total_chars: usize = messages
25            .iter()
26            .map(|(role, content)| role.len() + content.len() + 4) // +4 for message overhead
27            .sum();
28        Ok(total_chars.div_ceil(4))
29    }
30
31    /// Get the maximum tokens for a model
32    pub fn get_max_tokens(&self) -> usize {
33        let model_name = self.get_base_model_name();
34
35        if model_name.contains("gpt-4o")
36            || model_name.contains("gpt-4-turbo")
37            || model_name.contains("gpt-4-1106")
38        {
39            128000
40        } else if model_name.contains("gpt-4-32k") {
41            32768
42        } else if model_name.contains("gpt-4") {
43            8192
44        } else if model_name.contains("gpt-3.5-turbo-16k") {
45            16384
46        } else if model_name.contains("gpt-3.5-turbo") {
47            4096
48        } else if model_name.contains("claude-3") {
49            200000
50        } else if model_name.contains("claude") {
51            100000
52        } else if model_name.contains("llama-3") {
53            8192
54        } else if model_name.contains("llama-2") {
55            4096
56        } else if model_name.contains("codellama") {
57            16384
58        } else if model_name.contains("deepseek-coder") {
59            65536
60        } else if model_name.contains("qwen")
61            || model_name.contains("mistral")
62            || model_name.contains("mixtral")
63        {
64            32768
65        } else {
66            8192 // Conservative default
67        }
68    }
69
70    /// Calculate remaining tokens in context window
71    pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
72        let max_tokens = self.get_max_tokens();
73        max_tokens.saturating_sub(used_tokens)
74    }
75
76    /// Get the base model name (strip provider prefix)
77    fn get_base_model_name(&self) -> String {
78        if let Some(idx) = self.model_name.find('/') {
79            self.model_name[idx + 1..].to_string()
80        } else {
81            self.model_name.clone()
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let text = "Hello, world! This is a test message.";
        // 37 chars -> ceil(37 / 4) = 10 tokens
        assert_eq!(tokenizer.count_tokens(text).unwrap(), 10);
        // Empty input yields zero tokens.
        assert_eq!(tokenizer.count_tokens("").unwrap(), 0);
    }

    #[test]
    fn test_chat_token_counting() {
        let tokenizer = Tokenizer::new("gpt-4");
        // ("user", "hi"): 4 + 2 + 4 overhead = 10 chars -> ceil(10 / 4) = 3
        let messages = vec![("user".to_string(), "hi".to_string())];
        assert_eq!(tokenizer.count_chat_tokens(&messages).unwrap(), 3);
        // No messages -> zero tokens.
        assert_eq!(tokenizer.count_chat_tokens(&[]).unwrap(), 0);
    }

    #[test]
    fn test_model_name_extraction() {
        let tokenizer = Tokenizer::new("ollama/gpt-4");
        assert_eq!(tokenizer.get_base_model_name(), "gpt-4");

        let tokenizer = Tokenizer::new("unknown-model");
        assert_eq!(tokenizer.get_base_model_name(), "unknown-model");
    }

    #[test]
    fn test_max_tokens() {
        let tokenizer = Tokenizer::new("gpt-4");
        assert_eq!(tokenizer.get_max_tokens(), 8192);

        let tokenizer = Tokenizer::new("gpt-4o");
        assert_eq!(tokenizer.get_max_tokens(), 128000);

        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        assert_eq!(tokenizer.get_max_tokens(), 4096);
    }

    #[test]
    fn test_remaining_tokens() {
        let tokenizer = Tokenizer::new("gpt-4");
        assert_eq!(tokenizer.remaining_tokens(1000), 7192);
        // Usage beyond the window saturates to zero rather than underflowing.
        assert_eq!(tokenizer.remaining_tokens(10_000), 0);
    }
}