// flat/tokens.rs

use std::fmt;
use tiktoken_rs::CoreBPE;

/// Which tokenizer to use for token counting.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum TokenizerKind {
    /// Heuristic estimation: bytes/3 for code, bytes/4 for prose (default)
    #[default]
    Heuristic,
    /// Claude tokenizer (uses cl100k_base as an approximation)
    Claude,
    /// GPT-4 tokenizer (cl100k_base)
    Gpt4,
    /// GPT-3.5 tokenizer (cl100k_base)
    Gpt35,
}

impl fmt::Display for TokenizerKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenizerKind::Heuristic => write!(f, "heuristic"),
            TokenizerKind::Claude => write!(f, "claude"),
            TokenizerKind::Gpt4 => write!(f, "gpt-4"),
            TokenizerKind::Gpt35 => write!(f, "gpt-3.5"),
        }
    }
}

impl TokenizerKind {
    /// Parse a tokenizer name from CLI input.
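    ///
    /// Returns `None` for unrecognized names. Canonical names are those
    /// listed by [`TokenizerKind::valid_names`]; a few aliases such as
    /// `gpt4` and `gpt3.5` are also accepted. A minimal sketch of the
    /// intended call site (hypothetical, marked `ignore` since this file's
    /// crate path is not shown here):
    ///
    /// ```ignore
    /// assert_eq!(TokenizerKind::parse_name("gpt4"), Some(TokenizerKind::Gpt4));
    /// assert_eq!(TokenizerKind::parse_name("nope"), None);
    /// ```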
    pub fn parse_name(s: &str) -> Option<Self> {
        match s {
            "heuristic" => Some(TokenizerKind::Heuristic),
            "claude" => Some(TokenizerKind::Claude),
            "gpt-4" | "gpt4" => Some(TokenizerKind::Gpt4),
            "gpt-3.5" | "gpt3.5" | "gpt-35" => Some(TokenizerKind::Gpt35),
            _ => None,
        }
    }

    /// List valid tokenizer names for help text.
    pub fn valid_names() -> &'static str {
        "heuristic, claude, gpt-4, gpt-3.5"
    }
}

/// A tokenizer that can count tokens in text.
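///
/// A minimal usage sketch (hypothetical call site, marked `ignore` since this
/// file's crate path is not shown here):
///
/// ```ignore
/// let tok = Tokenizer::new(TokenizerKind::Heuristic);
/// // 12 bytes of code => 12 / 3 = 4 tokens under the heuristic.
/// assert_eq!(tok.count_tokens("fn main() {}", false), 4);
/// ```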
pub struct Tokenizer {
    kind: TokenizerKind,
    bpe: Option<CoreBPE>,
}

impl Tokenizer {
    /// Create a new tokenizer of the given kind.
    ///
    /// Falls back to the heuristic (with a warning on stderr) if the BPE
    /// model fails to load.
    pub fn new(kind: TokenizerKind) -> Self {
        let bpe = match &kind {
            TokenizerKind::Heuristic => None,
            TokenizerKind::Claude | TokenizerKind::Gpt4 | TokenizerKind::Gpt35 => {
                match tiktoken_rs::cl100k_base() {
                    Ok(bpe) => Some(bpe),
                    Err(e) => {
                        eprintln!(
                            "Warning: failed to load {} tokenizer: {}, falling back to heuristic",
                            kind, e
                        );
                        None
                    }
                }
            }
        };
        Self { kind, bpe }
    }

    /// Count tokens in the given text.
    ///
    /// `is_prose` only affects the heuristic path; real BPE encoders ignore it.
    pub fn count_tokens(&self, content: &str, is_prose: bool) -> usize {
        match &self.bpe {
            Some(bpe) => bpe.encode_with_special_tokens(content).len(),
            None => estimate_tokens_heuristic(content, is_prose),
        }
    }

    /// Return the kind of this tokenizer.
    pub fn kind(&self) -> &TokenizerKind {
        &self.kind
    }

    /// Whether this tokenizer uses real BPE encoding (not the heuristic).
    pub fn is_real(&self) -> bool {
        self.bpe.is_some()
    }
}

/// Estimate the number of tokens in a piece of content using the byte-count
/// heuristic.
///
/// Uses pessimistic (conservative) estimation per the PDR spec:
/// - Code files: bytes / 3 (~3.0 chars/token)
/// - Prose files: bytes / 4 (~4.0 chars/token)
///
/// This intentionally overestimates to stay within context windows.
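///
/// For example, a 3,000-byte code file estimates to 3000 / 3 = 1,000 tokens,
/// while the same 3,000 bytes of prose estimate to 3000 / 4 = 750 tokens.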
pub fn estimate_tokens(content: &str, is_prose: bool) -> usize {
    estimate_tokens_heuristic(content, is_prose)
}

fn estimate_tokens_heuristic(content: &str, is_prose: bool) -> usize {
    let byte_count = content.len();
    if is_prose {
        byte_count / 4
    } else {
        byte_count / 3
    }
}

/// Check if a file extension indicates prose content.
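///
/// Matching is case-insensitive. A sketch (hypothetical call site, marked
/// `ignore` since this file's crate path is not shown here):
///
/// ```ignore
/// assert!(is_prose_extension("md"));
/// assert!(is_prose_extension("MD"));
/// assert!(!is_prose_extension("rs"));
/// ```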
pub fn is_prose_extension(ext: &str) -> bool {
    matches!(
        ext.to_lowercase().as_str(),
        "md" | "txt" | "rst" | "adoc" | "textile" | "org" | "wiki"
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens_code() {
        // 300 bytes of code = 100 tokens (300/3)
        let code = "x".repeat(300);
        assert_eq!(estimate_tokens(&code, false), 100);
    }

    #[test]
    fn test_estimate_tokens_prose() {
        // 400 bytes of prose = 100 tokens (400/4)
        let prose = "x".repeat(400);
        assert_eq!(estimate_tokens(&prose, true), 100);
    }

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(estimate_tokens("", false), 0);
        assert_eq!(estimate_tokens("", true), 0);
    }

    #[test]
    fn test_is_prose_extension() {
        assert!(is_prose_extension("md"));
        assert!(is_prose_extension("txt"));
        assert!(is_prose_extension("rst"));
        assert!(!is_prose_extension("rs"));
        assert!(!is_prose_extension("py"));
        assert!(!is_prose_extension("ts"));
    }

    #[test]
    fn test_tokenizer_kind_parse() {
        assert_eq!(TokenizerKind::parse_name("heuristic"), Some(TokenizerKind::Heuristic));
        assert_eq!(TokenizerKind::parse_name("claude"), Some(TokenizerKind::Claude));
        assert_eq!(TokenizerKind::parse_name("gpt-4"), Some(TokenizerKind::Gpt4));
        assert_eq!(TokenizerKind::parse_name("gpt4"), Some(TokenizerKind::Gpt4));
        assert_eq!(TokenizerKind::parse_name("gpt-3.5"), Some(TokenizerKind::Gpt35));
        assert_eq!(TokenizerKind::parse_name("gpt3.5"), Some(TokenizerKind::Gpt35));
        assert_eq!(TokenizerKind::parse_name("invalid"), None);
    }

    #[test]
    fn test_tokenizer_kind_default() {
        assert_eq!(TokenizerKind::default(), TokenizerKind::Heuristic);
    }

    #[test]
    fn test_heuristic_tokenizer() {
        let tok = Tokenizer::new(TokenizerKind::Heuristic);
        assert!(!tok.is_real());
        let code = "x".repeat(300);
        assert_eq!(tok.count_tokens(&code, false), 100);
    }

    #[test]
    fn test_real_tokenizer_loads() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        assert!(tok.is_real());
    }

    #[test]
    fn test_real_tokenizer_counts() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        let count = tok.count_tokens("Hello, world!", false);
        assert!(count > 0 && count < 10, "Expected 1-9 tokens, got {}", count);
    }

    #[test]
    fn test_real_tokenizer_known_value() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        let count = tok.count_tokens("Hello world", false);
        assert_eq!(count, 2, "Expected 2 tokens for 'Hello world', got {}", count);
    }

    #[test]
    fn test_real_vs_heuristic_comparison() {
        let real = Tokenizer::new(TokenizerKind::Gpt4);
        let heuristic = Tokenizer::new(TokenizerKind::Heuristic);

        let code = "fn main() {\n    println!(\"Hello, world!\");\n}\n";
        let real_count = real.count_tokens(code, false);
        let heuristic_count = heuristic.count_tokens(code, false);

        assert!(real_count > 0);
        assert!(heuristic_count > 0);
        assert!(real_count < code.len(), "Real count should be less than byte length");
    }

    #[test]
    fn test_heuristic_overestimates_code() {
        let real = Tokenizer::new(TokenizerKind::Gpt4);
        let heuristic = Tokenizer::new(TokenizerKind::Heuristic);

        let code = r#"
use std::collections::HashMap;

pub struct Config {
    pub name: String,
    pub values: HashMap<String, Vec<u32>>,
}

impl Config {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            values: HashMap::new(),
        }
    }

    pub fn insert(&mut self, key: &str, val: u32) {
        self.values.entry(key.to_string()).or_default().push(val);
    }
}
"#;
        let real_count = real.count_tokens(code, false);
        let heuristic_count = heuristic.count_tokens(code, false);

        assert!(
            heuristic_count >= real_count,
            "Heuristic ({}) should overestimate vs real ({}) for code",
            heuristic_count,
            real_count
        );
    }

    #[test]
    fn test_heuristic_prose_within_3x_of_real() {
        let real = Tokenizer::new(TokenizerKind::Gpt4);
        let heuristic = Tokenizer::new(TokenizerKind::Heuristic);

        let prose = "The quick brown fox jumps over the lazy dog. \
            This is a longer piece of prose text that should be tokenized \
            differently from code. Natural language tends to have longer tokens \
            on average compared to source code with its punctuation and symbols.";

        let real_count = real.count_tokens(prose, true);
        let heuristic_count = heuristic.count_tokens(prose, true);

        assert!(real_count > 0, "Real tokenizer should produce tokens for prose");
        assert!(heuristic_count > 0, "Heuristic should produce tokens for prose");
        let ratio = heuristic_count as f64 / real_count as f64;
        assert!(
            ratio > 0.5 && ratio < 3.0,
            "Heuristic ({}) and real ({}) should be within 3x of each other for prose (ratio: {:.2})",
            heuristic_count,
            real_count,
            ratio
        );
    }

    #[test]
    fn test_all_real_tokenizers_produce_same_counts() {
        let claude = Tokenizer::new(TokenizerKind::Claude);
        let gpt4 = Tokenizer::new(TokenizerKind::Gpt4);
        let gpt35 = Tokenizer::new(TokenizerKind::Gpt35);

        let text = "fn main() { println!(\"Hello, world!\"); }";

        let claude_count = claude.count_tokens(text, false);
        let gpt4_count = gpt4.count_tokens(text, false);
        let gpt35_count = gpt35.count_tokens(text, false);

        assert_eq!(claude_count, gpt4_count, "Claude and GPT-4 should match");
        assert_eq!(gpt4_count, gpt35_count, "GPT-4 and GPT-3.5 should match");
    }

    #[test]
    fn test_real_tokenizer_empty_string() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        assert_eq!(tok.count_tokens("", false), 0);
        assert_eq!(tok.count_tokens("", true), 0);
    }

    #[test]
    fn test_real_tokenizer_whitespace_only() {
        let tok = Tokenizer::new(TokenizerKind::Gpt4);
        let count = tok.count_tokens("   \n\n\t  ", false);
        assert!(count > 0, "Whitespace should produce at least 1 token, got {}", count);
    }

    #[test]
    fn test_tokenizer_kind_display() {
        assert_eq!(format!("{}", TokenizerKind::Heuristic), "heuristic");
        assert_eq!(format!("{}", TokenizerKind::Claude), "claude");
        assert_eq!(format!("{}", TokenizerKind::Gpt4), "gpt-4");
        assert_eq!(format!("{}", TokenizerKind::Gpt35), "gpt-3.5");
    }

    #[test]
    fn test_tokenizer_kind_roundtrip() {
        for kind in [
            TokenizerKind::Heuristic,
            TokenizerKind::Claude,
            TokenizerKind::Gpt4,
            TokenizerKind::Gpt35,
        ] {
            let display = format!("{}", kind);
            let parsed = TokenizerKind::parse_name(&display);
            assert_eq!(
                parsed,
                Some(kind.clone()),
                "Roundtrip failed for {}",
                display
            );
        }
    }

    #[test]
    fn test_tokenizer_kind_accessor() {
        let tok = Tokenizer::new(TokenizerKind::Claude);
        assert_eq!(*tok.kind(), TokenizerKind::Claude);
    }
}