// token_count/tokenizers/openai.rs
use crate::tokenizers::{ModelInfo, TokenDetail, Tokenizer};
use anyhow::{Context, Result};
use tiktoken_rs::CoreBPE;
6
/// Tokenizer for OpenAI models, backed by a tiktoken byte-pair encoder.
pub struct OpenAITokenizer {
    // Loaded BPE used for all encode/decode calls.
    bpe: CoreBPE,
    // Metadata describing the model this tokenizer was configured for;
    // returned verbatim by `get_model_info`.
    model_info: ModelInfo,
}
12
13impl OpenAITokenizer {
14 pub fn new(encoding_name: &str, model_info: ModelInfo) -> Result<Self> {
16 let tokenizer_enum = match encoding_name {
17 "o200k_base" => tiktoken_rs::tokenizer::Tokenizer::O200kBase,
18 "cl100k_base" => tiktoken_rs::tokenizer::Tokenizer::Cl100kBase,
19 "p50k_base" => tiktoken_rs::tokenizer::Tokenizer::P50kBase,
20 "r50k_base" => tiktoken_rs::tokenizer::Tokenizer::R50kBase,
21 "gpt2" => tiktoken_rs::tokenizer::Tokenizer::Gpt2,
22 _ => anyhow::bail!("Unsupported encoding: {}", encoding_name),
23 };
24
25 let bpe = tiktoken_rs::get_bpe_from_tokenizer(tokenizer_enum)
26 .with_context(|| format!("Failed to load encoding: {}", encoding_name))?;
27
28 Ok(Self { bpe, model_info })
29 }
30}
31
32impl Tokenizer for OpenAITokenizer {
33 fn count_tokens(&self, text: &str) -> Result<usize> {
34 let tokens = self.bpe.encode_with_special_tokens(text);
35 Ok(tokens.len())
36 }
37
38 fn get_model_info(&self) -> ModelInfo {
39 self.model_info.clone()
40 }
41
42 fn encode_with_details(&self, text: &str) -> Result<Option<Vec<TokenDetail>>> {
43 const MAX_DEBUG_INPUT_SIZE: usize = 50 * 1024; if text.len() > MAX_DEBUG_INPUT_SIZE {
49 eprintln!(
50 "Warning: Input size ({} bytes) exceeds debug mode limit ({} bytes). \
51 Showing token count only. For token IDs, provide smaller input.",
52 text.len(),
53 MAX_DEBUG_INPUT_SIZE
54 );
55 return Ok(None);
56 }
57
58 let token_ids = self.bpe.encode_with_special_tokens(text);
59
60 let mut details = Vec::new();
62 for token_id in token_ids.iter().take(10) {
63 let decoded = self.bpe.decode(vec![*token_id])?;
65 details.push(TokenDetail { id: *token_id, text: decoded });
66 }
67
68 Ok(Some(details))
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a cl100k_base tokenizer tagged with GPT-4 metadata.
    /// Extracted to remove the four identical copies of this setup.
    fn gpt4_tokenizer() -> OpenAITokenizer {
        let model_info = ModelInfo {
            name: "gpt-4".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 128000,
            description: "GPT-4 model".to_string(),
        };
        OpenAITokenizer::new("cl100k_base", model_info).unwrap()
    }

    #[test]
    fn test_basic_tokenization() {
        // "Hello world" encodes to exactly two cl100k_base tokens.
        let count = gpt4_tokenizer().count_tokens("Hello world").unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_empty_string() {
        assert_eq!(gpt4_tokenizer().count_tokens("").unwrap(), 0);
    }

    #[test]
    fn test_encode_with_details_large_input() {
        // 60 KiB exceeds the 50 KiB debug limit, so details are withheld.
        let large_input = "a".repeat(60 * 1024);
        let result = gpt4_tokenizer().encode_with_details(&large_input);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None);
    }

    #[test]
    fn test_encode_with_details_normal_input() {
        let details = gpt4_tokenizer()
            .encode_with_details("Hello world")
            .unwrap()
            .expect("small input should produce details");
        assert_eq!(details.len(), 2);
        assert_eq!(details[0].id, 9906); // "Hello"
        assert_eq!(details[1].id, 1917); // " world"
    }
}