// candle_mi/tokenizer/mod.rs
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Tokenizer abstraction: dispatch between `HuggingFace` and RWKV backends.
//!
//! [`MITokenizer`] provides a unified encode/decode interface regardless of
//! the underlying tokenizer implementation.

#[cfg(feature = "rwkv-tokenizer")]
mod rwkv;

use crate::error::{MIError, Result};
use crate::util::positioning::EncodingWithOffsets;
14/// Unified tokenizer supporting multiple backends.
15///
16/// Most models use the `HuggingFace` `tokenizers` crate. RWKV-6 models
17/// ship their own vocabulary format and require a custom trie-based
18/// tokenizer, which is available behind the `rwkv-tokenizer` feature.
19///
20/// # Example
21///
22/// ```no_run
23/// use candle_mi::MITokenizer;
24///
25/// # fn main() -> candle_mi::Result<()> {
26/// let tok = MITokenizer::from_hf_path("tokenizer.json")?;
27/// let ids = tok.encode("fn main()")?;
28/// let text = tok.decode(&ids)?;
29/// assert!(!ids.is_empty());
30/// # Ok(())
31/// # }
32/// ```
33#[non_exhaustive]
34pub enum MITokenizer {
35    /// `HuggingFace` `tokenizers` backend.
36    HuggingFace(Box<tokenizers::Tokenizer>),
37    /// RWKV World tokenizer (trie-based greedy longest-match).
38    #[cfg(feature = "rwkv-tokenizer")]
39    Rwkv(rwkv::RwkvTokenizer),
40}
41
42impl MITokenizer {
43    /// Load a `HuggingFace` tokenizer from a `tokenizer.json` file.
44    ///
45    /// # Errors
46    ///
47    /// Returns [`MIError::Tokenizer`] if the file cannot be loaded or parsed.
48    pub fn from_hf_path(path: impl AsRef<std::path::Path>) -> Result<Self> {
49        let tok = tokenizers::Tokenizer::from_file(path.as_ref()).map_err(|e| {
50            MIError::Tokenizer(format!(
51                "failed to load HF tokenizer from {}: {e}",
52                path.as_ref().display()
53            ))
54        })?;
55        Ok(Self::HuggingFace(Box::new(tok)))
56    }
57
58    /// Wrap an already-loaded `HuggingFace` tokenizer.
59    #[must_use]
60    pub fn from_hf(tokenizer: tokenizers::Tokenizer) -> Self {
61        Self::HuggingFace(Box::new(tokenizer))
62    }
63
64    /// Load an RWKV World tokenizer from a vocabulary file.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`MIError::Tokenizer`] if the file cannot be loaded or parsed.
69    #[cfg(feature = "rwkv-tokenizer")]
70    pub fn from_rwkv_path(path: impl AsRef<std::path::Path>) -> Result<Self> {
71        let tok = rwkv::RwkvTokenizer::from_file(path.as_ref())?;
72        Ok(Self::Rwkv(tok))
73    }
74
75    /// Encode text into token IDs, adding special tokens (e.g. BOS for Gemma).
76    ///
77    /// Special tokens are added according to the tokenizer's configured
78    /// post-processor, matching the `HuggingFace` convention for inference.
79    ///
80    /// # Errors
81    ///
82    /// Returns [`MIError::Tokenizer`] if encoding fails.
83    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
84        match self {
85            Self::HuggingFace(tok) => {
86                let encoding = tok
87                    .encode(text, true)
88                    .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
89                Ok(encoding.get_ids().to_vec())
90            }
91            #[cfg(feature = "rwkv-tokenizer")]
92            Self::Rwkv(tok) => tok.encode(text),
93        }
94    }
95
96    /// Encode text into token IDs **without** adding special tokens.
97    ///
98    /// Useful for MI analyses that need raw tokenization without BOS/EOS.
99    ///
100    /// # Errors
101    ///
102    /// Returns [`MIError::Tokenizer`] if encoding fails.
103    pub fn encode_raw(&self, text: &str) -> Result<Vec<u32>> {
104        match self {
105            Self::HuggingFace(tok) => {
106                let encoding = tok
107                    .encode(text, false)
108                    .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
109                Ok(encoding.get_ids().to_vec())
110            }
111            #[cfg(feature = "rwkv-tokenizer")]
112            Self::Rwkv(tok) => tok.encode(text),
113        }
114    }
115
116    /// Encode text into token IDs with character offset mapping.
117    ///
118    /// Returns an [`EncodingWithOffsets`] containing token IDs, token strings,
119    /// and byte-offset ranges for each token. Special tokens are added
120    /// (e.g., BOS for Gemma); special tokens receive a `(0, 0)` offset.
121    ///
122    /// # Errors
123    ///
124    /// Returns [`MIError::Tokenizer`] if encoding fails or if the backend
125    /// does not support offset mapping (RWKV).
126    pub fn encode_with_offsets(&self, text: &str) -> Result<EncodingWithOffsets> {
127        self.encode_with_offsets_inner(text, true)
128    }
129
130    /// Encode text into token IDs with character offset mapping, **without**
131    /// adding special tokens.
132    ///
133    /// # Errors
134    ///
135    /// Returns [`MIError::Tokenizer`] if encoding fails or if the backend
136    /// does not support offset mapping (RWKV).
137    pub fn encode_raw_with_offsets(&self, text: &str) -> Result<EncodingWithOffsets> {
138        self.encode_with_offsets_inner(text, false)
139    }
140
141    /// Shared implementation for offset-bearing encode methods.
142    fn encode_with_offsets_inner(
143        &self,
144        text: &str,
145        add_special_tokens: bool,
146    ) -> Result<EncodingWithOffsets> {
147        match self {
148            Self::HuggingFace(tok) => {
149                let encoding = tok
150                    .encode(text, add_special_tokens)
151                    .map_err(|e| MIError::Tokenizer(format!("HF encode failed: {e}")))?;
152                let ids = encoding.get_ids().to_vec();
153                let tokens: Vec<String> = encoding
154                    .get_tokens()
155                    .iter()
156                    .map(ToString::to_string)
157                    .collect();
158                let offsets = encoding.get_offsets().to_vec();
159                Ok(EncodingWithOffsets::new(ids, tokens, offsets))
160            }
161            #[cfg(feature = "rwkv-tokenizer")]
162            Self::Rwkv(_) => Err(MIError::Tokenizer(
163                "RWKV tokenizer does not support offset mapping".into(),
164            )),
165        }
166    }
167
168    /// Decode token IDs back to a string.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`MIError::Tokenizer`] if decoding fails.
173    pub fn decode(&self, ids: &[u32]) -> Result<String> {
174        match self {
175            Self::HuggingFace(tok) => tok
176                .decode(ids, false)
177                .map_err(|e| MIError::Tokenizer(format!("HF decode failed: {e}"))),
178            #[cfg(feature = "rwkv-tokenizer")]
179            Self::Rwkv(tok) => tok.decode(ids),
180        }
181    }
182
183    /// Get vocabulary size.
184    #[must_use]
185    pub fn vocab_size(&self) -> usize {
186        match self {
187            Self::HuggingFace(tok) => tok.get_vocab_size(true),
188            #[cfg(feature = "rwkv-tokenizer")]
189            Self::Rwkv(tok) => tok.vocab_size(),
190        }
191    }
192
193    /// Find the token ID for a word, trying `" word"` (with leading space) first,
194    /// then bare `"word"`.
195    ///
196    /// This handles BPE tokenizers that represent word-initial tokens with a
197    /// leading space (e.g., `" cat"` → single token).
198    ///
199    /// # Errors
200    ///
201    /// Returns [`MIError::Tokenizer`] if the word cannot be resolved to a
202    /// single token in either form.
203    pub fn find_token_id(&self, word: &str) -> Result<u32> {
204        // Try with leading space first (most BPE tokenizers).
205        let with_space = format!(" {word}");
206        let ids = self.encode(&with_space)?;
207        // ids[0] is BOS (if present), ids[1] would be the word token.
208        if ids.len() == 2 {
209            return ids
210                .get(1)
211                .copied()
212                .ok_or_else(|| MIError::Tokenizer(format!("unexpected encoding for \" {word}\"")));
213        }
214
215        // Try bare word.
216        let bare_ids = self.encode(word)?;
217        if bare_ids.len() == 2 {
218            return bare_ids
219                .get(1)
220                .copied()
221                .ok_or_else(|| MIError::Tokenizer(format!("unexpected encoding for \"{word}\"")));
222        }
223
224        // Last resort: return last token.
225        ids.last().copied().ok_or_else(|| {
226            MIError::Tokenizer(format!("could not find single token ID for \"{word}\""))
227        })
228    }
229
230    /// Decode a single token ID to its string representation.
231    ///
232    /// # Errors
233    ///
234    /// Returns [`MIError::Tokenizer`] if decoding fails.
235    pub fn decode_token(&self, token_id: u32) -> Result<String> {
236        self.decode(&[token_id])
237    }
238}
239
240impl std::fmt::Debug for MITokenizer {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        match self {
243            Self::HuggingFace(_) => f.debug_tuple("HuggingFace").field(&"...").finish(),
244            #[cfg(feature = "rwkv-tokenizer")]
245            Self::Rwkv(tok) => f.debug_tuple("Rwkv").field(tok).finish(),
246        }
247    }
248}