use splintr::{Tokenizer, LLAMA3_PATTERN};
use std::sync::LazyLock;
// Process-wide tokenizer instance, initialized lazily on first access so the
// vocabulary is loaded at most once and shared by every test in this file
// (via `create_llama3_tokenizer`).
static TOKENIZER: LazyLock<Tokenizer> = LazyLock::new(create_llama3_tokenizer_impl);
#[test]
fn test_llama3_hello_world_tokens() {
    // Pin the exact token IDs for a trivial ASCII phrase so any vocabulary or
    // pre-tokenization regression is caught immediately.
    let ids = create_llama3_tokenizer().encode("Hello world");
    assert_eq!(ids, vec![9906, 1917], "Token IDs for 'Hello world' changed");
}
#[test]
fn test_llama3_hello_world_punctuation_tokens() {
    // Same phrase with punctuation: comma and bang must split into their own
    // tokens (11 and 0) around the word pieces.
    let ids = create_llama3_tokenizer().encode("Hello, world!");
    assert_eq!(
        ids,
        vec![9906, 11, 1917, 0],
        "Token IDs for 'Hello, world!' changed"
    );
}
#[test]
fn test_llama3_chinese_tokens() {
    // CJK input: pins the expected IDs for a short Chinese phrase.
    let ids = create_llama3_tokenizer().encode("你好世界");
    assert_eq!(
        ids,
        vec![57668, 53901, 102616],
        "Token IDs for '你好世界' changed"
    );
}
#[test]
fn test_llama3_emoji_tokens() {
    // Emoji are multi-byte; the 🌍 glyph is expected to span several
    // byte-level tokens (11410, 234, 235) between the words.
    let ids = create_llama3_tokenizer().encode("Hello 🌍 World!");
    assert_eq!(
        ids,
        vec![9906, 11410, 234, 235, 4435, 0],
        "Token IDs for emoji text changed"
    );
}
#[test]
fn test_llama3_encode_decode_roundtrip() {
    // decode(encode(text)) must reproduce the input byte-for-byte across a
    // mix of ASCII, punctuation, digits, newlines, and non-Latin scripts.
    let tokenizer = create_llama3_tokenizer();
    let samples = [
        "Hello, world!",
        "The quick brown fox jumps over the lazy dog.",
        "Rust is a systems programming language.",
        "1234567890",
        "Special characters: !@#$%^&*()",
        "Multi-line\ntext\nwith\nnewlines",
        "Unicode: こんにちは 世界 🦀",
    ];
    for text in samples {
        let roundtripped = tokenizer.decode(&tokenizer.encode(text)).unwrap();
        assert_eq!(roundtripped, text, "Roundtrip failed for: {:?}", text);
    }
}
#[test]
fn test_llama3_vocab_size() {
    // Llama 3 ships a ~128k base vocabulary; anything smaller means the
    // vocab file failed to load completely.
    let size = create_llama3_tokenizer().vocab_size();
    assert!(
        size >= 128000,
        "Vocab size should be at least 128,000, got {}",
        size
    );
}
#[test]
fn test_llama3_meta_special_tokens() {
    // The canonical Meta control tokens must be recognized as single IDs by
    // encode_with_special, not split into literal text.
    let tokenizer = create_llama3_tokenizer();

    let wrapped = tokenizer.encode_with_special("<|begin_of_text|>Hello<|end_of_text|>");
    assert!(wrapped.contains(&128000), "Should contain begin_of_text (128000)");
    assert!(wrapped.contains(&128001), "Should contain end_of_text (128001)");

    let header = tokenizer.encode_with_special("<|start_header_id|>system<|end_header_id|>");
    assert!(header.contains(&128006), "Should contain start_header_id (128006)");
    assert!(header.contains(&128007), "Should contain end_header_id (128007)");

    let eot = tokenizer.encode_with_special("<|eot_id|>");
    assert!(eot.contains(&128009), "Should contain eot_id (128009)");
}
#[test]
fn test_llama3_1_special_tokens() {
    // Control tokens introduced with the Llama 3.1 release.
    let tokenizer = create_llama3_tokenizer();

    let pad = tokenizer.encode_with_special("<|finetune_right_pad_id|>");
    assert!(pad.contains(&128004), "Should contain finetune_right_pad_id (128004)");

    let eom = tokenizer.encode_with_special("<|eom_id|>");
    assert!(eom.contains(&128008), "Should contain eom_id (128008)");

    let py = tokenizer.encode_with_special("<|python_tag|>");
    assert!(py.contains(&128010), "Should contain python_tag (128010)");
}
#[test]
fn test_llama3_agent_tokens() {
    // Extended agent markers registered in the 128300+ range: chat roles,
    // reasoning delimiters, and tool-call delimiters.
    let tokenizer = create_llama3_tokenizer();

    let roles = tokenizer.encode_with_special("<|system|>You are helpful.<|user|>Hi<|assistant|>");
    assert!(roles.contains(&128300), "Should contain system (128300)");
    assert!(roles.contains(&128301), "Should contain user (128301)");
    assert!(roles.contains(&128302), "Should contain assistant (128302)");

    let thought = tokenizer.encode_with_special("<|think|>Let me reason...<|/think|>");
    assert!(thought.contains(&128305), "Should contain think (128305)");
    assert!(thought.contains(&128306), "Should contain think_end (128306)");

    let call = tokenizer.encode_with_special("<|function|>get_weather<|/function|>");
    assert!(call.contains(&128315), "Should contain function (128315)");
    assert!(call.contains(&128316), "Should contain function_end (128316)");
}
#[test]
fn test_llama3_chat_format() {
    // A realistic chat prompt must keep all of its control tokens and decode
    // back to the original string losslessly.
    let tokenizer = create_llama3_tokenizer();
    let chat = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
    let ids = tokenizer.encode_with_special(chat);
    // begin_of_text, start_header_id, end_header_id, eot_id respectively.
    for expected in [128000, 128006, 128007, 128009] {
        assert!(ids.contains(&expected));
    }
    assert_eq!(tokenizer.decode(&ids).unwrap(), chat);
}
#[test]
fn test_llama3_batch_encode() {
    // Batch encoding must agree element-wise with one-at-a-time encoding.
    let tokenizer = create_llama3_tokenizer();
    let texts: Vec<String> = ["Hello, world!", "How are you?", "I'm doing great!"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    let batch_tokens = tokenizer.encode_batch(&texts);
    assert_eq!(batch_tokens.len(), 3);
    for (i, text) in texts.iter().enumerate() {
        assert_eq!(
            batch_tokens[i],
            tokenizer.encode(text),
            "Batch encoding should match individual encoding for text {}: {:?}",
            i, text
        );
    }
}
#[test]
fn test_llama3_special_token_decode() {
    // Each control-token ID must decode back to its literal tag text.
    let tokenizer = create_llama3_tokenizer();
    let cases = [
        (128000, "<|begin_of_text|>"),
        (128009, "<|eot_id|>"),
        (128008, "<|eom_id|>"),
        (128010, "<|python_tag|>"),
    ];
    for (id, tag) in cases {
        assert_eq!(tokenizer.decode(&[id]).unwrap(), tag);
    }
}
#[test]
fn test_llama3_2_vision_tokens() {
    // Llama 3.2 step/image markers: recognized on encode and restored on decode.
    let tokenizer = create_llama3_tokenizer();

    let step = tokenizer.encode_with_special("<|step_id|>");
    assert!(step.contains(&128005), "Should contain step_id (128005)");

    let image = tokenizer.encode_with_special("<|image|>content<|/image|>");
    assert!(image.contains(&128256), "Should contain image (128256)");
    assert!(image.contains(&128257), "Should contain image_end (128257)");

    assert_eq!(tokenizer.decode(&[128005]).unwrap(), "<|step_id|>");
    assert_eq!(tokenizer.decode(&[128256]).unwrap(), "<|image|>");
}
#[test]
fn test_llama3_empty_input() {
    // Degenerate cases: empty text and empty token slice must both succeed.
    let tokenizer = create_llama3_tokenizer();
    assert!(
        tokenizer.encode("").is_empty(),
        "Empty input should produce empty tokens"
    );
    assert!(
        tokenizer.decode(&[]).unwrap().is_empty(),
        "Empty tokens should decode to empty string"
    );
}
#[test]
fn test_llama3_from_pretrained_variants() {
    // Every Llama 3.x variant name must construct successfully and produce
    // identical encodings (they share a vocabulary, as exercised here).
    //
    // The original version built each tokenizer and dropped it, then rebuilt
    // two of them from scratch (six expensive vocab loads) and compared only
    // "llama3" against "llama3.3". Build each variant once and compare ALL
    // variants pairwise against the first.
    let variants = ["llama3", "llama3.1", "llama3.2", "llama3.3"];
    let tokenizers: Vec<_> = variants
        .iter()
        .map(|name| create_llama3_tokenizer_by_name(name))
        .collect();
    let text = "Hello, world!";
    let expected = tokenizers[0].encode(text);
    for t in &tokenizers[1..] {
        assert_eq!(
            expected,
            t.encode(text),
            "All Llama 3 variants should produce same encoding"
        );
    }
}
fn create_llama3_tokenizer() -> &'static Tokenizer {
&TOKENIZER
}
/// Builds a fresh tokenizer for a named Llama 3 variant.
///
/// The `_name` argument is currently ignored: every variant name used in
/// `test_llama3_from_pretrained_variants` ("llama3" through "llama3.3") maps
/// to the same construction path, and no validation of unknown names is
/// performed. NOTE(review): presumably intentional because the variants share
/// one vocabulary — confirm if per-name lookup is ever expected here.
fn create_llama3_tokenizer_by_name(_name: &str) -> Tokenizer {
create_llama3_tokenizer_impl()
}
fn create_llama3_tokenizer_impl() -> Tokenizer {
    // Build a Llama 3 tokenizer from the bundled tiktoken vocabulary plus the
    // special-token table: Meta control tokens (128000-128010), image tokens
    // (128256-128257), and the extended agent/tool markers (128300+).
    let vocab_bytes = include_bytes!("../python/splintr/vocabs/llama3.tiktoken");
    let special_pairs = [
        ("<|begin_of_text|>", 128000),
        ("<|end_of_text|>", 128001),
        ("<|reserved_special_token_0|>", 128002),
        ("<|reserved_special_token_1|>", 128003),
        ("<|finetune_right_pad_id|>", 128004),
        ("<|step_id|>", 128005),
        ("<|start_header_id|>", 128006),
        ("<|end_header_id|>", 128007),
        ("<|eom_id|>", 128008),
        ("<|eot_id|>", 128009),
        ("<|python_tag|>", 128010),
        ("<|image|>", 128256),
        ("<|/image|>", 128257),
        ("<|system|>", 128300),
        ("<|user|>", 128301),
        ("<|assistant|>", 128302),
        ("<|im_start|>", 128303),
        ("<|im_end|>", 128304),
        ("<|think|>", 128305),
        ("<|/think|>", 128306),
        ("<|plan|>", 128307),
        ("<|/plan|>", 128308),
        ("<|step|>", 128309),
        ("<|/step|>", 128310),
        ("<|act|>", 128311),
        ("<|/act|>", 128312),
        ("<|observe|>", 128313),
        ("<|/observe|>", 128314),
        ("<|function|>", 128315),
        ("<|/function|>", 128316),
        ("<|result|>", 128317),
        ("<|/result|>", 128318),
        ("<|error|>", 128319),
        ("<|/error|>", 128320),
        ("<|code|>", 128321),
        ("<|/code|>", 128322),
        ("<|output|>", 128323),
        ("<|/output|>", 128324),
    ];
    let mut special = rustc_hash::FxHashMap::default();
    special.extend(special_pairs.iter().map(|&(tag, id)| (tag.to_string(), id)));
    Tokenizer::from_bytes(vocab_bytes, LLAMA3_PATTERN, special).unwrap()
}