use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
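/// Magic string written as the first line of every serialized tokenizer file.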
pub const FORMAT_MAGIC: &str = "oxitokenizer v1";
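/// The standard base64 alphabet (RFC 4648).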
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
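/// Encodes `bytes` as standard base64 with `=` padding, e.g. `b"hi"` -> `"aGk="`.
/// Each 3-byte group maps to four 6-bit characters, so the output length is
/// always a multiple of 4.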
pub fn base64_encode(bytes: &[u8]) -> String {
let mut out = Vec::with_capacity(bytes.len().div_ceil(3) * 4);
for chunk in bytes.chunks(3) {
let b0 = chunk[0] as u32;
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
let combined = (b0 << 16) | (b1 << 8) | b2;
out.push(BASE64_CHARS[((combined >> 18) & 0x3F) as usize]);
out.push(BASE64_CHARS[((combined >> 12) & 0x3F) as usize]);
if chunk.len() > 1 {
out.push(BASE64_CHARS[((combined >> 6) & 0x3F) as usize]);
} else {
out.push(b'=');
}
if chunk.len() > 2 {
out.push(BASE64_CHARS[(combined & 0x3F) as usize]);
} else {
out.push(b'=');
}
}
    // Every pushed byte comes from BASE64_CHARS or is b'=', all ASCII, so this
    // conversion cannot fail.
    String::from_utf8(out).expect("base64 output is ASCII")
}
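/// Decodes standard base64, accepting both padded and unpadded input.
/// Fails on characters outside the alphabet and on a trailing group of a
/// single character, which cannot encode a full byte.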
pub fn base64_decode(s: &str) -> Result<Vec<u8>, SerializationError> {
    // '=' padding carries no information; strip it and decode what remains.
    let s = s.trim_end_matches('=');
    let mut out = Vec::with_capacity((s.len() * 3) / 4 + 1);
    let decode_char = |c: u8| -> Result<u32, SerializationError> {
        match c {
            b'A'..=b'Z' => Ok((c - b'A') as u32),
            b'a'..=b'z' => Ok((c - b'a' + 26) as u32),
            b'0'..=b'9' => Ok((c - b'0' + 52) as u32),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err(SerializationError::Base64Error(format!(
                "invalid char '{}'",
                c as char
            ))),
        }
    };
    let bytes: Vec<u8> = s.bytes().collect();
    for chunk in bytes.chunks(4) {
        // A lone trailing char carries at most 6 bits: not enough for a byte.
        if chunk.len() == 1 {
            return Err(SerializationError::Base64Error(
                "truncated base64 group of 1 char".to_string(),
            ));
        }
        // Pack the 6-bit values into a 24-bit group, high bits first.
        let mut combined: u32 = 0;
        for (i, &c) in chunk.iter().enumerate() {
            combined |= decode_char(c)? << (18 - 6 * i);
        }
        // A group of n chars (2 <= n <= 4) decodes to n - 1 bytes.
        for i in 0..chunk.len() - 1 {
            out.push(((combined >> (16 - 8 * i)) & 0xFF) as u8);
        }
    }
Ok(out)
}
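/// Errors produced while saving or loading a [`TokenizerState`].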
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("invalid format magic: expected '{expected}', got '{got}'")]
InvalidMagic { expected: String, got: String },
#[error("parse error on line {line}: {msg}")]
ParseError { line: usize, msg: String },
#[error("base64 decode error: {0}")]
Base64Error(String),
#[error("duplicate token id {0}")]
DuplicateId(u32),
}
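/// A serializable snapshot of a trained tokenizer: the id-to-token vocabulary,
/// the ordered BPE merge rules as `(left, right, merged)` id triples, and the
/// table of special tokens.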
#[derive(Debug)]
pub struct TokenizerState {
pub vocab: HashMap<u32, String>,
pub merges: Vec<(u32, u32, u32)>,
pub special_tokens: HashMap<String, u32>,
}
impl TokenizerState {
pub fn new() -> Self {
Self {
vocab: HashMap::new(),
merges: Vec::new(),
special_tokens: HashMap::new(),
}
}
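    /// Builds a state from a trained tokenizer, classifying every token of the
    /// form `<...>` (e.g. `<unk>`) as a special token.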
pub fn from_trained(trained: &crate::trainer::TrainedTokenizer) -> Self {
let mut state = Self::new();
for (&id, token) in &trained.vocab {
if token.starts_with('<') && token.ends_with('>') {
state.special_tokens.insert(token.clone(), id);
}
state.vocab.insert(id, token.clone());
}
for rule in &trained.merges {
state.merges.push((rule.left, rule.right, rule.merged));
}
state
}
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
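    /// Writes the state in the plain-text `oxitokenizer v1` format: the magic
    /// line, `vocab_size N` and `merges N` headers, one `tok_id <id> <b64>`
    /// line per token (sorted by id), one `merge <left> <right> <merged>` line
    /// per rule (in order), and one `special <b64> <id>` line per special
    /// token (sorted by name). Token text is base64-encoded so entries may
    /// safely contain spaces or newlines.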
pub fn save_to<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
writeln!(writer, "{}", FORMAT_MAGIC)?;
writeln!(writer, "vocab_size {}", self.vocab.len())?;
writeln!(writer, "merges {}", self.merges.len())?;
let mut vocab_entries: Vec<(u32, &str)> =
self.vocab.iter().map(|(&id, s)| (id, s.as_str())).collect();
vocab_entries.sort_by_key(|(id, _)| *id);
for (id, token) in &vocab_entries {
let encoded = base64_encode(token.as_bytes());
writeln!(writer, "tok_id {id} {encoded}")?;
}
for &(left, right, merged) in &self.merges {
writeln!(writer, "merge {left} {right} {merged}")?;
}
let mut special_entries: Vec<(&str, u32)> = self
.special_tokens
.iter()
.map(|(k, &v)| (k.as_str(), v))
.collect();
special_entries.sort_by_key(|(name, _)| *name);
for (name, id) in &special_entries {
let encoded = base64_encode(name.as_bytes());
writeln!(writer, "special {encoded} {id}")?;
}
Ok(())
}
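    /// Parses the format written by [`Self::save_to`]. The `special` lines are
    /// read until end of input, so no count header is needed for them; blank
    /// lines in that trailing section are ignored.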
pub fn load_from<R: BufRead>(reader: &mut R) -> Result<Self, SerializationError> {
let mut lines = reader.lines();
let mut line_no: usize = 0;
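        // `lines` is captured mutably by this closure, so `line_no` is passed
        // in as a parameter; capturing it as well would conflict with the
        // direct reads of `line_no` in the parsing code below.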
let mut next_line = |line_no: &mut usize| -> Result<String, SerializationError> {
*line_no += 1;
match lines.next() {
Some(Ok(l)) => Ok(l),
Some(Err(e)) => Err(SerializationError::Io(e)),
None => Err(SerializationError::ParseError {
line: *line_no,
msg: "unexpected end of file".to_string(),
}),
}
};
let magic_line = next_line(&mut line_no)?;
if magic_line.trim() != FORMAT_MAGIC {
return Err(SerializationError::InvalidMagic {
expected: FORMAT_MAGIC.to_string(),
got: magic_line.trim().to_string(),
});
}
let vocab_size_line = next_line(&mut line_no)?;
let vocab_size = parse_count_line(&vocab_size_line, "vocab_size", line_no)?;
let merges_line = next_line(&mut line_no)?;
let merges_count = parse_count_line(&merges_line, "merges", line_no)?;
let mut vocab: HashMap<u32, String> = HashMap::with_capacity(vocab_size);
for _ in 0..vocab_size {
let l = next_line(&mut line_no)?;
let parts: Vec<&str> = l.trim().splitn(3, ' ').collect();
if parts.len() != 3 || parts[0] != "tok_id" {
return Err(SerializationError::ParseError {
line: line_no,
msg: format!("expected 'tok_id <id> <b64>', got '{l}'"),
});
}
            let id: u32 = parse_u32(parts[1], "token id", line_no)?;
let token_bytes = base64_decode(parts[2])?;
let token =
String::from_utf8(token_bytes).map_err(|_| SerializationError::ParseError {
line: line_no,
msg: "token text is not valid UTF-8".to_string(),
})?;
if vocab.contains_key(&id) {
return Err(SerializationError::DuplicateId(id));
}
vocab.insert(id, token);
}
let mut merges: Vec<(u32, u32, u32)> = Vec::with_capacity(merges_count);
for _ in 0..merges_count {
let l = next_line(&mut line_no)?;
let parts: Vec<&str> = l.trim().splitn(4, ' ').collect();
if parts.len() != 4 || parts[0] != "merge" {
return Err(SerializationError::ParseError {
line: line_no,
msg: format!("expected 'merge <left> <right> <merged>', got '{l}'"),
});
}
            let left = parse_u32(parts[1], "merge left id", line_no)?;
            let right = parse_u32(parts[2], "merge right id", line_no)?;
            let merged = parse_u32(parts[3], "merge merged id", line_no)?;
merges.push((left, right, merged));
}
let mut special_tokens: HashMap<String, u32> = HashMap::new();
for maybe_line in lines {
line_no += 1;
            let l = maybe_line?;
let l = l.trim();
if l.is_empty() {
continue;
}
let parts: Vec<&str> = l.splitn(3, ' ').collect();
if parts.len() != 3 || parts[0] != "special" {
return Err(SerializationError::ParseError {
line: line_no,
msg: format!("expected 'special <b64> <id>', got '{l}'"),
});
}
let name_bytes = base64_decode(parts[1])?;
let name =
String::from_utf8(name_bytes).map_err(|_| SerializationError::ParseError {
line: line_no,
msg: "special token name is not valid UTF-8".to_string(),
})?;
            let id: u32 = parse_u32(parts[2], "special token id", line_no)?;
special_tokens.insert(name, id);
}
Ok(TokenizerState {
vocab,
merges,
special_tokens,
})
}
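    /// Saves the state to `path`, creating the file or truncating an existing one.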
pub fn save(&self, path: &Path) -> Result<(), SerializationError> {
        let file = File::create(path)?;
let mut writer = BufWriter::new(file);
self.save_to(&mut writer)?;
        writer.flush()?;
Ok(())
}
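    /// Loads a state from the file at `path`.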
pub fn load(path: &Path) -> Result<Self, SerializationError> {
        let file = File::open(path)?;
let mut reader = BufReader::new(file);
Self::load_from(&mut reader)
}
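    /// Builds a runtime tokenizer from this state. Tokens listed in
    /// `special_tokens` are registered as special; merge rules are rehydrated
    /// by looking the left and right ids back up in the vocabulary.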
pub fn to_oxi_tokenizer(&self) -> crate::OxiTokenizer {
use crate::{
bpe::BpeMerges,
tokenizer::{OxiTokenizer, TokenizerConfig},
vocab::Vocabulary,
};
let mut vocabulary = Vocabulary::new();
for (&id, token) in &self.vocab {
if self.special_tokens.contains_key(token.as_str()) {
vocabulary.add_special(token, id);
} else {
vocabulary.insert(token, id);
}
}
let mut bpe_merges = BpeMerges::new();
for &(left_id, right_id, merged_id) in &self.merges {
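            // A merge referencing an id absent from the vocab falls back to an
            // empty string instead of panicking.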
let left_str = self.vocab.get(&left_id).map(|s| s.as_str()).unwrap_or("");
let right_str = self.vocab.get(&right_id).map(|s| s.as_str()).unwrap_or("");
bpe_merges.add_merge(left_str, right_str, merged_id);
}
let config = TokenizerConfig::default();
OxiTokenizer::new(vocabulary, bpe_merges, config)
}
}
impl Default for TokenizerState {
fn default() -> Self {
Self::new()
}
}
fn parse_count_line(
line: &str,
keyword: &str,
line_no: usize,
) -> Result<usize, SerializationError> {
let parts: Vec<&str> = line.trim().splitn(2, ' ').collect();
if parts.len() != 2 || parts[0] != keyword {
return Err(SerializationError::ParseError {
line: line_no,
msg: format!("expected '{keyword} <N>', got '{line}'"),
});
}
parts[1]
.parse::<usize>()
.map_err(|_| SerializationError::ParseError {
line: line_no,
msg: format!("invalid count value '{}'", parts[1]),
})
}
/// Parses a `u32` field, mapping failure to a positioned `ParseError`.
fn parse_u32(s: &str, what: &str, line_no: usize) -> Result<u32, SerializationError> {
    s.parse().map_err(|_| SerializationError::ParseError {
        line: line_no,
        msg: format!("invalid {what} '{s}'"),
    })
}
#[cfg(test)]
mod inline_tests {
use super::*;
#[test]
fn base64_encode_decode_hello() {
let original = b"Hello, World!";
let encoded = base64_encode(original);
let decoded = base64_decode(&encoded).expect("decode should succeed");
assert_eq!(decoded, original);
}
#[test]
fn base64_empty() {
let encoded = base64_encode(b"");
assert_eq!(encoded, "");
let decoded = base64_decode("").expect("decode empty");
assert!(decoded.is_empty());
}
#[test]
fn tokenizer_state_roundtrip_basic() {
let mut state = TokenizerState::new();
state.vocab.insert(0, "<unk>".to_string());
state.vocab.insert(1, "a".to_string());
state.merges.push((0, 1, 2));
let mut buf = Vec::new();
state.save_to(&mut buf).expect("save should succeed");
let mut reader = std::io::BufReader::new(buf.as_slice());
let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");
assert_eq!(loaded.vocab.get(&0), Some(&"<unk>".to_string()));
assert_eq!(loaded.vocab.get(&1), Some(&"a".to_string()));
assert_eq!(loaded.merges, vec![(0, 1, 2)]);
}
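    // Exercise every padding case: input lengths 0..=8 cover all three
    // residues modulo 3 and hence both the '=' and '==' padded tails.
    #[test]
    fn base64_roundtrip_all_padding_lengths() {
        let data = b"abcdefgh";
        for len in 0..=data.len() {
            let slice = &data[..len];
            let encoded = base64_encode(slice);
            assert_eq!(encoded.len() % 4, 0, "output is always padded to 4");
            let decoded = base64_decode(&encoded).expect("decode should succeed");
            assert_eq!(decoded, slice);
        }
    }
    // Characters outside the alphabet must be rejected, not silently skipped.
    #[test]
    fn base64_rejects_invalid_char() {
        assert!(base64_decode("ab!d").is_err());
    }
    // Special tokens travel through the trailing `special` section and must
    // survive a save/load cycle.
    #[test]
    fn tokenizer_state_roundtrip_special_tokens() {
        let mut state = TokenizerState::new();
        state.vocab.insert(0, "<unk>".to_string());
        state.special_tokens.insert("<unk>".to_string(), 0);
        let mut buf = Vec::new();
        state.save_to(&mut buf).expect("save should succeed");
        let mut reader = std::io::BufReader::new(buf.as_slice());
        let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");
        assert_eq!(loaded.special_tokens.get("<unk>"), Some(&0));
    }
    // A wrong magic line must surface as InvalidMagic, not a generic parse error.
    #[test]
    fn load_rejects_bad_magic() {
        let mut reader = std::io::BufReader::new(&b"not the magic\n"[..]);
        match TokenizerState::load_from(&mut reader) {
            Err(SerializationError::InvalidMagic { .. }) => {}
            other => panic!("expected InvalidMagic, got {other:?}"),
        }
    }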
}