Skip to main content

forgellm_runtime/
tokenizer.rs

//! Tokenizer wrapper — encode text to token IDs and decode back.
//!
//! Uses the HuggingFace `tokenizers` crate, loading either from a
//! `tokenizer.json` file (HF format) or from a JSON string in the same
//! format (for example, one extracted from GGUF metadata).
5
6use std::path::Path;
7
8/// Wrapper around the HuggingFace tokenizer.
9pub struct Tokenizer {
10    inner: tokenizers::Tokenizer,
11}
12
13/// Errors during tokenizer operations.
14#[derive(Debug, thiserror::Error)]
15pub enum TokenizerError {
16    #[error("failed to load tokenizer: {0}")]
17    Load(String),
18
19    #[error("encoding failed: {0}")]
20    Encode(String),
21
22    #[error("decoding failed: {0}")]
23    Decode(String),
24}
25
26impl Tokenizer {
27    /// Load a tokenizer from a `tokenizer.json` file.
28    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TokenizerError> {
29        let inner = tokenizers::Tokenizer::from_file(path.as_ref())
30            .map_err(|e| TokenizerError::Load(e.to_string()))?;
31        Ok(Self { inner })
32    }
33
34    /// Load a tokenizer from a JSON string.
35    pub fn from_json(json: &str) -> Result<Self, TokenizerError> {
36        let inner: tokenizers::Tokenizer =
37            json.parse()
38                .map_err(|e: Box<dyn std::error::Error + Send + Sync>| {
39                    TokenizerError::Load(e.to_string())
40                })?;
41        Ok(Self { inner })
42    }
43
44    /// Encode text into token IDs.
45    pub fn encode(&self, text: &str) -> Result<Vec<u32>, TokenizerError> {
46        let encoding = self
47            .inner
48            .encode(text, false)
49            .map_err(|e| TokenizerError::Encode(e.to_string()))?;
50        Ok(encoding.get_ids().to_vec())
51    }
52
53    /// Encode text with special tokens (e.g., BOS).
54    pub fn encode_with_special(&self, text: &str) -> Result<Vec<u32>, TokenizerError> {
55        let encoding = self
56            .inner
57            .encode(text, true)
58            .map_err(|e| TokenizerError::Encode(e.to_string()))?;
59        Ok(encoding.get_ids().to_vec())
60    }
61
62    /// Decode token IDs back to text.
63    pub fn decode(&self, ids: &[u32]) -> Result<String, TokenizerError> {
64        self.inner
65            .decode(ids, true)
66            .map_err(|e| TokenizerError::Decode(e.to_string()))
67    }
68
69    /// Decode a single token ID to text.
70    pub fn decode_one(&self, id: u32) -> Result<String, TokenizerError> {
71        self.decode(&[id])
72    }
73
74    /// Get the vocabulary size.
75    pub fn vocab_size(&self) -> usize {
76        self.inner.get_vocab_size(true)
77    }
78
79    /// Get the token ID for a special token by content (e.g., "<|endoftext|>").
80    pub fn token_to_id(&self, token: &str) -> Option<u32> {
81        self.inner.token_to_id(token)
82    }
83
84    /// Get the BOS (beginning of sequence) token ID, if defined.
85    pub fn bos_token_id(&self) -> Option<u32> {
86        // Common BOS tokens across models
87        self.token_to_id("<s>")
88            .or_else(|| self.token_to_id("<|begin_of_text|>"))
89            .or_else(|| self.token_to_id("<|startoftext|>"))
90    }
91
92    /// Get the EOS (end of sequence) token ID, if defined.
93    pub fn eos_token_id(&self) -> Option<u32> {
94        self.token_to_id("</s>")
95            .or_else(|| self.token_to_id("<|end_of_text|>"))
96            .or_else(|| self.token_to_id("<|endoftext|>"))
97            .or_else(|| self.token_to_id("<|im_end|>"))
98    }
99
100    /// Get all stop token IDs (EOS + chat-specific stop tokens).
101    /// Used to detect when generation should stop.
102    pub fn stop_token_ids(&self) -> Vec<u32> {
103        let candidates = [
104            "</s>",
105            "<|end_of_text|>",
106            "<|endoftext|>",
107            "<|im_end|>",
108            "<|eot_id|>",
109            "<|end|>",
110        ];
111        candidates
112            .iter()
113            .filter_map(|&token| self.token_to_id(token))
114            .collect()
115    }
116
117    /// Check if a token ID is a stop token.
118    pub fn is_stop_token(&self, token_id: u32) -> bool {
119        self.stop_token_ids().contains(&token_id)
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal byte-level BPE tokenizer JSON for testing.
    ///
    /// Special tokens: `<s>` = 0, `</s>` = 1, `<unk>` = 2. The regular vocab
    /// covers only the letters of "hello" plus a few merged pieces, and the
    /// merges turn "hello" into "he" (7) + "llo" (11).
    fn minimal_tokenizer_json() -> String {
        r#"{
            "version": "1.0",
            "truncation": null,
            "padding": null,
            "added_tokens": [
                {"id": 0, "content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 1, "content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 2, "content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}
            ],
            "normalizer": null,
            "pre_tokenizer": {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": true},
            "post_processor": null,
            "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true},
            "model": {
                "type": "BPE",
                "dropout": null,
                "unk_token": "<unk>",
                "continuing_subword_prefix": null,
                "end_of_word_suffix": null,
                "fuse_unk": false,
                "byte_fallback": false,
                "ignore_merges": false,
                "vocab": {
                    "<s>": 0, "</s>": 1, "<unk>": 2,
                    "h": 3, "e": 4, "l": 5, "o": 6,
                    "he": 7, "ll": 8, "lo": 9,
                    "hel": 10, "llo": 11
                },
                "merges": [
                    "h e", "l l", "l o", "he l", "ll o"
                ]
            }
        }"#
        .to_string()
    }

    /// Load the fixture tokenizer; the fixture is static, so failure is a test bug.
    fn fixture() -> Tokenizer {
        Tokenizer::from_json(&minimal_tokenizer_json()).expect("fixture JSON must parse")
    }

    #[test]
    fn load_from_json() {
        let tok = fixture();
        assert!(tok.vocab_size() > 0);
    }

    #[test]
    fn load_invalid_json_is_load_error() {
        assert!(matches!(
            Tokenizer::from_json("not json"),
            Err(TokenizerError::Load(_))
        ));
    }

    #[test]
    fn encode_decode_roundtrip() {
        let tok = fixture();

        let ids = tok.encode("hello").unwrap();
        assert!(!ids.is_empty());

        let text = tok.decode(&ids).unwrap();
        assert_eq!(text, "hello");
    }

    #[test]
    fn encode_applies_merges() {
        let tok = fixture();
        // "h e" and "l l" merge first, then "ll o", leaving "he" + "llo".
        assert_eq!(tok.encode("hello").unwrap(), vec![7, 11]);
    }

    #[test]
    fn encode_and_decode_empty() {
        let tok = fixture();
        assert!(tok.encode("").unwrap().is_empty());
        assert!(tok.decode(&[]).unwrap().is_empty());
    }

    #[test]
    fn decode_skips_special_tokens() {
        let tok = fixture();
        // `<s>` (0) and `</s>` (1) are special and must be dropped.
        assert_eq!(tok.decode(&[0, 3, 1]).unwrap(), "h");
    }

    #[test]
    fn special_tokens() {
        let tok = fixture();

        assert_eq!(tok.bos_token_id(), Some(0));
        assert_eq!(tok.eos_token_id(), Some(1));
    }

    #[test]
    fn token_lookup() {
        let tok = fixture();
        assert_eq!(tok.token_to_id("h"), Some(3));
        assert_eq!(tok.token_to_id("<|im_end|>"), None);
    }

    #[test]
    fn stop_tokens() {
        let tok = fixture();
        // Only "</s>" among the stop-token candidates exists in this vocab.
        assert_eq!(tok.stop_token_ids(), vec![1]);
        assert!(tok.is_stop_token(1));
        assert!(!tok.is_stop_token(0));
        assert!(!tok.is_stop_token(3));
    }

    #[test]
    fn decode_single_token() {
        let tok = fixture();

        // Token 3 = "h"
        let text = tok.decode_one(3).unwrap();
        assert_eq!(text, "h");
    }
}