use crate::{
error::{TokenizerError, TokenizerResult},
tokenizer::OxiTokenizer,
};
/// Incremental decoder that turns token ids into UTF-8 text as they
/// arrive, buffering bytes that do not yet form complete characters
/// between calls.
pub struct StreamingDecoder<'a> {
/// Tokenizer used to map each id back to its raw bytes.
tokenizer: &'a OxiTokenizer,
/// Decoded bytes not yet emitted as text — after a successful flush this
/// is the non-UTF-8-complete tail of the byte stream.
pending: Vec<u8>,
/// Running count of decoded bytes pushed through the decoder.
total_bytes: usize,
/// Running count of token ids pushed, including ids that decode to
/// nothing.
total_tokens: usize,
}
impl<'a> StreamingDecoder<'a> {
    /// Creates a decoder that borrows `tokenizer` for id-to-byte lookups.
    pub fn new(tokenizer: &'a OxiTokenizer) -> Self {
        Self {
            tokenizer,
            pending: Vec::with_capacity(8),
            total_bytes: 0,
            total_tokens: 0,
        }
    }

    /// Decodes one token id and returns whatever complete UTF-8 text is
    /// now available, or `None` if nothing can be emitted yet.
    ///
    /// Bytes forming an incomplete trailing multi-byte sequence stay
    /// buffered until later tokens (or `finish`) resolve them.
    pub fn push_token(&mut self, id: u32) -> Option<String> {
        self.total_tokens += 1;
        // Decode into a scratch buffer; `decode_id_into`'s append-vs-overwrite
        // contract is not visible here, so we never hand it `self.pending`.
        let mut scratch: Vec<u8> = Vec::with_capacity(8);
        self.tokenizer.decode_id_into(id, &mut scratch);
        if scratch.is_empty() {
            // Nothing decoded: counters for bytes stay put and there is no
            // new data that could complete a pending sequence.
            return None;
        }
        self.total_bytes += scratch.len();
        self.pending.extend_from_slice(&scratch);
        self.flush_complete()
    }

    /// Decodes a batch of ids, concatenating every piece produced.
    /// Returns `None` when the batch yields no text at all.
    pub fn push_tokens(&mut self, ids: &[u32]) -> Option<String> {
        let mut out = String::new();
        for &id in ids {
            if let Some(piece) = self.push_token(id) {
                out.push_str(&piece);
            }
        }
        (!out.is_empty()).then_some(out)
    }

    /// Consumes the decoder and returns any still-buffered bytes as text.
    ///
    /// # Errors
    /// Returns `TokenizerError::IncompleteUtf8` if the buffered bytes are
    /// not valid UTF-8 (e.g. the stream ended mid-character).
    pub fn finish(self) -> TokenizerResult<String> {
        // `self` is consumed, so `pending` can be moved out directly;
        // no `mem::take` needed. An empty buffer yields `Ok("")`.
        String::from_utf8(self.pending).map_err(|_| TokenizerError::IncompleteUtf8)
    }

    /// Like [`finish`](Self::finish), but replaces any invalid or
    /// incomplete bytes with U+FFFD instead of failing.
    pub fn finish_lossy(self) -> String {
        String::from_utf8_lossy(&self.pending).into_owned()
    }

    /// Number of decoded bytes currently buffered and not yet emitted.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }

    /// Clears the pending buffer and resets both counters to zero.
    pub fn reset(&mut self) {
        self.pending.clear();
        self.total_bytes = 0;
        self.total_tokens = 0;
    }

    /// Total decoded bytes pushed since construction (or the last `reset`).
    pub fn total_bytes(&self) -> usize {
        self.total_bytes
    }

    /// Total token ids pushed since construction (or the last `reset`).
    pub fn total_tokens(&self) -> usize {
        self.total_tokens
    }

