entelix_tokenizer_tiktoken/
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg))]
36#![doc(html_root_url = "https://docs.rs/entelix-tokenizer-tiktoken/0.5.3")]
37#![deny(missing_docs)]
38#![allow(
39 clippy::doc_markdown
43)]
44
45use std::fmt;
46use std::sync::Arc;
47
48use entelix_core::TokenCounter;
49use thiserror::Error;
50use tiktoken_rs::CoreBPE;
51
/// BPE vocabularies that a [`TiktokenCounter`] can load through `tiktoken-rs`.
///
/// The enum is `#[non_exhaustive]`, so more vocabularies can be added later
/// without a breaking change. Downstream `match`es therefore need a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum TiktokenEncoding {
    /// The `cl100k_base` vocabulary.
    Cl100kBase,
    /// The `o200k_base` vocabulary.
    O200kBase,
    /// The `p50k_base` vocabulary.
    P50kBase,
    /// The `r50k_base` vocabulary.
    R50kBase,
}

impl TiktokenEncoding {
    /// Canonical tiktoken identifier for this encoding (e.g. `"cl100k_base"`).
    ///
    /// A [`TiktokenCounter`] built from this encoding reports the same string
    /// from `TokenCounter::encoding_name`. Load errors also use it as their
    /// `encoding_name`.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Cl100kBase => "cl100k_base",
            Self::O200kBase => "o200k_base",
            Self::P50kBase => "p50k_base",
            Self::R50kBase => "r50k_base",
        }
    }
}

