// mermaid_cli/utils/tokenizer.rs
use anyhow::Result;

/// Approximate token counter for LLM prompts.
///
/// Counts are heuristic (~4 characters per token) rather than a true
/// model-specific tokenization, so no tokenizer data files are required.
pub struct Tokenizer {
    // Full model identifier, possibly provider-prefixed (e.g. "ollama/gpt-4") —
    // see `get_base_model_name`, which strips everything up to the first '/'.
    model_name: String,
}
8
9impl Tokenizer {
10 pub fn new(model_name: &str) -> Self {
12 Self {
13 model_name: model_name.to_string(),
14 }
15 }
16
17 pub fn count_tokens(&self, text: &str) -> Result<usize> {
19 Ok(text.len().div_ceil(4))
20 }
21
22 pub fn count_chat_tokens(&self, messages: &[(String, String)]) -> Result<usize> {
24 let total_chars: usize = messages
25 .iter()
26 .map(|(role, content)| role.len() + content.len() + 4) .sum();
28 Ok(total_chars.div_ceil(4))
29 }
30
31 pub fn get_max_tokens(&self) -> usize {
33 let model_name = self.get_base_model_name();
34
35 if model_name.contains("gpt-4o")
36 || model_name.contains("gpt-4-turbo")
37 || model_name.contains("gpt-4-1106")
38 {
39 128000
40 } else if model_name.contains("gpt-4-32k") {
41 32768
42 } else if model_name.contains("gpt-4") {
43 8192
44 } else if model_name.contains("gpt-3.5-turbo-16k") {
45 16384
46 } else if model_name.contains("gpt-3.5-turbo") {
47 4096
48 } else if model_name.contains("claude-3") {
49 200000
50 } else if model_name.contains("claude") {
51 100000
52 } else if model_name.contains("llama-3") {
53 8192
54 } else if model_name.contains("llama-2") {
55 4096
56 } else if model_name.contains("codellama") {
57 16384
58 } else if model_name.contains("deepseek-coder") {
59 65536
60 } else if model_name.contains("qwen")
61 || model_name.contains("mistral")
62 || model_name.contains("mixtral")
63 {
64 32768
65 } else {
66 8192 }
68 }
69
70 pub fn remaining_tokens(&self, used_tokens: usize) -> usize {
72 let max_tokens = self.get_max_tokens();
73 max_tokens.saturating_sub(used_tokens)
74 }
75
76 fn get_base_model_name(&self) -> String {
78 if let Some(idx) = self.model_name.find('/') {
79 self.model_name[idx + 1..].to_string()
80 } else {
81 self.model_name.clone()
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    // The heuristic count is positive yet strictly below the character count.
    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::new("gpt-3.5-turbo");
        let text = "Hello, world! This is a test message.";
        let count = tokenizer.count_tokens(text).unwrap();
        assert!(count > 0 && count < text.len());
    }

    // Provider prefixes are stripped; bare names pass through unchanged.
    #[test]
    fn test_model_name_extraction() {
        assert_eq!(Tokenizer::new("ollama/gpt-4").get_base_model_name(), "gpt-4");
        assert_eq!(
            Tokenizer::new("unknown-model").get_base_model_name(),
            "unknown-model"
        );
    }

    // Spot-check context-window limits for a few well-known models.
    #[test]
    fn test_max_tokens() {
        let cases = [
            ("gpt-4", 8192usize),
            ("gpt-4o", 128000),
            ("gpt-3.5-turbo", 4096),
        ];
        for (model, expected) in cases {
            assert_eq!(Tokenizer::new(model).get_max_tokens(), expected);
        }
    }
}