use super::byte_level::byte_level_decode_bytes;
use super::tokenizer::Tokenizer;
/// Incrementally decodes token ids into UTF-8 text, holding back bytes that
/// end mid-character until the rest of the sequence arrives.
pub struct StreamingDecoder<'a> {
    // Source of id → byte lookups via `decoder()` / `special_tokens_decoder()`.
    tokenizer: &'a Tokenizer,
    // Pending bytes that do not yet end on a complete UTF-8 boundary.
    buffer: Vec<u8>,
}
impl<'a> StreamingDecoder<'a> {
    /// Creates a decoder that borrows `tokenizer` for token-id lookups.
    pub fn new(tokenizer: &'a Tokenizer) -> Self {
        Self {
            tokenizer,
            // 16 bytes comfortably covers a few pending multi-byte characters.
            buffer: Vec::with_capacity(16),
        }
    }

    /// Feeds one token id. Returns the longest run of complete UTF-8 text now
    /// available, or `None` if the id is unknown or the buffered bytes still
    /// end in the middle of a character.
    ///
    /// NOTE: an unknown id returns `None` immediately, without draining text
    /// that may already be pending in the buffer.
    pub fn add_token(&mut self, token_id: u32) -> Option<String> {
        let bytes = self.token_bytes(token_id)?;
        self.buffer.extend_from_slice(bytes);
        self.extract_complete_utf8()
    }

    /// Feeds a batch of token ids. Unknown ids are silently skipped (unlike
    /// `add_token`, which reports them via `None`). Returns any complete
    /// UTF-8 text accumulated after appending the whole batch.
    pub fn add_tokens(&mut self, token_ids: &[u32]) -> Option<String> {
        for &token_id in token_ids {
            if let Some(bytes) = self.token_bytes(token_id) {
                self.buffer.extend_from_slice(bytes);
            }
        }
        self.extract_complete_utf8()
    }

    /// Drains everything still buffered, replacing invalid or incomplete
    /// trailing bytes with U+FFFD.
    pub fn flush(&mut self) -> String {
        if self.buffer.is_empty() {
            return String::new();
        }
        let result = String::from_utf8_lossy(&self.buffer).into_owned();
        self.buffer.clear();
        result
    }

    /// Discards any buffered bytes without producing output.
    pub fn reset(&mut self) {
        self.buffer.clear();
    }

    /// True when bytes are buffered awaiting the rest of a character.
    pub fn has_pending(&self) -> bool {
        !self.buffer.is_empty()
    }

    /// Number of buffered (not yet emitted) bytes.
    pub fn pending_bytes(&self) -> usize {
        self.buffer.len()
    }

    /// Looks up the raw bytes for `token_id` in the regular decoder map,
    /// falling back to the special-tokens map. `None` for unknown ids.
    fn token_bytes(&self, token_id: u32) -> Option<&'a [u8]> {
        // Copy the `&'a Tokenizer` so the returned slice borrows the
        // tokenizer (lifetime 'a), not `self`.
        let tokenizer = self.tokenizer;
        tokenizer
            .decoder()
            .get(&token_id)
            .map(|b| b.as_slice())
            .or_else(|| {
                tokenizer
                    .special_tokens_decoder()
                    .get(&token_id)
                    .map(|s| s.as_bytes())
            })
    }

    /// Removes and returns the longest valid-UTF-8 prefix of the buffer,
    /// leaving any trailing partial character buffered. `None` when no
    /// complete text is available yet.
    fn extract_complete_utf8(&mut self) -> Option<String> {
        let valid_len = self.find_valid_utf8_len();
        if valid_len == 0 {
            return None;
        }
        let valid_bytes: Vec<u8> = self.buffer.drain(..valid_len).collect();
        // SAFETY: `find_valid_utf8_len` only reports a length for which
        // `std::str::from_utf8` succeeded on that exact prefix.
        Some(unsafe { String::from_utf8_unchecked(valid_bytes) })
    }

    /// Length of the longest buffer prefix that is valid UTF-8, preferring
    /// to hold back a trailing byte run that looks like an in-progress
    /// multi-byte sequence.
    fn find_valid_utf8_len(&self) -> usize {
        let bytes = &self.buffer;
        let len = bytes.len();
        if len == 0 {
            return 0;
        }
        // Fast path: the whole buffer is already valid.
        if std::str::from_utf8(bytes).is_ok() {
            return len;
        }
        // A UTF-8 sequence is at most 4 bytes, so at most 3 trailing bytes
        // can belong to an incomplete character. Try holding those back.
        for incomplete_len in 1..=len.min(3) {
            let check_len = len - incomplete_len;
            if check_len == 0 {
                continue;
            }
            if std::str::from_utf8(&bytes[..check_len]).is_ok()
                && self.could_be_incomplete_sequence(&bytes[check_len..])
            {
                return check_len;
            }
        }
        // Otherwise the tail is invalid rather than merely incomplete: fall
        // back to the longest valid prefix of any length.
        for i in (0..len).rev() {
            if std::str::from_utf8(&bytes[..=i]).is_ok() {
                return i + 1;
            }
        }
        0
    }

    /// Whether `bytes` could be the beginning of a not-yet-complete UTF-8
    /// sequence: a multi-byte lead byte followed by too few bytes.
    ///
    /// NOTE(review): 0xC0/0xC1 can never start valid UTF-8 but are treated
    /// as "incomplete" here; such bytes surface as U+FFFD on `flush`.
    fn could_be_incomplete_sequence(&self, bytes: &[u8]) -> bool {
        let Some(&first) = bytes.first() else {
            return false;
        };
        match first {
            0xC0..=0xDF => bytes.len() < 2, // 2-byte lead
            0xE0..=0xEF => bytes.len() < 3, // 3-byte lead
            0xF0..=0xF7 => bytes.len() < 4, // 4-byte lead
            _ => false,                     // ASCII / continuation / invalid
        }
    }
}
/// Streaming decoder for byte-level (GPT-2 style) vocabularies: token bytes
/// are passed through `byte_level_decode_bytes` before UTF-8 assembly.
pub struct ByteLevelStreamingDecoder<'a> {
    // Source of id → byte lookups via `decoder()` / `special_tokens_decoder()`.
    tokenizer: &'a Tokenizer,
    // Pending raw bytes that do not yet end on a complete UTF-8 boundary.
    buffer: Vec<u8>,
}
impl<'a> ByteLevelStreamingDecoder<'a> {
    /// Creates a decoder that borrows `tokenizer` for token-id lookups.
    pub fn new(tokenizer: &'a Tokenizer) -> Self {
        Self {
            tokenizer,
            // 16 bytes comfortably covers a few pending multi-byte characters.
            buffer: Vec::with_capacity(16),
        }
    }

