mermaid_cli/utils/tokenizer.rs

use anyhow::Result;
use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

/// Token counting utility for various model families
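///
/// # Example
///
/// A minimal usage sketch (marked `ignore` so it is not run as a doctest):
///
/// ```ignore
/// let tokenizer = Tokenizer::new("gpt-4o");
/// let used = tokenizer.count_tokens("fn main() {}").unwrap_or(0);
/// let remaining = tokenizer.remaining_tokens(used);
/// assert!(remaining <= tokenizer.get_max_tokens());
/// ```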
pub struct Tokenizer {
    model_name: String,
}

impl Tokenizer {
    /// Create a new tokenizer for the given model
    pub fn new(model_name: &str) -> Self {
        Self {
            model_name: model_name.to_string(),
        }
    }

    /// Count tokens in a single text string
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        // Extract the base model name for tokenizer selection
        let model_for_encoding = self.get_base_model_name();

        // Get the appropriate tokenizer
        match tiktoken_rs::get_bpe_from_model(&model_for_encoding) {
            Ok(bpe) => {
                // Count tokens using the BPE tokenizer
                Ok(bpe.encode_with_special_tokens(text).len())
            },
            Err(_) => {
                // Fall back to the cl100k_base encoding (GPT-4/GPT-3.5-turbo),
                // then to a rough ~4-characters-per-token heuristic
                tiktoken_rs::cl100k_base()
                    .map(|bpe| bpe.encode_with_special_tokens(text).len())
                    .or_else(|_| Ok(text.len() / 4))
            },
        }
    }

    /// Count tokens in a chat message format
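    ///
    /// Messages are `(role, content)` pairs. A usage sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let messages = vec![
    ///     ("system".to_string(), "You are helpful.".to_string()),
    ///     ("user".to_string(), "Hello!".to_string()),
    /// ];
    /// let count = tokenizer.count_chat_tokens(&messages)?;
    /// ```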
    pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
        // Convert to tiktoken's ChatCompletionRequestMessage format
        let chat_messages: Vec<ChatCompletionRequestMessage> = messages
            .iter()
            .map(|(role, content)| ChatCompletionRequestMessage {
                role: role.clone(),
                content: Some(content.clone()),
                name: None,
                function_call: None,
            })
            .collect();

        let model_for_encoding = self.get_base_model_name();

        // Use tiktoken's chat token counter
        match num_tokens_from_messages(&model_for_encoding, &chat_messages) {
            Ok(count) => Ok(count),
            Err(_) => {
                // Fallback to GPT-3.5 encoding
                num_tokens_from_messages("gpt-3.5-turbo", &chat_messages).or_else(|_| {
                    // Last resort: simple ~4-characters-per-token approximation
                    let total_chars: usize =
                        messages.iter().map(|(_, content)| content.len()).sum();
                    Ok(total_chars / 4)
                })
            },
        }
    }

    /// Get the maximum tokens for a model
    pub fn get_max_tokens(&self) -> usize {
        // Match on the raw model name (lowercased), not on get_base_model_name():
        // that helper maps e.g. Claude models onto "gpt-4", which would make the
        // non-OpenAI branches below unreachable.
        let model_name = self.model_name.to_lowercase();

        // Return max tokens based on common models
        // These are approximate values for context window sizes
        if model_name.contains("gpt-4o") {
            128000 // GPT-4o
        } else if model_name.contains("gpt-4-turbo") || model_name.contains("gpt-4-1106") {
            128000 // GPT-4 Turbo
        } else if model_name.contains("gpt-4-32k") {
            32768 // GPT-4 32k
        } else if model_name.contains("gpt-4") {
            8192 // GPT-4
        } else if model_name.contains("gpt-3.5-turbo-16k") {
            16384 // GPT-3.5 Turbo 16k
        } else if model_name.contains("gpt-3.5-turbo") {
            4096 // GPT-3.5 Turbo
        } else if model_name.contains("claude-3") {
            200000 // Claude 3
        } else if model_name.contains("claude") {
            100000 // Claude 2
        } else if model_name.contains("llama-3") || model_name.contains("llama3") {
            8192 // Llama 3
        } else if model_name.contains("llama-2") || model_name.contains("llama2") {
            4096 // Llama 2
        } else if model_name.contains("codellama") {
            16384 // Code Llama
        } else if model_name.contains("deepseek-coder") {
            65536 // DeepSeek Coder
        } else if model_name.contains("qwen") {
            32768 // Qwen models
        } else if model_name.contains("mistral") || model_name.contains("mixtral") {
            32768 // Mistral/Mixtral
        } else {
            8192 // Conservative default
        }
    }

    /// Calculate remaining tokens in context window
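    ///
    /// Saturates at zero when `used_tokens` already exceeds the window.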
    pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
        let max_tokens = self.get_max_tokens();
        max_tokens.saturating_sub(used_tokens)
    }

    /// Get the base model name for tokenizer selection
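    ///
    /// E.g. "ollama/llama3" maps to "gpt-3.5-turbo" and
    /// "anthropic/claude-3-opus" maps to "gpt-4".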
    fn get_base_model_name(&self) -> String {
        // Remove provider prefix if present (e.g., "ollama/gpt-4" -> "gpt-4")
        let base_name = if let Some(idx) = self.model_name.find('/') {
            &self.model_name[idx + 1..]
        } else {
            &self.model_name
        };
        // Lowercase once, rather than on every loop iteration below
        let base_name = base_name.to_lowercase();

        // Map model variations to their base tokenizer
        // IMPORTANT: Order matters - check most specific patterns first!
        let model_mappings: &[(&str, &str)] = &[
            // OpenAI models - most specific first
            ("gpt-4o", "gpt-4o"),
            ("gpt-4-turbo", "gpt-4-turbo"),
            ("gpt-4-32k", "gpt-4-32k"),
            ("gpt-4", "gpt-4"),
            ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"),
            ("gpt-3.5-turbo", "gpt-3.5-turbo"),
            // Claude models - use GPT-4 encoding as approximation
            ("claude-3-opus", "gpt-4"),
            ("claude-3-sonnet", "gpt-4"),
            ("claude-3-haiku", "gpt-4"),
            ("claude-3", "gpt-4"),
            ("claude", "gpt-4"),
            // Llama models (hyphenated and unhyphenated forms) - use GPT-3.5
            // encoding as approximation
            ("codellama", "gpt-3.5-turbo"),
            ("llama-3", "gpt-3.5-turbo"),
            ("llama3", "gpt-3.5-turbo"),
            ("llama-2", "gpt-3.5-turbo"),
            ("llama2", "gpt-3.5-turbo"),
            // Other models - use GPT-3.5 as default
            ("deepseek", "gpt-3.5-turbo"),
            ("qwen", "gpt-3.5-turbo"),
            ("mistral", "gpt-3.5-turbo"),
            ("mixtral", "gpt-3.5-turbo"),
        ];

        // Find the best matching tokenizer (checks in order)
        for &(pattern, tokenizer) in model_mappings {
            if base_name.contains(pattern) {
                return tokenizer.to_string();
            }
        }

        // Default to GPT-3.5 tokenizer for unknown models
        "gpt-3.5-turbo".to_string()
    }
}

/// Count tokens in file contents (convenience function)
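///
/// A usage sketch (the path is illustrative):
///
/// ```ignore
/// let source = std::fs::read_to_string("src/main.rs").unwrap();
/// let tokens = count_file_tokens(&source, "gpt-4o");
/// ```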
pub fn count_file_tokens(content: &str, model_name: &str) -> usize {
    let tokenizer = Tokenizer::new(model_name);
    tokenizer
        .count_tokens(content)
        // On error, fall back to the ~4-characters-per-token heuristic
        .unwrap_or_else(|_| content.len() / 4)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let text = "Hello, world! This is a test message.";
        let count = tokenizer.count_tokens(text).unwrap();
        assert!(count > 0);
        assert!(count < text.len()); // Tokens should be fewer than characters
    }

    #[test]
    fn test_model_name_extraction() {
        let tokenizer = Tokenizer::new("ollama/gpt-4");
        assert_eq!(tokenizer.get_base_model_name(), "gpt-4");

        let tokenizer = Tokenizer::new("anthropic/claude-3-sonnet");
        assert_eq!(tokenizer.get_base_model_name(), "gpt-4"); // Mapped to GPT-4

        let tokenizer = Tokenizer::new("unknown-model");
        assert_eq!(tokenizer.get_base_model_name(), "gpt-3.5-turbo"); // Default
    }

    #[test]
    fn test_max_tokens() {
        let tokenizer = Tokenizer::new("gpt-4");
        assert_eq!(tokenizer.get_max_tokens(), 8192); // GPT-4 standard

        let tokenizer = Tokenizer::new("gpt-4o");
        assert_eq!(tokenizer.get_max_tokens(), 128000); // GPT-4o

        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        assert_eq!(tokenizer.get_max_tokens(), 4096); // GPT-3.5 Turbo
    }
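
    // Sketch tests for the remaining-token and chat paths; exact token counts
    // depend on the tiktoken-rs version, so the chat assertion stays loose.
    #[test]
    fn test_remaining_tokens() {
        let tokenizer = Tokenizer::new("gpt-4");
        assert_eq!(tokenizer.remaining_tokens(192), 8000); // 8192 - 192
        assert_eq!(tokenizer.remaining_tokens(100_000), 0); // saturating_sub floors at zero
    }

    #[test]
    fn test_chat_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let messages = vec![
            ("system".to_string(), "You are a helpful assistant.".to_string()),
            ("user".to_string(), "Hello!".to_string()),
        ];
        // Chat framing adds per-message overhead, so the count is strictly positive
        assert!(tokenizer.count_chat_tokens(&messages).unwrap() > 0);
    }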
}