/// Maps a model-architecture identifier to the canonical short name used for
/// tensor-name template lookup.
///
/// Accepts Hugging Face style class names (e.g. `"Qwen2ForCausalLM"`), a few
/// lowercase aliases, and every short name this function itself returns, so
/// `normalize_architecture(normalize_architecture(x)) == normalize_architecture(x)`
/// holds for all inputs.
///
/// Note the Phi mapping: `PhiForCausalLM` maps to `"phi2"` while
/// `Phi3ForCausalLM` maps to `"phi"`.
///
/// Any unrecognized identifier (including the empty string) falls back to
/// `"llama"`.
#[must_use]
pub fn normalize_architecture(name: &str) -> &'static str {
    match name {
        // The short names "phi2" / "phi" are accepted so that the Phi outputs
        // are fixed points like every other canonical name; without them a
        // re-normalized Phi name silently fell through to "llama".
        "PhiForCausalLM" | "phi2" => "phi2",
        "Phi3ForCausalLM" | "phi" => "phi",
        "LlamaForCausalLM" | "llama" => "llama",
        "Qwen2ForCausalLM" | "qwen2" | "qwen2.5" | "qwen" => "qwen2",
        "Qwen3ForCausalLM" | "qwen3" => "qwen3",
        "Qwen3MoeForCausalLM" | "qwen3-moe" | "qwen3moe" => "qwen3moe",
        "MistralForCausalLM" | "mistral" => "mistral",
        "GPT2LMHeadModel" | "gpt2" => "gpt2",
        "GPTNeoXForCausalLM" | "gpt_neox" | "gptneox" => "gpt_neox",
        "GemmaForCausalLM" | "gemma" => "gemma",
        "Gemma2ForCausalLM" | "gemma2" => "gemma2",
        // Unknown architectures default to the llama naming scheme.
        _ => "llama",
    }
}
/// Returns the accepted tensor names for the GPT-2 token embedding (`wte`):
/// the bare `wte.weight` first, then the `transformer.`-prefixed form.
#[must_use]
pub fn gpt2_global_names_for_embedding() -> Vec<&'static str> {
    const EMBEDDING_NAMES: [&str; 2] = ["wte.weight", "transformer.wte.weight"];
    EMBEDDING_NAMES.to_vec()
}
/// Returns `(fused, separate_q)` template counts for the attention projection
/// of `arch`.
///
/// `"gpt2"` and `"gpt_neox"` get one fused QKV template and no separate Q
/// template; every other string (matched exactly, not normalized) gets one
/// separate Q template and no fused one.
#[must_use]
pub fn fused_qkv_template_count_for_arch(arch: &str) -> (usize, usize) {
    let uses_fused_qkv = matches!(arch, "gpt2" | "gpt_neox");
    if uses_fused_qkv {
        (1, 0)
    } else {
        (0, 1)
    }
}
/// Outcome of the TNAME-001 check (YAML/Rust name-list parity).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname001Verdict {
    Pass,
    Fail,
}

/// TNAME-001: passes when `yaml_names` and `rust_names` hold the same names
/// with the same multiplicities, ignoring order.
///
/// Fails when either list is empty.
#[must_use]
pub fn verdict_from_yaml_rust_parity(
    yaml_names: &[String],
    rust_names: &[String],
) -> Tname001Verdict {
    /// Borrowed, sorted view of `names` (avoids cloning each `String`).
    fn sorted_view(names: &[String]) -> Vec<&str> {
        let mut view: Vec<&str> = names.iter().map(String::as_str).collect();
        view.sort_unstable();
        view
    }

    if yaml_names.is_empty() || rust_names.is_empty() {
        return Tname001Verdict::Fail;
    }
    // Different lengths can never be the same multiset.
    if yaml_names.len() != rust_names.len() {
        return Tname001Verdict::Fail;
    }
    if sorted_view(yaml_names) == sorted_view(rust_names) {
        Tname001Verdict::Pass
    } else {
        Tname001Verdict::Fail
    }
}
/// Outcome of the TNAME-002 check (architecture-map completeness).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname002Verdict {
    Pass,
    Fail,
}

