Skip to main content

entelix_tokenizer_tiktoken/
lib.rs

//! # entelix-tokenizer-tiktoken
//!
//! Vendor-accurate [`TokenCounter`] for OpenAI's BPE tokenizer family —
//! `cl100k_base`, `o200k_base`, `p50k_base`, `r50k_base`. Wraps
//! [`tiktoken-rs`](https://crates.io/crates/tiktoken-rs) with eager BPE
//! preload at construction so the per-call `count` stays synchronous
//! per the [`TokenCounter`] contract.
//!
//! ## Encoding to model mapping
//!
//! - [`TiktokenEncoding::Cl100kBase`] — GPT-3.5-turbo, GPT-4,
//!   GPT-4-turbo, text-embedding-3-*.
//! - [`TiktokenEncoding::O200kBase`] — GPT-4o, GPT-4o-mini, o1, o3,
//!   o3-mini, o4.
//! - [`TiktokenEncoding::P50kBase`] — GPT-3 davinci, codex.
//! - [`TiktokenEncoding::R50kBase`] — GPT-3 ada / babbage / curie,
//!   GPT-2.
//!
//! The model-to-encoding mapping is left to operators by design —
//! OpenAI changes it over time, and accidentally pinning a stale
//! mapping silently miscounts without surfacing a build failure. Pick
//! the encoding for your target model and the wrapper preloads the
//! matching BPE tables.
//!
//! ## Why eager preload
//!
//! The [`TokenCounter`] trait is intentionally synchronous — counters
//! get called from inside hot dispatch paths (pre-flight `RunBudget`
//! checks, splitter sizing) where awaiting on a lazy table-load
//! introduces unbounded latency. `TiktokenCounter` therefore loads the
//! BPE tables eagerly inside [`TiktokenCounter::for_encoding`] and
//! caches them behind an [`Arc`]. Cloning a `TiktokenCounter` is
//! cheap; loading a fresh one re-parses the embedded tables, so prefer
//! `clone` for fan-out.

// docs.rs builds get the `doc_cfg` nightly feature for per-feature badges.
#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc(html_root_url = "https://docs.rs/entelix-tokenizer-tiktoken/0.5.3")]
#![deny(missing_docs)]
#![allow(
    // Vendor-name proper nouns (`OpenAI`, `OTel`, `BPE`, `GPT-4o`)
    // appear throughout the docs; backtick-quoting every occurrence
    // hurts readability without adding signal.
    clippy::doc_markdown
)]

use std::fmt;
use std::sync::Arc;

use entelix_core::TokenCounter;
use thiserror::Error;
use tiktoken_rs::CoreBPE;

/// OpenAI BPE encoding family. Pick the variant matching the target
/// model — see the [crate-level docs](crate#encoding-to-model-mapping)
/// for the model-to-encoding table.
///
/// `#[non_exhaustive]` because OpenAI ships new encodings over time;
/// downstream `match`es must keep a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum TiktokenEncoding {
    /// `cl100k_base` — GPT-3.5-turbo, GPT-4, GPT-4-turbo, the
    /// `text-embedding-3-*` family.
    Cl100kBase,
    /// `o200k_base` — GPT-4o, GPT-4o-mini, o1, o3, o3-mini, o4.
    O200kBase,
    /// `p50k_base` — GPT-3 davinci, codex.
    P50kBase,
    /// `r50k_base` — GPT-3 ada / babbage / curie + the original GPT-2
    /// tokenizer.
    R50kBase,
}

70impl TiktokenEncoding {
71    /// Canonical encoding name as published by OpenAI's tiktoken
72    /// reference implementation. Surfaces on
73    /// [`TokenCounter::encoding_name`] and the OTel
74    /// `gen_ai.tokenizer.name` attribute.
75    #[must_use]
76    pub const fn name(self) -> &'static str {
77        match self {
78            Self::Cl100kBase => "cl100k_base",
79            Self::O200kBase => "o200k_base",
80            Self::P50kBase => "p50k_base",
81            Self::R50kBase => "r50k_base",
82        }
83    }
84}
85
/// Errors raised when constructing a [`TiktokenCounter`].
///
/// [`tiktoken-rs`](https://crates.io/crates/tiktoken-rs) returns an
/// opaque error chain from its loader functions; this type strips the
/// upstream chain down to a `String` so the error stays
/// `Send + Sync + 'static` for ergonomic cross-thread propagation
/// (downstream operators map this onto `entelix_core::Error::config`).
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TiktokenError {
    /// Loading the BPE tables for the requested encoding failed.
    ///
    /// In practice the upstream loaders only fail if the embedded
    /// merge / vocab tables fail to parse, which would indicate an
    /// upstream packaging bug. Variant shape mirrors
    /// `HfTokenizerError::Load` for cross-companion consistency.
    #[error("tiktoken BPE load failed for {encoding_name}: {message}")]
    Load {
        /// Canonical encoding name the load was attempted for
        /// (e.g. `"cl100k_base"`).
        encoding_name: &'static str,
        /// Upstream `tiktoken-rs` error message (chain stripped).
        message: String,
    },
}

