candle_mi/tokenizer/mod.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Tokenizer abstraction: dispatch between `HuggingFace` and RWKV backends.
4//!
5//! [`MITokenizer`] provides a unified encode/decode interface regardless of
6//! the underlying tokenizer implementation.
7
8#[cfg(feature = "rwkv-tokenizer")]
9mod rwkv;
10
11use crate::error::{MIError, Result};
12use crate::util::positioning::EncodingWithOffsets;
13
/// Unified tokenizer supporting multiple backends.
///
/// Most models use the `HuggingFace` `tokenizers` crate. RWKV-6 models
/// ship their own vocabulary format and require a custom trie-based
/// tokenizer, which is available behind the `rwkv-tokenizer` feature.
///
/// # Example
///
/// ```no_run
/// use candle_mi::MITokenizer;
///
/// # fn main() -> candle_mi::Result<()> {
/// let tok = MITokenizer::from_hf_path("tokenizer.json")?;
/// let ids = tok.encode("fn main()")?;
/// let text = tok.decode(&ids)?;
/// assert!(!ids.is_empty());
/// # Ok(())
/// # }
/// ```
#[non_exhaustive]
pub enum MITokenizer {
    /// `HuggingFace` `tokenizers` backend.
    ///
    /// Boxed so the enum stays small regardless of the size of
    /// `tokenizers::Tokenizer` (avoids clippy's `large_enum_variant`).
    HuggingFace(Box<tokenizers::Tokenizer>),
    /// RWKV World tokenizer (trie-based greedy longest-match).
    #[cfg(feature = "rwkv-tokenizer")]
    Rwkv(rwkv::RwkvTokenizer),
}
41
42impl MITokenizer {
43 /// Load a `HuggingFace` tokenizer from a `tokenizer.json` file.
44 ///
45 /// # Errors
46 ///
47 /// Returns [`MIError::Tokenizer`] if the file cannot be loaded or parsed.
48 pub fn from_hf_path(path: impl AsRef<std::path::Path>) -> Result<Self> {
49 let tok = tokenizers::Tokenizer::from_file(path.as_ref()).map_err(|e| {
50 MIError::Tokenizer(format!(
51 "failed to load HF tokenizer from {}: {e}",
52 path.as_ref().display()
53 ))
54 })?;
55 Ok(Self::HuggingFace(Box::new(tok)))
56 }
57
58 /// Wrap an already-loaded `HuggingFace` tokenizer.
59 #[must_use]
60 pub fn from_hf(tokenizer: tokenizers::Tokenizer) -> Self {
61 Self::HuggingFace(Box::new(tokenizer))
62 }
63
64 /// Load an RWKV World tokenizer from a vocabulary file.
65 ///
66 /// # Errors
67 ///
68 /// Returns [`MIError::Tokenizer`] if the file cannot be loaded or parsed.
69 #[cfg(feature = "rwkv-tokenizer")]
70 pub fn from_rwkv_path(path: impl AsRef<std::path::Path>) -> Result<Self> {
71 let tok = rwkv::RwkvTokenizer::from_file(path.as_ref())?;
72 Ok(Self::Rwkv(tok))
73 }
74
75 /// Encode text into token IDs, adding special tokens (e.g. BOS for Gemma).
76 ///
77 /// Special tokens are added according to the tokenizer's configured
78 /// post-processor, matching the `HuggingFace` convention for inference.
79 ///
80 /// # Errors
81 ///
82 /// Returns [`MIError::Tokenizer`] if encoding fails.
83 pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
84 match self {
85 Self::HuggingFace(tok) => {
86 let encoding = tok
87 .encode(text, true)
88 .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
89 Ok(encoding.get_ids().to_vec())
90 }
91 #[cfg(feature = "rwkv-tokenizer")]
92 Self::Rwkv(tok) => tok.encode(text),
93 }
94 }
95
96 /// Encode text into token IDs **without** adding special tokens.
97 ///
98 /// Useful for MI analyses that need raw tokenization without BOS/EOS.
99 ///
100 /// # Errors
101 ///
102 /// Returns [`MIError::Tokenizer`] if encoding fails.
103 pub fn encode_raw(&self, text: &str) -> Result<Vec<u32>> {
104 match self {
105 Self::HuggingFace(tok) => {
106 let encoding = tok
107 .encode(text, false)
108 .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
109 Ok(encoding.get_ids().to_vec())
110 }
111 #[cfg(feature = "rwkv-tokenizer")]
112 Self::Rwkv(tok) => tok.encode(text),
113 }
114 }
115
116 /// Encode text into token IDs with character offset mapping.
117 ///
118 /// Returns an [`EncodingWithOffsets`] containing token IDs, token strings,
119 /// and byte-offset ranges for each token. Special tokens are added
120 /// (e.g., BOS for Gemma); special tokens receive a `(0, 0)` offset.
121 ///
122 /// # Errors
123 ///
124 /// Returns [`MIError::Tokenizer`] if encoding fails or if the backend
125 /// does not support offset mapping (RWKV).
126 pub fn encode_with_offsets(&self, text: &str) -> Result<EncodingWithOffsets> {
127 self.encode_with_offsets_inner(text, true)
128 }
129
130 /// Encode text into token IDs with character offset mapping, **without**
131 /// adding special tokens.
132 ///
133 /// # Errors
134 ///
135 /// Returns [`MIError::Tokenizer`] if encoding fails or if the backend
136 /// does not support offset mapping (RWKV).
137 pub fn encode_raw_with_offsets(&self, text: &str) -> Result<EncodingWithOffsets> {
138 self.encode_with_offsets_inner(text, false)
139 }
140
141 /// Shared implementation for offset-bearing encode methods.
142 fn encode_with_offsets_inner(
143 &self,
144 text: &str,
145 add_special_tokens: bool,
146 ) -> Result<EncodingWithOffsets> {
147 match self {
148 Self::HuggingFace(tok) => {
149 let encoding = tok
150 .encode(text, add_special_tokens)
151 .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
152 let ids = encoding.get_ids().to_vec();
153 let tokens: Vec<String> = encoding
154 .get_tokens()
155 .iter()
156 .map(ToString::to_string)
157 .collect();
158 let offsets = encoding.get_offsets().to_vec();
159 Ok(EncodingWithOffsets::new(ids, tokens, offsets))
160 }
161 #[cfg(feature = "rwkv-tokenizer")]
162 Self::Rwkv(_) => Err(MIError::Tokenizer(
163 "RWKV tokenizer does not support offset mapping".into(),
164 )),
165 }
166 }
167
168 /// Decode token IDs back to a string.
169 ///
170 /// # Errors
171 ///
172 /// Returns [`MIError::Tokenizer`] if decoding fails.
173 pub fn decode(&self, ids: &[u32]) -> Result<String> {
174 match self {
175 Self::HuggingFace(tok) => tok
176 .decode(ids, false)
177 .map_err(|e| MIError::Tokenizer(format!("HF decode failed: {e}"))),
178 #[cfg(feature = "rwkv-tokenizer")]
179 Self::Rwkv(tok) => tok.decode(ids),
180 }
181 }
182
183 /// Get vocabulary size.
184 #[must_use]
185 pub fn vocab_size(&self) -> usize {
186 match self {
187 Self::HuggingFace(tok) => tok.get_vocab_size(true),
188 #[cfg(feature = "rwkv-tokenizer")]
189 Self::Rwkv(tok) => tok.vocab_size(),
190 }
191 }
192
193 /// Find the token ID for a word, trying `" word"` (with leading space) first,
194 /// then bare `"word"`.
195 ///
196 /// This handles BPE tokenizers that represent word-initial tokens with a
197 /// leading space (e.g., `" cat"` → single token).
198 ///
199 /// # Errors
200 ///
201 /// Returns [`MIError::Tokenizer`] if the word cannot be resolved to a
202 /// single token in either form.
203 pub fn find_token_id(&self, word: &str) -> Result<u32> {
204 // Try with leading space first (most BPE tokenizers).
205 let with_space = format!(" {word}");
206 let ids = self.encode(&with_space)?;
207 // ids[0] is BOS (if present), ids[1] would be the word token.
208 if ids.len() == 2 {
209 return ids
210 .get(1)
211 .copied()
212 .ok_or_else(|| MIError::Tokenizer(format!("unexpected encoding for \" {word}\"")));
213 }
214
215 // Try bare word.
216 let bare_ids = self.encode(word)?;
217 if bare_ids.len() == 2 {
218 return bare_ids
219 .get(1)
220 .copied()
221 .ok_or_else(|| MIError::Tokenizer(format!("unexpected encoding for \"{word}\"")));
222 }
223
224 // Last resort: return last token.
225 ids.last().copied().ok_or_else(|| {
226 MIError::Tokenizer(format!("could not find single token ID for \"{word}\""))
227 })
228 }
229
230 /// Decode a single token ID to its string representation.
231 ///
232 /// # Errors
233 ///
234 /// Returns [`MIError::Tokenizer`] if decoding fails.
235 pub fn decode_token(&self, token_id: u32) -> Result<String> {
236 self.decode(&[token_id])
237 }
238}
239
240impl std::fmt::Debug for MITokenizer {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 match self {
243 Self::HuggingFace(_) => f.debug_tuple("HuggingFace").field(&"...").finish(),
244 #[cfg(feature = "rwkv-tokenizer")]
245 Self::Rwkv(tok) => f.debug_tuple("Rwkv").field(tok).finish(),
246 }
247 }
248}