    /// Feeds one token id. Returns the longest run of complete UTF-8 text now
    /// available, or `None` if the id is unknown or the buffered bytes still
    /// end in the middle of a character.
    ///
    /// NOTE: an unknown id returns `None` immediately, without draining text
    /// that may already be pending in the buffer.
    pub fn add_token(&mut self, token_id: u32) -> Option<String> {
        if !self.push_token_bytes(token_id) {
            return None;
        }
        self.extract_complete_utf8()
    }

    /// Feeds a batch of token ids. Unknown ids are silently skipped (unlike
    /// `add_token`, which reports them via `None`). Returns any complete
    /// UTF-8 text accumulated after appending the whole batch.
    pub fn add_tokens(&mut self, token_ids: &[u32]) -> Option<String> {
        for &token_id in token_ids {
            let _ = self.push_token_bytes(token_id);
        }
        self.extract_complete_utf8()
    }

    /// Drains everything still buffered, replacing invalid or incomplete
    /// trailing bytes with U+FFFD.
    pub fn flush(&mut self) -> String {
        if self.buffer.is_empty() {
            return String::new();
        }
        let result = String::from_utf8_lossy(&self.buffer).into_owned();
        self.buffer.clear();
        result
    }

    /// Discards any buffered bytes without producing output.
    pub fn reset(&mut self) {
        self.buffer.clear();
    }

    /// True when bytes are buffered awaiting the rest of a character.
    pub fn has_pending(&self) -> bool {
        !self.buffer.is_empty()
    }

    /// Number of buffered (not yet emitted) bytes.
    pub fn pending_bytes(&self) -> usize {
        self.buffer.len()
    }

    /// Appends the decoded bytes for `token_id` to the buffer. Regular tokens
    /// are run through byte-level decoding (falling back to the stored bytes
    /// if they are not valid byte-level text); special tokens are appended as
    /// plain text. Returns `false` for unknown ids.
    fn push_token_bytes(&mut self, token_id: u32) -> bool {
        // Copy the `&'a Tokenizer` so map lookups borrow the tokenizer
        // (lifetime 'a) while we mutate `self.buffer`.
        let tokenizer = self.tokenizer;
        if let Some(encoded_bytes) = tokenizer.decoder().get(&token_id) {
            match byte_level_decode_bytes(encoded_bytes) {
                Some(raw_bytes) => self.buffer.extend_from_slice(&raw_bytes),
                None => self.buffer.extend_from_slice(encoded_bytes),
            }
            true
        } else if let Some(special) = tokenizer.special_tokens_decoder().get(&token_id) {
            self.buffer.extend_from_slice(special.as_bytes());
            true
        } else {
            false
        }
    }