/// [`TokenCounter`] impl backed by [`tiktoken-rs`](https://crates.io/crates/tiktoken-rs).
///
/// Cloning is cheap — the BPE tables sit behind an [`Arc`] so every
/// clone shares one preloaded instance. Construct once at app boot,
/// share across `ChatModelConfig` instances and ingestion pipelines.
///
/// `Debug` is implemented by hand so the parsed BPE tables never end
/// up in log output.
#[derive(Clone)]
pub struct TiktokenCounter {
    // Eagerly parsed BPE tables; `Arc` makes `clone` a refcount bump
    // rather than a re-parse.
    bpe: Arc<CoreBPE>,
    // Remembered so `encoding()` / `encoding_name()` stay allocation-free.
    encoding: TiktokenEncoding,
}

122impl fmt::Debug for TiktokenCounter {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        f.debug_struct("TiktokenCounter")
125            .field("encoding", &self.encoding)
126            .finish_non_exhaustive()
127    }
128}
129
130impl TiktokenCounter {
131    /// Construct a counter for `encoding`, preloading the BPE tables
132    /// eagerly. The returned counter is `Clone` and ready for hot-path
133    /// dispatch.
134    pub fn for_encoding(encoding: TiktokenEncoding) -> Result<Self, TiktokenError> {
135        let bpe = match encoding {
136            TiktokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
137            TiktokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
138            TiktokenEncoding::P50kBase => tiktoken_rs::p50k_base(),
139            TiktokenEncoding::R50kBase => tiktoken_rs::r50k_base(),
140        }
141        .map_err(|e| TiktokenError::Load {
142            encoding_name: encoding.name(),
143            message: e.to_string(),
144        })?;
145        Ok(Self {
146            bpe: Arc::new(bpe),
147            encoding,
148        })
149    }
150
151    /// Inspect the configured encoding.
152    #[must_use]
153    pub const fn encoding(&self) -> TiktokenEncoding {
154        self.encoding
155    }
156}
157
158impl TokenCounter for TiktokenCounter {
159    fn count(&self, text: &str) -> u64 {
160        // `encode_ordinary` skips special-token handling — the right
161        // shape for content-economy budgeting since system / chat
162        // priming overhead is vendor-and-version-specific (operators
163        // wanting an exact chat-message tally override
164        // `count_messages` on a wrapper counter).
165        u64::try_from(self.bpe.encode_ordinary(text).len()).unwrap_or(u64::MAX)
166    }
167
168    fn encoding_name(&self) -> &'static str {
169        self.encoding.name()
170    }
171}
172
#[cfg(test)]
mod tests {
    //! Loader coverage for every encoding, reference token counts
    //! cross-checked against the upstream Python `tiktoken` library,
    //! and cheap-clone / Debug-hygiene invariants.

    use super::*;
    use entelix_core::ir::{ContentPart, Message, Role};

    // Bubble construction failures via `?` so a broken loader surfaces
    // the `TiktokenError` rather than an `unwrap` panic backtrace.
    type TestResult = Result<(), TiktokenError>;

    #[test]
    fn each_encoding_loads_successfully() -> TestResult {
        for encoding in [
            TiktokenEncoding::Cl100kBase,
            TiktokenEncoding::O200kBase,
            TiktokenEncoding::P50kBase,
            TiktokenEncoding::R50kBase,
        ] {
            let counter = TiktokenCounter::for_encoding(encoding)?;
            assert_eq!(counter.encoding(), encoding);
            assert_eq!(counter.encoding_name(), encoding.name());
        }
        Ok(())
    }

    #[test]
    fn empty_string_counts_zero() -> TestResult {
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?;
        assert_eq!(counter.count(""), 0);
        Ok(())
    }

