mermaid_cli/utils/
tokenizer.rs1use anyhow::Result;
2use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};
3
/// Token counter for LLM context-budget management.
///
/// Wraps tiktoken-based counting for a named model, with heuristic
/// fallbacks when the model has no known encoding.
pub struct Tokenizer {
    // Raw model identifier; may carry a provider prefix,
    // e.g. "ollama/gpt-4" or "anthropic/claude-3-sonnet".
    model_name: String,
}
8
9impl Tokenizer {
10 pub fn new(model_name: &str) -> Self {
12 Self {
13 model_name: model_name.to_string(),
14 }
15 }
16
17 pub fn count_tokens(&self, text: &str) -> Result<usize> {
19 let model_for_encoding = self.get_base_model_name();
21
22 match tiktoken_rs::get_bpe_from_model(&model_for_encoding) {
24 Ok(bpe) => {
25 Ok(bpe.encode_with_special_tokens(text).len())
27 },
28 Err(_) => {
29 tiktoken_rs::cl100k_base()
31 .map(|bpe| bpe.encode_with_special_tokens(text).len())
32 .or_else(|_| Ok(text.len() / 4))
33 },
34 }
35 }
36
37 pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
39 let chat_messages: Vec<ChatCompletionRequestMessage> = messages
41 .iter()
42 .map(|(role, content)| ChatCompletionRequestMessage {
43 role: role.clone(),
44 content: Some(content.clone()),
45 name: None,
46 function_call: None,
47 })
48 .collect();
49
50 let model_for_encoding = self.get_base_model_name();
51
52 match num_tokens_from_messages(&model_for_encoding, &chat_messages) {
54 Ok(count) => Ok(count),
55 Err(_) => {
56 num_tokens_from_messages("gpt-3.5-turbo", &chat_messages).or_else(|_| {
58 let total_chars: usize =
60 messages.iter().map(|(_, content)| content.len()).sum();
61 Ok(total_chars / 4)
62 })
63 },
64 }
65 }
66
67 pub fn get_max_tokens(&self) -> usize {
69 let model_name = self.get_base_model_name();
70
71 if model_name.contains("gpt-4o") {
74 128000 } else if model_name.contains("gpt-4-turbo") || model_name.contains("gpt-4-1106") {
76 128000 } else if model_name.contains("gpt-4-32k") {
78 32768 } else if model_name.contains("gpt-4") {
80 8192 } else if model_name.contains("gpt-3.5-turbo-16k") {
82 16384 } else if model_name.contains("gpt-3.5-turbo") {
84 4096 } else if model_name.contains("claude-3") {
86 200000 } else if model_name.contains("claude") {
88 100000 } else if model_name.contains("llama-3") {
90 8192 } else if model_name.contains("llama-2") {
92 4096 } else if model_name.contains("codellama") {
94 16384 } else if model_name.contains("deepseek-coder") {
96 65536 } else if model_name.contains("qwen") {
98 32768 } else if model_name.contains("mistral") || model_name.contains("mixtral") {
100 32768 } else {
102 8192 }
104 }
105
106 pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
108 let max_tokens = self.get_max_tokens();
109 max_tokens.saturating_sub(used_tokens)
110 }
111
112 fn get_base_model_name(&self) -> String {
114 let base_name = if let Some(idx) = self.model_name.find('/') {
116 &self.model_name[idx + 1..]
117 } else {
118 &self.model_name
119 };
120
121 let model_mappings: Vec<(&str, &str)> = vec![
124 ("gpt-4o", "gpt-4o"),
126 ("gpt-4-turbo", "gpt-4-turbo"),
127 ("gpt-4-32k", "gpt-4-32k"),
128 ("gpt-4", "gpt-4"),
129 ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"),
130 ("gpt-3.5-turbo", "gpt-3.5-turbo"),
131 ("claude-3-opus", "gpt-4"),
133 ("claude-3-sonnet", "gpt-4"),
134 ("claude-3-haiku", "gpt-4"),
135 ("claude-3", "gpt-4"),
136 ("claude", "gpt-4"),
137 ("codellama", "gpt-3.5-turbo"),
139 ("llama3", "gpt-3.5-turbo"),
140 ("llama2", "gpt-3.5-turbo"),
141 ("deepseek", "gpt-3.5-turbo"),
143 ("qwen", "gpt-3.5-turbo"),
144 ("mistral", "gpt-3.5-turbo"),
145 ("mixtral", "gpt-3.5-turbo"),
146 ];
147
148 for (pattern, tokenizer) in &model_mappings {
150 if base_name.to_lowercase().contains(pattern) {
151 return tokenizer.to_string();
152 }
153 }
154
155 "gpt-3.5-turbo".to_string()
157 }
158}
159
160pub fn count_file_tokens(content: &str, model_name: &str) -> usize {
162 let tokenizer = Tokenizer::new(model_name);
163 tokenizer
164 .count_tokens(content)
165 .unwrap_or_else(|_| content.len() / 4)
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let text = "Hello, world! This is a test message.";
        let count = tokenizer.count_tokens(text).unwrap();
        // BPE tokens span multiple characters on average, so the count is
        // positive yet strictly below the character count.
        assert!(count > 0);
        assert!(count < text.len());
    }

    #[test]
    fn test_model_name_extraction() {
        // Provider prefixes are stripped before mapping.
        assert_eq!(Tokenizer::new("ollama/gpt-4").get_base_model_name(), "gpt-4");

        // Claude models map onto the gpt-4 tokenizer.
        assert_eq!(
            Tokenizer::new("anthropic/claude-3-sonnet").get_base_model_name(),
            "gpt-4"
        );

        // Unknown models fall back to gpt-3.5-turbo.
        assert_eq!(
            Tokenizer::new("unknown-model").get_base_model_name(),
            "gpt-3.5-turbo"
        );
    }

    #[test]
    fn test_max_tokens() {
        assert_eq!(Tokenizer::new("gpt-4").get_max_tokens(), 8192);
        assert_eq!(Tokenizer::new("gpt-4o").get_max_tokens(), 128000);
        assert_eq!(Tokenizer::new("gpt-3.5-turbo").get_max_tokens(), 4096);
    }
}