use std::fs::File;
use std::io::Read;
use std::path::Path;
use super::model::MoonshineError;
pub struct BinTokenizer {
tokens_to_bytes: Vec<Vec<u8>>,
space_string: &'static str,
}
impl BinTokenizer {
pub fn new(path: &Path) -> Result<Self, MoonshineError> {
let tokenizer_path = path.join("tokenizer.bin");
if !tokenizer_path.exists() {
return Err(MoonshineError::TokenizerNotFound(
tokenizer_path.display().to_string(),
));
}
let mut file = File::open(&tokenizer_path).map_err(MoonshineError::Io)?;
let mut data = Vec::new();
file.read_to_end(&mut data).map_err(MoonshineError::Io)?;
let mut tokens_to_bytes = Vec::new();
let mut offset = 0;
while offset < data.len() {
let first_byte = data[offset];
offset += 1;
if first_byte == 0 {
tokens_to_bytes.push(Vec::new());
continue;
}
let byte_count = if first_byte < 128 {
first_byte as usize
} else {
if offset >= data.len() {
break;
}
let second_byte = data[offset];
offset += 1;
(second_byte as usize * 128) + first_byte as usize - 128
};
if offset + byte_count > data.len() {
break;
}
let bytes = data[offset..offset + byte_count].to_vec();
offset += byte_count;
tokens_to_bytes.push(bytes);
}
if tokens_to_bytes.is_empty() {
return Err(MoonshineError::Tokenization(
"No tokens found in tokenizer.bin".to_string(),
));
}
log::info!(
"Loaded {} tokens from {:?}",
tokens_to_bytes.len(),
tokenizer_path
);
Ok(Self {
tokens_to_bytes,
space_string: "\u{2581}", })
}
pub fn decode(&self, tokens: &[i64]) -> Result<String, MoonshineError> {
let mut result_bytes: Vec<u8> = Vec::new();
for &token in tokens {
let idx = token as usize;
if idx >= self.tokens_to_bytes.len() {
continue;
}
let bytes = &self.tokens_to_bytes[idx];
if bytes.is_empty() {
continue;
}
if bytes.len() > 2 && bytes[0] == b'<' && bytes[bytes.len() - 1] == b'>' {
continue;
}
result_bytes.extend_from_slice(bytes);
}
let text = String::from_utf8_lossy(&result_bytes);
let text = text.replace(self.space_string, " ");
let text = text.trim().to_string();
Ok(text)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bin_format_parsing() {
let data: Vec<u8> = vec![
0x00, 0x03, b'c', b'a', b't', 0x01, b' ', ];
let dir = std::env::temp_dir().join("bin_tokenizer_test");
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("tokenizer.bin");
std::fs::write(&path, &data).unwrap();
let tokenizer = BinTokenizer::new(&dir).unwrap();
assert_eq!(tokenizer.tokens_to_bytes.len(), 3);
assert_eq!(tokenizer.tokens_to_bytes[0], Vec::<u8>::new());
assert_eq!(tokenizer.tokens_to_bytes[1], b"cat".to_vec());
assert_eq!(tokenizer.tokens_to_bytes[2], b" ".to_vec());
std::fs::remove_dir_all(&dir).ok();
}
}