    /// Drains and returns the longest decodable prefix of `pending`.
    ///
    /// Two distinct UTF-8 error cases are handled:
    /// - `error_len() == None`: the buffer merely ends mid-character; the
    ///   tail stays buffered for future tokens.
    /// - `error_len() == Some(n)`: the next `n` bytes can never become
    ///   valid UTF-8. They are consumed and emitted as U+FFFD so the
    ///   decoder cannot stall. (Previously such bytes stayed at the front
    ///   of `pending` forever, permanently blocking all further output.)
    fn flush_complete(&mut self) -> Option<String> {
        let mut out = String::new();
        loop {
            match std::str::from_utf8(&self.pending) {
                Ok(s) => {
                    out.push_str(s);
                    self.pending.clear();
                    break;
                }
                Err(e) => {
                    let valid_up_to = e.valid_up_to();
                    // Safe: `valid_up_to` guarantees this prefix is valid,
                    // so no second allocation-and-revalidation round trip.
                    out.push_str(
                        std::str::from_utf8(&self.pending[..valid_up_to])
                            .expect("prefix validity guaranteed by valid_up_to"),
                    );
                    match e.error_len() {
                        // Definitively invalid bytes: replace and keep going.
                        Some(bad) => {
                            out.push('\u{FFFD}');
                            self.pending.drain(..valid_up_to + bad);
                        }
                        // Incomplete trailing sequence: keep it buffered.
                        None => {
                            self.pending.drain(..valid_up_to);
                            break;
                        }
                    }
                }
            }
        }
        (!out.is_empty()).then_some(out)
    }
}
#[cfg(test)]
mod tests {
    use crate::OxiTokenizer;

    /// Round-tripping plain ASCII through the streaming decoder must
    /// reproduce the input text exactly.
    #[test]
    fn ascii_passthrough() {
        let tok = OxiTokenizer::char_level_stub(256);
        let ids = tok.encode("abc").expect("encode");
        let mut dec = tok.streaming_decoder();
        let mut text: String = ids.iter().filter_map(|&id| dec.push_token(id)).collect();
        text.push_str(&dec.finish().expect("finish ok"));
        assert_eq!(text, "abc");
    }

    /// `reset` must zero out the pending buffer and both counters.
    #[test]
    fn reset_clears_state() {
        let tok = OxiTokenizer::char_level_stub(256);
        let mut dec = tok.streaming_decoder();
        let ids = tok.encode("abc").expect("encode");
        for &id in &ids {
            dec.push_token(id);
        }
        dec.reset();
        assert_eq!(
            (dec.pending_len(), dec.total_bytes(), dec.total_tokens()),
            (0, 0, 0)
        );
    }

    /// The batch API should yield non-empty text for a non-empty input.
    #[test]
    fn push_tokens_batch() {
        let tok = OxiTokenizer::char_level_stub(256);
        let mut dec = tok.streaming_decoder();
        let ids = tok.encode("hello").expect("encode");
        let decoded = dec.push_tokens(&ids).unwrap_or_default();
        assert!(!decoded.is_empty());
    }

    /// Finishing a decoder that never saw a token is `Ok("")`.
    #[test]
    fn finish_on_empty_is_ok() {
        let tok = OxiTokenizer::char_level_stub(256);
        let dec = tok.streaming_decoder();
        assert_eq!(dec.finish().expect("empty finish ok"), "");
    }

    /// The lossy variant always produces a `String`, empty input included.
    #[test]
    fn finish_lossy_never_fails() {
        let tok = OxiTokenizer::char_level_stub(256);
        let dec = tok.streaming_decoder();
        assert_eq!(dec.finish_lossy(), "");
    }

    /// Pushing tokens must advance both the token and byte counters.
    #[test]
    fn counters_advance() {
        let tok = OxiTokenizer::char_level_stub(256);
        let mut dec = tok.streaming_decoder();
        let ids = tok.encode("ab").expect("encode");
        ids.iter().for_each(|&id| {
            dec.push_token(id);
        });
        assert!(dec.total_tokens() >= ids.len());
        assert!(dec.total_bytes() > 0);
    }
}