devboy_format_pipeline/
token_counter.rs1use std::sync::OnceLock;
21
22use serde::{Deserialize, Serialize};
23use tiktoken_rs::CoreBPE;
24
/// Assumed average amount of text per token for the heuristic estimator.
///
/// Note the unit: the heuristic divides `str::len()`, i.e. UTF-8 *bytes*,
/// by this value, so multi-byte text is measured by bytes, not `char`s.
const CHARS_PER_TOKEN: f64 = 3.5_f64;
34
35pub fn estimate_tokens(text: &str) -> usize {
42 Tokenizer::Heuristic.count(text)
43}
44
45pub fn tokens_to_chars(tokens: usize) -> usize {
49 (tokens as f64 * CHARS_PER_TOKEN).floor() as usize
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
58#[serde(rename_all = "snake_case")]
59pub enum Tokenizer {
60 #[default]
62 Heuristic,
63 Cl100kBase,
65 O200kBase,
68}
69
70impl Tokenizer {
71 pub fn count(&self, text: &str) -> usize {
78 if text.is_empty() {
79 return 0;
80 }
81 match self {
82 Self::Heuristic => (text.len() as f64 / CHARS_PER_TOKEN).ceil() as usize,
83 Self::Cl100kBase => match cl100k_bpe() {
84 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
85 None => Self::Heuristic.count(text),
86 },
87 Self::O200kBase => match o200k_bpe() {
88 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
89 None => Self::Heuristic.count(text),
90 },
91 }
92 }
93
94 pub fn as_str(&self) -> &'static str {
96 match self {
97 Self::Heuristic => "heuristic",
98 Self::Cl100kBase => "cl100k_base",
99 Self::O200kBase => "o200k_base",
100 }
101 }
102
103 pub fn from_str_lossy(s: &str) -> Self {
106 match s.to_ascii_lowercase().as_str() {
107 "cl100k_base" | "cl100k" => Self::Cl100kBase,
108 "o200k_base" | "o200k" => Self::O200kBase,
109 _ => Self::Heuristic,
110 }
111 }
112}
113
114fn cl100k_bpe() -> Option<&'static CoreBPE> {
119 static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
120 BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
121 Ok(b) => Some(b),
122 Err(e) => {
123 tracing::warn!(
124 target: "devboy_format_pipeline::tokenizer",
125 "cl100k_base BPE table failed to load: {e} — \
126 falling back to chars/3.5 heuristic"
127 );
128 None
129 }
130 })
131 .as_ref()
132}
133
134fn o200k_bpe() -> Option<&'static CoreBPE> {
135 static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
136 BPE.get_or_init(|| match tiktoken_rs::o200k_base() {
137 Ok(b) => Some(b),
138 Err(e) => {
139 tracing::warn!(
140 target: "devboy_format_pipeline::tokenizer",
141 "o200k_base BPE table failed to load: {e} — \
142 falling back to chars/3.5 heuristic"
143 );
144 None
145 }
146 })
147 .as_ref()
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_string() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(Tokenizer::Cl100kBase.count(""), 0);
        assert_eq!(Tokenizer::O200kBase.count(""), 0);
    }

    #[test]
    fn test_short_text() {
        // 5 bytes / 3.5 = 1.43, rounded up.
        assert_eq!(estimate_tokens("hello"), 2);
    }

    #[test]
    fn test_structured_data() {
        // 36 bytes / 3.5 = 10.29, rounded up.
        let toon = "key: gh#1\ntitle: Fix bug\nstate: open";
        let tokens = estimate_tokens(toon);
        assert_eq!(tokens, 11);
    }

    #[test]
    fn heuristic_measures_utf8_bytes_not_chars() {
        // "é" is 2 bytes and "日本語" is 9 bytes; the heuristic uses `str::len`.
        assert_eq!(estimate_tokens("é"), 1);
        assert_eq!(estimate_tokens("日本語"), 3);
    }

    #[test]
    fn test_round_trip() {
        // Includes odd and zero budgets, where the floor/ceil pairing matters.
        for budget in [0usize, 1, 2, 7, 8000] {
            let chars = tokens_to_chars(budget);
            let back = estimate_tokens(&"x".repeat(chars));
            assert!(
                back.abs_diff(budget) <= 1,
                "budget {budget} round-tripped to {back}"
            );
        }
    }

    #[test]
    fn test_tokens_to_chars() {
        assert_eq!(tokens_to_chars(0), 0);
        // 3.5 is floored.
        assert_eq!(tokens_to_chars(1), 3);
        assert_eq!(tokens_to_chars(8000), 28000);
    }

    #[test]
    fn cl100k_and_o200k_produce_positive_counts_on_simple_input() {
        let phrase = "hello world";
        // Real BPE encodings should never be worse than the heuristic on plain
        // English words.
        let heuristic = Tokenizer::Heuristic.count(phrase);
        for tk in [Tokenizer::Cl100kBase, Tokenizer::O200kBase] {
            let n = tk.count(phrase);
            assert!(n > 0, "{tk:?} returned zero on `{phrase}`");
            assert!(
                n <= heuristic,
                "{tk:?} reported {n} tokens, worse than the {heuristic}-token heuristic prior"
            );
        }
    }

    #[test]
    fn cl100k_and_o200k_agree_on_hello_world() {
        let cl = Tokenizer::Cl100kBase.count("hello world");
        let o2 = Tokenizer::O200kBase.count("hello world");
        assert_eq!(cl, o2, "cl100k and o200k should agree on `hello world`");
    }

    #[test]
    fn cl100k_and_o200k_disagree_on_jsonish() {
        // This also guards against a silent table-load failure: if both
        // variants fell back to the heuristic, their counts would be equal.
        let json = "{\"id\":42,\"name\":\"alpha\",\"tags\":[\"x\",\"y\",\"z\"]}";
        let cl = Tokenizer::Cl100kBase.count(json);
        let o2 = Tokenizer::O200kBase.count(json);
        assert!(cl > 0 && o2 > 0);
        assert_ne!(cl, o2);
    }

    #[test]
    fn heuristic_default_is_heuristic() {
        assert_eq!(Tokenizer::default(), Tokenizer::Heuristic);
        assert_eq!(Tokenizer::default().as_str(), "heuristic");
    }

    #[test]
    fn from_str_lossy_known_and_unknown() {
        assert_eq!(
            Tokenizer::from_str_lossy("cl100k_base"),
            Tokenizer::Cl100kBase
        );
        assert_eq!(Tokenizer::from_str_lossy("CL100K"), Tokenizer::Cl100kBase);
        assert_eq!(
            Tokenizer::from_str_lossy("o200k_base"),
            Tokenizer::O200kBase
        );
        assert_eq!(Tokenizer::from_str_lossy("o200k"), Tokenizer::O200kBase);
        assert_eq!(Tokenizer::from_str_lossy("nonsense"), Tokenizer::Heuristic);
        assert_eq!(Tokenizer::from_str_lossy(""), Tokenizer::Heuristic);
    }

    #[test]
    fn round_trip_serde() {
        for tk in [
            Tokenizer::Heuristic,
            Tokenizer::Cl100kBase,
            Tokenizer::O200kBase,
        ] {
            let json = serde_json::to_string(&tk).unwrap();
            // The serde wire name must stay in sync with `as_str`.
            assert_eq!(json, format!("\"{}\"", tk.as_str()));
            let back: Tokenizer = serde_json::from_str(&json).unwrap();
            assert_eq!(tk, back);
            assert_eq!(Tokenizer::from_str_lossy(tk.as_str()), tk);
        }
    }
}