    #[test]
    fn cl100k_base_counts_match_known_tiktoken_values() -> TestResult {
        // Reference values verified against the upstream Python
        // `tiktoken` library (`enc.encode_ordinary(text)`):
        //   "Hello world"      → [9906, 1917]                    = 2
        //   "tiktoken is great!" → [83, 1609, 5963, 374, 2294, 0] = 6
        // Hard-pinning here is a regression gate against
        // `tiktoken-rs` upstream encoding drift.
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?;
        assert_eq!(counter.count("Hello world"), 2);
        assert_eq!(counter.count("tiktoken is great!"), 6);
        Ok(())
    }

    #[test]
    fn o200k_base_handles_multibyte_utf8() -> TestResult {
        // CJK characters: tokenisation differs vs cl100k_base
        // because o200k_base extends the vocabulary. Just verify the
        // count is non-zero and bounded — exact value is encoding-
        // version-specific so a strict pin would brittle-test.
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::O200kBase)?;
        let count = counter.count("안녕 세계");
        assert!(count > 0, "non-empty CJK text must count above zero");
        assert!(
            count < 20,
            "five-grapheme CJK should not bloat past 20 tokens"
        );
        Ok(())
    }

    #[test]
    fn longer_text_produces_more_tokens() -> TestResult {
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?;
        let short = counter.count("hello");
        let long = counter.count("hello world this is a longer sentence with more tokens");
        assert!(long > short, "monotonicity: longer input → more tokens");
        Ok(())
    }

    #[test]
    fn count_messages_default_walks_text_parts() -> TestResult {
        // Exercises the `TokenCounter::count_messages` default method,
        // which should sum the counts of the text parts.
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?;
        let msg = Message::new(
            Role::User,
            vec![
                ContentPart::text("Hello world"),        // 2 tokens (verified above)
                ContentPart::text("tiktoken is great!"), // 6 tokens
            ],
        );
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 8);
        Ok(())
    }

    #[test]
    fn count_messages_skips_non_text_parts() -> TestResult {
        // Non-text parts (tool calls) contribute nothing to the
        // default `count_messages` tally.
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?;
        let msg = Message::new(
            Role::Assistant,
            vec![
                ContentPart::text("Hello world"), // 2 tokens
                ContentPart::ToolUse {
                    id: "call_1".into(),
                    name: "search".into(),
                    input: serde_json::json!({"q": "rust"}),
                    provider_echoes: Vec::new(),
                },
            ],
        );
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 2);
        Ok(())
    }

    #[test]
    fn arc_dyn_dispatch_forwards_through_blanket_impl() -> TestResult {
        // Counters are shared as `Arc<dyn TokenCounter>` in practice;
        // verify trait-object dispatch forwards correctly.
        let counter: Arc<dyn TokenCounter> =
            Arc::new(TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?);
        assert_eq!(counter.count("Hello world"), 2);
        assert_eq!(counter.encoding_name(), "cl100k_base");
        Ok(())
    }

    #[test]
    fn clone_shares_bpe_and_keeps_encoding() -> TestResult {
        let original = TiktokenCounter::for_encoding(TiktokenEncoding::O200kBase)?;
        let cloned = original.clone();
        assert_eq!(cloned.encoding(), TiktokenEncoding::O200kBase);
        assert_eq!(cloned.count("hello"), original.count("hello"));
        // Both clones share the same Arc — pointer equality verifies
        // clone is cheap (shared parsed BPE table).
        assert!(Arc::ptr_eq(&original.bpe, &cloned.bpe));
        Ok(())
    }

    #[test]
    fn debug_includes_encoding_not_bpe_table() -> TestResult {
        let counter = TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)?;
        let debug = format!("{counter:?}");
        assert!(debug.contains("Cl100kBase"));
        assert!(
            !debug.contains("CoreBPE"),
            "Debug must not dump the BPE tables: {debug}"
        );
        Ok(())
    }

    #[test]
    fn encoding_name_round_trips() {
        assert_eq!(TiktokenEncoding::Cl100kBase.name(), "cl100k_base");
        assert_eq!(TiktokenEncoding::O200kBase.name(), "o200k_base");
        assert_eq!(TiktokenEncoding::P50kBase.name(), "p50k_base");
        assert_eq!(TiktokenEncoding::R50kBase.name(), "r50k_base");
    }
}