    /// Removes and returns the longest valid-UTF-8 prefix of the buffer,
    /// leaving any trailing partial character buffered. `None` when no
    /// complete text is available yet.
    fn extract_complete_utf8(&mut self) -> Option<String> {
        let valid_len = self.find_valid_utf8_len();
        if valid_len == 0 {
            return None;
        }
        let valid_bytes: Vec<u8> = self.buffer.drain(..valid_len).collect();
        // SAFETY: `find_valid_utf8_len` only reports a length for which
        // `std::str::from_utf8` succeeded on that exact prefix.
        Some(unsafe { String::from_utf8_unchecked(valid_bytes) })
    }

    /// Length of the longest buffer prefix that is valid UTF-8, preferring
    /// to hold back a trailing byte run that looks like an in-progress
    /// multi-byte sequence.
    fn find_valid_utf8_len(&self) -> usize {
        let bytes = &self.buffer;
        let len = bytes.len();
        if len == 0 {
            return 0;
        }
        // Fast path: the whole buffer is already valid.
        if std::str::from_utf8(bytes).is_ok() {
            return len;
        }
        // A UTF-8 sequence is at most 4 bytes, so at most 3 trailing bytes
        // can belong to an incomplete character. Try holding those back.
        for incomplete_len in 1..=len.min(3) {
            let check_len = len - incomplete_len;
            if check_len == 0 {
                continue;
            }
            if std::str::from_utf8(&bytes[..check_len]).is_ok()
                && self.could_be_incomplete_sequence(&bytes[check_len..])
            {
                return check_len;
            }
        }
        // Otherwise the tail is invalid rather than merely incomplete: fall
        // back to the longest valid prefix of any length.
        for i in (0..len).rev() {
            if std::str::from_utf8(&bytes[..=i]).is_ok() {
                return i + 1;
            }
        }
        0
    }

    /// Whether `bytes` could be the beginning of a not-yet-complete UTF-8
    /// sequence: a multi-byte lead byte followed by too few bytes.
    ///
    /// NOTE(review): 0xC0/0xC1 can never start valid UTF-8 but are treated
    /// as "incomplete" here; such bytes surface as U+FFFD on `flush`.
    fn could_be_incomplete_sequence(&self, bytes: &[u8]) -> bool {
        let Some(&first) = bytes.first() else {
            return false;
        };
        match first {
            0xC0..=0xDF => bytes.len() < 2, // 2-byte lead
            0xE0..=0xEF => bytes.len() < 3, // 3-byte lead
            0xF0..=0xF7 => bytes.len() < 4, // 4-byte lead
            _ => false,                     // ASCII / continuation / invalid
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHashMap;

    /// Builds a plain (non-byte-level) tokenizer: ids 0..=255 each map to the
    /// single raw byte of the same value, plus two merged tokens (256, 257).
    fn make_test_tokenizer() -> Tokenizer {
        let mut encoder = FxHashMap::default();
        for b in 0u8..=255 {
            encoder.insert(vec![b], b as u32);
        }
        encoder.insert("Hello".as_bytes().to_vec(), 256);
        encoder.insert("世界".as_bytes().to_vec(), 257);
        let special_tokens = FxHashMap::default();
        let pattern = r".";
        Tokenizer::new(encoder, special_tokens, pattern).unwrap()
    }

    /// Single-byte ASCII tokens decode immediately with nothing buffered.
    #[test]
    fn test_simple_ascii() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        assert_eq!(decoder.add_token(b'H' as u32), Some("H".to_string()));
        assert_eq!(decoder.add_token(b'i' as u32), Some("i".to_string()));
        assert!(!decoder.has_pending());
    }

    /// A token whose bytes are already complete UTF-8 decodes in one call.
    #[test]
    fn test_multi_byte_complete() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        assert_eq!(decoder.add_token(257), Some("世界".to_string()));
        assert!(!decoder.has_pending());
    }