/// TNAME-002: passes when every architecture listed in `referenced` also
/// appears (exact string match) in `mapped`.
///
/// Fails when `referenced` is empty.
#[must_use]
pub fn verdict_from_arch_map_complete(
    referenced: &[String],
    mapped: &[String],
) -> Tname002Verdict {
    let complete = !referenced.is_empty() && referenced.iter().all(|arch| mapped.contains(arch));
    if complete {
        Tname002Verdict::Pass
    } else {
        Tname002Verdict::Fail
    }
}
/// Outcome of the TNAME-003 check (required entries have a fallback).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname003Verdict {
    Pass,
    Fail,
}

/// TNAME-003: each input pair is `(is_required, count)`; passes when no pair
/// has `is_required == true` together with `count == 0`.
///
/// Pairs that are not required may have any count. Fails when the input is
/// empty.
#[must_use]
pub fn verdict_from_required_has_fallback(required_with_count: &[(bool, usize)]) -> Tname003Verdict {
    let has_required_without_fallback = required_with_count
        .iter()
        .any(|&(is_required, count)| is_required && count == 0);
    match (required_with_count.is_empty(), has_required_without_fallback) {
        (false, false) => Tname003Verdict::Pass,
        _ => Tname003Verdict::Fail,
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname004Verdict { Pass, Fail }
#[must_use]
pub fn verdict_from_phi_normalization() -> Tname004Verdict {
if normalize_architecture("PhiForCausalLM") != "phi2" { return Tname004Verdict::Fail; }
if normalize_architecture("Phi3ForCausalLM") != "phi" { return Tname004Verdict::Fail; }
Tname004Verdict::Pass
}
/// Outcome of the TNAME-005 check (unknown architectures default to llama).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname005Verdict {
    Pass,
    Fail,
}

/// TNAME-005: passes when `normalize_architecture(unknown_name)` yields
/// `"llama"`.
///
/// NOTE(review): a known llama alias (e.g. `"llama"`) also passes, because
/// only the normalized result is inspected.
#[must_use]
pub fn verdict_from_unknown_arch_default(unknown_name: &str) -> Tname005Verdict {
    match normalize_architecture(unknown_name) {
        "llama" => Tname005Verdict::Pass,
        _ => Tname005Verdict::Fail,
    }
}
/// Outcome of the TNAME-006 check (GPT-2 embedding tensor names).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname006Verdict {
    Pass,
    Fail,
}

/// TNAME-006: passes when `observed` has exactly as many entries as
/// `gpt2_global_names_for_embedding()`, contains every one of those names
/// (in any order), and has no entry starting with `model.`.
#[must_use]
pub fn verdict_from_gpt2_embedding_names(observed: &[String]) -> Tname006Verdict {
    let canonical = gpt2_global_names_for_embedding();
    let matches_canonical = observed.len() == canonical.len()
        && canonical
            .iter()
            .all(|&wanted| observed.iter().any(|seen| seen == wanted));
    // Redundant with the current canonical list (none of its names start with
    // `model.`), but rejects such names even if that list changes.
    let has_model_prefix = observed.iter().any(|seen| seen.starts_with("model."));
    if matches_canonical && !has_model_prefix {
        Tname006Verdict::Pass
    } else {
        Tname006Verdict::Fail
    }
}
/// Outcome of the TNAME-007 check (fused vs separate QKV template resolution).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tname007Verdict {
    Pass,
    Fail,
}

