use super::*;
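// These tests hand-assemble GGUF v3 blobs. Per the GGUF spec, the layout is:
// the 4-byte magic "GGUF", a u32 version, a u64 tensor count, and a u64
// metadata key/value count, followed by that many key/value pairs. A key is
// a u64 byte length plus UTF-8 bytes; a value is a u32 type tag plus its
// payload. The type tags used here are 4 = u32, 6 = f32, 8 = string, and
// 9 = array (whose payload is a u32 element type, a u64 element count, then
// the elements). All integers are little-endian.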
#[test]
fn test_gguf_non_gpt2_model_string() {
let mut data = Vec::new();
data.extend_from_slice(b"GGUF");
data.extend_from_slice(&3u32.to_le_bytes());
data.extend_from_slice(&0u64.to_le_bytes());
data.extend_from_slice(&6u64.to_le_bytes());
let key1 = b"tokenizer.ggml.tokens";
data.extend_from_slice(&(key1.len() as u64).to_le_bytes());
data.extend_from_slice(key1);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&8u32.to_le_bytes());
let tokens = ["<unk>", "<s>", "</s>"];
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for token in &tokens {
let bytes = token.as_bytes();
data.extend_from_slice(&(bytes.len() as u64).to_le_bytes());
data.extend_from_slice(bytes);
}
let key2 = b"tokenizer.ggml.scores";
data.extend_from_slice(&(key2.len() as u64).to_le_bytes());
data.extend_from_slice(key2);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&6u32.to_le_bytes());
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for _ in &tokens {
data.extend_from_slice(&0.0f32.to_le_bytes());
}
let key3 = b"tokenizer.ggml.model";
data.extend_from_slice(&(key3.len() as u64).to_le_bytes());
data.extend_from_slice(key3);
data.extend_from_slice(&8u32.to_le_bytes()); let model_str = b"llama";
data.extend_from_slice(&(model_str.len() as u64).to_le_bytes());
data.extend_from_slice(model_str);
let key4 = b"tokenizer.ggml.bos_token_id";
data.extend_from_slice(&(key4.len() as u64).to_le_bytes());
data.extend_from_slice(key4);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&1u32.to_le_bytes());
let key5 = b"tokenizer.ggml.eos_token_id";
data.extend_from_slice(&(key5.len() as u64).to_le_bytes());
data.extend_from_slice(key5);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&2u32.to_le_bytes());
let key6 = b"tokenizer.ggml.unknown_token_id";
data.extend_from_slice(&(key6.len() as u64).to_le_bytes());
data.extend_from_slice(key6);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&0u32.to_le_bytes());
let result = LlamaTokenizer::from_gguf_bytes(&data);
assert!(result.is_ok());
let tokenizer = result.unwrap();
assert_eq!(tokenizer.model(), TokenizerModel::SentencePiece);
}
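/// Builds a minimal tokenizer GGUF blob whose metadata additionally carries
/// one entry, `general.extra`, with the given value type tag and raw payload
/// bytes. `val_bytes` must be exactly the size implied by `val_type`,
/// otherwise every entry after it is misread.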
pub(crate) fn create_gguf_with_extra_metadata(val_type: u32, val_bytes: &[u8]) -> Vec<u8> {
let mut data = Vec::new();
data.extend_from_slice(b"GGUF");
data.extend_from_slice(&3u32.to_le_bytes());
data.extend_from_slice(&0u64.to_le_bytes());
data.extend_from_slice(&6u64.to_le_bytes());
let key1 = b"tokenizer.ggml.tokens";
data.extend_from_slice(&(key1.len() as u64).to_le_bytes());
data.extend_from_slice(key1);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&8u32.to_le_bytes());
let tokens = ["<unk>", "<s>", "</s>"];
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for token in &tokens {
let bytes = token.as_bytes();
data.extend_from_slice(&(bytes.len() as u64).to_le_bytes());
data.extend_from_slice(bytes);
}
let key2 = b"tokenizer.ggml.scores";
data.extend_from_slice(&(key2.len() as u64).to_le_bytes());
data.extend_from_slice(key2);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&6u32.to_le_bytes());
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for _ in &tokens {
data.extend_from_slice(&0.0f32.to_le_bytes());
}
let key3 = b"general.extra";
data.extend_from_slice(&(key3.len() as u64).to_le_bytes());
data.extend_from_slice(key3);
data.extend_from_slice(&val_type.to_le_bytes());
data.extend_from_slice(val_bytes);
let key4 = b"tokenizer.ggml.bos_token_id";
data.extend_from_slice(&(key4.len() as u64).to_le_bytes());
data.extend_from_slice(key4);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&1u32.to_le_bytes());
let key5 = b"tokenizer.ggml.eos_token_id";
data.extend_from_slice(&(key5.len() as u64).to_le_bytes());
data.extend_from_slice(key5);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&2u32.to_le_bytes());
let key6 = b"tokenizer.ggml.unknown_token_id";
data.extend_from_slice(&(key6.len() as u64).to_le_bytes());
data.extend_from_slice(key6);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&0u32.to_le_bytes());
data
}
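/// Like `create_gguf_with_extra_metadata`, but the extra entry,
/// `general.array`, is an array value. The element count written into the
/// blob is derived from `elem_bytes.len()` and the fixed per-element size of
/// `elem_type`; variable-size element types such as strings (tag 8) fall
/// into the `_ => 1` arm and are not supported here.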
pub(crate) fn create_gguf_with_array_metadata(elem_type: u32, elem_bytes: &[u8]) -> Vec<u8> {
let elem_size = match elem_type {
0 | 1 | 7 => 1, // u8, i8, bool
2 | 3 => 2, // u16, i16
4..=6 => 4, // u32, i32, f32
10..=12 => 8, // u64, i64, f64
_ => 1, // not a fixed-size type (e.g. string or nested array)
};
let count = elem_bytes.len() / elem_size;
let mut data = Vec::new();
data.extend_from_slice(b"GGUF");
data.extend_from_slice(&3u32.to_le_bytes());
data.extend_from_slice(&0u64.to_le_bytes());
data.extend_from_slice(&6u64.to_le_bytes());
let key1 = b"tokenizer.ggml.tokens";
data.extend_from_slice(&(key1.len() as u64).to_le_bytes());
data.extend_from_slice(key1);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&8u32.to_le_bytes());
let tokens = ["<unk>", "<s>", "</s>"];
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for token in &tokens {
let bytes = token.as_bytes();
data.extend_from_slice(&(bytes.len() as u64).to_le_bytes());
data.extend_from_slice(bytes);
}
let key2 = b"tokenizer.ggml.scores";
data.extend_from_slice(&(key2.len() as u64).to_le_bytes());
data.extend_from_slice(key2);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&6u32.to_le_bytes());
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for _ in &tokens {
data.extend_from_slice(&0.0f32.to_le_bytes());
}
let key3 = b"general.array";
data.extend_from_slice(&(key3.len() as u64).to_le_bytes());
data.extend_from_slice(key3);
data.extend_from_slice(&9u32.to_le_bytes()); data.extend_from_slice(&elem_type.to_le_bytes());
data.extend_from_slice(&(count as u64).to_le_bytes());
data.extend_from_slice(elem_bytes);
let key4 = b"tokenizer.ggml.bos_token_id";
data.extend_from_slice(&(key4.len() as u64).to_le_bytes());
data.extend_from_slice(key4);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&1u32.to_le_bytes());
let key5 = b"tokenizer.ggml.eos_token_id";
data.extend_from_slice(&(key5.len() as u64).to_le_bytes());
data.extend_from_slice(key5);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&2u32.to_le_bytes());
let key6 = b"tokenizer.ggml.unknown_token_id";
data.extend_from_slice(&(key6.len() as u64).to_le_bytes());
data.extend_from_slice(key6);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&0u32.to_le_bytes());
data
}
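// Minimal usage sketches for the two fixture helpers above. They assume only
// what the neighboring tests already establish: `from_gguf_bytes` skips
// metadata entries it does not recognize.
#[test]
fn test_gguf_extra_u32_metadata_is_skipped() {
let data = create_gguf_with_extra_metadata(4, &42u32.to_le_bytes());
assert!(LlamaTokenizer::from_gguf_bytes(&data).is_ok());
}
#[test]
fn test_gguf_u32_array_metadata_is_skipped() {
let mut elems = Vec::new();
elems.extend_from_slice(&7u32.to_le_bytes());
elems.extend_from_slice(&11u32.to_le_bytes());
let data = create_gguf_with_array_metadata(4, &elems);
assert!(LlamaTokenizer::from_gguf_bytes(&data).is_ok());
}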
#[test]
fn test_decode_gpt2_skips_unknown_token_id() {
let tokens = vec![
"<unk>".to_string(),
"<|endoftext|>".to_string(),
"</s>".to_string(),
"Hello".to_string(),
];
let scores = vec![0.0; tokens.len()];
let mut tokenizer = LlamaTokenizer::new(tokens, scores, 1, 2, 0).unwrap();
tokenizer.set_model(TokenizerModel::Gpt2);
let decoded = tokenizer.decode(&[3, 9999]);
assert_eq!(decoded, "Hello");
}
#[test]
fn test_decode_sentencepiece_skips_unknown_token_id() {
let tokenizer = create_test_tokenizer();
let decoded = tokenizer.decode(&[3, 9999]);
assert_eq!(decoded, "Hello");
}
#[test]
fn test_encode_byte_fallback_no_byte_token_in_vocab() {
let tokens = vec!["<unk>".to_string(), "<s>".to_string(), "</s>".to_string()];
let scores = vec![0.0; tokens.len()];
let tokenizer = LlamaTokenizer::new(tokens, scores, 1, 2, 0).unwrap();
let encoded = tokenizer.encode("A");
assert!(!encoded.is_empty());
for &token_id in &encoded {
assert_eq!(token_id, tokenizer.unk_token_id());
}
}
#[test]
fn test_encode_gpt2_space_and_newline_normalization() {
let tokens = vec![
"<unk>".to_string(),
"<s>".to_string(),
"</s>".to_string(),
"Hello".to_string(),
"\u{0120}world".to_string(), "\u{010A}line".to_string(), ];
let scores = vec![0.0; tokens.len()];
let mut tokenizer = LlamaTokenizer::new(tokens, scores, 1, 2, 0).unwrap();
tokenizer.set_model(TokenizerModel::Gpt2);
let encoded = tokenizer.encode("Hello world");
assert!(!encoded.is_empty());
assert!(encoded.contains(&3));
}
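// The two tests below feed the metadata parser value types it cannot
// recognize. Type tag 99 does not exist in GGUF, so a scalar value of that
// type has no defined payload size and parsing may fail; an array of an
// unknown element type can still be skipped when it declares zero elements.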
#[test]
fn test_skip_value_unknown_type() {
let mut data = Vec::new();
data.extend_from_slice(b"GGUF");
data.extend_from_slice(&3u32.to_le_bytes());
data.extend_from_slice(&0u64.to_le_bytes());
data.extend_from_slice(&6u64.to_le_bytes());
let key1 = b"tokenizer.ggml.tokens";
data.extend_from_slice(&(key1.len() as u64).to_le_bytes());
data.extend_from_slice(key1);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&8u32.to_le_bytes());
let tokens = ["<unk>", "<s>", "</s>"];
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for token in &tokens {
let bytes = token.as_bytes();
data.extend_from_slice(&(bytes.len() as u64).to_le_bytes());
data.extend_from_slice(bytes);
}
let key2 = b"tokenizer.ggml.scores";
data.extend_from_slice(&(key2.len() as u64).to_le_bytes());
data.extend_from_slice(key2);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&6u32.to_le_bytes());
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for _ in &tokens {
data.extend_from_slice(&0.0f32.to_le_bytes());
}
let key3 = b"general.unknown_type";
data.extend_from_slice(&(key3.len() as u64).to_le_bytes());
data.extend_from_slice(key3);
data.extend_from_slice(&99u32.to_le_bytes());
let key4 = b"tokenizer.ggml.bos_token_id";
data.extend_from_slice(&(key4.len() as u64).to_le_bytes());
data.extend_from_slice(key4);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&1u32.to_le_bytes());
let key5 = b"tokenizer.ggml.eos_token_id";
data.extend_from_slice(&(key5.len() as u64).to_le_bytes());
data.extend_from_slice(key5);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&2u32.to_le_bytes());
let key6 = b"tokenizer.ggml.unknown_token_id";
data.extend_from_slice(&(key6.len() as u64).to_le_bytes());
data.extend_from_slice(key6);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&0u32.to_le_bytes());
let result = LlamaTokenizer::from_gguf_bytes(&data);
// With an unrecognized value type the parser cannot know how many bytes to
// skip, so either outcome is acceptable; the test only pins down that
// parsing does not panic.
let _ = result;
}
#[test]
fn test_skip_value_array_unknown_elem_type() {
let mut data = Vec::new();
data.extend_from_slice(b"GGUF");
data.extend_from_slice(&3u32.to_le_bytes());
data.extend_from_slice(&0u64.to_le_bytes());
data.extend_from_slice(&6u64.to_le_bytes());
let key1 = b"tokenizer.ggml.tokens";
data.extend_from_slice(&(key1.len() as u64).to_le_bytes());
data.extend_from_slice(key1);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&8u32.to_le_bytes());
let tokens = ["<unk>", "<s>", "</s>"];
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for token in &tokens {
let bytes = token.as_bytes();
data.extend_from_slice(&(bytes.len() as u64).to_le_bytes());
data.extend_from_slice(bytes);
}
let key2 = b"tokenizer.ggml.scores";
data.extend_from_slice(&(key2.len() as u64).to_le_bytes());
data.extend_from_slice(key2);
data.extend_from_slice(&9u32.to_le_bytes());
data.extend_from_slice(&6u32.to_le_bytes());
data.extend_from_slice(&(tokens.len() as u64).to_le_bytes());
for _ in &tokens {
data.extend_from_slice(&0.0f32.to_le_bytes());
}
let key3 = b"general.weird_array";
data.extend_from_slice(&(key3.len() as u64).to_le_bytes());
data.extend_from_slice(key3);
data.extend_from_slice(&9u32.to_le_bytes()); data.extend_from_slice(&99u32.to_le_bytes()); data.extend_from_slice(&0u64.to_le_bytes());
let key4 = b"tokenizer.ggml.bos_token_id";
data.extend_from_slice(&(key4.len() as u64).to_le_bytes());
data.extend_from_slice(key4);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&1u32.to_le_bytes());
let key5 = b"tokenizer.ggml.eos_token_id";
data.extend_from_slice(&(key5.len() as u64).to_le_bytes());
data.extend_from_slice(key5);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&2u32.to_le_bytes());
let key6 = b"tokenizer.ggml.unknown_token_id";
data.extend_from_slice(&(key6.len() as u64).to_le_bytes());
data.extend_from_slice(key6);
data.extend_from_slice(&4u32.to_le_bytes());
data.extend_from_slice(&0u32.to_le_bytes());
let result = LlamaTokenizer::from_gguf_bytes(&data);
assert!(result.is_ok());
}
#[test]
fn test_decode_sentencepiece_leading_space_removal() {
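// "▁" (U+2581) is the SentencePiece word-boundary marker: it decodes to a
// space, and the space is dropped at the very start of the output.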
let tokens = vec![
"<unk>".to_string(),
"<s>".to_string(),
"</s>".to_string(),
"▁Hello".to_string(),
"▁world".to_string(),
];
let scores = vec![0.0; tokens.len()];
let tokenizer = LlamaTokenizer::new(tokens, scores, 1, 2, 0).unwrap();
let decoded = tokenizer.decode(&[3, 4]);
assert_eq!(decoded, "Hello world");
}