use super::*;
use std::sync::atomic::{AtomicU64, Ordering};
/// Creates an empty scratch directory under the system temp dir and returns
/// its path.
///
/// The directory name is built from `tag`, the current process id and a
/// process-wide counter that is bumped on every call, so repeated calls in
/// one process yield distinct paths.
///
/// # Panics
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    let leaf = format!("mlxrs-lm-gguf-{tag}-{}-{n}", std::process::id());
    let mut scratch = std::env::temp_dir();
    scratch.push(leaf);
    // Best-effort wipe of anything already at that path; a missing directory
    // is the normal case, so the result is deliberately ignored.
    let _ = std::fs::remove_dir_all(&scratch);
    std::fs::create_dir_all(&scratch).unwrap();
    scratch
}
/// Checks `translate_weight_names` against a fixed table of
/// (input name, expected output name) pairs, then confirms that a key which
/// appears nowhere in the table's naming scheme is returned unchanged.
#[test]
fn translate_weight_names_table_matches_python_reference() {
    // One `(input, expected)` pair per row.
    let table = [
        ("model.layers.0.input_layernorm.weight", "blk.0.attn_norm.weight"),
        ("model.layers.12.post_attention_layernorm.weight", "blk.12.ffn_norm.weight"),
        ("model.layers.3.block_sparse_moe.gate.weight", "blk.3.ffn_gate_inp.weight"),
        ("model.layers.3.block_sparse_moe.experts.0.w1.weight", "blk.3.ffn_gate.0.weight"),
        ("model.layers.3.block_sparse_moe.experts.7.w2.weight", "blk.3.ffn_down.7.weight"),
        ("model.layers.3.block_sparse_moe.experts.15.w3.weight", "blk.3.ffn_up.15.weight"),
        ("model.layers.1.mlp.gate_proj.weight", "blk.1.ffn_gate.weight"),
        ("model.layers.1.mlp.down_proj.weight", "blk.1.ffn_down.weight"),
        ("model.layers.1.mlp.up_proj.weight", "blk.1.ffn_up.weight"),
        ("model.layers.2.self_attn.q_proj.weight", "blk.2.attn_q.weight"),
        ("model.layers.2.self_attn.k_proj.weight", "blk.2.attn_k.weight"),
        ("model.layers.2.self_attn.v_proj.weight", "blk.2.attn_v.weight"),
        ("model.layers.2.self_attn.o_proj.weight", "blk.2.attn_output.weight"),
        ("model.layers.5.input_layernorm.weight", "blk.5.attn_norm.weight"),
        ("model.layers.5.post_attention_layernorm.weight", "blk.5.ffn_norm.weight"),
        ("model.embed_tokens.weight", "token_embd.weight"),
        ("model.norm.weight", "output_norm.weight"),
        ("lm_head.weight", "output.weight"),
    ];
    for (input, expected) in table {
        assert_eq!(
            translate_weight_names(input),
            expected,
            "translate_weight_names({input:?}) mismatch",
        );
    }

    // A key outside the table's naming scheme must pass through untouched.
    let passthrough = "some.unrelated.key";
    assert_eq!(translate_weight_names(passthrough), passthrough);
}
/// Feeds an 8x1 column holding the values 0..8 through `permute_weights`
/// with arguments `2` and `Some(2)` and pins the resulting row order.
#[test]
fn permute_weights_q_k_matches_python_reference() {
    let rows = [0.0_f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let weight = Array::from_slice::<f32>(&rows, &(8_usize, 1)).unwrap();
    let mut permuted = permute_weights(&weight, 2, Some(2)).unwrap();

    // The shape must survive; only the order of the rows is expected to move.
    assert_eq!(permuted.shape(), vec![8, 1]);
    let expected_order = vec![0.0, 2.0, 1.0, 3.0, 4.0, 6.0, 5.0, 7.0];
    assert_eq!(permuted.to_vec::<f32>().unwrap(), expected_order);
}
/// Runs `permute_weights` on a 4x1 column with arguments `4` and `Some(2)`
/// and expects the rows to come back in their original order.
#[test]
fn permute_weights_kv_overrides_n_head() {
    let rows = [0.0_f32, 1.0, 2.0, 3.0];
    let weight = Array::from_slice::<f32>(&rows, &(4_usize, 1)).unwrap();
    let mut permuted = permute_weights(&weight, 4, Some(2)).unwrap();

    assert_eq!(permuted.shape(), vec![4, 1]);
    let unchanged = vec![0.0, 1.0, 2.0, 3.0];
    assert_eq!(permuted.to_vec::<f32>().unwrap(), unchanged);
}
/// A 6x1 weight passed to `permute_weights` with arguments `4` and `Some(4)`
/// must yield an error whose debug rendering mentions `permute_weights`.
#[test]
fn permute_weights_rejects_invalid_leading_dim() {
    let zeros = [0.0_f32; 6];
    let weight = Array::from_slice::<f32>(&zeros, &(6_usize, 1)).unwrap();
    let msg = format!("{:?}", permute_weights(&weight, 4, Some(4)).unwrap_err());
    assert!(msg.contains("permute_weights"), "{msg}");
}
/// Writes a small BPE `tokenizer.json` and a matching `tokenizer_config.json`
/// into `dir`, then loads them through `Tokenizer::from_path`.
///
/// The fixture has a four-entry base vocab (`<unk>`, `<s>`, `</s>`, `a`).
/// Its added tokens mark ids 0-2 as special and add `<extra>` at id 100 as a
/// non-special token.
///
/// # Panics
/// Panics if either file cannot be written or the tokenizer fails to load.
fn write_tokenizer_fixture(dir: &std::path::Path) -> crate::tokenizer::Tokenizer {
    use serde_json::json;

    // Every added-token entry shares the same flags apart from `special`.
    let added_token = |id: u32, content: &str, special: bool| {
        json!({
            "id": id,
            "content": content,
            "single_word": false,
            "lstrip": false,
            "rstrip": false,
            "normalized": false,
            "special": special,
        })
    };

    let tokenizer_json = json!({
        "version": "1.0",
        "model": {
            "type": "BPE",
            "vocab": {
                "<unk>": 0,
                "<s>": 1,
                "</s>": 2,
                "a": 3,
            },
            "merges": []
        },
        "added_tokens": [
            added_token(0, "<unk>", true),
            added_token(1, "<s>", true),
            added_token(2, "</s>", true),
            added_token(100, "<extra>", false),
        ],
    });
    let tokenizer_path = dir.join("tokenizer.json");
    std::fs::write(tokenizer_path, tokenizer_json.to_string()).unwrap();

    let tokenizer_config = json!({
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
    });
    let config_path = dir.join("tokenizer_config.json");
    std::fs::write(config_path, tokenizer_config.to_string()).unwrap();

    crate::tokenizer::Tokenizer::from_path(dir, None).unwrap()
}
/// Builds an `HfVocab` from the shared tokenizer fixture and checks its
/// sizes, the per-token type classification and the constant score, then
/// verifies that `prepare_metadata` emits the three `tokenizer.ggml.*`
/// entries for a minimal llama config.
#[test]
fn hf_vocab_to_gguf_round_trip() {
    let dir = fresh_dir("hf_vocab");
    let tokenizer = write_tokenizer_fixture(&dir);
    let vocab = HfVocab::from_tokenizer(&tokenizer).unwrap();
    assert_eq!(vocab.vocab_size_base(), 4);
    assert_eq!(vocab.vocab_size(), 5);

    let triples = vocab.all_tokens().unwrap();
    assert_eq!(triples.len(), 5);
    // Ids 0-2 are the special tokens, id 3 is plain `a`, and the last entry
    // is the non-special added token `<extra>`.
    let kinds: Vec<_> = triples.iter().map(|entry| &entry.2).collect();
    assert_eq!(
        kinds,
        [
            &TokenType::Control,
            &TokenType::Control,
            &TokenType::Control,
            &TokenType::Normal,
            &TokenType::UserDefined,
        ]
    );
    assert_eq!(triples[4].0, "<extra>");
    for (_, score, _) in triples.iter() {
        assert!((score - -1000.0).abs() < 1e-6, "score {score} != -1000.0");
    }

    // Serialize once and derive both the typed config and the raw JSON value
    // from that same text.
    let config_text = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 8,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 4,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 16,
        "max_position_embeddings": 32,
        "rms_norm_eps": 1e-5,
    })
    .to_string();
    let config = Config::from_json(&config_text).unwrap();
    let raw_json: serde_json::Value = serde_json::from_str(&config_text).unwrap();

    let meta = prepare_metadata(&config, &raw_json, &vocab).unwrap();
    for key in [
        "tokenizer.ggml.tokens",
        "tokenizer.ggml.scores",
        "tokenizer.ggml.token_type",
    ] {
        assert!(meta.contains_key(key));
    }

    let _ = std::fs::remove_dir_all(&dir);
}
/// Runs `prepare_metadata` on a small llama config (with expert counts, a
/// `_name_or_path` and linear rope scaling) and checks every scalar and
/// string entry it is expected to emit.
#[test]
fn prepare_metadata_minimal_llama_config() {
    /// Reads the first `u32` of the array stored under `key`; panics if the
    /// key is missing or holds a non-array value.
    fn unwrap_u32_scalar(m: &HashMap<String, GgufMetadata>, key: &str) -> u32 {
        let Some(entry) = m.get(key) else {
            panic!("missing metadata key {key}");
        };
        let GgufMetadata::Array(a) = entry else {
            panic!("metadata key {key} was not a scalar array");
        };
        let mut a = a.try_clone().unwrap();
        a.to_vec::<u32>().unwrap()[0]
    }

    /// Reads the first `f32` of the array stored under `key`; panics if the
    /// key is missing or holds a non-array value.
    fn unwrap_f32_scalar(m: &HashMap<String, GgufMetadata>, key: &str) -> f32 {
        let Some(entry) = m.get(key) else {
            panic!("missing metadata key {key}");
        };
        let GgufMetadata::Array(a) = entry else {
            panic!("metadata key {key} was not a scalar array");
        };
        let mut a = a.try_clone().unwrap();
        a.to_vec::<f32>().unwrap()[0]
    }

    /// Clones the string stored under `key`; panics if the key is missing or
    /// holds a non-string value.
    fn unwrap_string(m: &HashMap<String, GgufMetadata>, key: &str) -> String {
        let Some(entry) = m.get(key) else {
            panic!("missing metadata key {key}");
        };
        let GgufMetadata::String(s) = entry else {
            panic!("metadata key {key} was not a string");
        };
        s.clone()
    }

    let dir = fresh_dir("prep_meta");
    let tokenizer = write_tokenizer_fixture(&dir);
    let vocab = HfVocab::from_tokenizer(&tokenizer).unwrap();
    let config_text = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 16,
        "num_hidden_layers": 4,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "head_dim": 4,
        "rope_theta": 500_000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 64,
        "max_position_embeddings": 128,
        "rms_norm_eps": 1e-5,
        "num_local_experts": 8,
        "num_experts_per_tok": 2,
        "_name_or_path": "foo/bar-7b",
        "rope_scaling": { "type": "linear", "factor": 2.0 },
    })
    .to_string();
    let raw_json: serde_json::Value = serde_json::from_str(&config_text).unwrap();
    let config = Config::from_json(&config_text).unwrap();
    let meta = prepare_metadata(&config, &raw_json, &vocab).unwrap();

    // Integer-valued entries.
    let expected_u32 = [
        ("llama.context_length", 128),
        ("llama.embedding_length", 16),
        ("llama.block_count", 4),
        ("llama.feed_forward_length", 64),
        ("llama.rope.dimension_count", 4),
        ("llama.attention.head_count", 4),
        ("llama.attention.head_count_kv", 2),
        ("llama.expert_count", 8),
        ("llama.expert_used_count", 2),
        ("general.file_type", 1),
        ("general.quantization_version", 1),
        ("general.alignment", 32),
    ];
    for (key, want) in expected_u32 {
        assert_eq!(unwrap_u32_scalar(&meta, key), want, "{key}");
    }

    // Float-valued entries, each with its own absolute tolerance.
    let expected_f32 = [
        ("llama.attention.layer_norm_rms_epsilon", 1e-5, 1e-9),
        ("llama.rope.freq_base", 500_000.0, 1e-3),
        ("llama.rope.scaling.factor", 2.0, 1e-6),
    ];
    for (key, want, tolerance) in expected_f32 {
        assert!((unwrap_f32_scalar(&meta, key) - want).abs() < tolerance, "{key}");
    }

    // String-valued entries.
    let expected_strings = [
        ("llama.rope.scaling.type", "linear"),
        ("general.architecture", "llama"),
        ("general.name", "bar-7b"),
        ("tokenizer.ggml.model", "llama"),
    ];
    for (key, want) in expected_strings {
        assert_eq!(unwrap_string(&meta, key), want, "{key}");
    }

    let _ = std::fs::remove_dir_all(&dir);
}
/// Writes a one-layer llama checkpoint (tokenizer, config and safetensors
/// weights) to a scratch directory, runs `convert_to_gguf` on it, and checks
/// that the reloaded GGUF file holds exactly the expected tensor names and
/// that the three norm tensors are `F32`.
#[test]
fn convert_to_gguf_end_to_end_minimal() {
    let dir = fresh_dir("e2e");
    let _ = write_tokenizer_fixture(&dir);
    let config = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 4,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 2,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 8,
        "max_position_embeddings": 16,
        "rms_norm_eps": 1e-5,
    });
    std::fs::write(dir.join("config.json"), config.to_string()).unwrap();

    // Constant-filled tensors, one constructor per shape used below.
    let attn_4x4 = || Array::from_slice::<f32>(&[0.5_f32; 16], &(4_usize, 4)).unwrap();
    let mlp_8x4 = || Array::from_slice::<f32>(&[0.25_f32; 32], &(8_usize, 4)).unwrap();
    let mlp_4x8 = || Array::from_slice::<f32>(&[0.125_f32; 32], &(4_usize, 8)).unwrap();
    let norm_4 = || Array::from_slice::<f32>(&[1.0_f32; 4], &(4_usize,)).unwrap();
    let embed_5x4 = || Array::from_slice::<f32>(&[0.0_f32; 20], &(5_usize, 4)).unwrap();

    let weights: HashMap<String, Array> = [
        ("model.embed_tokens.weight", embed_5x4()),
        ("model.layers.0.input_layernorm.weight", norm_4()),
        ("model.layers.0.post_attention_layernorm.weight", norm_4()),
        ("model.layers.0.self_attn.q_proj.weight", attn_4x4()),
        ("model.layers.0.self_attn.k_proj.weight", attn_4x4()),
        ("model.layers.0.self_attn.v_proj.weight", attn_4x4()),
        ("model.layers.0.self_attn.o_proj.weight", attn_4x4()),
        ("model.layers.0.mlp.gate_proj.weight", mlp_8x4()),
        ("model.layers.0.mlp.up_proj.weight", mlp_8x4()),
        ("model.layers.0.mlp.down_proj.weight", mlp_4x8()),
        ("model.norm.weight", norm_4()),
        ("lm_head.weight", embed_5x4()),
    ]
    .into_iter()
    .map(|(name, tensor)| (name.to_string(), tensor))
    .collect();
    crate::io::save_safetensors(&dir.join("model.safetensors"), &weights).unwrap();

    let gguf_path = dir.join("out.gguf");
    let args = ConvertToGgufArgs {
        model_path: dir.clone(),
        gguf_path: gguf_path.clone(),
    };
    convert_to_gguf(&args).unwrap();
    assert!(gguf_path.exists(), "gguf file not written");

    let (loaded_weights, _meta) = crate::io::load_gguf(&gguf_path).unwrap();
    let expected_keys: std::collections::BTreeSet<&str> = std::collections::BTreeSet::from([
        "token_embd.weight",
        "blk.0.attn_norm.weight",
        "blk.0.ffn_norm.weight",
        "blk.0.attn_q.weight",
        "blk.0.attn_k.weight",
        "blk.0.attn_v.weight",
        "blk.0.attn_output.weight",
        "blk.0.ffn_gate.weight",
        "blk.0.ffn_up.weight",
        "blk.0.ffn_down.weight",
        "output_norm.weight",
        "output.weight",
    ]);
    let got_keys: std::collections::BTreeSet<&str> =
        loaded_weights.keys().map(String::as_str).collect();
    assert_eq!(got_keys, expected_keys, "weight name set mismatch");

    for norm_key in [
        "blk.0.attn_norm.weight",
        "blk.0.ffn_norm.weight",
        "output_norm.weight",
    ] {
        let tensor = loaded_weights.get(norm_key).unwrap();
        assert_eq!(tensor.dtype().unwrap(), Dtype::F32, "{norm_key} should be F32");
    }

    let _ = std::fs::remove_dir_all(&dir);
}
/// Writes a 1 MiB `model.safetensors` file into `dir` in which every byte is
/// `0xAB`.
///
/// # Panics
/// Panics if the file cannot be written.
fn write_sentinel_weights(dir: &std::path::Path) {
    const SENTINEL_LEN: usize = 1024 * 1024;
    let payload: Vec<u8> = std::iter::repeat(0xAB_u8).take(SENTINEL_LEN).collect();
    let target = dir.join("model.safetensors");
    std::fs::write(target, &payload).unwrap();
}
/// Asserts that `msg` contains none of the substrings this module treats as
/// evidence that a weight file was opened.
///
/// The comparison is case-insensitive: `msg` is lower-cased once and each
/// (already lower-case) needle is searched for in that copy.
///
/// # Panics
/// Panics on the first needle found, naming the needle and echoing the
/// original, un-lowered `msg`.
fn assert_no_safetensors_load_signature(msg: &str) {
    const WEIGHT_LOAD_NEEDLES: [&str; 5] = [
        "safetensors",
        "load_safetensors",
        "header",
        "deserializ",
        "mlx_load",
    ];
    // Lower-case once up front instead of once per needle.
    let lowered = msg.to_lowercase();
    for needle in WEIGHT_LOAD_NEEDLES {
        assert!(
            !lowered.contains(needle),
            "unexpected weight-load signature {needle:?} in error: {msg}"
        );
    }
}
/// A config whose `model_type` is `qwen3` must make `convert_to_gguf` fail
/// with `Error::UnknownEnumValue` naming that value and the `model_type`
/// field. The error text must also be free of every weight-load marker
/// checked by `assert_no_safetensors_load_signature`, even though the
/// directory's `model.safetensors` is a garbage sentinel file.
#[test]
fn convert_to_gguf_rejects_unsupported_arch() {
    let dir = fresh_dir("reject_arch");
    let _ = write_tokenizer_fixture(&dir);
    let config = serde_json::json!({
        "model_type": "qwen3",
        "hidden_size": 4,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 2,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 8,
        "max_position_embeddings": 16,
        "rms_norm_eps": 1e-5,
    });
    std::fs::write(dir.join("config.json"), config.to_string()).unwrap();
    write_sentinel_weights(&dir);

    let args = ConvertToGgufArgs {
        model_path: dir.clone(),
        gguf_path: dir.join("out.gguf"),
    };
    let err = convert_to_gguf(&args).unwrap_err();

    let rejected = match &err {
        Error::UnknownEnumValue(p) => p,
        _ => panic!("expected Error::UnknownEnumValue for unsupported arch, got {err:?}"),
    };
    assert_eq!(rejected.value(), "qwen3");
    assert!(
        rejected.type_name().contains("model_type"),
        "type_name should name the rejected field: {}",
        rejected.type_name()
    );

    let msg = format!("{err:?}");
    assert_no_safetensors_load_signature(&msg);
    let _ = std::fs::remove_dir_all(&dir);
}
/// A llama config carrying a `quantization` object must make
/// `convert_to_gguf` fail with `Error::InvariantViolation` for the
/// "checkpoint quantization" context. The error text must contain none of
/// the weight-load markers checked by `assert_no_safetensors_load_signature`.
#[test]
fn convert_to_gguf_rejects_quantized() {
    let dir = fresh_dir("reject_quant");
    let _ = write_tokenizer_fixture(&dir);
    let config = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 4,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 2,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 8,
        "max_position_embeddings": 16,
        "rms_norm_eps": 1e-5,
        "quantization": { "group_size": 64, "bits": 4 },
    });
    std::fs::write(dir.join("config.json"), config.to_string()).unwrap();
    write_sentinel_weights(&dir);

    let args = ConvertToGgufArgs {
        model_path: dir.clone(),
        gguf_path: dir.join("out.gguf"),
    };
    let err = convert_to_gguf(&args).unwrap_err();

    let violation = match &err {
        Error::InvariantViolation(p) => p,
        _ => panic!("expected Error::InvariantViolation for quantized checkpoint, got {err:?}"),
    };
    assert_eq!(violation.context(), "convert_to_gguf: checkpoint quantization");
    assert!(violation.requirement().contains("must be None"));

    let msg = format!("{err:?}");
    assert_no_safetensors_load_signature(&msg);
    let _ = std::fs::remove_dir_all(&dir);
}
/// Same expectation as the `quantization` rejection test, but the config
/// spells the key `quantization_config`. `convert_to_gguf` must still fail
/// with `Error::InvariantViolation` for the "checkpoint quantization"
/// context, and the error text must carry no weight-load marker.
#[test]
fn convert_to_gguf_rejects_quantization_config_key() {
    let dir = fresh_dir("reject_quant_cfg");
    let _ = write_tokenizer_fixture(&dir);
    let config = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 4,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 2,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 8,
        "max_position_embeddings": 16,
        "rms_norm_eps": 1e-5,
        "quantization_config": { "group_size": 64, "bits": 4 },
    });
    std::fs::write(dir.join("config.json"), config.to_string()).unwrap();
    write_sentinel_weights(&dir);

    let args = ConvertToGgufArgs {
        model_path: dir.clone(),
        gguf_path: dir.join("out.gguf"),
    };
    let err = convert_to_gguf(&args).unwrap_err();

    let violation = match &err {
        Error::InvariantViolation(p) => p,
        _ => panic!("expected Error::InvariantViolation for quantized checkpoint, got {err:?}"),
    };
    assert_eq!(violation.context(), "convert_to_gguf: checkpoint quantization");
    assert!(violation.requirement().contains("must be None"));

    let msg = format!("{err:?}");
    assert_no_safetensors_load_signature(&msg);
    let _ = std::fs::remove_dir_all(&dir);
}
fn write_base_vocab_special_fixture(dir: &std::path::Path) -> crate::tokenizer::Tokenizer {
use serde_json::json;
let tok = json!({
"version": "1.0",
"model": {
"type": "BPE",
"vocab": {
"<unk>": 0,
"<s>": 1,
"</s>": 2,
"<pad>": 3,
"a": 4,
"b": 5,
},
"merges": []
},
"added_tokens": [],
});
std::fs::write(dir.join("tokenizer.json"), tok.to_string()).unwrap();
let cfg = json!({
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
"additional_special_tokens": ["b"],
});
std::fs::write(dir.join("tokenizer_config.json"), cfg.to_string()).unwrap();
crate::tokenizer::Tokenizer::from_path(dir, None).unwrap()
}
/// With special tokens declared only through `tokenizer_config.json` and
/// present only in the base vocab, `HfVocab` must list their ids in
/// `special_ids` and classify them as `Control`. The same classification
/// must reach the `tokenizer.ggml.token_type` array from `prepare_metadata`.
#[test]
fn convert_to_gguf_uses_base_vocab_special_token_ids() {
    let dir = fresh_dir("base_vocab_specials");
    let tokenizer = write_base_vocab_special_fixture(&dir);
    let vocab = HfVocab::from_tokenizer(&tokenizer).unwrap();

    // Ids that must be recorded as special, with the complaint to raise
    // when one is absent.
    for (id, complaint) in [
        (0, "unk (id 0) missing"),
        (1, "bos (id 1) missing"),
        (2, "eos (id 2) missing"),
        (3, "pad (id 3) missing"),
        (5, "additional 'b' missing"),
    ] {
        assert!(vocab.special_ids.contains(&id), "{complaint}");
    }
    assert!(
        !vocab.special_ids.contains(&4),
        "plain 'a' must NOT be classified Control"
    );

    for (id, token, kind) in [
        (0, "<unk>", TokenType::Control),
        (1, "<s>", TokenType::Control),
        (2, "</s>", TokenType::Control),
        (3, "<pad>", TokenType::Control),
        (4, "a", TokenType::Normal),
        (5, "b", TokenType::Control),
    ] {
        assert_eq!(vocab.get_token_type(id, token), kind, "{token}");
    }

    let config_text = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 8,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 4,
        "rope_theta": 10000.0,
        "vocab_size": 6,
        "tie_word_embeddings": false,
        "intermediate_size": 16,
        "max_position_embeddings": 32,
        "rms_norm_eps": 1e-5,
    })
    .to_string();
    let raw_json: serde_json::Value = serde_json::from_str(&config_text).unwrap();
    let config = Config::from_json(&config_text).unwrap();
    let meta = prepare_metadata(&config, &raw_json, &vocab).unwrap();

    let GgufMetadata::Array(stored) = meta.get("tokenizer.ggml.token_type").unwrap() else {
        panic!("token_type was not an Array");
    };
    let mut stored = stored.try_clone().unwrap();
    let toktype_vals = stored.to_vec::<u32>().unwrap();

    assert_eq!(toktype_vals.len(), 6);
    // Expected numeric token type per id, in id order, with its label.
    let expected_by_id = [
        (TokenType::Control as u32, "unk (id 0)"),
        (TokenType::Control as u32, "bos (id 1)"),
        (TokenType::Control as u32, "eos (id 2)"),
        (TokenType::Control as u32, "pad (id 3)"),
        (TokenType::Normal as u32, "'a' (id 4)"),
        (TokenType::Control as u32, "additional 'b' (id 5)"),
    ];
    for (actual, (expected, label)) in toktype_vals.iter().zip(expected_by_id) {
        assert_eq!(*actual, expected, "{label}");
    }

    let _ = std::fs::remove_dir_all(&dir);
}
/// `HfVocab::special_ids` must hold special tokens from both sources at
/// once: (a) the `special: true` added token `<added>`, and (b) the
/// base-vocab tokens named in `tokenizer_config.json` (`<s>` and `<extra>`).
/// The plain token `a` must be absent.
#[test]
fn convert_to_gguf_special_ids_unions_added_and_base_vocab() {
    let dir = fresh_dir("union_specials");
    let tokenizer_json = serde_json::json!({
        "version": "1.0",
        "model": {
            "type": "BPE",
            "vocab": {
                "<unk>": 0,
                "<s>": 1,
                "a": 2,
                "<extra>": 3,
            },
            "merges": []
        },
        "added_tokens": [
            {"id": 100, "content": "<added>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
        ],
    });
    std::fs::write(dir.join("tokenizer.json"), tokenizer_json.to_string()).unwrap();
    let tokenizer_config = serde_json::json!({
        "bos_token": "<s>",
        "additional_special_tokens": ["<extra>"],
    });
    std::fs::write(
        dir.join("tokenizer_config.json"),
        tokenizer_config.to_string(),
    )
    .unwrap();
    let tokenizer = crate::tokenizer::Tokenizer::from_path(&dir, None).unwrap();

    // Look the id up from the loaded tokenizer instead of hard-coding it.
    let added_id = tokenizer
        .hf()
        .get_added_tokens_decoder()
        .iter()
        .find_map(|(id, token)| (token.content == "<added>").then_some(*id))
        .expect("`<added>` must appear in added_tokens_decoder");

    let vocab = HfVocab::from_tokenizer(&tokenizer).unwrap();
    let special_ids = &vocab.special_ids;
    assert!(
        special_ids.contains(&added_id),
        "added <added> (id {added_id}) missing — source (a) failed; special_ids={:?}",
        special_ids,
    );
    assert!(
        special_ids.contains(&1),
        "bos <s> (base id 1) missing — source (b) failed; special_ids={:?}",
        special_ids,
    );
    assert!(
        special_ids.contains(&3),
        "additional <extra> (base id 3) missing — source (b) failed; special_ids={:?}",
        special_ids,
    );
    assert!(
        !special_ids.contains(&2),
        "plain 'a' (id 2) should not be in special_ids; special_ids={:?}",
        special_ids,
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// An added token declared `special: false` in `tokenizer.json` but listed
/// under `additional_special_tokens` in `tokenizer_config.json` must still
/// land in `special_ids` and be emitted as `Control` (not `UserDefined`) in
/// the `tokenizer.ggml.token_type` metadata.
#[test]
fn convert_to_gguf_added_token_via_additional_special_tokens_classifies_as_control() {
    let dir = fresh_dir("added_via_additional_special");
    let tokenizer_json = serde_json::json!({
        "version": "1.0",
        "model": {
            "type": "BPE",
            "vocab": {
                "<unk>": 0,
                "<s>": 1,
                "</s>": 2,
                "a": 3,
            },
            "merges": []
        },
        "added_tokens": [
            {"id": 100, "content": "<custom>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": false},
        ],
    });
    std::fs::write(dir.join("tokenizer.json"), tokenizer_json.to_string()).unwrap();
    let tokenizer_config = serde_json::json!({
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "additional_special_tokens": ["<custom>"],
    });
    std::fs::write(
        dir.join("tokenizer_config.json"),
        tokenizer_config.to_string(),
    )
    .unwrap();
    let tokenizer = crate::tokenizer::Tokenizer::from_path(&dir, None).unwrap();
    let vocab = HfVocab::from_tokenizer(&tokenizer).unwrap();

    // Look the id up from the loaded tokenizer instead of hard-coding it.
    let custom_id = tokenizer
        .hf()
        .get_added_tokens_decoder()
        .iter()
        .find_map(|(id, token)| (token.content == "<custom>").then_some(*id))
        .expect("`<custom>` must appear in added_tokens_decoder");

    assert!(
        custom_id >= vocab.vocab_size_base(),
        "`<custom>` id {custom_id} should be past base vocab ({})",
        vocab.vocab_size_base(),
    );
    assert!(
        vocab.special_ids.contains(&custom_id),
        "special-ids union failed: special_ids should contain `<custom>` id {custom_id}; \
         special_ids={:?}",
        vocab.special_ids,
    );
    assert!(
        !vocab.specials.contains_key("<custom>"),
        "fixture invariant: `<custom>` should NOT be in `specials` (the gap this test covers); \
         specials={:?}",
        vocab.specials,
    );

    let config_text = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 8,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 4,
        "rope_theta": 10000.0,
        "vocab_size": vocab.vocab_size(),
        "tie_word_embeddings": false,
        "intermediate_size": 16,
        "max_position_embeddings": 32,
        "rms_norm_eps": 1e-5,
    })
    .to_string();
    let raw_json: serde_json::Value = serde_json::from_str(&config_text).unwrap();
    let config = Config::from_json(&config_text).unwrap();
    let meta = prepare_metadata(&config, &raw_json, &vocab).unwrap();

    let GgufMetadata::Array(stored) = meta.get("tokenizer.ggml.token_type").unwrap() else {
        panic!("token_type was not an Array");
    };
    let mut stored = stored.try_clone().unwrap();
    let toktype_vals = stored.to_vec::<u32>().unwrap();

    assert_eq!(toktype_vals.len() as u32, vocab.vocab_size());
    let custom_type = toktype_vals[custom_id as usize];
    assert_eq!(
        custom_type,
        TokenType::Control as u32,
        "`<custom>` (id {custom_id}) should classify as Control, \
         got {} (UserDefined would be {}); full token_type={:?}",
        custom_type,
        TokenType::UserDefined as u32,
        toktype_vals,
    );
    assert_ne!(
        custom_type,
        TokenType::UserDefined as u32,
        "explicit not-UserDefined check",
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// When `tokenizer.json` holds text that is not valid JSON,
/// `convert_to_gguf` must fail with an error whose debug text mentions
/// "tokenizer" (case-insensitively). That text must also contain none of the
/// weight-load markers checked by `assert_no_safetensors_load_signature`,
/// even though a garbage `model.safetensors` sits alongside.
#[test]
fn convert_to_gguf_malformed_tokenizer_rejects_before_weight_load() {
    let dir = fresh_dir("malformed_tokenizer");
    let config = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 4,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 2,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 8,
        "max_position_embeddings": 16,
        "rms_norm_eps": 1e-5,
    });
    std::fs::write(dir.join("config.json"), config.to_string()).unwrap();
    let broken_tokenizer = "{ this is not valid tokenizer json }";
    std::fs::write(dir.join("tokenizer.json"), broken_tokenizer).unwrap();
    write_sentinel_weights(&dir);

    let args = ConvertToGgufArgs {
        model_path: dir.clone(),
        gguf_path: dir.join("out.gguf"),
    };
    let msg = format!("{:?}", convert_to_gguf(&args).unwrap_err());

    assert!(
        msg.to_lowercase().contains("tokenizer"),
        "error should name tokenizer-loading failure; got: {msg}"
    );
    assert_no_safetensors_load_signature(&msg);
    let _ = std::fs::remove_dir_all(&dir);
}
/// When `tokenizer.json` exists as a directory rather than a file,
/// `convert_to_gguf` must fail with an error whose debug text mentions
/// "tokenizer" (case-insensitively). That text must also contain none of the
/// weight-load markers checked by `assert_no_safetensors_load_signature`,
/// even though a garbage `model.safetensors` sits alongside.
#[test]
fn convert_to_gguf_directory_at_tokenizer_path_rejects_before_weight_load() {
    let dir = fresh_dir("dir_at_tokenizer");
    let config = serde_json::json!({
        "model_type": "llama",
        "hidden_size": 4,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "head_dim": 2,
        "rope_theta": 10000.0,
        "vocab_size": 5,
        "tie_word_embeddings": false,
        "intermediate_size": 8,
        "max_position_embeddings": 16,
        "rms_norm_eps": 1e-5,
    });
    std::fs::write(dir.join("config.json"), config.to_string()).unwrap();
    // Occupy the tokenizer path with a directory so it cannot be read as a file.
    std::fs::create_dir_all(dir.join("tokenizer.json")).unwrap();
    write_sentinel_weights(&dir);

    let args = ConvertToGgufArgs {
        model_path: dir.clone(),
        gguf_path: dir.join("out.gguf"),
    };
    let msg = format!("{:?}", convert_to_gguf(&args).unwrap_err());

    assert!(
        msg.to_lowercase().contains("tokenizer"),
        "error should name tokenizer-loading failure; got: {msg}"
    );
    assert_no_safetensors_load_signature(&msg);
    let _ = std::fs::remove_dir_all(&dir);
}
/// Pins which strings `is_byte_token` accepts: `<0x..>` with exactly two
/// hex digits in either case. It also pins which ones it rejects: a non-hex
/// digit, three digits, missing angle brackets, and an unrelated token.
#[test]
fn is_byte_token_classifier() {
    for token in ["<0x0A>", "<0xff>", "<0xAB>"] {
        assert!(is_byte_token(token), "{token:?} should be a byte token");
    }
    for token in ["<0xZ>", "<0x0AB>", "0x0A", "<unk>"] {
        assert!(!is_byte_token(token), "{token:?} should not be a byte token");
    }
}