use crate::pre_tokenizers::byte_level::BYTE_TO_CHAR;
const CHAR_TO_BYTE: [u8; 324] = build_char_to_byte();
const fn build_char_to_byte() -> [u8; 324] {
let mut table = [0u8; 324];
let mut i = 0u16;
while i < 256 {
let ch = BYTE_TO_CHAR[i as usize];
table[ch as usize] = i as u8;
i += 1;
}
table
}
#[derive(Debug)]
pub struct ByteLevelDecoder;
impl ByteLevelDecoder {
pub fn decode_chain(&self, tokens: Vec<String>) -> Vec<String> {
let joined: String = tokens.into_iter().collect();
let mut bytes: Vec<u8> = Vec::with_capacity(joined.len());
for c in joined.chars() {
let cp = c as usize;
if cp < CHAR_TO_BYTE.len() {
bytes.push(CHAR_TO_BYTE[cp]);
} else {
let mut buf = [0u8; 4];
let s = c.encode_utf8(&mut buf);
bytes.extend_from_slice(s.as_bytes());
}
}
vec![String::from_utf8_lossy(&bytes).into_owned()]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `ByteLevelDecoder::decode_chain` over borrowed token strings.
    fn decode(tokens: &[&str]) -> Vec<String> {
        ByteLevelDecoder.decode_chain(tokens.iter().map(|t| (*t).to_owned()).collect())
    }

    /// Encodes raw bytes as one byte-level character each, via `BYTE_TO_CHAR`.
    fn byte_level(bytes: &[u8]) -> Vec<String> {
        bytes
            .iter()
            .map(|&b| BYTE_TO_CHAR[usize::from(b)].to_string())
            .collect()
    }

    #[test]
    fn roundtrip_ascii() {
        assert_eq!(decode(&["Hello"]), vec!["Hello"]);
    }

    #[test]
    fn roundtrip_space() {
        // U+0120 is the byte-level stand-in for an ASCII space in this test's expectation.
        assert_eq!(decode(&["\u{120}Hello"]), vec![" Hello"]);
    }

    #[test]
    fn roundtrip_multibyte() {
        // The three UTF-8 bytes of the euro sign, packed into a single token.
        let encoded = byte_level(&[0xE2, 0x82, 0xAC]).concat();
        assert_eq!(ByteLevelDecoder.decode_chain(vec![encoded]), vec!["€"]);
    }

    #[test]
    fn multibyte_split_across_tokens() {
        // Same bytes, one token per byte: tokens are joined before decoding, so the
        // UTF-8 sequence must still be reassembled.
        assert_eq!(
            ByteLevelDecoder.decode_chain(byte_level(&[0xE2, 0x82, 0xAC])),
            vec!["€"]
        );
    }

    #[test]
    fn non_gpt2_chars_preserved() {
        // Characters outside the byte-level alphabet pass through unchanged, so the
        // expected text is the input itself. The expectation is spelled with the same
        // escapes as the input so U+FF5C cannot be confused with an ASCII `|`.
        let token = "<\u{FF5C}begin\u{2581}of\u{2581}sentence\u{FF5C}>";
        assert_eq!(
            decode(&[token]),
            vec!["<\u{FF5C}begin\u{2581}of\u{2581}sentence\u{FF5C}>"]
        );
    }
}