use super::family::load_families;
use super::{FamilyInfo, FAMILY_ALIASES};
use std::path::Path;
/// Table of `(alias, architecture)` string pairs.
///
/// The first element is a lowercase alias; several aliases appear in both an
/// underscore spelling and a compact spelling (for example `gpt_neo` and
/// `gptneo`). The second element is an architecture name — these look like
/// Hugging Face model class names, but that is not established here.
///
/// NOTE(review): nothing in this part of the file reads this table, so entry
/// order is kept exactly as written in case a consumer relies on it.
pub(crate) const ALIAS_ARCHITECTURES: &[(&str, &str)] = &[
    ("bloom", "BloomForCausalLM"),
    ("bloomz", "BloomForCausalLM"),
    ("bloom_560m", "BloomForCausalLM"),
    ("falcon", "FalconForCausalLM"),
    ("mixtral", "MixtralForCausalLM"),
    ("phi3", "Phi3ForCausalLM"),
    ("phi3small", "Phi3SmallForCausalLM"),
    ("codellama", "LlamaForCausalLM"),
    ("vicuna", "LlamaForCausalLM"),
    ("solar", "LlamaForCausalLM"),
    ("nemotron", "LlamaForCausalLM"),
    ("olmo", "OlmoForCausalLM"),
    ("olmo2", "Olmo2ForCausalLM"),
    ("granite", "GraniteForCausalLM"),
    ("internlm2", "InternLM2ForCausalLM"),
    ("yi", "LlamaForCausalLM"),
    ("baichuan", "BaichuanForCausalLM"),
    ("stablelm", "StableLmForCausalLM"),
    ("starcoder2", "Starcoder2ForCausalLM"),
    ("codestral", "MistralForCausalLM"),
    ("mathstral", "MistralForCausalLM"),
    ("pixtral", "LlavaMistralForCausalLM"),
    ("zephyr", "MistralForCausalLM"),
    ("openchat", "MistralForCausalLM"),
    ("openhermes", "MistralForCausalLM"),
    ("qwen2_moe", "Qwen2MoeForCausalLM"),
    ("qwen2moe", "Qwen2MoeForCausalLM"),
    ("qwen3_moe", "Qwen3MoeForCausalLM"),
    ("qwen3moe", "Qwen3MoeForCausalLM"),
    ("deepseek_v2", "DeepseekV2ForCausalLM"),
    ("deepseekv2", "DeepseekV2ForCausalLM"),
    ("qwq", "Qwen2ForCausalLM"),
    ("bigscience", "BloomForCausalLM"),
    ("qwen3_next", "Qwen3ForCausalLM"),
    ("qwen3next", "Qwen3ForCausalLM"),
    ("smollm", "LlamaForCausalLM"),
    ("smollm2", "LlamaForCausalLM"),
    ("gpt_neo", "GPTNeoForCausalLM"),
    ("gptneo", "GPTNeoForCausalLM"),
    ("gpt_neox", "GPTNeoXForCausalLM"),
    ("gptneox", "GPTNeoXForCausalLM"),
    ("gpt_j", "GPTJForCausalLM"),
    ("gptj", "GPTJForCausalLM"),
    ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    ("gptbigcode", "GPTBigCodeForCausalLM"),
    ("starcoder1", "GPTBigCodeForCausalLM"),
    ("codegen", "CodeGenForCausalLM"),
    ("xglm", "XGLMForCausalLM"),
    ("opt", "OPTForCausalLM"),
    ("galactica", "OPTForCausalLM"),
    ("roberta", "RobertaForMaskedLM"),
    ("deberta", "DebertaV2ForMaskedLM"),
    ("electra", "ElectraForPreTraining"),
    ("distilbert", "DistilBertModel"),
    ("tinyllama", "LlamaForCausalLM"),
    ("phi4", "Phi3ForCausalLM"),
    ("codegemma", "CodeGemmaForCausalLM"),
    ("gemma2", "Gemma2ForCausalLM"),
    ("gemma3", "Gemma3ForCausalLM"),
];
/// Returns the static alias table `FAMILY_ALIASES`.
///
/// Each entry is an `(alias, target)` pair; `resolve_family` in this module
/// compares the second element against a family's `family` name.
pub fn family_aliases() -> &'static [(&'static str, &'static str)] {
    FAMILY_ALIASES
}
/// Produces the canonical lookup form of a user-supplied name.
///
/// The input is lowercased, surrounding whitespace is trimmed, and every `-`
/// and `.` is replaced by `_`. All other characters are kept as they are.
pub(crate) fn normalize_input(input: &str) -> String {
    let lowered = input.to_lowercase();
    lowered
        .trim()
        .chars()
        .map(|ch| match ch {
            '-' | '.' => '_',
            other => other,
        })
        .collect()
}
/// Returns `normalized` with every underscore removed.
///
/// Yields `None` when there is no underscore to remove, so callers can tell
/// that the compact form would be identical to the normalized one.
pub(crate) fn compact_input(normalized: &str) -> Option<String> {
    normalized
        .contains('_')
        .then(|| normalized.replace('_', ""))
}
pub fn resolve_family(input: &str) -> Option<FamilyInfo> {
let lower = input.to_lowercase();
let lower = lower.trim();
if lower.is_empty() {
return None;
}
let ascii_only: String = lower.chars().filter(|c| c.is_ascii()).collect();
let ascii_only = ascii_only.trim();
if ascii_only.is_empty() {
return None;
}
let lower = ascii_only;
let families = load_families();
let normalized = normalize_input(lower);
let compact = compact_input(&normalized);
if let Some(f) = families.iter().find(|f| {
f.family == lower
|| f.family == normalized
|| compact.as_deref().is_some_and(|c| f.family == c)
|| compact
.as_deref()
.is_some_and(|c| compact_input(&f.family).as_deref() == Some(c))
}) {
return Some(f.clone());
}
let alias_match = FAMILY_ALIASES
.iter()
.find(|(alias, _)| *alias == lower)
.or_else(|| {
FAMILY_ALIASES
.iter()
.find(|(alias, _)| *alias == normalized.as_str())
})
.or_else(|| {
compact
.as_deref()
.and_then(|c| FAMILY_ALIASES.iter().find(|(alias, _)| *alias == c))
});
if let Some((matched_alias, target)) = alias_match {
if let Some(f) = families.iter().find(|f| f.family == *target) {
let mut aliased = f.clone();
aliased.display_name = format!("{} (via {} kernel pipeline)", matched_alias, f.family);
return Some(aliased);
}
}
if let Some(f) = families.iter().find(|f| {
f.architectures
.iter()
.any(|a| a.to_lowercase() == lower || a == input)
}) {
return Some(f.clone());
}
let stripped = strip_arch_suffix(lower);
if stripped != lower {
if let Some((_, target)) = FAMILY_ALIASES.iter().find(|(alias, _)| *alias == stripped) {
if let Some(f) = families.iter().find(|f| f.family == *target) {
let mut aliased = f.clone();
aliased.display_name = format!("{stripped} (via {target} kernel pipeline)");
return Some(aliased);
}
}
if let Some(f) = families.iter().find(|f| f.family == stripped) {
return Some(f.clone());
}
}
let search_forms: Vec<&str> = {
let mut v = vec![];
if normalized.len() >= 3 {
v.push(normalized.as_str());
}
if let Some(ref c) = compact {
if c.len() >= 3 {
v.push(c.as_str());
}
}
v
};
if !search_forms.is_empty() {
for search in &search_forms {
if let Some((matched_alias, target)) = FAMILY_ALIASES
.iter()
.find(|(alias, _)| search.starts_with(alias))
{
if let Some(f) = families.iter().find(|f| f.family == *target) {
let mut aliased = f.clone();
aliased.display_name =
format!("{} (via {} kernel pipeline)", matched_alias, f.family);
return Some(aliased);
}
}
}
for search in &search_forms {
if let Some(f) = families
.iter()
.find(|f| f.family.starts_with(*search) || search.starts_with(f.family.as_str()))
{
return Some(f.clone());
}
}
for search in &search_forms {
if let Some((matched_alias, target)) = FAMILY_ALIASES
.iter()
.find(|(alias, _)| alias.starts_with(*search))
{
if let Some(f) = families.iter().find(|f| f.family == *target) {
let mut aliased = f.clone();
aliased.display_name =
format!("{} (via {} kernel pipeline)", matched_alias, f.family);
return Some(aliased);
}
}
}
}
None
}
/// Removes a trailing architecture-class suffix from `s`.
///
/// The suffixes `forconditionalgeneration`, `forcausallm` and `model` are
/// tried in that order and matched case-sensitively, so `s` should already be
/// lowercase. A suffix is only removed when something is left in front of it;
/// otherwise `s` is returned unchanged.
pub(crate) fn strip_arch_suffix(s: &str) -> &str {
    const SUFFIXES: [&str; 3] = ["forconditionalgeneration", "forcausallm", "model"];
    SUFFIXES
        .iter()
        .filter_map(|suffix| s.strip_suffix(*suffix))
        .find(|stem| !stem.is_empty())
        .unwrap_or(s)
}
/// Resolves a model family from a `config.json`-style file at `path`.
///
/// Returns `None` when the file cannot be read, when its content starts with
/// `[` (a top-level array rather than an object), when neither a
/// `model_type` nor an architecture name can be extracted, or when
/// `resolve_family` does not recognise the extracted name.
///
/// `model_type` is preferred. Without it, the `architectures` value is used:
/// first via `extract_json_string`, then via a manual read of the first
/// array element. The architecture name is lowercased and its class suffix is
/// removed (see `strip_arch_suffix`) before resolution.
pub fn resolve_from_config_json(path: &Path) -> Option<FamilyInfo> {
    let content = std::fs::read_to_string(path).ok()?;
    if content.trim().starts_with('[') {
        return None;
    }
    let model_type = match super::config::extract_json_string(&content, "model_type") {
        Some(mt) => mt,
        None => {
            let arch = super::config::extract_json_string(&content, "architectures")
                .or_else(|| first_listed_architecture(&content))?;
            let lower = arch.to_lowercase();
            strip_arch_suffix(&lower).to_string()
        }
    };
    resolve_family(&model_type)
}

/// Extracts the first quoted string inside the `"architectures": [ ... ]` array.
///
/// The key must be followed directly by `:` and `[` (whitespace allowed), and
/// the search stops at the array's closing `]`. This keeps an empty array or a
/// non-array value such as `null` from picking up an unrelated string that
/// appears later in the file. A missing `]` (truncated input) is tolerated.
fn first_listed_architecture(content: &str) -> Option<String> {
    const KEY: &str = "\"architectures\"";
    let after_key = &content[content.find(KEY)? + KEY.len()..];
    let list = after_key
        .trim_start()
        .strip_prefix(':')?
        .trim_start()
        .strip_prefix('[')?;
    // Confine the search to the array's own contents.
    let items = list.find(']').map_or(list, |end| &list[..end]);
    let (_, rest) = items.split_once('"')?;
    let (name, _) = rest.split_once('"')?;
    Some(name.to_string())
}