// oxibonsai_tokenizer/streaming.rs
//! UTF-8-safe streaming decoder.
//!
//! When a server emits tokens one at a time, naive `decode(&[id])` can return
//! strings with invalid UTF-8 because a single BPE token may hold *part* of a
//! multi-byte codepoint (common for CJK / emoji output). The decoder in this
//! module keeps a small byte buffer across calls and only flushes characters
//! that form a complete UTF-8 sequence.
//!
//! ## Usage
//!
//! ```rust
//! use oxibonsai_tokenizer::OxiTokenizer;
//!
//! let tok = OxiTokenizer::char_level_stub(256);
//! let ids = tok.encode("Hello!").expect("encode");
//! let mut dec = tok.streaming_decoder();
//! let mut out = String::new();
//! for id in &ids {
//!     if let Some(piece) = dec.push_token(*id) {
//!         out.push_str(&piece);
//!     }
//! }
//! out.push_str(&dec.finish().expect("stream must end on a UTF-8 boundary"));
//! assert_eq!(out, "Hello!");
//! ```
26
27use crate::{
28 error::{TokenizerError, TokenizerResult},
29 tokenizer::OxiTokenizer,
30};
31
/// A streaming decoder that yields well-formed UTF-8 slices as tokens arrive.
///
/// The decoder holds a reference to its parent [`OxiTokenizer`] so that
/// special-token handling, vocabulary lookup and byte-level decoding remain
/// consistent with [`OxiTokenizer::decode`].
///
/// Typically obtained via `tokenizer.streaming_decoder()` (see the module
/// example) rather than constructed directly.
pub struct StreamingDecoder<'a> {
    /// Parent tokenizer; used by `push_token` to decode each ID into bytes.
    tokenizer: &'a OxiTokenizer,
    /// Bytes that have been decoded but not yet emitted because they are
    /// part of an incomplete UTF-8 sequence.
    pending: Vec<u8>,
    /// Total bytes the decoder has seen across the stream (for diagnostics).
    total_bytes: usize,
    /// Total tokens the decoder has seen across the stream.
    total_tokens: usize,
}
47
48impl<'a> StreamingDecoder<'a> {
49 /// Create a fresh decoder tied to `tokenizer`.
50 pub fn new(tokenizer: &'a OxiTokenizer) -> Self {
51 Self {
52 tokenizer,
53 pending: Vec::with_capacity(8),
54 total_bytes: 0,
55 total_tokens: 0,
56 }
57 }
58
59 /// Push a single token ID and return the next well-formed UTF-8 slice, if
60 /// any. Returns `None` when the token's bytes do not extend any
61 /// previously-pending prefix into a full UTF-8 character.
62 ///
63 /// The returned `String` contains all characters that became complete as
64 /// a result of this push — may be multiple characters if the token
65 /// carries several whole code points.
66 pub fn push_token(&mut self, id: u32) -> Option<String> {
67 self.total_tokens += 1;
68 let mut scratch: Vec<u8> = Vec::with_capacity(8);
69 self.tokenizer.decode_id_into(id, &mut scratch);
70 if scratch.is_empty() {
71 return None;
72 }
73 self.total_bytes += scratch.len();
74 self.pending.extend_from_slice(&scratch);
75 self.flush_complete()
76 }
77
78 /// Push many tokens at once. Equivalent to repeatedly calling
79 /// [`Self::push_token`] but only returns once, with all complete
80 /// characters concatenated.
81 pub fn push_tokens(&mut self, ids: &[u32]) -> Option<String> {
82 let mut out = String::new();
83 for &id in ids {
84 if let Some(piece) = self.push_token(id) {
85 out.push_str(&piece);
86 }
87 }
88 if out.is_empty() {
89 None
90 } else {
91 Some(out)
92 }
93 }
94
95 /// Finish the stream and return any remaining bytes as a `String`.
96 ///
97 /// Returns an error if the pending buffer still contains an incomplete
98 /// UTF-8 sequence (strict mode). If lossy finishing is desired, use
99 /// [`Self::finish_lossy`] instead.
100 pub fn finish(mut self) -> TokenizerResult<String> {
101 if self.pending.is_empty() {
102 return Ok(String::new());
103 }
104 match String::from_utf8(std::mem::take(&mut self.pending)) {
105 Ok(s) => Ok(s),
106 Err(_) => Err(TokenizerError::IncompleteUtf8),
107 }
108 }
109
110 /// Finish the stream, replacing any trailing invalid bytes with
111 /// `\u{FFFD}`. Never fails.
112 pub fn finish_lossy(mut self) -> String {
113 if self.pending.is_empty() {
114 return String::new();
115 }
116 let bytes = std::mem::take(&mut self.pending);
117 String::from_utf8_lossy(&bytes).into_owned()
118 }
119
120 /// Number of bytes currently held in the pending buffer.
121 ///
122 /// A non-zero value after a `push_token` call indicates that the last
123 /// token ended mid-UTF-8-sequence.
124 pub fn pending_len(&self) -> usize {
125 self.pending.len()
126 }
127
128 /// Reset the decoder state without destroying the `OxiTokenizer`
129 /// reference — useful when processing multiple independent streams.
130 pub fn reset(&mut self) {
131 self.pending.clear();
132 self.total_bytes = 0;
133 self.total_tokens = 0;
134 }
135
136 /// Total bytes processed since construction or last [`Self::reset`].
137 pub fn total_bytes(&self) -> usize {
138 self.total_bytes
139 }
140
141 /// Total tokens processed since construction or last [`Self::reset`].
142 pub fn total_tokens(&self) -> usize {
143 self.total_tokens
144 }
145
146 /// Pull all complete UTF-8 characters out of `pending`, leaving any
147 /// trailing incomplete sequence behind.
148 fn flush_complete(&mut self) -> Option<String> {
149 if self.pending.is_empty() {
150 return None;
151 }
152
153 // Find the longest UTF-8-valid prefix of `pending`.
154 match std::str::from_utf8(&self.pending) {
155 Ok(s) => {
156 // Entire buffer is valid.
157 let owned = s.to_owned();
158 self.pending.clear();
159 if owned.is_empty() {
160 None
161 } else {
162 Some(owned)
163 }
164 }
165 Err(e) => {
166 let valid_up_to = e.valid_up_to();
167 if valid_up_to == 0 {
168 return None;
169 }
170 // Extract the complete prefix.
171 let prefix_bytes = self.pending[..valid_up_to].to_vec();
172 self.pending.drain(..valid_up_to);
173 match String::from_utf8(prefix_bytes) {
174 Ok(s) if !s.is_empty() => Some(s),
175 _ => None,
176 }
177 }
178 }
179 }
180}
181
// ── Tests ────────────────────────────────────────────────────────────────────
183
#[cfg(test)]
mod tests {
    use crate::OxiTokenizer;