impl fmt::Display for TiktokenEncoding {
    /// Formats the encoding as its canonical identifier (see [`TiktokenEncoding::name`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (rather than `write_str`) honours width/fill/alignment flags
        // such as `{:>12}`.
        f.pad(self.name())
    }
}
85
86#[derive(Debug, Error)]
94#[non_exhaustive]
95pub enum TiktokenError {
96 #[error("tiktoken BPE load failed for {encoding_name}: {message}")]
102 Load {
103 encoding_name: &'static str,
106 message: String,
108 },
109}
110
111#[derive(Clone)]
117pub struct TiktokenCounter {
118 bpe: Arc<CoreBPE>,
119 encoding: TiktokenEncoding,
120}
121
122impl fmt::Debug for TiktokenCounter {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("TiktokenCounter")
125 .field("encoding", &self.encoding)
126 .finish_non_exhaustive()
127 }
128}
129
130impl TiktokenCounter {
131 pub fn for_encoding(encoding: TiktokenEncoding) -> Result<Self, TiktokenError> {
135 let bpe = match encoding {
136 TiktokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
137 TiktokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
138 TiktokenEncoding::P50kBase => tiktoken_rs::p50k_base(),
139 TiktokenEncoding::R50kBase => tiktoken_rs::r50k_base(),
140 }
141 .map_err(|e| TiktokenError::Load {
142 encoding_name: encoding.name(),
143 message: e.to_string(),
144 })?;
145 Ok(Self {
146 bpe: Arc::new(bpe),
147 encoding,
148 })
149 }
150
151 #[must_use]
153 pub const fn encoding(&self) -> TiktokenEncoding {
154 self.encoding
155 }
156}
157
158impl TokenCounter for TiktokenCounter {
159 fn count(&self, text: &str) -> u64 {
160 u64::try_from(self.bpe.encode_ordinary(text).len()).unwrap_or(u64::MAX)
166 }
167
168 fn encoding_name(&self) -> &'static str {
169 self.encoding.name()
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use entelix_core::ir::{ContentPart, Message, Role};

    type TestResult = Result<(), TiktokenError>;

    /// Builds the `cl100k_base` counter that most fixtures below are pinned to.
    fn cl100k() -> Result<TiktokenCounter, TiktokenError> {
        TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)
    }

    #[test]
    fn each_encoding_loads_successfully() -> TestResult {
        let every_encoding = [
            TiktokenEncoding::Cl100kBase,
            TiktokenEncoding::O200kBase,
            TiktokenEncoding::P50kBase,
            TiktokenEncoding::R50kBase,
        ];
        every_encoding.iter().try_for_each(|&encoding| -> TestResult {
            let counter = TiktokenCounter::for_encoding(encoding)?;
            assert_eq!(counter.encoding(), encoding);
            assert_eq!(counter.encoding_name(), encoding.name());
            Ok(())
        })
    }

    #[test]
    fn empty_string_counts_zero() -> TestResult {
        let counter = cl100k()?;
        assert_eq!(counter.count(""), 0);
        Ok(())
    }

    #[test]
    fn cl100k_base_counts_match_known_tiktoken_values() -> TestResult {
        let counter = cl100k()?;
        // Reference counts from the upstream tiktoken implementation.
        let fixtures = [("Hello world", 2_u64), ("tiktoken is great!", 6_u64)];
        for (text, expected) in fixtures {
            assert_eq!(counter.count(text), expected);
        }
        Ok(())
    }

    #[test]
    fn o200k_base_handles_multibyte_utf8() -> TestResult {
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::O200kBase)?;
        let count = counter.count("안녕 세계");
        assert!(count > 0, "non-empty CJK text must count above zero");
        assert!(
            count < 20,
            "five-grapheme CJK should not bloat past 20 tokens"
        );
        Ok(())
    }

    #[test]
    fn longer_text_produces_more_tokens() -> TestResult {
        let counter = cl100k()?;
        let (short, long) = (
            counter.count("hello"),
            counter.count("hello world this is a longer sentence with more tokens"),
        );
        assert!(long > short, "monotonicity: longer input → more tokens");
        Ok(())
    }

    #[test]
    fn count_messages_default_walks_text_parts() -> TestResult {
        let counter = cl100k()?;
        let parts = vec![
            ContentPart::text("Hello world"),
            ContentPart::text("tiktoken is great!"),
        ];
        let msg = Message::new(Role::User, parts);
        // 2 ("Hello world") + 6 ("tiktoken is great!").
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 8);
        Ok(())
    }

    #[test]
    fn count_messages_skips_non_text_parts() -> TestResult {
        let counter = cl100k()?;
        let tool_call = ContentPart::ToolUse {
            id: "call_1".into(),
            name: "search".into(),
            input: serde_json::json!({"q": "rust"}),
            provider_echoes: Vec::new(),
        };
        let msg = Message::new(
            Role::Assistant,
            vec![ContentPart::text("Hello world"), tool_call],
        );
        // Only the text part contributes; the tool call adds nothing.
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 2);
        Ok(())
    }

    #[test]
    fn arc_dyn_dispatch_forwards_through_blanket_impl() -> TestResult {
        let concrete = cl100k()?;
        let counter: Arc<dyn TokenCounter> = Arc::new(concrete);
        assert_eq!(counter.count("Hello world"), 2);
        assert_eq!(counter.encoding_name(), "cl100k_base");
        Ok(())
    }

    #[test]
    fn clone_shares_bpe_and_keeps_encoding() -> TestResult {
        let original = TiktokenCounter::for_encoding(TiktokenEncoding::O200kBase)?;
        let cloned = original.clone();
        assert_eq!(cloned.encoding(), TiktokenEncoding::O200kBase);
        assert_eq!(cloned.count("hello"), original.count("hello"));
        // Cloning must bump the refcount, not reload the BPE tables.
        assert!(Arc::ptr_eq(&original.bpe, &cloned.bpe));
        Ok(())
    }

    #[test]
    fn debug_includes_encoding_not_bpe_table() -> TestResult {
        let debug = format!("{:?}", cl100k()?);
        assert!(debug.contains("Cl100kBase"));
        assert!(
            !debug.contains("CoreBPE"),
            "Debug must not dump the BPE tables: {debug}"
        );
        Ok(())
    }

    #[test]
    fn encoding_name_round_trips() {
        let expected = [
            (TiktokenEncoding::Cl100kBase, "cl100k_base"),
            (TiktokenEncoding::O200kBase, "o200k_base"),
            (TiktokenEncoding::P50kBase, "p50k_base"),
            (TiktokenEncoding::R50kBase, "r50k_base"),
        ];
        for (encoding, name) in expected {
            assert_eq!(encoding.name(), name);
        }
    }
}