// oxibonsai_runtime/tokenizer_bridge.rs

use crate::error::{RuntimeError, RuntimeResult};
/// Thin wrapper around the HuggingFace `tokenizers` crate that compiles on
/// every target.
///
/// On native targets this owns a real [`tokenizers::Tokenizer`]. On `wasm32`
/// (where the `tokenizers` crate does not build) the struct is an empty shell
/// and every fallible operation returns [`RuntimeError::Tokenizer`].
pub struct TokenizerBridge {
    // Real tokenizer — only present off-wasm.
    #[cfg(not(target_arch = "wasm32"))]
    inner: tokenizers::Tokenizer,
    // Zero-sized placeholder so the type still exists on wasm32 without
    // referencing anything from the `tokenizers` crate.
    #[cfg(target_arch = "wasm32")]
    _phantom: (),
}
18
/// Incremental state for streaming, token-by-token decoding.
///
/// Holds the buffers threaded through the `tokenizers` streaming decode:
/// the full id history, the text already emitted, and the index into `ids`
/// that the emitted text covers. Kept separate from the tokenizer itself so
/// several concurrent streams can share one `TokenizerBridge`.
#[derive(Default)]
// On wasm32 `step_decode` always errors, so these fields are never read there.
#[cfg_attr(target_arch = "wasm32", allow(dead_code))]
pub struct DecodeStreamState {
    // All token ids fed into the stream so far.
    ids: Vec<u32>,
    // Text already produced, corresponding to `ids[..prefix_index]`.
    prefix: String,
    // Number of leading ids that `prefix` accounts for.
    prefix_index: usize,
    // Whether decoding drops special tokens (fixed at stream creation).
    skip_special_tokens: bool,
}
39
40impl DecodeStreamState {
41 pub fn new(skip_special_tokens: bool) -> Self {
46 Self {
47 ids: Vec::new(),
48 prefix: String::new(),
49 prefix_index: 0,
50 skip_special_tokens,
51 }
52 }
53
54 pub fn reset(&mut self) {
56 *self = Self::new(self.skip_special_tokens);
57 }
58}
59
impl TokenizerBridge {
    /// Loads a tokenizer from a HuggingFace `tokenizer.json` file at `path`.
    ///
    /// # Errors
    /// Returns [`RuntimeError::Tokenizer`] if the file cannot be read or parsed.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_file(path: &str) -> RuntimeResult<Self> {
        let inner = tokenizers::Tokenizer::from_file(path)
            .map_err(|e| RuntimeError::Tokenizer(e.to_string()))?;
        Ok(Self { inner })
    }

    /// wasm32 stub: construction always fails, since the `tokenizers`
    /// crate is unavailable on this target.
    #[cfg(target_arch = "wasm32")]
    pub fn from_file(_path: &str) -> RuntimeResult<Self> {
        Err(RuntimeError::Tokenizer(
            "tokenizers library is not available on wasm32 targets".to_string(),
        ))
    }

    /// Encodes `text` into token ids.
    ///
    /// The `false` flag means special tokens (BOS/EOS etc.) are NOT added.
    ///
    /// # Errors
    /// Returns [`RuntimeError::Tokenizer`] if encoding fails.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn encode(&self, text: &str) -> RuntimeResult<Vec<u32>> {
        let encoding = self
            .inner
            .encode(text, false)
            .map_err(|e| RuntimeError::Tokenizer(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// wasm32 stub: always fails.
    #[cfg(target_arch = "wasm32")]
    pub fn encode(&self, _text: &str) -> RuntimeResult<Vec<u32>> {
        Err(RuntimeError::Tokenizer(
            "tokenizers library is not available on wasm32 targets".to_string(),
        ))
    }

    /// Decodes `ids` to text in one shot, skipping special tokens
    /// (the `true` flag). For incremental decoding use [`Self::step_decode`].
    ///
    /// # Errors
    /// Returns [`RuntimeError::Tokenizer`] if decoding fails.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn decode(&self, ids: &[u32]) -> RuntimeResult<String> {
        self.inner
            .decode(ids, true)
            .map_err(|e| RuntimeError::Tokenizer(e.to_string()))
    }

    /// wasm32 stub: always fails.
    #[cfg(target_arch = "wasm32")]
    pub fn decode(&self, _ids: &[u32]) -> RuntimeResult<String> {
        Err(RuntimeError::Tokenizer(
            "tokenizers library is not available on wasm32 targets".to_string(),
        ))
    }

    /// Vocabulary size including added (special) tokens (the `true` flag).
    #[cfg(not(target_arch = "wasm32"))]
    pub fn vocab_size(&self) -> usize {
        self.inner.get_vocab_size(true)
    }

    /// wasm32 stub: reports 0 instead of erroring.
    // NOTE(review): this is inconsistent with the other wasm32 stubs, which
    // return Err — confirm callers treat 0 as "no tokenizer available".
    #[cfg(target_arch = "wasm32")]
    pub fn vocab_size(&self) -> usize {
        0
    }

    /// Direct access to the underlying tokenizer for native-only callers.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn inner(&self) -> &tokenizers::Tokenizer {
        &self.inner
    }

    /// Creates a fresh streaming-decode state bound to no particular stream;
    /// feed it to [`Self::step_decode`] one token id at a time.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new_decode_stream(&self, skip_special_tokens: bool) -> DecodeStreamState {
        DecodeStreamState::new(skip_special_tokens)
    }

    /// wasm32 variant: state creation succeeds (it is plain data), but any
    /// subsequent `step_decode` call will error.
    #[cfg(target_arch = "wasm32")]
    pub fn new_decode_stream(&self, skip_special_tokens: bool) -> DecodeStreamState {
        DecodeStreamState::new(skip_special_tokens)
    }

    /// Advances a streaming decode by one token id.
    ///
    /// Returns `Ok(Some(chunk))` when the new id completes a printable piece
    /// of text, `Ok(None)` when more ids are needed (e.g. mid multi-byte
    /// UTF-8 sequence).
    ///
    /// # Errors
    /// Returns [`RuntimeError::Tokenizer`] if the underlying decode fails.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn step_decode(
        &self,
        state: &mut DecodeStreamState,
        id: u32,
    ) -> RuntimeResult<Option<String>> {
        // NOTE(review): `vec![id]` assumes the `step_decode_stream` signature
        // of the pinned `tokenizers` version; some releases take a single
        // `u32` instead — confirm against Cargo.lock when upgrading.
        tokenizers::step_decode_stream(
            &*self.inner,
            vec![id],
            state.skip_special_tokens,
            &mut state.ids,
            &mut state.prefix,
            &mut state.prefix_index,
        )
        .map_err(|e| RuntimeError::Tokenizer(e.to_string()))
    }

    /// wasm32 stub: always fails.
    #[cfg(target_arch = "wasm32")]
    pub fn step_decode(
        &self,
        _state: &mut DecodeStreamState,
        _id: u32,
    ) -> RuntimeResult<Option<String>> {
        Err(RuntimeError::Tokenizer(
            "tokenizers library is not available on wasm32 targets".to_string(),
        ))
    }
}
197
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use std::path::Path;

    /// Optional on-disk tokenizer fixture. Tests pass silently (with a note
    /// on stderr) when it is absent, so CI without the model stays green.
    const FIXTURE_TOKENIZER: &str = "../../models/tokenizer.json";

    /// Loads the fixture tokenizer, or `None` if missing/unparsable.
    fn maybe_load_fixture() -> Option<TokenizerBridge> {
        if !Path::new(FIXTURE_TOKENIZER).exists() {
            eprintln!(
                "skipped: tokenizer fixture not found at {FIXTURE_TOKENIZER} \
                 (run scripts/download_tokenizer.sh to enable)",
            );
            return None;
        }
        match TokenizerBridge::from_file(FIXTURE_TOKENIZER) {
            Ok(t) => Some(t),
            Err(e) => {
                eprintln!("skipped: failed to load tokenizer fixture: {e}");
                None
            }
        }
    }

    /// Streams every id through `step_decode` with a fresh state and
    /// concatenates the emitted chunks.
    fn stream_through(tok: &TokenizerBridge, ids: &[u32]) -> RuntimeResult<String> {
        let mut state = tok.new_decode_stream(true);
        let mut out = String::new();
        for &id in ids {
            if let Some(chunk) = tok.step_decode(&mut state, id)? {
                out.push_str(&chunk);
            }
        }
        Ok(out)
    }

    /// CJK text spans multi-byte UTF-8 sequences across token boundaries;
    /// streaming decode must never emit U+FFFD and must round-trip exactly.
    #[test]
    fn streaming_decode_cjk_no_replacement_chars() -> RuntimeResult<()> {
        let Some(tok) = maybe_load_fixture() else {
            return Ok(());
        };

        let input = "日本語処理を専門";
        let ids = tok.encode(input)?;
        assert!(!ids.is_empty(), "encoding yielded no token ids");

        let streamed = stream_through(&tok, &ids)?;

        assert!(
            !streamed.contains('\u{FFFD}'),
            "streaming decode produced U+FFFD replacement char(s); output: {streamed:?}",
        );
        assert_eq!(
            streamed, input,
            "streaming decode did not reconstruct the original CJK input",
        );
        Ok(())
    }

    /// Plain ASCII is the easy case: every token decodes to complete text.
    #[test]
    fn streaming_decode_ascii_passes_through() -> RuntimeResult<()> {
        let Some(tok) = maybe_load_fixture() else {
            return Ok(());
        };

        let input = "Hello, world! Streaming ASCII works fine.";
        let ids = tok.encode(input)?;
        let streamed = stream_through(&tok, &ids)?;
        assert!(!streamed.contains('\u{FFFD}'));
        assert_eq!(streamed, input);
        Ok(())
    }

    /// Empty input yields empty output, and a reset state behaves like a
    /// fresh one.
    #[test]
    fn streaming_decode_handles_empty_input() -> RuntimeResult<()> {
        let Some(tok) = maybe_load_fixture() else {
            return Ok(());
        };

        // Zero tokens must produce zero output.
        let streamed = stream_through(&tok, &[])?;
        assert!(
            streamed.is_empty(),
            "empty token stream should yield empty output, got {streamed:?}",
        );

        // Exercise `reset` for real: feed one token to dirty the state,
        // reset, then decode a full sequence through the SAME state.
        // (Previously the reset state was dropped unused, so `reset` was
        // effectively untested.)
        let input = "reset me";
        let ids = tok.encode(input)?;
        let mut state = tok.new_decode_stream(true);
        if let Some(&first) = ids.first() {
            let _ = tok.step_decode(&mut state, first)?;
        }
        state.reset();
        let mut out = String::new();
        for &id in &ids {
            if let Some(chunk) = tok.step_decode(&mut state, id)? {
                out.push_str(&chunk);
            }
        }
        assert_eq!(
            out, input,
            "decoding through a reset state should match a fresh decode",
        );
        Ok(())
    }
}