use oxibonsai_tokenizer::{
base64_decode, base64_encode,
trainer::{BpeTrainer, TrainerConfig},
SerializationError, TokenizerState,
};
use std::io::{BufReader, Cursor};
#[test]
fn base64_roundtrip_ascii() {
let original = b"Hello, World!";
let encoded = base64_encode(original);
let decoded = base64_decode(&encoded).expect("decode should succeed");
assert_eq!(decoded, original);
}
#[test]
fn base64_roundtrip_binary() {
let original: Vec<u8> = (0u8..=255u8).collect();
let encoded = base64_encode(&original);
let decoded = base64_decode(&encoded).expect("decode should succeed");
assert_eq!(decoded, original);
}
#[test]
fn base64_decode_invalid() {
let result = base64_decode("SGVsbG8!!");
assert!(result.is_err(), "expected error for invalid base64");
}
#[test]
fn tokenizer_state_new_empty() {
let state = TokenizerState::new();
assert_eq!(state.vocab_size(), 0);
assert!(state.vocab.is_empty());
assert!(state.merges.is_empty());
assert!(state.special_tokens.is_empty());
}
#[test]
fn tokenizer_state_save_load_empty() {
let state = TokenizerState::new();
let mut buf: Vec<u8> = Vec::new();
state.save_to(&mut buf).expect("save_to should succeed");
let mut reader = BufReader::new(buf.as_slice());
let loaded = TokenizerState::load_from(&mut reader).expect("load_from should succeed");
assert_eq!(loaded.vocab_size(), 0);
assert!(loaded.merges.is_empty());
assert!(loaded.special_tokens.is_empty());
}
#[test]
fn tokenizer_state_save_load_with_vocab() {
let mut state = TokenizerState::new();
state.vocab.insert(10, "hello".to_string());
state.vocab.insert(11, "world".to_string());
state.vocab.insert(12, "foo bar".to_string()); state.vocab.insert(13, "<unk>".to_string());
let mut buf: Vec<u8> = Vec::new();
state.save_to(&mut buf).expect("save should succeed");
let mut reader = BufReader::new(buf.as_slice());
let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");
assert_eq!(loaded.vocab_size(), 4);
assert_eq!(loaded.vocab.get(&10), Some(&"hello".to_string()));
assert_eq!(loaded.vocab.get(&11), Some(&"world".to_string()));
assert_eq!(loaded.vocab.get(&12), Some(&"foo bar".to_string()));
assert_eq!(loaded.vocab.get(&13), Some(&"<unk>".to_string()));
}
#[test]
fn tokenizer_state_save_load_with_merges() {
let mut state = TokenizerState::new();
state.vocab.insert(0, "a".to_string());
state.vocab.insert(1, "b".to_string());
state.vocab.insert(2, "ab".to_string());
state.merges.push((0, 1, 2));
state.merges.push((2, 0, 3));
let mut buf: Vec<u8> = Vec::new();
state.save_to(&mut buf).expect("save should succeed");
let mut reader = BufReader::new(buf.as_slice());
let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");
assert_eq!(loaded.merges.len(), 2);
assert_eq!(loaded.merges[0], (0, 1, 2));
assert_eq!(loaded.merges[1], (2, 0, 3));
}
#[test]
fn tokenizer_state_save_load_special_tokens() {
let mut state = TokenizerState::new();
state.vocab.insert(0, "<bos>".to_string());
state.vocab.insert(1, "<eos>".to_string());
state.special_tokens.insert("<bos>".to_string(), 0);
state.special_tokens.insert("<eos>".to_string(), 1);
let mut buf: Vec<u8> = Vec::new();
state.save_to(&mut buf).expect("save should succeed");
let mut reader = BufReader::new(buf.as_slice());
let loaded = TokenizerState::load_from(&mut reader).expect("load should succeed");
assert_eq!(loaded.special_tokens.get("<bos>"), Some(&0));
assert_eq!(loaded.special_tokens.get("<eos>"), Some(&1));
}
#[test]
fn tokenizer_state_save_tempfile() {
let mut state = TokenizerState::new();
state.vocab.insert(42, "rust".to_string());
let mut tmp = std::env::temp_dir();
tmp.push("oxibonsai_serialization_test.txt");
state.save(&tmp).expect("save to file should succeed");
let loaded = TokenizerState::load(&tmp).expect("load from file should succeed");
assert_eq!(loaded.vocab.get(&42), Some(&"rust".to_string()));
let _ = std::fs::remove_file(&tmp);
}
#[test]
fn tokenizer_state_invalid_magic() {
let bad_data = b"not a tokenizer\nvocab_size 0\nmerges 0\n";
let mut reader = BufReader::new(bad_data.as_ref());
let result = TokenizerState::load_from(&mut reader);
match result {
Err(SerializationError::InvalidMagic { .. }) => {} other => panic!("expected InvalidMagic, got: {other:?}"),
}
}
#[test]
fn tokenizer_state_from_trained() {
let mut trainer = BpeTrainer::new(TrainerConfig::new(300));
let corpus = ["hello world", "hello rust", "world rust"];
let trained = trainer.train(&corpus).expect("training should succeed");
let state = TokenizerState::from_trained(&trained);
assert!(state.vocab_size() > 0);
assert_eq!(state.vocab_size(), trained.vocab_size());
assert_eq!(state.merges.len(), trained.merges.len());
}
#[test]
fn tokenizer_state_to_oxi_tokenizer() {
let mut state = TokenizerState::new();
for (i, c) in (b'a'..=b'z').enumerate() {
state.vocab.insert(i as u32, (c as char).to_string());
}
let tokenizer = state.to_oxi_tokenizer();
assert_eq!(tokenizer.vocab_size(), 26);
let result = tokenizer.encode("abc");
let _ = result;
}
#[test]
fn base64_roundtrip_unicode() {
let text = "日本語テスト🦀";
let encoded = base64_encode(text.as_bytes());
let decoded = base64_decode(&encoded).expect("decode should succeed");
let restored = String::from_utf8(decoded).expect("must be valid UTF-8");
assert_eq!(restored, text);
}
#[test]
fn tokenizer_state_default_is_empty() {
let state = TokenizerState::default();
assert_eq!(state.vocab_size(), 0);
}
#[test]
fn save_load_roundtrip_via_cursor() {
let mut state = TokenizerState::new();
state.vocab.insert(0, "a".to_string());
state.vocab.insert(1, "b".to_string());
state.merges.push((0, 1, 2));
state.special_tokens.insert("<pad>".to_string(), 3);
state.vocab.insert(3, "<pad>".to_string());
let mut buf: Vec<u8> = Vec::new();
state.save_to(&mut buf).expect("save");
let cursor = Cursor::new(buf);
let mut reader = BufReader::new(cursor);
let loaded = TokenizerState::load_from(&mut reader).expect("load");
assert_eq!(loaded.vocab.get(&0), Some(&"a".to_string()));
assert_eq!(loaded.vocab.get(&1), Some(&"b".to_string()));
assert_eq!(loaded.merges, vec![(0, 1, 2)]);
assert_eq!(loaded.special_tokens.get("<pad>"), Some(&3));
}