// entelix_tokenizer_hf/lib.rs
#![cfg_attr(docsrs, feature(doc_cfg))]
45#![doc(html_root_url = "https://docs.rs/entelix-tokenizer-hf/0.5.3")]
46#![deny(missing_docs)]
47#![allow(
48 clippy::doc_markdown
52)]
53
54use std::fmt;
55use std::sync::Arc;
56
57use entelix_core::TokenCounter;
58use thiserror::Error;
59use tokenizers::Tokenizer;
60
/// Errors that can occur while constructing an [`HfTokenCounter`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HfTokenizerError {
    /// The tokenizer definition bytes could not be parsed into a
    /// HuggingFace [`Tokenizer`].
    #[error("HuggingFace tokenizer load failed for {encoding_name}: {message}")]
    Load {
        /// The caller-supplied encoding label, included for diagnostics.
        encoding_name: String,
        /// The stringified error reported by the `tokenizers` crate.
        message: String,
    },
}
84
/// A `TokenCounter` backed by a HuggingFace `tokenizers` [`Tokenizer`].
///
/// Cloning is cheap: the parsed tokenizer is shared behind an [`Arc`]
/// (see the `Clone` test asserting `Arc::ptr_eq` on clones).
#[derive(Clone)]
pub struct HfTokenCounter {
    // Parsed tokenizer, shared between clones.
    tokenizer: Arc<Tokenizer>,
    // Leaked once in `from_bytes` so it can be handed out as `&'static str`.
    encoding_name: &'static str,
}
95
96impl fmt::Debug for HfTokenCounter {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.debug_struct("HfTokenCounter")
99 .field("encoding_name", &self.encoding_name)
100 .finish_non_exhaustive()
101 }
102}
103
impl HfTokenCounter {
    /// Builds a counter from the raw bytes of a HuggingFace `tokenizer.json`.
    ///
    /// `encoding_name` is an arbitrary label reported by [`Self::encoding`]
    /// and embedded in load errors and encode-failure logs.
    ///
    /// # Errors
    ///
    /// Returns [`HfTokenizerError::Load`] when the bytes are not a valid
    /// tokenizer definition (the upstream message is preserved).
    ///
    /// NOTE: each *successful* call leaks one small string allocation via
    /// `Box::leak` to obtain the `&'static str` the `TokenCounter` contract
    /// requires; avoid constructing unbounded numbers of counters with
    /// distinct dynamic names.
    pub fn from_bytes(
        bytes: &[u8],
        encoding_name: impl Into<String>,
    ) -> Result<Self, HfTokenizerError> {
        let encoding_name = encoding_name.into();
        // Parse first: on failure we return the owned name in the error and
        // leak nothing.
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| HfTokenizerError::Load {
            encoding_name: encoding_name.clone(),
            message: e.to_string(),
        })?;
        // Promote the label to 'static only once parsing has succeeded.
        let encoding_name: &'static str = Box::leak(encoding_name.into_boxed_str());
        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            encoding_name,
        })
    }

    /// Returns the encoding label supplied at construction time.
    #[must_use]
    pub const fn encoding(&self) -> &'static str {
        self.encoding_name
    }
}
137
138impl TokenCounter for HfTokenCounter {
139 fn count(&self, text: &str) -> u64 {
140 match self.tokenizer.encode(text, false) {
141 Ok(encoding) => u64::try_from(encoding.len()).unwrap_or(u64::MAX),
142 Err(error) => {
143 tracing::warn!(
144 tokenizer = %self.encoding_name,
145 error = %error,
146 "HfTokenCounter::count encode failed; returning u64::MAX (conservative)",
147 );
148 u64::MAX
149 }
150 }
151 }
152
153 fn encoding_name(&self) -> &'static str {
154 self.encoding_name
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use entelix_core::ir::{ContentPart, Message, Role};

    // Minimal WordLevel tokenizer with a whitespace pre-tokenizer and a
    // 5-entry vocab, so tests run offline without fetching a real model.
    // Unknown words map to "[UNK]" (id 0).
    const TINY_TOKENIZER_JSON: &str = r#"{
        "version": "1.0",
        "truncation": null,
        "padding": null,
        "added_tokens": [],
        "normalizer": null,
        "pre_tokenizer": { "type": "Whitespace" },
        "post_processor": null,
        "decoder": null,
        "model": {
            "type": "WordLevel",
            "vocab": {
                "[UNK]": 0,
                "hello": 1,
                "world": 2,
                "foo": 3,
                "bar": 4
            },
            "unk_token": "[UNK]"
        }
    }"#;

    type TestResult = Result<(), HfTokenizerError>;

    // Shared constructor for the fixture counter.
    fn counter() -> Result<HfTokenCounter, HfTokenizerError> {
        HfTokenCounter::from_bytes(TINY_TOKENIZER_JSON.as_bytes(), "tiny-wordlevel")
    }

    #[test]
    fn from_bytes_accepts_valid_tokenizer_json() -> TestResult {
        let counter = counter()?;
        // Both the inherent accessor and the trait method report the label.
        assert_eq!(counter.encoding(), "tiny-wordlevel");
        assert_eq!(counter.encoding_name(), "tiny-wordlevel");
        Ok(())
    }

    #[test]
    fn from_bytes_rejects_garbage_input() {
        let result = HfTokenCounter::from_bytes(b"this is not json", "any");
        assert!(matches!(result, Err(HfTokenizerError::Load { .. })));
    }

    #[test]
    fn from_bytes_rejects_empty_input() {
        let result = HfTokenCounter::from_bytes(b"", "any");
        assert!(matches!(result, Err(HfTokenizerError::Load { .. })));
    }

    #[test]
    fn load_error_captures_encoding_name() {
        let result = HfTokenCounter::from_bytes(b"garbage", "my-bad-tokenizer");
        match result {
            Err(HfTokenizerError::Load {
                encoding_name,
                message,
            }) => {
                assert_eq!(encoding_name, "my-bad-tokenizer");
                assert!(!message.is_empty(), "upstream message must propagate");
            }
            other => panic!("expected Load error, got {other:?}"),
        }
    }

    #[test]
    fn count_known_inputs_match_vocab_size() -> TestResult {
        let counter = counter()?;
        // Whitespace pre-tokenizer: one token per whitespace-separated word.
        assert_eq!(counter.count(""), 0);
        assert_eq!(counter.count("hello"), 1);
        assert_eq!(counter.count("hello world"), 2);
        assert_eq!(counter.count("hello world foo bar"), 4);
        Ok(())
    }

    #[test]
    fn unknown_words_count_as_unk_tokens() -> TestResult {
        let counter = counter()?;
        // Out-of-vocab words still count (each becomes "[UNK]").
        assert_eq!(counter.count("xyz abc"), 2);
        assert_eq!(counter.count("hello xyz world abc"), 4);
        Ok(())
    }

    #[test]
    fn count_messages_default_walks_text_parts() -> TestResult {
        let counter = counter()?;
        let msg = Message::new(
            Role::User,
            vec![
                ContentPart::text("hello world"),
                ContentPart::text("foo bar"),
            ],
        );
        // 2 + 2 words across the two text parts.
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 4);
        Ok(())
    }

    #[test]
    fn count_messages_skips_non_text_parts() -> TestResult {
        let counter = counter()?;
        let msg = Message::new(
            Role::Assistant,
            vec![
                ContentPart::text("hello world"),
                // The tool-use part contributes nothing to the count.
                ContentPart::ToolUse {
                    id: "call_1".into(),
                    name: "search".into(),
                    input: serde_json::json!({"q": "rust"}),
                    provider_echoes: Vec::new(),
                },
            ],
        );
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 2);
        Ok(())
    }

    #[test]
    fn arc_dyn_dispatch_forwards_through_blanket_impl() -> TestResult {
        // Trait-object dispatch through Arc must behave like direct calls.
        let counter: Arc<dyn TokenCounter> = Arc::new(counter()?);
        assert_eq!(counter.count("hello world"), 2);
        assert_eq!(counter.encoding_name(), "tiny-wordlevel");
        Ok(())
    }

    #[test]
    fn clone_shares_tokenizer_and_keeps_encoding_name() -> TestResult {
        let original = counter()?;
        let cloned = original.clone();
        assert_eq!(cloned.encoding(), "tiny-wordlevel");
        assert_eq!(cloned.count("hello"), original.count("hello"));
        // Clone must share the parsed tokenizer, not re-parse it.
        assert!(Arc::ptr_eq(&original.tokenizer, &cloned.tokenizer));
        Ok(())
    }

    #[test]
    fn debug_includes_encoding_not_tokenizer_table() -> TestResult {
        let counter = counter()?;
        let debug = format!("{counter:?}");
        assert!(debug.contains("tiny-wordlevel"));
        // Guard against accidentally dumping the whole parsed model.
        assert!(
            !debug.contains("Tokenizer ") && !debug.contains("vocab"),
            "Debug must not dump the parsed tokenizer: {debug}"
        );
        Ok(())
    }

    #[test]
    fn encoding_name_outlives_counter_drop() -> TestResult {
        // The name is Box::leak-ed, so the &'static str stays valid after
        // the counter itself is dropped.
        let leaked: &'static str = {
            let counter = HfTokenCounter::from_bytes(TINY_TOKENIZER_JSON.as_bytes(), "scoped")?;
            counter.encoding_name()
        };
        assert_eq!(leaked, "scoped");
        Ok(())
    }
}