#[cfg(test)]
mod tests {
use minbpe::test_common::{LLAMA_TEXT, SPECIAL_TOKENS};
use minbpe::AllowedSpecial;
use minbpe::BasicTokenizer;
use minbpe::Loadable;
use minbpe::RegexTokenizerStruct;
use minbpe::RegexTokenizerTrait;
use minbpe::Saveable;
use minbpe::Token;
use minbpe::Trainable;
use indexmap::IndexMap;
use tempfile::tempdir;
/// Trains `tokenizer` on the short string `"aaabdaaabac"` and checks both the
/// exact token ids it produces and that decoding those ids restores the input.
///
/// The second argument to `train` is `256 + 3`; presumably this is the target
/// vocabulary size (256 byte-level tokens plus three merges) — confirm against
/// the `Trainable` trait.
fn test_wikipedia_example_inner(tokenizer: &mut Box<dyn Trainable>) {
    let sample = "aaabdaaabac";
    tokenizer.train(sample, 256 + 3, false);

    // Exact ids expected for the sample after training.
    let expected_ids = [258, 100, 258, 97, 99];
    let actual_ids = tokenizer.encode(sample);
    assert_eq!(actual_ids, expected_ids);

    // Encode again from scratch and make sure decoding is the inverse.
    let round_trip_ids = tokenizer.encode(sample);
    let restored = tokenizer.decode(&round_trip_ids);
    assert_eq!(restored, sample);
}
/// Runs the shared Wikipedia-example check against every tokenizer
/// implementation listed here: a `BasicTokenizer` and a default
/// `RegexTokenizerStruct`, both used through the `Trainable` trait object.
#[test]
fn test_wikipedia_example() {
    let mut candidates: Vec<Box<dyn Trainable>> = vec![
        Box::new(BasicTokenizer::new()),
        Box::<RegexTokenizerStruct>::default(),
    ];

    // Each candidate is trained in place by the helper, hence the mutable borrow.
    for candidate in candidates.iter_mut() {
        test_wikipedia_example_inner(candidate);
    }
}
/// Verifies that a trained `RegexTokenizerStruct` survives a save/load cycle.
///
/// Steps:
/// 1. Train a fresh tokenizer on `LLAMA_TEXT`, register `special_tokens`, and
///    confirm an encode/decode round trip reproduces the text.
/// 2. Save the tokenizer into a temporary directory.
/// 3. Load the saved `.model` file into a brand-new tokenizer and confirm it
///    decodes the earlier ids, round-trips the text, and encodes the text to
///    the same ids as the original tokenizer did.
fn test_save_load_inner(special_tokens: &IndexMap<String, Token>) {
    let corpus = LLAMA_TEXT;

    // --- Train and sanity-check the original tokenizer. ---
    let mut trained = RegexTokenizerStruct::default();
    trained.train(corpus, 256 + 64, false);
    trained.set_special_tokens(special_tokens.clone());

    let reference_ids = trained.encode_special(corpus, AllowedSpecial::All);
    let reference_text = trained.decode(&reference_ids);
    assert_eq!(reference_text, corpus);

    // --- Persist it; the temporary directory lives until the end of this function. ---
    let dir = tempdir().unwrap();
    trained.save(dir.path(), "test_tokenizer_tmp");

    // --- Restore into a fresh instance from the file the test expects `save` to have written. ---
    let model_file = dir.path().join("test_tokenizer_tmp.model");
    let mut reloaded = RegexTokenizerStruct::default();
    reloaded.load(&model_file);

    // The reloaded tokenizer must decode ids produced before saving.
    assert_eq!(reloaded.decode(&reference_ids), corpus);

    // It must also round-trip the text on its own.
    let reloaded_ids = reloaded.encode_special(corpus, AllowedSpecial::All);
    assert_eq!(reloaded.decode(&reloaded_ids), corpus);

    // And a fresh encoding must match the ids from the original tokenizer.
    assert_eq!(
        reloaded.encode_special(corpus, AllowedSpecial::All),
        reference_ids
    );
}
/// Exercises the save/load round trip twice: first with no special tokens
/// registered, then with the shared `SPECIAL_TOKENS` fixture.
#[test]
fn test_save_load() {
    // Case 1: empty special-token map.
    test_save_load_inner(&IndexMap::new());

    // Case 2: the special tokens provided by the test fixtures.
    test_save_load_inner(&SPECIAL_TOKENS);
}
}