    /// Feeding the three bytes of '世' (E4 B8 96) one token at a time:
    /// the first two are buffered; the third completes the character.
    #[test]
    fn test_multi_byte_split() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        assert_eq!(decoder.add_token(0xE4), None);
        assert!(decoder.has_pending());
        assert_eq!(decoder.pending_bytes(), 1);
        assert_eq!(decoder.add_token(0xB8), None);
        assert_eq!(decoder.pending_bytes(), 2);
        assert_eq!(decoder.add_token(0x96), Some("世".to_string()));
        assert!(!decoder.has_pending());
    }

    /// Flushing with an incomplete sequence buffered yields U+FFFD and
    /// empties the buffer.
    #[test]
    fn test_flush_incomplete() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        decoder.add_token(0xE4);
        decoder.add_token(0xB8);
        let flushed = decoder.flush();
        assert!(flushed.contains('\u{FFFD}'));
        assert!(!decoder.has_pending());
    }

    /// `reset` drops buffered bytes without producing output.
    #[test]
    fn test_reset() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        decoder.add_token(0xE4);
        assert!(decoder.has_pending());
        decoder.reset();
        assert!(!decoder.has_pending());
    }

    /// Complete text is emitted eagerly; a following partial byte buffers.
    #[test]
    fn test_mixed_complete_incomplete() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        let result1 = decoder.add_token(b'H' as u32);
        assert_eq!(result1, Some("H".to_string()));
        assert!(!decoder.has_pending());
        let result2 = decoder.add_token(0xE4);
        assert_eq!(result2, None);
        assert!(decoder.has_pending());
    }

    /// `add_tokens` concatenates the whole batch into one result.
    #[test]
    fn test_add_tokens_batch() {
        let tokenizer = make_test_tokenizer();
        let mut decoder = StreamingDecoder::new(&tokenizer);
        let result = decoder.add_tokens(&[b'H' as u32, b'i' as u32, b'!' as u32]);
        assert_eq!(result, Some("Hi!".to_string()));
    }

    use super::super::byte_level::byte_level_encode;

    /// Builds a byte-level tokenizer: ids 100-102 are byte-level-encoded
    /// merged tokens, ids 200..=202 each cover one byte-level-encoded byte
    /// of "你", and 1000 is a special token stored as plain text.
    fn make_byte_level_tokenizer() -> Tokenizer {
        let mut encoder = FxHashMap::default();
        encoder.insert(byte_level_encode(b"Hello").into_bytes(), 100);
        encoder.insert(byte_level_encode(b" world").into_bytes(), 101);
        encoder.insert(byte_level_encode("你好".as_bytes()).into_bytes(), 102);
        let ni_bytes = "你".as_bytes();
        for (i, &b) in ni_bytes.iter().enumerate() {
            let byte_level = byte_level_encode(&[b]);
            encoder.insert(byte_level.into_bytes(), 200 + i as u32);
        }
        let mut special_tokens = FxHashMap::default();
        special_tokens.insert("<|think|>".to_string(), 1000);
        let pattern = r".";
        Tokenizer::new_byte_level(encoder, special_tokens, pattern).unwrap()
    }

    /// Byte-level ASCII token decodes back to its original text.
    #[test]
    fn test_byte_level_simple_ascii() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        let result = decoder.add_token(100);
        assert_eq!(result, Some("Hello".to_string()));
        assert!(!decoder.has_pending());
    }

    /// Leading space survives the byte-level round trip.
    #[test]
    fn test_byte_level_with_space() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        let result = decoder.add_token(101);
        assert_eq!(result, Some(" world".to_string()));
    }

    /// A byte-level token holding complete multi-byte text decodes at once.
    #[test]
    fn test_byte_level_chinese() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        let result = decoder.add_token(102);
        assert_eq!(result, Some("你好".to_string()));
    }

    /// The three one-byte tokens of "你" (ids 200-202) buffer until the
    /// character is complete.
    #[test]
    fn test_byte_level_split_chinese() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        let result1 = decoder.add_token(200);
        assert_eq!(result1, None);
        assert!(decoder.has_pending());
        let result2 = decoder.add_token(201);
        assert_eq!(result2, None);
        assert!(decoder.has_pending());
        let result3 = decoder.add_token(202);
        assert_eq!(result3, Some("你".to_string()));
        assert!(!decoder.has_pending());
    }

    /// Special tokens bypass byte-level decoding and come out verbatim.
    #[test]
    fn test_byte_level_special_token() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        let result = decoder.add_token(1000);
        assert_eq!(result, Some("<|think|>".to_string()));
    }

    /// Regular and special tokens interleave correctly in one batch.
    #[test]
    fn test_byte_level_mixed() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        let result = decoder.add_tokens(&[100, 1000, 101]);
        assert_eq!(result, Some("Hello<|think|> world".to_string()));
    }

    /// Flushing an incomplete byte-level sequence yields U+FFFD.
    #[test]
    fn test_byte_level_flush() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        decoder.add_token(200);
        decoder.add_token(201);
        assert!(decoder.has_pending());
        let flushed = decoder.flush();
        assert!(flushed.contains('\u{FFFD}'));
        assert!(!decoder.has_pending());
    }

    /// `reset` drops pending byte-level bytes without producing output.
    #[test]
    fn test_byte_level_reset() {
        let tokenizer = make_byte_level_tokenizer();
        let mut decoder = ByteLevelStreamingDecoder::new(&tokenizer);
        decoder.add_token(200);
        assert!(decoder.has_pending());
        decoder.reset();
        assert!(!decoder.has_pending());
    }
}