// mermaid_cli/utils/tokenizer.rs

use anyhow::Result;

/// Heuristic tokenizer: estimates token counts (~4 bytes per token) and
/// reports known context-window sizes per model family.
pub struct Tokenizer {
    // Lowercased model name with any "provider/" prefix stripped
    // (e.g. "ollama/llama3:8b" -> "llama3:8b").
    base_model_name: String,
}
9
10impl Tokenizer {
11 pub fn new(model_name: &str) -> Self {
13 let base = if let Some(idx) = model_name.find('/') {
14 &model_name[idx + 1..]
16 } else {
17 model_name
18 };
19 Self {
20 base_model_name: base.to_lowercase(),
21 }
22 }
23
24 pub fn count_tokens(&self, text: &str) -> Result<usize> {
26 Ok(text.len().div_ceil(4))
27 }
28
29 pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
31 let total_chars: usize = messages
32 .iter()
33 .map(|(role, content)| role.len() + content.len() + 4) .sum();
35 Ok(total_chars.div_ceil(4))
36 }
37
38 pub fn get_max_tokens(&self) -> usize {
44 let model_name = &self.base_model_name;
45
46 if model_name.contains("qwen3-coder")
48 || model_name.contains("qwen2.5-coder")
49 || model_name.contains("deepseek-v3")
50 || model_name.contains("deepseek-r1")
51 || model_name.contains("kimi")
52 {
53 131072
54 }
55 else if model_name.contains("deepseek-coder") || model_name.contains("command-r") {
57 65536
58 }
59 else if model_name.contains("qwen")
61 || model_name.contains("mistral")
62 || model_name.contains("mixtral")
63 || model_name.contains("gemma2")
64 {
65 32768
66 }
67 else if model_name.contains("codellama") || model_name.contains("phi") {
69 16384
70 }
71 else if model_name.contains("llama3")
73 || model_name.contains("llama-3")
74 || model_name.contains("gemma")
75 {
76 8192
77 }
78 else if model_name.contains("llama2")
80 || model_name.contains("llama-2")
81 || model_name.contains("tinyllama")
82 {
83 4096
84 } else {
85 8192 }
87 }
88
89 pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
91 let max_tokens = self.get_max_tokens();
92 max_tokens.saturating_sub(used_tokens)
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// The 4-bytes-per-token heuristic yields a positive estimate that is
    /// smaller than the raw length for ordinary ASCII text.
    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let message = "Hello, world! This is a test message.";
        let estimate = tokenizer.count_tokens(message).unwrap();
        assert!(estimate > 0);
        assert!(estimate < message.len());
    }

    /// `new` strips a leading "provider/" prefix and leaves bare names alone.
    #[test]
    fn test_model_name_extraction() {
        assert_eq!(Tokenizer::new("ollama/gpt-4").base_model_name, "gpt-4");
        assert_eq!(
            Tokenizer::new("unknown-model").base_model_name,
            "unknown-model"
        );
    }

    /// Context-window lookup across a spread of model families, including
    /// the 8192 default for unrecognized names.
    #[test]
    fn test_max_tokens() {
        let cases = [
            ("ollama/qwen3-coder:30b", 131072),
            ("ollama/llama3:8b", 8192),
            ("tinyllama", 4096),
            ("ollama/mistral", 32768),
            ("unknown-model", 8192),
        ];
        for (model, expected) in cases {
            assert_eq!(Tokenizer::new(model).get_max_tokens(), expected);
        }
    }
}