use super::resolve::{strip_arch_suffix, ALIAS_ARCHITECTURES};
use super::{ConfigField, FamilyInfo, KernelClass};
use std::collections::BTreeMap;
use std::path::Path;
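/// Extracts the scalar value of `key` from raw JSON text by string
/// scanning rather than full JSON parsing. String values are returned
/// without their quotes; numbers and booleans are returned as raw tokens.
/// Returns `None` for arrays, objects, `null`, empty values, and missing
/// keys.
///
/// Known limitations of the scan: it matches the first occurrence of
/// `"key"` anywhere in the text (including nested objects), and it does
/// not handle escaped quotes inside string values.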
pub fn extract_json_string(json: &str, key: &str) -> Option<String> {
let search = format!("\"{key}\"");
let pos = json.find(&search)?;
let after = &json[pos + search.len()..];
let after = after.trim_start().strip_prefix(':')?;
let after = after.trim_start();
if let Some(after) = after.strip_prefix('"') {
let end = after.find('"')?;
Some(after[..end].to_string())
} else if after.starts_with('[') || after.starts_with('{') {
None
} else {
let end = after.find(|c: char| c == ',' || c == '}' || c.is_whitespace())?;
let val = after[..end].trim();
if val.is_empty() || val == "null" {
None
} else {
Some(val.to_string())
}
}
}
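/// Reads a HuggingFace `config.json` at `path` and extracts the fields
/// that drive kernel selection, pairing each with a rationale (enriched
/// with value-specific detail where possible). The first entry of the
/// `architectures` array is stored under the synthetic `_architectures`
/// key. Returns an empty map if the file cannot be read.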
pub fn extract_config_mapping(path: &Path) -> BTreeMap<String, ConfigField> {
let mut map = BTreeMap::new();
let Ok(content) = std::fs::read_to_string(path) else {
return map;
};
if let Some(pos) = content.find("\"architectures\"") {
let after = &content[pos..];
if let Some(bracket) = after.find('[') {
let inner = &after[bracket + 1..];
if let Some(quote_start) = inner.find('"') {
let rest = &inner[quote_start + 1..];
if let Some(quote_end) = rest.find('"') {
let arch = &rest[..quote_end];
map.insert(
"_architectures".to_string(),
ConfigField {
value: arch.to_string(),
rationale: "HuggingFace architecture class".to_string(),
},
);
}
}
}
}
let fields = [
("model_type", "Architecture class dispatch"),
("hidden_act", "Activation kernel selection"),
("rms_norm_eps", "RMSNorm (not LayerNorm)"),
("layer_norm_epsilon", "LayerNorm (not RMSNorm)"),
("layer_norm_eps", "LayerNorm (not RMSNorm)"),
("norm_epsilon", "Normalization epsilon"),
("num_key_value_heads", "GQA vs MHA vs MQA"),
("num_kv_heads", "GQA vs MHA (Falcon field name)"),
("multi_query", "MQA flag (Falcon-7B)"),
("num_attention_heads", "Number of query heads"),
("rope_theta", "RoPE positional encoding"),
("intermediate_size", "MLP width (SwiGLU detection)"),
("hidden_size", "Model hidden dimension"),
("num_hidden_layers", "Transformer depth"),
("num_local_experts", "MoE expert routing"),
("num_experts", "MoE expert routing"),
("n_routed_experts", "MoE expert routing (DeepSeek)"),
("num_experts_per_tok", "MoE active experts per token"),
("moe_intermediate_size", "MoE per-expert MLP width"),
("head_dim", "Explicit attention head dimension"),
(
"tie_word_embeddings",
"Weight sharing: embedding <-> lm_head",
),
("vocab_size", "Vocabulary size"),
("max_position_embeddings", "Maximum sequence length"),
];
for (key, rationale) in &fields {
if let Some(val) = extract_json_string(&content, key) {
let enriched = enrich_rationale(key, &val, &content);
map.insert(
(*key).to_string(),
ConfigField {
value: val,
rationale: enriched.unwrap_or_else(|| (*rationale).to_string()),
},
);
}
}
map
}
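/// Builds a value-specific rationale for `key` where the raw JSON gives
/// enough context (e.g. classifying GQA vs MQA from head counts), or
/// `None` when the generic rationale from the field table should be kept.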
pub(crate) fn enrich_rationale(key: &str, value: &str, json: &str) -> Option<String> {
match key {
"hidden_act" => enrich_activation(value),
"rms_norm_eps" => Some("RMSNorm (not LayerNorm)".to_string()),
"num_key_value_heads" => enrich_kv_heads(value, json),
"rope_theta" => Some("RoPE positional encoding".to_string()),
"intermediate_size" => enrich_intermediate_size(value, json),
"num_local_experts" | "num_experts" | "n_routed_experts" => enrich_experts(value),
"num_experts_per_tok" => enrich_experts_per_tok(value),
"tie_word_embeddings" => enrich_tie_embeddings(value),
"num_attention_heads" => enrich_attention_heads(value, json),
"hidden_size" => enrich_hidden_size(value, json),
"num_hidden_layers" => enrich_hidden_layers(value),
"vocab_size" => enrich_vocab_size(value, json),
"max_position_embeddings" => enrich_max_position(value),
_ => None,
}
}
fn enrich_activation(value: &str) -> Option<String> {
match value {
"silu" => Some("SiLU activation (not GELU)".to_string()),
"gelu" | "gelu_new" | "gelu_pytorch_tanh" | "gelu_fast" => {
Some(format!("GELU activation: {value} (not SiLU)"))
}
_ => Some(format!("Activation: {value}")),
}
}
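/// Classifies the attention variant from head counts: one KV head is MQA,
/// fewer KV heads than query heads is GQA, and equal counts is MHA.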
fn enrich_kv_heads(value: &str, json: &str) -> Option<String> {
let num_heads =
extract_json_string(json, "num_attention_heads").and_then(|v| v.parse::<u32>().ok());
let kv_heads = value.parse::<u32>().ok();
match (num_heads, kv_heads) {
(Some(h), Some(kv)) if kv == 1 => Some(format!("MQA ({kv} KV head < {h} Q heads)")),
(Some(h), Some(kv)) if kv < h => Some(format!("GQA ({kv} KV heads < {h} Q heads)")),
(Some(h), Some(kv)) if kv == h => Some(format!("MHA ({kv} KV heads == {h} Q heads)")),
_ => None,
}
}
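/// Labels the MLP block from the activation and the intermediate/hidden
/// ratio: SiLU/Swish implies a gated SwiGLU MLP, GELU a standard FFN; with
/// no recognized activation, a ratio above 2.5x is taken as SwiGLU.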
fn enrich_intermediate_size(value: &str, json: &str) -> Option<String> {
let hidden = extract_json_string(json, "hidden_size").and_then(|v| v.parse::<f64>().ok());
let inter = value.parse::<f64>().ok();
let act = extract_json_string(json, "hidden_act")
.unwrap_or_default()
.to_lowercase();
let is_gelu = act.contains("gelu");
let is_silu = act == "silu" || act == "swish";
match (hidden, inter) {
(Some(h), Some(i)) if h > 0.0 => {
let ratio = i / h;
let mlp_type = if is_gelu {
"GELU FFN"
} else if is_silu {
"SwiGLU MLP"
} else if ratio > 2.5 {
"SwiGLU MLP"
} else {
"Standard FFN"
};
Some(format!("{mlp_type} ({i:.0}/{h:.0} = {ratio:.2}x)"))
}
_ => None,
}
}
fn enrich_experts(value: &str) -> Option<String> {
let n: i32 = value.parse().unwrap_or(0);
if n > 1 {
Some(format!("MoE with {n} experts (expert routing kernel)"))
} else if n == 1 {
Some("1 expert (dense model, not MoE)".to_string())
} else if n < 0 {
Some(format!("Invalid: {n} experts (negative)"))
} else {
None
}
}
fn enrich_experts_per_tok(value: &str) -> Option<String> {
let n: u32 = value.parse().unwrap_or(0);
if n > 0 {
let plural = if n == 1 { "expert" } else { "experts" };
Some(format!("{n} active {plural} per token"))
} else {
None
}
}
fn enrich_tie_embeddings(value: &str) -> Option<String> {
match value {
"true" => Some("Shared: embedding == lm_head (saves memory)".to_string()),
"false" => Some("Separate embedding and lm_head weights".to_string()),
_ => None,
}
}
fn enrich_attention_heads(value: &str, json: &str) -> Option<String> {
let kv = extract_json_string(json, "num_key_value_heads").and_then(|v| v.parse::<u32>().ok());
let n: u32 = value.parse().unwrap_or(0);
match kv {
Some(kv_n) if kv_n == 1 => Some(format!("{n} query heads, MQA (1 KV head)")),
Some(kv_n) if kv_n < n => {
let ratio = n / kv_n;
Some(format!(
"{n} query heads, GQA ({ratio} queries per KV group)"
))
}
Some(kv_n) if kv_n == n => Some(format!("{n} heads, MHA (no KV grouping)")),
_ => None,
}
}
fn enrich_hidden_size(value: &str, json: &str) -> Option<String> {
let n: u64 = value.parse().unwrap_or(0);
if n > 0 {
let params_est = estimate_params(n, json);
Some(format!("Hidden dim {n}{params_est}"))
} else {
None
}
}
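/// Rough decoder-only parameter estimate assembled from config fields:
///
///   total ≈ layers * (attn + mlp + 2n) + embed
///
/// where `attn = 2 * num_heads * head_dim * n + 2 * kv_dim * n` (Q/O plus
/// K/V projections), `mlp` uses a 3x factor for gated activations
/// (SiLU/Swish/gegelu) and 2x otherwise, `2n` covers the two per-layer
/// norm weights, and `embed` counts the embedding matrix once when tied
/// and twice otherwise. For MoE configs the routed-expert MLPs are added
/// on top of the dense MLP term. Without an `intermediate_size`, falls
/// back to the classic 12 * n^2 per-layer approximation. Returns a
/// ", ~X params" suffix, or an empty string for models under ~1M
/// parameters or with no layer count.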
fn estimate_params(n: u64, json: &str) -> String {
let Some(layers) =
extract_json_string(json, "num_hidden_layers").and_then(|v| v.parse::<u64>().ok())
else {
return String::new();
};
let inter = extract_json_string(json, "intermediate_size").and_then(|v| v.parse::<u64>().ok());
let vocab = extract_json_string(json, "vocab_size")
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
let tied = extract_json_string(json, "tie_word_embeddings").is_some_and(|v| v == "true");
let embed_params = if tied { vocab * n } else { 2 * vocab * n };
let kv_heads =
    extract_json_string(json, "num_key_value_heads").and_then(|v| v.parse::<u64>().ok());
let num_heads = extract_json_string(json, "num_attention_heads")
    .and_then(|v| v.parse::<u64>().ok())
    .unwrap_or(1);
let head_dim_val = extract_json_string(json, "head_dim")
    .and_then(|v| v.parse::<u64>().ok())
    .unwrap_or_else(|| if num_heads > 0 { n / num_heads } else { 0 });
let kv_dim = kv_heads.map_or(n, |kv| kv * head_dim_val);
let attn_params = 2 * num_heads * head_dim_val * n + 2 * kv_dim * n;
let n_experts = extract_json_string(json, "num_local_experts")
.or_else(|| extract_json_string(json, "num_experts"))
.or_else(|| extract_json_string(json, "n_routed_experts"))
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
let moe_inter =
extract_json_string(json, "moe_intermediate_size").and_then(|v| v.parse::<u64>().ok());
let est = if let Some(i) = inter {
let act = extract_json_string(json, "hidden_act")
.unwrap_or_default()
.to_lowercase();
let is_gated = act == "silu" || act == "swish" || act.contains("gegelu");
let mlp_factor = if is_gated { 3 } else { 2 };
let dense_mlp = mlp_factor * n * i;
let expert_mlp = if n_experts > 1 {
    let ei = moe_inter.unwrap_or(i);
    n_experts * mlp_factor * n * ei
} else {
    0
};
let mlp_total = if n_experts > 1 {
expert_mlp + dense_mlp
} else {
dense_mlp
};
layers * (attn_params + mlp_total + 2 * n) + embed_params
} else {
layers * 12 * n * n + embed_params
};
if est > 1_000_000_000 {
format!(", ~{:.1}B params", est as f64 / 1e9)
} else if est > 1_000_000 {
format!(", ~{:.0}M params", est as f64 / 1e6)
} else {
String::new()
}
}
fn enrich_hidden_layers(value: &str) -> Option<String> {
let n: u32 = value.parse().unwrap_or(0);
if n > 0 {
Some(format!("{n} transformer layers"))
} else {
None
}
}
fn enrich_vocab_size(value: &str, json: &str) -> Option<String> {
let n: u64 = value.parse().unwrap_or(0);
let hidden = extract_json_string(json, "hidden_size").and_then(|v| v.parse::<u64>().ok());
if let Some(h) = hidden {
let embed_mb = (n * h * 2) as f64 / 1_048_576.0;
Some(format!("{n} tokens (embedding: {embed_mb:.0} MB at fp16)"))
} else if n > 0 {
Some(format!("{n} tokens"))
} else {
None
}
}
fn enrich_max_position(value: &str) -> Option<String> {
let n: u64 = value.parse().unwrap_or(0);
if n >= 1_048_576 {
Some(format!("{n} max seq len (1M+ context)"))
} else if n >= 524_288 {
Some(format!("{n} max seq len (512K+ context)"))
} else if n >= 262_144 {
Some(format!("{n} max seq len (256K+ context)"))
} else if n >= 131_072 {
Some(format!("{n} max seq len (128K+ context)"))
} else if n >= 32_768 {
Some(format!("{n} max seq len (32K+ context)"))
} else if n >= 8_192 {
Some(format!("{n} max seq len (8K+ context)"))
} else if n > 0 {
Some(format!("{n} max seq len"))
} else {
None
}
}
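/// Picks the architecture label to display, in order of preference: the
/// `_architectures` entry captured from config.json, then `model_type`,
/// then the HF class (or alias name) recovered from an "X (via Y)" display
/// name, then the family's first registered architecture, else "Unknown".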
pub fn extract_architecture_display(
family: &FamilyInfo,
config_mapping: &BTreeMap<String, ConfigField>,
) -> String {
if let Some(arch) = config_mapping.get("_architectures") {
return arch.value.clone();
}
if let Some(mt) = config_mapping.get("model_type") {
return mt.value.clone();
}
if family.display_name.contains(" (via ") {
if let Some(alias_name) = family.display_name.split(" (via ").next() {
if let Some((_, hf_arch)) = ALIAS_ARCHITECTURES
.iter()
.find(|(alias, _)| *alias == alias_name)
{
return (*hf_arch).to_string();
}
return alias_name.to_string();
}
}
family
.architectures
.first()
.map_or("Unknown".to_string(), Clone::clone)
}
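/// Runs every consistency check between the parsed config fields and the
/// family's declared constraints, collecting human-readable warnings.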
pub fn detect_constraint_mismatches(
family: &FamilyInfo,
config_mapping: &BTreeMap<String, ConfigField>,
) -> Vec<String> {
let mut warnings = Vec::new();
warnings.extend(check_type_mismatch(config_mapping));
warnings.extend(check_activation_mismatch(family, config_mapping));
warnings.extend(check_normalization_mismatch(family, config_mapping));
warnings.extend(check_attention_mismatch(family, config_mapping));
warnings.extend(check_moe_mismatch(family, config_mapping));
warnings.extend(check_dimension_validity(config_mapping));
warnings
}
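/// Warns when `model_type` and the `architectures` entry disagree after
/// normalizing case and underscores and stripping common class-name
/// suffixes; prefix matches are tolerated.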
fn check_type_mismatch(config_mapping: &BTreeMap<String, ConfigField>) -> Vec<String> {
let mut warnings = Vec::new();
if let Some(mt) = config_mapping.get("model_type") {
if let Some(arch) = config_mapping.get("_architectures") {
let arch_lower = arch.value.to_lowercase();
let mt_lower = mt.value.to_lowercase();
let arch_family = strip_arch_suffix(&arch_lower);
let mt_compact = mt_lower.replace('_', "");
let arch_compact = arch_family.replace('_', "");
if arch_family != mt_lower
&& arch_compact != mt_compact
&& !arch_lower.starts_with(&mt_lower)
{
warnings.push(format!(
"model_type '{}' conflicts with architectures ['{}']. Using model_type for dispatch.",
mt.value, arch.value
));
}
}
}
warnings
}
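/// Warns on GELU-vs-SiLU disagreement between config.json and the family
/// constraint, and flags `gegelu` (Gated GELU) separately because it
/// requires a different kernel.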
fn check_activation_mismatch(
family: &FamilyInfo,
config_mapping: &BTreeMap<String, ConfigField>,
) -> Vec<String> {
let mut warnings = Vec::new();
if let Some(act) = config_mapping.get("hidden_act") {
let config_act = act.value.to_lowercase();
let family_act = family.constraints.activation.to_lowercase();
let config_is_gelu = config_act.contains("gelu");
let family_is_gelu = family_act.contains("gelu");
let config_is_silu = config_act == "silu" || config_act == "swish";
let family_is_silu = family_act == "silu" || family_act == "swish";
if (config_is_gelu && family_is_silu) || (config_is_silu && family_is_gelu) {
warnings.push(format!(
"Activation mismatch: config.json has '{}' but family '{}' uses '{}'",
act.value, family.family, family.constraints.activation
));
}
if config_act == "gegelu" && family_act != "gegelu" {
warnings.push(format!(
"Activation variant: config.json uses 'gegelu' (Gated GELU) but family '{}' uses '{}'. Different kernel.",
family.family, family.constraints.activation
));
}
}
warnings
}
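/// Warns when the epsilon field present in config.json (RMSNorm vs
/// LayerNorm) contradicts the family's norm type, or when both kinds of
/// epsilon appear at once.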
fn check_normalization_mismatch(
family: &FamilyInfo,
config_mapping: &BTreeMap<String, ConfigField>,
) -> Vec<String> {
let mut warnings = Vec::new();
let has_rms = config_mapping.contains_key("rms_norm_eps");
let has_ln = config_mapping.contains_key("layer_norm_epsilon")
|| config_mapping.contains_key("layer_norm_eps")
|| config_mapping.contains_key("norm_epsilon");
let family_norm = family.constraints.norm_type.to_lowercase();
if has_rms && has_ln {
warnings.push(
"Conflicting norm config: both rms_norm_eps (RMSNorm) and layer_norm_epsilon (LayerNorm) present. Only one should exist.".to_string()
);
} else if has_rms && !has_ln && family_norm == "layernorm" {
warnings.push(format!(
"Norm mismatch: config.json has rms_norm_eps (RMSNorm) but family '{}' uses LayerNorm",
family.family
));
} else if has_ln && !has_rms && family_norm == "rmsnorm" {
warnings.push(format!(
"Norm mismatch: config.json has layer_norm_epsilon (LayerNorm) but family '{}' uses RMSNorm",
family.family
));
}
warnings
}
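/// Validates head counts (KV heads must not exceed query heads and must
/// divide them evenly for GQA), then compares the implied attention
/// variant against the family constraint. An MHA config mapped to a GQA
/// family is accepted, since MHA is the degenerate GQA case with group
/// size 1. Falcon-style configs that only set `multi_query=true` are
/// checked against an MQA constraint.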
fn check_attention_mismatch(
family: &FamilyInfo,
config_mapping: &BTreeMap<String, ConfigField>,
) -> Vec<String> {
let mut warnings = Vec::new();
let kv_field = config_mapping
.get("num_key_value_heads")
.or_else(|| config_mapping.get("num_kv_heads"));
if let Some(kv) = kv_field {
if let Some(q) = config_mapping.get("num_attention_heads") {
let kv_n: u32 = kv.value.parse().unwrap_or(0);
let q_n: u32 = q.value.parse().unwrap_or(0);
if kv_n > q_n && q_n > 0 {
warnings.push(format!(
"Invalid attention config: num_key_value_heads ({kv_n}) > num_attention_heads ({q_n}). KV heads cannot exceed query heads."
));
} else if kv_n > 0 && q_n > 0 && q_n % kv_n != 0 {
warnings.push(format!(
"Invalid GQA config: num_attention_heads ({q_n}) not divisible by num_key_value_heads ({kv_n}). GQA requires even grouping."
));
} else {
let config_attn = if kv_n == 1 {
"mqa"
} else if q_n > 0 && kv_n < q_n {
"gqa"
} else {
"mha"
};
let family_attn = family.constraints.attention_type.to_lowercase();
let is_mha_gqa_compat = config_attn == "mha" && family_attn == "gqa";
if config_attn != family_attn && !family_attn.is_empty() && !is_mha_gqa_compat {
warnings.push(format!(
"Attention mismatch: config.json implies {} but family '{}' uses {}",
config_attn.to_uppercase(),
family.family,
family.constraints.attention_type.to_uppercase()
));
}
}
}
} else if let Some(mq) = config_mapping.get("multi_query") {
if mq.value == "true" {
let family_attn = family.constraints.attention_type.to_lowercase();
if family_attn != "mqa" && !family_attn.is_empty() {
warnings.push(format!(
"Attention mismatch: config.json has multi_query=true (MQA) but family '{}' uses {}",
family.family,
family.constraints.attention_type.to_uppercase()
));
}
}
}
warnings
}
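/// Warns when an MoE config (or an MoE-looking family or alias name) is
/// mapped to a kernel class without expert routing, and when the expert
/// count is negative.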
fn check_moe_mismatch(
family: &FamilyInfo,
config_mapping: &BTreeMap<String, ConfigField>,
) -> Vec<String> {
let mut warnings = Vec::new();
let expert_field = config_mapping
.get("num_local_experts")
.or_else(|| config_mapping.get("num_experts"))
.or_else(|| config_mapping.get("n_routed_experts"));
if let Some(ef) = expert_field {
let n_experts: i32 = ef.value.parse().unwrap_or(0);
if n_experts < 0 {
warnings.push(format!(
"Invalid config: expert count ({}) is negative.",
ef.value
));
} else if n_experts > 1 && family.kernel_class != KernelClass::E {
warnings.push(format!(
"MoE model ({n_experts} experts) mapped to non-MoE class {}. Expert routing kernel not covered.",
family.kernel_class.letter()
));
}
} else if family.kernel_class != KernelClass::E {
let dn = family.display_name.to_lowercase();
let is_alias = dn.contains(" (via ");
let alias_name = if is_alias {
dn.split(" (via ").next().unwrap_or("")
} else {
""
};
if family.family.contains("moe")
|| alias_name.contains("moe")
|| alias_name.contains("mixtral")
|| alias_name.contains("mixture")
{
warnings.push(format!(
"MoE architecture detected (from name) but mapped to non-MoE class {}. Expert routing kernel not covered.",
family.kernel_class.letter()
));
}
}
warnings
}
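/// Sanity-checks core dimensions: negative values, zeros that would cause
/// division by zero in dispatch, implausibly large hidden sizes, and a
/// hidden size not divisible by the head count when no explicit
/// `head_dim` is given.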
fn check_dimension_validity(config_mapping: &BTreeMap<String, ConfigField>) -> Vec<String> {
let mut warnings = Vec::new();
for (key, label) in &[
("hidden_size", "Hidden size"),
("num_attention_heads", "Attention heads"),
("num_hidden_layers", "Hidden layers"),
("vocab_size", "Vocabulary size"),
] {
if let Some(field) = config_mapping.get(*key) {
if let Ok(n) = field.value.parse::<i64>() {
if n < 0 {
warnings.push(format!(
"Invalid config: {label} ({key}={n}) is negative. Must be positive."
));
} else if n == 0 && (*key == "hidden_size" || *key == "num_attention_heads") {
warnings.push(format!(
"Invalid config: {label} ({key}=0) is zero. Would cause division by zero in kernel dispatch."
));
}
}
}
}
if let Some(field) = config_mapping.get("hidden_size") {
if let Ok(n) = field.value.parse::<u64>() {
if n > 100_000 {
warnings.push(format!(
"Implausible hidden_size={n}. Largest known models have hidden_size ~16384."
));
}
}
}
let has_explicit_head_dim = config_mapping.contains_key("head_dim");
if !has_explicit_head_dim {
if let (Some(hs), Some(nh)) = (
config_mapping.get("hidden_size"),
config_mapping.get("num_attention_heads"),
) {
if let (Ok(h), Ok(n)) = (hs.value.parse::<u64>(), nh.value.parse::<u64>()) {
if n > 0 && h > 0 && h % n != 0 {
warnings.push(format!(
"Invalid config: hidden_size ({h}) not divisible by num_attention_heads ({n}). Head dimension must be an integer."
));
}
}
}
}
warnings
}
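// A minimal sketch of unit tests for the pure string helpers above; the
// JSON fragments are hypothetical config.json snippets, not taken from any
// real model repository.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_scalar_and_string_values() {
        let json = r#"{"model_type": "llama", "rope_theta": 10000.0, "tie_word_embeddings": false}"#;
        assert_eq!(extract_json_string(json, "model_type").as_deref(), Some("llama"));
        assert_eq!(extract_json_string(json, "rope_theta").as_deref(), Some("10000.0"));
        assert_eq!(extract_json_string(json, "tie_word_embeddings").as_deref(), Some("false"));
        // Arrays, null, and missing keys all yield None.
        assert_eq!(extract_json_string(r#"{"architectures": ["LlamaForCausalLM"]}"#, "architectures"), None);
        assert_eq!(extract_json_string(r#"{"head_dim": null}"#, "head_dim"), None);
        assert_eq!(extract_json_string(json, "absent"), None);
    }

    #[test]
    fn classifies_attention_variants_from_head_counts() {
        let json = r#"{"num_attention_heads": 32}"#;
        assert_eq!(enrich_kv_heads("8", json).as_deref(), Some("GQA (8 KV heads < 32 Q heads)"));
        assert_eq!(enrich_kv_heads("1", json).as_deref(), Some("MQA (1 KV head < 32 Q heads)"));
        assert_eq!(enrich_kv_heads("32", json).as_deref(), Some("MHA (32 KV heads == 32 Q heads)"));
    }
}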