// oxibonsai_runtime/tokenizer_bridge.rs
1//! Tokenizer bridge wrapping HuggingFace tokenizers.
2//!
3//! On WASM targets, the `tokenizers` crate is unavailable (requires native C extensions).
4//! A stub implementation is provided that returns errors for all operations.
5
6use crate::error::{RuntimeError, RuntimeResult};
7
8/// Thin wrapper around `tokenizers::Tokenizer`.
9///
10/// On non-WASM targets, delegates to the full HuggingFace tokenizers library.
11/// On WASM targets, all methods return a `RuntimeError::Tokenizer` error.
/// Thin wrapper around `tokenizers::Tokenizer`.
///
/// On non-WASM targets, delegates to the full HuggingFace tokenizers library.
/// On WASM targets, all methods return a `RuntimeError::Tokenizer` error.
pub struct TokenizerBridge {
    /// The wrapped HuggingFace tokenizer (native targets only).
    #[cfg(not(target_arch = "wasm32"))]
    inner: tokenizers::Tokenizer,
    /// Zero-sized placeholder so the struct still has a body on wasm32,
    /// where the `tokenizers` crate cannot be compiled.
    #[cfg(target_arch = "wasm32")]
    _phantom: (),
}
18
19/// Per-stream UTF-8-safe decode state. Owned by the caller.
20///
21/// BPE / byte-level tokenizers (Qwen3, GPT-2, etc.) sometimes emit a single
22/// token that carries only **part** of a multi-byte UTF-8 character (e.g. one
23/// byte of a CJK ideograph or emoji).  Decoding tokens one-at-a-time without
24/// buffering breaks those multi-byte sequences and produces `U+FFFD`
25/// replacement characters in the output stream.  This state mirrors what the
26/// HuggingFace `tokenizers::DecodeStream` keeps internally so that we can own
27/// it externally and feed tokens through `TokenizerBridge::step_decode`.
28///
29/// Use one `DecodeStreamState` per generation request; reset (or drop &
30/// re-create) it between independent requests.
#[derive(Default)]
#[cfg_attr(target_arch = "wasm32", allow(dead_code))]
pub struct DecodeStreamState {
    // Token ids buffered since the last emitted chunk (mirrors the HF
    // `DecodeStream::ids` field — see that type for exact semantics).
    ids: Vec<u32>,
    // Previously decoded text used to diff out the newly completed suffix
    // (mirrors `DecodeStream::prefix`).
    prefix: String,
    // Offset into `ids` associated with `prefix`
    // (mirrors `DecodeStream::prefix_index`).
    prefix_index: usize,
    // Whether sentinel tokens (e.g. `<|im_end|>`) are dropped from output;
    // fixed at construction and preserved across `reset()`.
    skip_special_tokens: bool,
}
39
40impl DecodeStreamState {
41    /// Construct a fresh decode-stream state.
42    ///
43    /// `skip_special_tokens` matches the existing `decode()` behavior — pass
44    /// `true` to drop sentinel tokens (e.g. `<|im_end|>`) from the output.
45    pub fn new(skip_special_tokens: bool) -> Self {
46        Self {
47            ids: Vec::new(),
48            prefix: String::new(),
49            prefix_index: 0,
50            skip_special_tokens,
51        }
52    }
53
54    /// Reset the state, preserving the original `skip_special_tokens` flag.
55    pub fn reset(&mut self) {
56        *self = Self::new(self.skip_special_tokens);
57    }
58}
59
60impl TokenizerBridge {
61    /// Load a tokenizer from a JSON file.
62    #[cfg(not(target_arch = "wasm32"))]
63    pub fn from_file(path: &str) -> RuntimeResult<Self> {
64        let inner = tokenizers::Tokenizer::from_file(path)
65            .map_err(|e| RuntimeError::Tokenizer(e.to_string()))?;
66        Ok(Self { inner })
67    }
68
69    /// Load a tokenizer from a JSON file.
70    ///
71    /// On WASM targets, always returns an error since the tokenizers library
72    /// requires native code not available in WebAssembly.
73    #[cfg(target_arch = "wasm32")]
74    pub fn from_file(_path: &str) -> RuntimeResult<Self> {
75        Err(RuntimeError::Tokenizer(
76            "tokenizers library is not available on wasm32 targets".to_string(),
77        ))
78    }
79
80    /// Encode text to token IDs.
81    #[cfg(not(target_arch = "wasm32"))]
82    pub fn encode(&self, text: &str) -> RuntimeResult<Vec<u32>> {
83        let encoding = self
84            .inner
85            .encode(text, false)
86            .map_err(|e| RuntimeError::Tokenizer(e.to_string()))?;
87        Ok(encoding.get_ids().to_vec())
88    }
89
90    /// Encode text to token IDs.
91    ///
92    /// On WASM targets, always returns an error.
93    #[cfg(target_arch = "wasm32")]
94    pub fn encode(&self, _text: &str) -> RuntimeResult<Vec<u32>> {
95        Err(RuntimeError::Tokenizer(
96            "tokenizers library is not available on wasm32 targets".to_string(),
97        ))
98    }
99
100    /// Decode token IDs to text.
101    #[cfg(not(target_arch = "wasm32"))]
102    pub fn decode(&self, ids: &[u32]) -> RuntimeResult<String> {
103        self.inner
104            .decode(ids, true)
105            .map_err(|e| RuntimeError::Tokenizer(e.to_string()))
106    }
107
108    /// Decode token IDs to text.
109    ///
110    /// On WASM targets, always returns an error.
111    #[cfg(target_arch = "wasm32")]
112    pub fn decode(&self, _ids: &[u32]) -> RuntimeResult<String> {
113        Err(RuntimeError::Tokenizer(
114            "tokenizers library is not available on wasm32 targets".to_string(),
115        ))
116    }
117
118    /// Get the vocabulary size.
119    #[cfg(not(target_arch = "wasm32"))]
120    pub fn vocab_size(&self) -> usize {
121        self.inner.get_vocab_size(true)
122    }
123
124    /// Get the vocabulary size.
125    ///
126    /// On WASM targets, returns 0 since no tokenizer is available.
127    #[cfg(target_arch = "wasm32")]
128    pub fn vocab_size(&self) -> usize {
129        0
130    }
131
132    /// Get the internal tokenizer reference.
133    #[cfg(not(target_arch = "wasm32"))]
134    pub fn inner(&self) -> &tokenizers::Tokenizer {
135        &self.inner
136    }
137
138    /// Create a fresh decode-stream state for one generation request.
139    ///
140    /// See [`DecodeStreamState`] and [`Self::step_decode`] for the streaming
141    /// decode protocol.  Use this instead of repeatedly calling
142    /// [`Self::decode`] with single-token slices, which mishandles tokens that
143    /// straddle UTF-8 codepoint boundaries.
144    #[cfg(not(target_arch = "wasm32"))]
145    pub fn new_decode_stream(&self, skip_special_tokens: bool) -> DecodeStreamState {
146        DecodeStreamState::new(skip_special_tokens)
147    }
148
149    /// Create a fresh decode-stream state.
150    ///
151    /// On WASM targets this returns a state object, but [`Self::step_decode`]
152    /// will always error.  The state itself is harmless to construct.
153    #[cfg(target_arch = "wasm32")]
154    pub fn new_decode_stream(&self, skip_special_tokens: bool) -> DecodeStreamState {
155        DecodeStreamState::new(skip_special_tokens)
156    }
157
158    /// Advance the decode stream by one token.
159    ///
160    /// Returns `Ok(Some(text))` only when the buffered bytes form a complete
161    /// UTF-8 chunk (which may span several previous tokens for CJK / emoji);
162    /// returns `Ok(None)` when more tokens are needed before any well-formed
163    /// text can be emitted.  Callers must **not** print the empty string when
164    /// `Ok(None)` is returned — wait for the next token.
165    #[cfg(not(target_arch = "wasm32"))]
166    pub fn step_decode(
167        &self,
168        state: &mut DecodeStreamState,
169        id: u32,
170    ) -> RuntimeResult<Option<String>> {
171        tokenizers::step_decode_stream(
172            &*self.inner,
173            vec![id],
174            state.skip_special_tokens,
175            &mut state.ids,
176            &mut state.prefix,
177            &mut state.prefix_index,
178        )
179        .map_err(|e| RuntimeError::Tokenizer(e.to_string()))
180    }
181
182    /// Advance the decode stream by one token.
183    ///
184    /// On WASM targets this always returns an error since the tokenizers
185    /// library is unavailable.
186    #[cfg(target_arch = "wasm32")]
187    pub fn step_decode(
188        &self,
189        _state: &mut DecodeStreamState,
190        _id: u32,
191    ) -> RuntimeResult<Option<String>> {
192        Err(RuntimeError::Tokenizer(
193            "tokenizers library is not available on wasm32 targets".to_string(),
194        ))
195    }
196}
197
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use std::path::Path;

    /// Path to the project's bundled Qwen3 tokenizer.  Tests that need a real
    /// BPE tokenizer skip themselves when this fixture is missing so that
    /// freshly-cloned working trees still pass `cargo test`.
    const FIXTURE_TOKENIZER: &str = "../../models/tokenizer.json";

    fn maybe_load_fixture() -> Option<TokenizerBridge> {
        if !Path::new(FIXTURE_TOKENIZER).exists() {
            eprintln!(
                "skipped: tokenizer fixture not found at {FIXTURE_TOKENIZER} \
                 (run scripts/download_tokenizer.sh to enable)",
            );
            return None;
        }
        match TokenizerBridge::from_file(FIXTURE_TOKENIZER) {
            Ok(t) => Some(t),
            Err(e) => {
                eprintln!("skipped: failed to load tokenizer fixture: {e}");
                None
            }
        }
    }

    /// Feed `ids` one-by-one through an *existing* decode-stream state,
    /// concatenating the well-formed chunks.  Taking the state as a parameter
    /// lets tests exercise a state across resets.
    fn drive_state(
        tok: &TokenizerBridge,
        state: &mut DecodeStreamState,
        ids: &[u32],
    ) -> RuntimeResult<String> {
        let mut out = String::new();
        for &id in ids {
            if let Some(chunk) = tok.step_decode(state, id)? {
                out.push_str(&chunk);
            }
        }
        Ok(out)
    }

    /// Drive every id through `step_decode` with a fresh state and concatenate
    /// the well-formed chunks.  Mirrors what the CLI / SSE code paths do.
    fn stream_through(tok: &TokenizerBridge, ids: &[u32]) -> RuntimeResult<String> {
        let mut state = tok.new_decode_stream(true);
        drive_state(tok, &mut state, ids)
    }

    #[test]
    fn streaming_decode_cjk_no_replacement_chars() -> RuntimeResult<()> {
        let Some(tok) = maybe_load_fixture() else {
            return Ok(());
        };

        // Mix of Japanese ideographs and hiragana exercising multi-byte UTF-8
        // (3 bytes per char) that BPE byte-level tokenization typically splits
        // across two or three tokens.
        let input = "日本語処理を専門";
        let ids = tok.encode(input)?;
        assert!(!ids.is_empty(), "encoding yielded no token ids");

        let streamed = stream_through(&tok, &ids)?;

        assert!(
            !streamed.contains('\u{FFFD}'),
            "streaming decode produced U+FFFD replacement char(s); output: {streamed:?}",
        );
        assert_eq!(
            streamed, input,
            "streaming decode did not reconstruct the original CJK input",
        );
        Ok(())
    }

    #[test]
    fn streaming_decode_ascii_passes_through() -> RuntimeResult<()> {
        let Some(tok) = maybe_load_fixture() else {
            return Ok(());
        };

        let input = "Hello, world! Streaming ASCII works fine.";
        let ids = tok.encode(input)?;
        let streamed = stream_through(&tok, &ids)?;
        assert!(!streamed.contains('\u{FFFD}'));
        assert_eq!(streamed, input);
        Ok(())
    }

    #[test]
    fn streaming_decode_handles_empty_input() -> RuntimeResult<()> {
        let Some(tok) = maybe_load_fixture() else {
            return Ok(());
        };

        // Driving zero ids must yield no output and must not panic.
        let streamed = stream_through(&tok, &[])?;
        assert!(
            streamed.is_empty(),
            "empty token stream should yield empty output, got {streamed:?}",
        );

        // A reset state must remain usable afterwards.  The original version
        // of this test reset a state and then called `stream_through`, which
        // builds its own fresh state internally — so the reset state was never
        // actually exercised.  Drive real ids through the *same* state object.
        let mut state = tok.new_decode_stream(true);
        state.reset();
        let input = "reset works";
        let ids = tok.encode(input)?;
        let redecoded = drive_state(&tok, &mut state, &ids)?;
        assert_eq!(
            redecoded, input,
            "decode-stream state was not usable after reset",
        );
        Ok(())
    }
}