//! Token counting: either a real BPE encoding (via `tiktoken_rs`) or a
//! cheap byte-length heuristic.

use std::fmt;
use tiktoken_rs::CoreBPE;

/// Which token-counting strategy to use.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum TokenizerKind {
    /// Byte-length heuristic; needs no tokenizer data.
    #[default]
    Heuristic,
    /// Claude models, approximated here with the cl100k_base encoding.
    Claude,
    /// GPT-4 models (cl100k_base encoding).
    Gpt4,
    /// GPT-3.5 models (cl100k_base encoding).
    Gpt35,
}

impl fmt::Display for TokenizerKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenizerKind::Heuristic => write!(f, "heuristic"),
            TokenizerKind::Claude => write!(f, "claude"),
            TokenizerKind::Gpt4 => write!(f, "gpt-4"),
            TokenizerKind::Gpt35 => write!(f, "gpt-3.5"),
        }
    }
}

impl TokenizerKind {
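    /// Parses a user-supplied tokenizer name, accepting the aliases matched
    /// below.
    ///
    /// A minimal doc-test sketch; the crate name `tokenizer` is an assumption,
    /// substitute the real crate path:
    ///
    /// ```ignore
    /// use tokenizer::TokenizerKind;
    ///
    /// assert_eq!(TokenizerKind::parse_name("gpt4"), Some(TokenizerKind::Gpt4));
    /// assert_eq!(TokenizerKind::parse_name("nope"), None);
    /// ```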
    pub fn parse_name(s: &str) -> Option<Self> {
        match s {
            "heuristic" => Some(TokenizerKind::Heuristic),
            "claude" => Some(TokenizerKind::Claude),
            "gpt-4" | "gpt4" => Some(TokenizerKind::Gpt4),
            "gpt-3.5" | "gpt3.5" | "gpt-35" => Some(TokenizerKind::Gpt35),
            _ => None,
        }
    }

    /// The canonical names accepted by `parse_name`, for use in error
    /// messages.
    pub fn valid_names() -> &'static str {
        "heuristic, claude, gpt-4, gpt-3.5"
    }
}

/// Counts tokens with a real BPE encoding when one is loaded, falling back
/// to the byte-length heuristic otherwise.
pub struct Tokenizer {
    kind: TokenizerKind,
    bpe: Option<CoreBPE>,
}

impl Tokenizer {
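    /// Builds a tokenizer of the given kind. For the model-specific kinds
    /// this loads the cl100k_base encoding; if loading fails, a warning is
    /// printed and the byte-length heuristic is used instead.
    ///
    /// A minimal usage sketch; the crate name `tokenizer` is an assumption:
    ///
    /// ```ignore
    /// use tokenizer::{Tokenizer, TokenizerKind};
    ///
    /// let tok = Tokenizer::new(TokenizerKind::Gpt4);
    /// println!("using {} tokenizer (real BPE: {})", tok.kind(), tok.is_real());
    /// ```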
    pub fn new(kind: TokenizerKind) -> Self {
        let bpe = match &kind {
            TokenizerKind::Heuristic => None,
            // All model-specific kinds currently share the cl100k_base encoding.
            TokenizerKind::Claude | TokenizerKind::Gpt4 | TokenizerKind::Gpt35 => {
                match tiktoken_rs::cl100k_base() {
                    Ok(bpe) => Some(bpe),
                    Err(e) => {
                        eprintln!(
                            "Warning: failed to load {} tokenizer: {}, falling back to heuristic",
                            kind, e
                        );
                        None
                    }
                }
            }
        };
        Self { kind, bpe }
    }
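    /// Counts the tokens in `content`, using the loaded BPE when available
    /// and the byte-length heuristic otherwise. `is_prose` only affects the
    /// heuristic path; real BPE counts ignore it.
    ///
    /// A minimal sketch of the heuristic path; the crate name `tokenizer` is
    /// an assumption:
    ///
    /// ```ignore
    /// use tokenizer::{Tokenizer, TokenizerKind};
    ///
    /// let tok = Tokenizer::new(TokenizerKind::Heuristic);
    /// assert_eq!(tok.count_tokens(&"x".repeat(300), false), 100); // 300 bytes / 3
    /// ```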
    pub fn count_tokens(&self, content: &str, is_prose: bool) -> usize {
        match &self.bpe {
            Some(bpe) => bpe.encode_with_special_tokens(content).len(),
            None => estimate_tokens_heuristic(content, is_prose),
        }
    }

    /// The kind this tokenizer was constructed with.
    pub fn kind(&self) -> &TokenizerKind {
        &self.kind
    }

    /// Whether a real BPE encoding is loaded rather than the heuristic.
    pub fn is_real(&self) -> bool {
        self.bpe.is_some()
    }
}
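/// Estimates tokens with the byte-length heuristic without constructing a
/// `Tokenizer`: roughly 4 bytes per token for prose, 3 bytes per token for
/// code.
///
/// A minimal doc-test sketch; the crate name `tokenizer` is an assumption:
///
/// ```ignore
/// use tokenizer::estimate_tokens;
///
/// assert_eq!(estimate_tokens(&"x".repeat(300), false), 100); // code: 300 / 3
/// assert_eq!(estimate_tokens(&"x".repeat(400), true), 100);  // prose: 400 / 4
/// ```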
pub fn estimate_tokens(content: &str, is_prose: bool) -> usize {
    estimate_tokens_heuristic(content, is_prose)
}

fn estimate_tokens_heuristic(content: &str, is_prose: bool) -> usize {
    let byte_count = content.len();
    if is_prose {
        // Prose averages longer tokens: assume ~4 bytes per token.
        byte_count / 4
    } else {
        // Code is denser in punctuation: assume ~3 bytes per token.
        byte_count / 3
    }
}

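/// Reports whether a file extension (any case) denotes a prose format, which
/// selects the prose divisor in the heuristic.
///
/// A minimal doc-test sketch; the crate name `tokenizer` is an assumption:
///
/// ```ignore
/// use tokenizer::is_prose_extension;
///
/// assert!(is_prose_extension("md"));
/// assert!(!is_prose_extension("rs"));
/// ```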
pub fn is_prose_extension(ext: &str) -> bool {
    matches!(
        ext.to_lowercase().as_str(),
        "md" | "txt" | "rst" | "adoc" | "textile" | "org" | "wiki"
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens_code() {
        let code = "x".repeat(300);
        // Code path: 300 bytes / 3 bytes per token = 100.
        assert_eq!(estimate_tokens(&code, false), 100);
    }

    #[test]
    fn test_estimate_tokens_prose() {
        let prose = "x".repeat(400);
        // Prose path: 400 bytes / 4 bytes per token = 100.
        assert_eq!(estimate_tokens(&prose, true), 100);
    }

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(estimate_tokens("", false), 0);
        assert_eq!(estimate_tokens("", true), 0);
    }

    #[test]
    fn test_is_prose_extension() {
        assert!(is_prose_extension("md"));
        assert!(is_prose_extension("txt"));
        assert!(is_prose_extension("rst"));
        assert!(!is_prose_extension("rs"));
        assert!(!is_prose_extension("py"));
        assert!(!is_prose_extension("ts"));
    }

    #[test]
    fn test_tokenizer_kind_parse() {
        assert_eq!(TokenizerKind::parse_name("heuristic"), Some(TokenizerKind::Heuristic));
        assert_eq!(TokenizerKind::parse_name("claude"), Some(TokenizerKind::Claude));
        assert_eq!(TokenizerKind::parse_name("gpt-4"), Some(TokenizerKind::Gpt4));
        assert_eq!(TokenizerKind::parse_name("gpt4"), Some(TokenizerKind::Gpt4));
        assert_eq!(TokenizerKind::parse_name("gpt-3.5"), Some(TokenizerKind::Gpt35));
        assert_eq!(TokenizerKind::parse_name("gpt3.5"), Some(TokenizerKind::Gpt35));
        assert_eq!(TokenizerKind::parse_name("invalid"), None);
    }

    #[test]
    fn test_tokenizer_kind_default() {
        assert_eq!(TokenizerKind::default(), TokenizerKind::Heuristic);
    }

    #[test]
    fn test_heuristic_tokenizer() {
        let tok = Tokenizer::new(TokenizerKind::Heuristic);
        assert!(!tok.is_real());
        let code = "x".repeat(300);
        assert_eq!(tok.count_tokens(&code, false), 100);
    }

    #[test]
    fn test_real_tokenizer_loads() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        assert!(tok.is_real());
    }

    #[test]
    fn test_real_tokenizer_counts() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        let count = tok.count_tokens("Hello, world!", false);
        assert!(count > 0 && count < 10, "Expected 1-9 tokens, got {}", count);
    }

    #[test]
    fn test_real_tokenizer_known_value() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        let count = tok.count_tokens("Hello world", false);
        assert_eq!(count, 2, "Expected 2 tokens for 'Hello world', got {}", count);
    }

    #[test]
    fn test_real_vs_heuristic_comparison() {
        let real = Tokenizer::new(TokenizerKind::Gpt4);
        let heuristic = Tokenizer::new(TokenizerKind::Heuristic);

        let code = "fn main() {\n    println!(\"Hello, world!\");\n}\n";
        let real_count = real.count_tokens(code, false);
        let heuristic_count = heuristic.count_tokens(code, false);

        assert!(real_count > 0);
        assert!(heuristic_count > 0);
        assert!(real_count < code.len(), "Real count should be less than byte length");
    }

    #[test]
    fn test_heuristic_overestimates_code() {
        let real = Tokenizer::new(TokenizerKind::Gpt4);
        let heuristic = Tokenizer::new(TokenizerKind::Heuristic);

        let code = r#"
use std::collections::HashMap;

pub struct Config {
    pub name: String,
    pub values: HashMap<String, Vec<u32>>,
}

impl Config {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            values: HashMap::new(),
        }
    }

    pub fn insert(&mut self, key: &str, val: u32) {
        self.values.entry(key.to_string()).or_default().push(val);
    }
}
"#;
        let real_count = real.count_tokens(code, false);
        let heuristic_count = heuristic.count_tokens(code, false);

        assert!(
            heuristic_count >= real_count,
            "Heuristic ({}) should overestimate vs real ({}) for code",
            heuristic_count,
            real_count
        );
    }

    #[test]
    fn test_heuristic_prose_ratio_within_bounds() {
        let real = Tokenizer::new(TokenizerKind::Gpt4);
        let heuristic = Tokenizer::new(TokenizerKind::Heuristic);

        let prose = "The quick brown fox jumps over the lazy dog. \
            This is a longer piece of prose text that should be tokenized \
            differently from code. Natural language tends to have longer tokens \
            on average compared to source code with its punctuation and symbols.";

        let real_count = real.count_tokens(prose, true);
        let heuristic_count = heuristic.count_tokens(prose, true);

        assert!(real_count > 0, "Real tokenizer should produce tokens for prose");
        assert!(heuristic_count > 0, "Heuristic should produce tokens for prose");
        let ratio = heuristic_count as f64 / real_count as f64;
        assert!(
            ratio > 0.5 && ratio < 3.0,
            "Heuristic ({}) and real ({}) should be within 3x of each other for prose (ratio: {:.2})",
            heuristic_count,
            real_count,
            ratio
        );
    }

    #[test]
    fn test_all_real_tokenizers_produce_same_counts() {
        // All model-specific kinds map to cl100k_base, so counts must agree.
        let claude = Tokenizer::new(TokenizerKind::Claude);
        let gpt4 = Tokenizer::new(TokenizerKind::Gpt4);
        let gpt35 = Tokenizer::new(TokenizerKind::Gpt35);

        let text = "fn main() { println!(\"Hello, world!\"); }";

        let claude_count = claude.count_tokens(text, false);
        let gpt4_count = gpt4.count_tokens(text, false);
        let gpt35_count = gpt35.count_tokens(text, false);

        assert_eq!(claude_count, gpt4_count, "Claude and GPT-4 should match");
        assert_eq!(gpt4_count, gpt35_count, "GPT-4 and GPT-3.5 should match");
    }

    #[test]
    fn test_real_tokenizer_empty_string() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        assert_eq!(tok.count_tokens("", false), 0);
        assert_eq!(tok.count_tokens("", true), 0);
    }

    #[test]
    fn test_real_tokenizer_whitespace_only() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        let count = tok.count_tokens(" \n\n\t ", false);
        assert!(count > 0, "Whitespace should produce at least 1 token, got {}", count);
    }

    #[test]
    fn test_tokenizer_kind_display() {
        assert_eq!(format!("{}", TokenizerKind::Heuristic), "heuristic");
        assert_eq!(format!("{}", TokenizerKind::Claude), "claude");
        assert_eq!(format!("{}", TokenizerKind::Gpt4), "gpt-4");
        assert_eq!(format!("{}", TokenizerKind::Gpt35), "gpt-3.5");
    }

    #[test]
    fn test_tokenizer_kind_roundtrip() {
        for kind in [
            TokenizerKind::Heuristic,
            TokenizerKind::Claude,
            TokenizerKind::Gpt4,
            TokenizerKind::Gpt35,
        ] {
            let display = format!("{}", kind);
            let parsed = TokenizerKind::parse_name(&display);
            assert_eq!(
                parsed,
                Some(kind.clone()),
                "Roundtrip failed for {}",
                display
            );
        }
    }

    #[test]
    fn test_tokenizer_kind_accessor() {
        let tok = Tokenizer::new(TokenizerKind::Claude);
        assert_eq!(*tok.kind(), TokenizerKind::Claude);
    }
}