// oxibonsai_tokenizer/streaming.rs
//! UTF-8-safe streaming decoder.
//!
//! When a server emits tokens one at a time, naive `decode(&[id])` can return
//! strings with invalid UTF-8 because a single BPE token may hold *part* of a
//! multi-byte codepoint (common for CJK / emoji output). The decoder in this
//! module keeps a small byte buffer across calls and only flushes characters
//! that form a complete UTF-8 sequence.
//!
//! ## Usage
//!
//! ```rust
//! use oxibonsai_tokenizer::OxiTokenizer;
//!
//! let tok = OxiTokenizer::char_level_stub(256);
//! let ids = tok.encode("Hello!").expect("encode");
//! let mut dec = tok.streaming_decoder();
//! let mut out = String::new();
//! for id in &ids {
//!     if let Some(piece) = dec.push_token(*id) {
//!         out.push_str(&piece);
//!     }
//! }
//! out.push_str(&dec.finish().expect("stream must end on a UTF-8 boundary"));
//! assert_eq!(out, "Hello!");
//! ```

use crate::{
    error::{TokenizerError, TokenizerResult},
    tokenizer::OxiTokenizer,
};

32/// A streaming decoder that yields well-formed UTF-8 slices as tokens arrive.
33///
34/// The decoder holds a reference to its parent [`OxiTokenizer`] so that
35/// special-token handling, vocabulary lookup and byte-level decoding remain
36/// consistent with [`OxiTokenizer::decode`].
37pub struct StreamingDecoder<'a> {
38    tokenizer: &'a OxiTokenizer,
39    /// Bytes that have been decoded but not yet emitted because they are
40    /// part of an incomplete UTF-8 sequence.
41    pending: Vec<u8>,
42    /// Total bytes the decoder has seen across the stream (for diagnostics).
43    total_bytes: usize,
44    /// Total tokens the decoder has seen across the stream.
45    total_tokens: usize,
46}
47
48impl<'a> StreamingDecoder<'a> {
49    /// Create a fresh decoder tied to `tokenizer`.
50    pub fn new(tokenizer: &'a OxiTokenizer) -> Self {
51        Self {
52            tokenizer,
53            pending: Vec::with_capacity(8),
54            total_bytes: 0,
55            total_tokens: 0,
56        }
57    }
58
59    /// Push a single token ID and return the next well-formed UTF-8 slice, if
60    /// any.  Returns `None` when the token's bytes do not extend any
61    /// previously-pending prefix into a full UTF-8 character.
62    ///
63    /// The returned `String` contains all characters that became complete as
64    /// a result of this push — may be multiple characters if the token
65    /// carries several whole code points.
66    pub fn push_token(&mut self, id: u32) -> Option<String> {
67        self.total_tokens += 1;
68        let mut scratch: Vec<u8> = Vec::with_capacity(8);
69        self.tokenizer.decode_id_into(id, &mut scratch);
70        if scratch.is_empty() {
71            return None;
72        }
73        self.total_bytes += scratch.len();
74        self.pending.extend_from_slice(&scratch);
75        self.flush_complete()
76    }
77
78    /// Push many tokens at once.  Equivalent to repeatedly calling
79    /// [`Self::push_token`] but only returns once, with all complete
80    /// characters concatenated.
81    pub fn push_tokens(&mut self, ids: &[u32]) -> Option<String> {
82        let mut out = String::new();
83        for &id in ids {
84            if let Some(piece) = self.push_token(id) {
85                out.push_str(&piece);
86            }
87        }
88        if out.is_empty() {
89            None
90        } else {
91            Some(out)
92        }
93    }
94
95    /// Finish the stream and return any remaining bytes as a `String`.
96    ///
97    /// Returns an error if the pending buffer still contains an incomplete
98    /// UTF-8 sequence (strict mode).  If lossy finishing is desired, use
99    /// [`Self::finish_lossy`] instead.
100    pub fn finish(mut self) -> TokenizerResult<String> {
101        if self.pending.is_empty() {
102            return Ok(String::new());
103        }
104        match String::from_utf8(std::mem::take(&mut self.pending)) {
105            Ok(s) => Ok(s),
106            Err(_) => Err(TokenizerError::IncompleteUtf8),
107        }
108    }
109
110    /// Finish the stream, replacing any trailing invalid bytes with
111    /// `\u{FFFD}`.  Never fails.
112    pub fn finish_lossy(mut self) -> String {
113        if self.pending.is_empty() {
114            return String::new();
115        }
116        let bytes = std::mem::take(&mut self.pending);
117        String::from_utf8_lossy(&bytes).into_owned()
118    }
119
120    /// Number of bytes currently held in the pending buffer.
121    ///
122    /// A non-zero value after a `push_token` call indicates that the last
123    /// token ended mid-UTF-8-sequence.
124    pub fn pending_len(&self) -> usize {
125        self.pending.len()
126    }
127
128    /// Reset the decoder state without destroying the `OxiTokenizer`
129    /// reference — useful when processing multiple independent streams.
130    pub fn reset(&mut self) {
131        self.pending.clear();
132        self.total_bytes = 0;
133        self.total_tokens = 0;
134    }
135
136    /// Total bytes processed since construction or last [`Self::reset`].
137    pub fn total_bytes(&self) -> usize {
138        self.total_bytes
139    }
140
141    /// Total tokens processed since construction or last [`Self::reset`].
142    pub fn total_tokens(&self) -> usize {
143        self.total_tokens
144    }
145
146    /// Pull all complete UTF-8 characters out of `pending`, leaving any
147    /// trailing incomplete sequence behind.
148    fn flush_complete(&mut self) -> Option<String> {
149        if self.pending.is_empty() {
150            return None;
151        }
152
153        // Find the longest UTF-8-valid prefix of `pending`.
154        match std::str::from_utf8(&self.pending) {
155            Ok(s) => {
156                // Entire buffer is valid.
157                let owned = s.to_owned();
158                self.pending.clear();
159                if owned.is_empty() {
160                    None
161                } else {
162                    Some(owned)
163                }
164            }
165            Err(e) => {
166                let valid_up_to = e.valid_up_to();
167                if valid_up_to == 0 {
168                    return None;
169                }
170                // Extract the complete prefix.
171                let prefix_bytes = self.pending[..valid_up_to].to_vec();
172                self.pending.drain(..valid_up_to);
173                match String::from_utf8(prefix_bytes) {
174                    Ok(s) if !s.is_empty() => Some(s),
175                    _ => None,
176                }
177            }
178        }
179    }
180}

// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use crate::OxiTokenizer;

    /// Plain ASCII must round-trip through the streaming decoder unchanged.
    #[test]
    fn ascii_passthrough() {
        let tok = OxiTokenizer::char_level_stub(256);
        let ids = tok.encode("abc").expect("encode");
        let mut dec = tok.streaming_decoder();
        let mut out: String = ids.iter().filter_map(|&id| dec.push_token(id)).collect();
        out.push_str(&dec.finish().expect("finish ok"));
        assert_eq!(out, "abc");
    }

    /// `reset` must zero the pending buffer and both counters.
    #[test]
    fn reset_clears_state() {
        let tok = OxiTokenizer::char_level_stub(256);
        let mut dec = tok.streaming_decoder();
        for id in tok.encode("abc").expect("encode") {
            dec.push_token(id);
        }
        dec.reset();
        assert_eq!(dec.pending_len(), 0);
        assert_eq!(dec.total_bytes(), 0);
        assert_eq!(dec.total_tokens(), 0);
    }

    /// Batch pushes should yield output for a char-level stub.
    #[test]
    fn push_tokens_batch() {
        let tok = OxiTokenizer::char_level_stub(256);
        let mut dec = tok.streaming_decoder();
        let ids = tok.encode("hello").expect("encode");
        let out = dec.push_tokens(&ids).unwrap_or_default();
        // Non-empty because char-level stub emits one char per token.
        assert!(!out.is_empty());
    }

    /// Finishing a decoder that saw nothing must succeed with "".
    #[test]
    fn finish_on_empty_is_ok() {
        let tok = OxiTokenizer::char_level_stub(256);
        let dec = tok.streaming_decoder();
        assert_eq!(dec.finish().expect("empty finish ok"), "");
    }

    /// The lossy variant never errors, even on an empty stream.
    #[test]
    fn finish_lossy_never_fails() {
        let tok = OxiTokenizer::char_level_stub(256);
        let dec = tok.streaming_decoder();
        assert_eq!(dec.finish_lossy(), "");
    }

    /// Byte and token counters must advance as tokens are pushed.
    #[test]
    fn counters_advance() {
        let tok = OxiTokenizer::char_level_stub(256);
        let mut dec = tok.streaming_decoder();
        let ids = tok.encode("ab").expect("encode");
        for &id in &ids {
            dec.push_token(id);
        }
        assert!(dec.total_tokens() >= ids.len());
        assert!(dec.total_bytes() > 0);
    }
}