/// TNAME-007: passes when `fused_qkv_template_count_for_arch(arch)` is
/// consistent with the architecture.
///
/// - `"gpt2"` / `"gpt_neox"` need at least one fused template and no separate
///   Q template.
/// - Every other `arch` needs no fused template and at least one separate Q
///   template.
#[must_use]
pub fn verdict_from_fused_qkv_resolution(arch: &str) -> Tname007Verdict {
    let (fused, separate_q) = fused_qkv_template_count_for_arch(arch);
    let expects_fused = matches!(arch, "gpt2" | "gpt_neox");
    let consistent = if expects_fused {
        fused > 0 && separate_q == 0
    } else {
        fused == 0 && separate_q > 0
    };
    if consistent {
        Tname007Verdict::Pass
    } else {
        Tname007Verdict::Fail
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned `Vec<String>` the verdicts take.
    fn s(items: &[&str]) -> Vec<String> {
        items.iter().map(|i| i.to_string()).collect()
    }

    #[test]
    fn ref_phi_normalization() {
        assert_eq!(normalize_architecture("PhiForCausalLM"), "phi2");
        assert_eq!(normalize_architecture("Phi3ForCausalLM"), "phi");
    }

    #[test]
    fn ref_unknown_to_llama() {
        assert_eq!(normalize_architecture("FutureArch2027"), "llama");
        assert_eq!(normalize_architecture(""), "llama");
        assert_eq!(normalize_architecture("xyz"), "llama");
    }

    #[test]
    fn ref_qwen_aliases() {
        assert_eq!(normalize_architecture("Qwen2ForCausalLM"), "qwen2");
        assert_eq!(normalize_architecture("qwen2.5"), "qwen2");
        assert_eq!(normalize_architecture("Qwen3ForCausalLM"), "qwen3");
    }

    #[test]
    fn ref_other_aliases() {
        assert_eq!(normalize_architecture("qwen3-moe"), "qwen3moe");
        assert_eq!(normalize_architecture("GPTNeoXForCausalLM"), "gpt_neox");
        assert_eq!(normalize_architecture("gptneox"), "gpt_neox");
        assert_eq!(normalize_architecture("Gemma2ForCausalLM"), "gemma2");
    }

    #[test]
    fn tname001_pass_match() {
        let yaml = s(&["a", "b", "c"]);
        let rust = s(&["c", "b", "a"]);
        assert_eq!(verdict_from_yaml_rust_parity(&yaml, &rust), Tname001Verdict::Pass);
    }

    #[test]
    fn tname001_fail_drift() {
        let yaml = s(&["a", "b"]);
        let rust = s(&["a", "c"]);
        assert_eq!(verdict_from_yaml_rust_parity(&yaml, &rust), Tname001Verdict::Fail);
    }

    #[test]
    fn tname001_fail_empty() {
        assert_eq!(verdict_from_yaml_rust_parity(&[], &s(&["a"])), Tname001Verdict::Fail);
        assert_eq!(verdict_from_yaml_rust_parity(&s(&["a"]), &[]), Tname001Verdict::Fail);
    }

    #[test]
    fn tname001_fail_multiplicity() {
        // Same distinct names, different counts: not parity.
        let yaml = s(&["a", "a", "b"]);
        let rust = s(&["a", "b", "b"]);
        assert_eq!(verdict_from_yaml_rust_parity(&yaml, &rust), Tname001Verdict::Fail);
    }

    #[test]
    fn tname002_pass_complete() {
        let refd = s(&["llama", "qwen2"]);
        let mapped = s(&["llama", "qwen2", "phi"]);
        assert_eq!(verdict_from_arch_map_complete(&refd, &mapped), Tname002Verdict::Pass);
    }

    #[test]
    fn tname002_fail_missing() {
        let refd = s(&["llama", "newarch"]);
        let mapped = s(&["llama"]);
        assert_eq!(verdict_from_arch_map_complete(&refd, &mapped), Tname002Verdict::Fail);
    }

    #[test]
    fn tname002_fail_empty() {
        assert_eq!(verdict_from_arch_map_complete(&[], &s(&["llama"])), Tname002Verdict::Fail);
        assert_eq!(verdict_from_arch_map_complete(&s(&["llama"]), &[]), Tname002Verdict::Fail);
    }

    #[test]
    fn tname003_pass_all_required_have_fallback() {
        let pairs = vec![(true, 3), (false, 0), (true, 1)];
        assert_eq!(verdict_from_required_has_fallback(&pairs), Tname003Verdict::Pass);
    }

    #[test]
    fn tname003_pass_optional_without_fallback() {
        assert_eq!(verdict_from_required_has_fallback(&[(false, 0)]), Tname003Verdict::Pass);
    }

    #[test]
    fn tname003_fail_required_no_fallback() {
        let pairs = vec![(true, 3), (true, 0)];
        assert_eq!(verdict_from_required_has_fallback(&pairs), Tname003Verdict::Fail);
    }

    #[test]
    fn tname003_fail_empty() {
        assert_eq!(verdict_from_required_has_fallback(&[]), Tname003Verdict::Fail);
    }

    #[test]
    fn tname004_pass() {
        assert_eq!(verdict_from_phi_normalization(), Tname004Verdict::Pass);
    }

    #[test]
    fn tname005_pass_unknown() {
        assert_eq!(verdict_from_unknown_arch_default("FutureArch2027"), Tname005Verdict::Pass);
    }

    #[test]
    fn tname005_pass_empty() {
        assert_eq!(verdict_from_unknown_arch_default(""), Tname005Verdict::Pass);
    }

    #[test]
    fn tname005_fail_known_arch() {
        assert_eq!(verdict_from_unknown_arch_default("qwen2"), Tname005Verdict::Fail);
        assert_eq!(verdict_from_unknown_arch_default("Qwen2ForCausalLM"), Tname005Verdict::Fail);
    }

    #[test]
    fn tname006_pass_canonical() {
        let names = s(&["wte.weight", "transformer.wte.weight"]);
        assert_eq!(verdict_from_gpt2_embedding_names(&names), Tname006Verdict::Pass);
    }

    #[test]
    fn tname006_pass_reordered() {
        let names = s(&["transformer.wte.weight", "wte.weight"]);
        assert_eq!(verdict_from_gpt2_embedding_names(&names), Tname006Verdict::Pass);
    }

    #[test]
    fn tname006_fail_with_model_prefix() {
        let names = s(&["model.embed_tokens.weight", "wte.weight"]);
        assert_eq!(verdict_from_gpt2_embedding_names(&names), Tname006Verdict::Fail);
    }

    #[test]
    fn tname006_fail_short_list() {
        let names = s(&["wte.weight"]);
        assert_eq!(verdict_from_gpt2_embedding_names(&names), Tname006Verdict::Fail);
    }

    #[test]
    fn tname006_fail_duplicate_entry() {
        // Right length, but one canonical name is missing.
        let names = s(&["wte.weight", "wte.weight"]);
        assert_eq!(verdict_from_gpt2_embedding_names(&names), Tname006Verdict::Fail);
    }

    #[test]
    fn tname006_fail_empty_and_extra() {
        assert_eq!(verdict_from_gpt2_embedding_names(&[]), Tname006Verdict::Fail);
        let names = s(&["wte.weight", "transformer.wte.weight", "lm_head.weight"]);
        assert_eq!(verdict_from_gpt2_embedding_names(&names), Tname006Verdict::Fail);
    }

    #[test]
    fn tname007_pass_gpt2_fused() {
        assert_eq!(verdict_from_fused_qkv_resolution("gpt2"), Tname007Verdict::Pass);
    }

    #[test]
    fn tname007_pass_gpt_neox_fused() {
        assert_eq!(verdict_from_fused_qkv_resolution("gpt_neox"), Tname007Verdict::Pass);
    }

    #[test]
    fn tname007_pass_llama_separate() {
        assert_eq!(verdict_from_fused_qkv_resolution("llama"), Tname007Verdict::Pass);
    }

    #[test]
    fn tname007_pass_qwen2_separate() {
        assert_eq!(verdict_from_fused_qkv_resolution("qwen2"), Tname007Verdict::Pass);
    }

    #[test]
    fn tname007_pass_unknown_separate() {
        assert_eq!(verdict_from_fused_qkv_resolution("FutureArch2027"), Tname007Verdict::Pass);
    }

    #[test]
    fn fused_qkv_helper_canonical() {
        assert_eq!(fused_qkv_template_count_for_arch("gpt2"), (1, 0));
        assert_eq!(fused_qkv_template_count_for_arch("gpt_neox"), (1, 0));
        assert_eq!(fused_qkv_template_count_for_arch("llama"), (0, 1));
        assert_eq!(fused_qkv_template_count_for_arch("qwen2"), (0, 1));
        // The helper matches exact short names; HF class names are not normalized.
        assert_eq!(fused_qkv_template_count_for_arch("GPT2LMHeadModel"), (0, 1));
    }
}