Skip to main content

entelix_tokenizer_hf/
lib.rs

1//! # entelix-tokenizer-hf
2//!
3//! Vendor-accurate [`TokenCounter`] wrapping the
4//! [HuggingFace `tokenizers`](https://crates.io/crates/tokenizers)
5//! crate. Construct from any `tokenizer.json` byte payload — the
6//! single canonical entry point for Llama 3, Qwen 2.5, Mistral,
7//! DeepSeek, Gemma, Phi, and every other open-weight model whose
8//! tokenizer is published in the HF format.
9//!
10//! ## Why bytes-only construction
11//!
12//! [`HfTokenCounter::from_bytes`] is the only constructor — there is
13//! no `from_file` or `from_pretrained` shortcut. Two reasons:
14//!
15//! - **Invariant 9 alignment** — entelix first-party crates do not
16//!   import `std::fs`. Operators read tokenizer files in their own
17//!   application code (or at compile time via `include_bytes!`) and
18//!   pass the byte payload in.
19//! - **No silent network IO** —
20//!   `tokenizers::Tokenizer::from_pretrained` does HTTP downloads
21//!   and disk caching as a side effect. Wrappers that need hub
22//!   integration ship as separate companion crates; this crate stays
23//!   pure.
24//!
25//! ## Encoding name
26//!
27//! [`TokenCounter::encoding_name`] returns `&'static str`, but the
28//! HF tokenizer format does not embed a canonical name. Operators
29//! supply a name at construction; the wrapper leaks it once into a
30//! `&'static str` so the trait method can return it directly. One
31//! allocation per [`HfTokenCounter::from_bytes`] call — the canonical
32//! "construct once at app boot, share an `Arc` everywhere"
33//! pattern keeps the leak cost a single `String` per process.
34//!
35//! ## Encode-failure semantics
36//!
37//! `tokenizers::Tokenizer::encode` is fallible — a misconfigured
38//! tokenizer JSON or a post-processor that rejects the input
39//! surfaces as `Err`. [`TokenCounter::count`] returns `u64::MAX` on
40//! such failures so `RunBudget` pre-flight checks fail closed
//! (refusing the call rather than silently under-counting).
42//! `tracing::warn!` records the underlying error.
43
44#![cfg_attr(docsrs, feature(doc_cfg))]
45#![doc(html_root_url = "https://docs.rs/entelix-tokenizer-hf/0.5.3")]
46#![deny(missing_docs)]
47#![allow(
48    // Vendor-name proper nouns (`HuggingFace`, `Llama`, `Qwen`,
49    // `OpenAI`, `BPE`) appear throughout the docs; backtick-quoting
50    // every occurrence hurts readability without adding signal.
51    clippy::doc_markdown
52)]
53
54use std::fmt;
55use std::sync::Arc;
56
57use entelix_core::TokenCounter;
58use thiserror::Error;
59use tokenizers::Tokenizer;
60
61/// Errors raised when constructing an [`HfTokenCounter`].
62///
63/// The underlying `tokenizers` crate error chain is stripped to a
64/// `String` so the variant stays `Send + Sync + 'static` for
65/// ergonomic cross-thread propagation (operators map this onto
66/// `entelix_core::Error::config`). Variant shape mirrors
67/// `TiktokenError` for cross-companion consistency.
68#[derive(Debug, Error)]
69#[non_exhaustive]
70pub enum HfTokenizerError {
71    /// Loading the tokenizer JSON failed — invalid format,
72    /// unsupported tokenizer type, or schema-version mismatch.
73    #[error("HuggingFace tokenizer load failed for {encoding_name}: {message}")]
74    Load {
75        /// Operator-supplied encoding name the load was attempted
76        /// for. Captured here so the error trail names which
77        /// counter construction failed when an app boot wires
78        /// multiple HF tokenizers.
79        encoding_name: String,
80        /// Upstream `tokenizers` error message (chain stripped).
81        message: String,
82    },
83}
84
85/// [`TokenCounter`] backed by a HuggingFace [`Tokenizer`].
86///
87/// Cloning is cheap — the tokenizer sits behind an [`Arc`] so every
88/// clone shares one parsed instance. Construct once at app boot,
89/// share across `ChatModelConfig` and ingestion pipelines.
90#[derive(Clone)]
91pub struct HfTokenCounter {
92    tokenizer: Arc<Tokenizer>,
93    encoding_name: &'static str,
94}
95
96impl fmt::Debug for HfTokenCounter {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("HfTokenCounter")
99            .field("encoding_name", &self.encoding_name)
100            .finish_non_exhaustive()
101    }
102}
103
104impl HfTokenCounter {
105    /// Construct a counter from a `tokenizer.json` byte payload.
106    /// The supplied `encoding_name` surfaces on
107    /// [`TokenCounter::encoding_name`] and the OTel
108    /// `gen_ai.tokenizer.name` attribute — pick a stable identifier
109    /// for the model whose tokenizer the bytes encode (`"llama-3"`,
110    /// `"qwen-2.5"`, `"mistral"`, …).
111    ///
112    /// The encoding name is leaked once into a `&'static str` —
113    /// see the [crate-level docs](crate#encoding-name) for the
114    /// rationale.
115    pub fn from_bytes(
116        bytes: &[u8],
117        encoding_name: impl Into<String>,
118    ) -> Result<Self, HfTokenizerError> {
119        let encoding_name = encoding_name.into();
120        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| HfTokenizerError::Load {
121            encoding_name: encoding_name.clone(),
122            message: e.to_string(),
123        })?;
124        let encoding_name: &'static str = Box::leak(encoding_name.into_boxed_str());
125        Ok(Self {
126            tokenizer: Arc::new(tokenizer),
127            encoding_name,
128        })
129    }
130
131    /// Inspect the configured encoding name.
132    #[must_use]
133    pub const fn encoding(&self) -> &'static str {
134        self.encoding_name
135    }
136}
137
138impl TokenCounter for HfTokenCounter {
139    fn count(&self, text: &str) -> u64 {
140        match self.tokenizer.encode(text, false) {
141            Ok(encoding) => u64::try_from(encoding.len()).unwrap_or(u64::MAX),
142            Err(error) => {
143                tracing::warn!(
144                    tokenizer = %self.encoding_name,
145                    error = %error,
146                    "HfTokenCounter::count encode failed; returning u64::MAX (conservative)",
147                );
148                u64::MAX
149            }
150        }
151    }
152
153    fn encoding_name(&self) -> &'static str {
154        self.encoding_name
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use entelix_core::ir::{ContentPart, Message, Role};

    /// Smallest useful HuggingFace `tokenizer.json`: a `WordLevel`
    /// model with four real words plus `[UNK]`, split by the
    /// `Whitespace` pre-tokenizer. It follows the canonical HF
    /// schema, so `from_bytes` is exercised against the same byte
    /// shape an operator would `include_bytes!` from a downloaded
    /// model.
    const TINY_TOKENIZER_JSON: &str = r#"{
        "version": "1.0",
        "truncation": null,
        "padding": null,
        "added_tokens": [],
        "normalizer": null,
        "pre_tokenizer": { "type": "Whitespace" },
        "post_processor": null,
        "decoder": null,
        "model": {
            "type": "WordLevel",
            "vocab": {
                "[UNK]": 0,
                "hello": 1,
                "world": 2,
                "foo": 3,
                "bar": 4
            },
            "unk_token": "[UNK]"
        }
    }"#;

    /// Encoding name given to the shared fixture counter.
    const FIXTURE_NAME: &str = "tiny-wordlevel";

    type TestResult = Result<(), HfTokenizerError>;

    /// Build a counter over the tiny fixture tokenizer.
    fn tiny_counter() -> Result<HfTokenCounter, HfTokenizerError> {
        HfTokenCounter::from_bytes(TINY_TOKENIZER_JSON.as_bytes(), FIXTURE_NAME)
    }

    #[test]
    fn from_bytes_accepts_valid_tokenizer_json() -> TestResult {
        let hf = tiny_counter()?;
        // The inherent accessor and the trait method must agree.
        assert_eq!(hf.encoding(), FIXTURE_NAME);
        assert_eq!(hf.encoding_name(), FIXTURE_NAME);
        Ok(())
    }

    #[test]
    fn from_bytes_rejects_garbage_input() {
        assert!(matches!(
            HfTokenCounter::from_bytes(b"this is not json", "any"),
            Err(HfTokenizerError::Load { .. })
        ));
    }

    #[test]
    fn from_bytes_rejects_empty_input() {
        assert!(matches!(
            HfTokenCounter::from_bytes(b"", "any"),
            Err(HfTokenizerError::Load { .. })
        ));
    }

    #[test]
    fn load_error_captures_encoding_name() {
        match HfTokenCounter::from_bytes(b"garbage", "my-bad-tokenizer") {
            Err(HfTokenizerError::Load {
                encoding_name,
                message,
            }) => {
                assert_eq!(encoding_name, "my-bad-tokenizer");
                assert!(!message.is_empty(), "upstream message must propagate");
            }
            other => panic!("expected Load error, got {other:?}"),
        }
    }

    #[test]
    fn count_known_inputs_match_vocab_size() -> TestResult {
        let hf = tiny_counter()?;
        // (input, expected token count) — all words are in-vocab.
        let cases = [
            ("", 0),
            ("hello", 1),
            ("hello world", 2),
            ("hello world foo bar", 4),
        ];
        for (text, expected) in cases {
            assert_eq!(hf.count(text), expected);
        }
        Ok(())
    }

    #[test]
    fn unknown_words_count_as_unk_tokens() -> TestResult {
        // Out-of-vocab words map to `[UNK]` (id 0), one token per
        // whitespace-separated word.
        let hf = tiny_counter()?;
        for (text, expected) in [("xyz abc", 2), ("hello xyz world abc", 4)] {
            assert_eq!(hf.count(text), expected);
        }
        Ok(())
    }

    #[test]
    fn count_messages_default_walks_text_parts() -> TestResult {
        let hf = tiny_counter()?;
        // Two text parts of 2 tokens each.
        let parts = vec![
            ContentPart::text("hello world"),
            ContentPart::text("foo bar"),
        ];
        let msg = Message::new(Role::User, parts);
        let total = hf.count_messages(std::slice::from_ref(&msg));
        assert_eq!(total, 4);
        Ok(())
    }

    #[test]
    fn count_messages_skips_non_text_parts() -> TestResult {
        let hf = tiny_counter()?;
        // One 2-token text part, then a tool call that must not
        // contribute to the total.
        let tool_call = ContentPart::ToolUse {
            id: "call_1".into(),
            name: "search".into(),
            input: serde_json::json!({"q": "rust"}),
            provider_echoes: Vec::new(),
        };
        let msg = Message::new(
            Role::Assistant,
            vec![ContentPart::text("hello world"), tool_call],
        );
        let total = hf.count_messages(std::slice::from_ref(&msg));
        assert_eq!(total, 2);
        Ok(())
    }

    #[test]
    fn arc_dyn_dispatch_forwards_through_blanket_impl() -> TestResult {
        let shared: Arc<dyn TokenCounter> = Arc::new(tiny_counter()?);
        assert_eq!(shared.count("hello world"), 2);
        assert_eq!(shared.encoding_name(), FIXTURE_NAME);
        Ok(())
    }

    #[test]
    fn clone_shares_tokenizer_and_keeps_encoding_name() -> TestResult {
        let original = tiny_counter()?;
        let cloned = original.clone();
        // Same allocation behind both handles, not a deep copy.
        assert!(Arc::ptr_eq(&original.tokenizer, &cloned.tokenizer));
        assert_eq!(cloned.encoding(), FIXTURE_NAME);
        assert_eq!(cloned.count("hello"), original.count("hello"));
        Ok(())
    }

    #[test]
    fn debug_includes_encoding_not_tokenizer_table() -> TestResult {
        let hf = tiny_counter()?;
        let debug = format!("{hf:?}");
        assert!(debug.contains(FIXTURE_NAME));
        assert!(
            !(debug.contains("Tokenizer ") || debug.contains("vocab")),
            "Debug must not dump the parsed tokenizer: {debug}"
        );
        Ok(())
    }

    #[test]
    fn encoding_name_outlives_counter_drop() -> TestResult {
        // The name is `&'static`, so it must stay readable after the
        // counter that handed it out has been dropped.
        let scoped = HfTokenCounter::from_bytes(TINY_TOKENIZER_JSON.as_bytes(), "scoped")?;
        let leaked: &'static str = scoped.encoding_name();
        drop(scoped);
        assert_eq!(leaked, "scoped");
        Ok(())
    }
}