    #[test]
    fn ascii_passthrough() {
        let tokenizer = OxiTokenizer::char_level_stub(256);
        let token_ids = tokenizer.encode("abc").expect("encode");
        let mut decoder = tokenizer.streaming_decoder();
        // Collect every piece that becomes complete as tokens stream in.
        let mut collected: String = token_ids
            .iter()
            .filter_map(|&token_id| decoder.push_token(token_id))
            .collect();
        collected.push_str(&decoder.finish().expect("finish ok"));
        assert_eq!(collected, "abc");
    }

    #[test]
    fn reset_clears_state() {
        let tokenizer = OxiTokenizer::char_level_stub(256);
        let mut decoder = tokenizer.streaming_decoder();
        for &token_id in &tokenizer.encode("abc").expect("encode") {
            decoder.push_token(token_id);
        }
        decoder.reset();
        // After reset, all counters and the pending buffer are back to zero.
        assert_eq!(decoder.pending_len(), 0);
        assert_eq!(decoder.total_bytes(), 0);
        assert_eq!(decoder.total_tokens(), 0);
    }

    #[test]
    fn push_tokens_batch() {
        let tokenizer = OxiTokenizer::char_level_stub(256);
        let mut decoder = tokenizer.streaming_decoder();
        let token_ids = tokenizer.encode("hello").expect("encode");
        let decoded = decoder.push_tokens(&token_ids).unwrap_or_default();
        // Non-empty because char-level stub emits one char per token.
        assert!(!decoded.is_empty());
    }

    #[test]
    fn finish_on_empty_is_ok() {
        let tokenizer = OxiTokenizer::char_level_stub(256);
        let decoder = tokenizer.streaming_decoder();
        assert_eq!(decoder.finish().expect("empty finish ok"), "");
    }

    #[test]
    fn finish_lossy_never_fails() {
        let tokenizer = OxiTokenizer::char_level_stub(256);
        assert_eq!(tokenizer.streaming_decoder().finish_lossy(), "");
    }

    #[test]
    fn counters_advance() {
        let tokenizer = OxiTokenizer::char_level_stub(256);
        let mut decoder = tokenizer.streaming_decoder();
        let token_ids = tokenizer.encode("ab").expect("encode");
        for &token_id in &token_ids {
            decoder.push_token(token_id);
        }
        assert!(decoder.total_tokens() >= token_ids.len());
        assert!(decoder.total_bytes() > 0);
    }
}
254}