mermaid_cli/utils/
tokenizer.rs1use anyhow::Result;
2
/// Heuristic token counter for chat-model prompts.
///
/// Estimates token usage with a ~4-bytes-per-token rule of thumb and maps
/// known model names to their context-window sizes.
pub struct Tokenizer {
    /// Model identifier, possibly provider-prefixed (e.g. "ollama/gpt-4");
    /// the prefix before the first '/' is stripped for limit lookup.
    model_name: String,
}
8
9impl Tokenizer {
10 pub fn new(model_name: &str) -> Self {
12 Self {
13 model_name: model_name.to_string(),
14 }
15 }
16
17 pub fn count_tokens(&self, text: &str) -> Result<usize> {
19 Ok((text.len() + 3) / 4)
20 }
21
22 pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
24 let total_chars: usize = messages
25 .iter()
26 .map(|(role, content)| role.len() + content.len() + 4) .sum();
28 Ok((total_chars + 3) / 4)
29 }
30
31 pub fn get_max_tokens(&self) -> usize {
33 let model_name = self.get_base_model_name();
34
35 if model_name.contains("gpt-4o") {
36 128000
37 } else if model_name.contains("gpt-4-turbo") || model_name.contains("gpt-4-1106") {
38 128000
39 } else if model_name.contains("gpt-4-32k") {
40 32768
41 } else if model_name.contains("gpt-4") {
42 8192
43 } else if model_name.contains("gpt-3.5-turbo-16k") {
44 16384
45 } else if model_name.contains("gpt-3.5-turbo") {
46 4096
47 } else if model_name.contains("claude-3") {
48 200000
49 } else if model_name.contains("claude") {
50 100000
51 } else if model_name.contains("llama-3") {
52 8192
53 } else if model_name.contains("llama-2") {
54 4096
55 } else if model_name.contains("codellama") {
56 16384
57 } else if model_name.contains("deepseek-coder") {
58 65536
59 } else if model_name.contains("qwen") {
60 32768
61 } else if model_name.contains("mistral") || model_name.contains("mixtral") {
62 32768
63 } else {
64 8192 }
66 }
67
68 pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
70 let max_tokens = self.get_max_tokens();
71 max_tokens.saturating_sub(used_tokens)
72 }
73
74 fn get_base_model_name(&self) -> String {
76 if let Some(idx) = self.model_name.find('/') {
77 self.model_name[idx + 1..].to_string()
78 } else {
79 self.model_name.clone()
80 }
81 }
82}
83
/// Rough token estimate for a file's contents: byte length divided by 4,
/// rounded up. The model name is accepted for interface symmetry but is
/// currently ignored.
pub fn count_file_tokens(content: &str, _model_name: &str) -> usize {
    let len = content.len();
    if len == 0 {
        0
    } else {
        // Ceiling division without overflow near usize::MAX.
        (len - 1) / 4 + 1
    }
}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tok = Tokenizer::new("gpt-3.5-turbo");
        let sample = "Hello, world! This is a test message.";
        let estimate = tok.count_tokens(sample).expect("heuristic count cannot fail");
        assert!(estimate > 0, "non-empty text must yield at least one token");
        assert!(
            estimate < sample.len(),
            "estimate must stay below the character count"
        );
    }

    #[test]
    fn test_model_name_extraction() {
        // Provider prefix is stripped at the first '/'.
        assert_eq!(Tokenizer::new("ollama/gpt-4").get_base_model_name(), "gpt-4");
        // Names without a prefix pass through unchanged.
        assert_eq!(
            Tokenizer::new("unknown-model").get_base_model_name(),
            "unknown-model"
        );
    }

    #[test]
    fn test_max_tokens() {
        let cases = [
            ("gpt-4", 8192),
            ("gpt-4o", 128000),
            ("gpt-3.5-turbo", 4096),
        ];
        for (model, expected) in cases {
            assert_eq!(Tokenizer::new(model).get_max_tokens(), expected);
        }
    }
}