use super::{Constraints, FamilyInfo, KernelClass, FAMILY_YAMLS};
/// Looks up a scalar value for `key` in a YAML-like document using a simple line scan.
///
/// The first line whose trimmed text starts with `"{key}:"` decides the result; later
/// occurrences of the same key are never consulted. A trailing `# comment` is removed,
/// then surrounding double and single quotes are stripped.
///
/// A `#` only starts a comment when it sits outside a quoted scalar and is either the
/// first character of the value or preceded by whitespace, so `"C# Model"` and
/// `issue#42` are returned intact.
///
/// Returns `None` when the key is absent, or when the value on the first matching line
/// is empty or the literal `null`.
pub(crate) fn yaml_str(text: &str, key: &str) -> Option<String> {
    // Returns `value` without its trailing inline comment, if it has one.
    fn strip_inline_comment(value: &str) -> &str {
        // For a quoted scalar a comment can only begin after the closing quote. An
        // unterminated quote falls back to scanning the whole value as a plain scalar.
        let scan_from = match value.chars().next() {
            Some(quote) if quote == '"' || quote == '\'' => {
                value[1..].find(quote).map_or(0, |close| close + 2)
            }
            _ => 0,
        };
        // True at the first scanned position and after any whitespace character.
        let mut at_boundary = true;
        for (idx, ch) in value.char_indices() {
            if idx < scan_from {
                continue;
            }
            if ch == '#' && at_boundary {
                return &value[..idx];
            }
            at_boundary = ch.is_whitespace();
        }
        value
    }

    let search = format!("{key}:");
    let raw = text
        .lines()
        .find_map(|line| line.trim().strip_prefix(search.as_str()))?;
    let val = strip_inline_comment(raw.trim()).trim();
    let val = val.trim_matches('"').trim_matches('\'');
    if val.is_empty() || val == "null" {
        None
    } else {
        Some(val.to_string())
    }
}
/// Reads `key` through `yaml_str` and reports whether its value is exactly `true`.
///
/// Any other value, and a lookup that yields `None`, gives `false`.
pub(crate) fn yaml_bool(text: &str, key: &str) -> bool {
    matches!(yaml_str(text, key).as_deref(), Some("true"))
}
/// Collects the block-style list items (`- item`) that follow a `"{key}:"` line.
///
/// Scanning starts at the first line whose trimmed text starts with `"{key}:"`. After
/// that, blank lines and comment-only lines are skipped, every `- ` line contributes one
/// item, and the first line of any other kind ends the list. Each item is trimmed, has a
/// trailing `# comment` removed, and is stripped of surrounding quotes.
///
/// A `#` only starts a comment when it is outside a quoted item and is either the first
/// character of the item or preceded by whitespace.
///
/// Returns an empty vector when the key is absent or has no `- ` items beneath it.
pub(crate) fn yaml_list(text: &str, key: &str) -> Vec<String> {
    // Normalises the text after "- ": trims it, drops an inline comment, strips quotes.
    fn clean_item(raw: &str) -> &str {
        let raw = raw.trim();
        // For a quoted item a comment can only begin after the closing quote. An
        // unterminated quote falls back to scanning the whole item as a plain scalar.
        let scan_from = match raw.chars().next() {
            Some(quote) if quote == '"' || quote == '\'' => {
                raw[1..].find(quote).map_or(0, |close| close + 2)
            }
            _ => 0,
        };
        let mut end = raw.len();
        // True at the first scanned position and after any whitespace character.
        let mut at_boundary = true;
        for (idx, ch) in raw.char_indices() {
            if idx < scan_from {
                continue;
            }
            if ch == '#' && at_boundary {
                end = idx;
                break;
            }
            at_boundary = ch.is_whitespace();
        }
        raw[..end].trim().trim_matches('"').trim_matches('\'')
    }

    let search = format!("{key}:");
    let mut in_section = false;
    let mut items = Vec::new();
    for line in text.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with(search.as_str()) {
            // The header line itself never contributes an item.
            in_section = true;
            continue;
        }
        if !in_section {
            continue;
        }
        if let Some(item) = trimmed.strip_prefix("- ") {
            items.push(clean_item(item).to_string());
        } else if !trimmed.is_empty() && !trimmed.starts_with('#') {
            // Anything that is not an item, a blank line or a comment ends the list.
            break;
        }
    }
    items
}
/// Builds a [`Constraints`] value from the top-level `constraints:` section of a family
/// YAML document.
///
/// The section is made of the lines that follow a `constraints:` header at column 0,
/// whether the header is the first line of the document or a later one. It ends just
/// before the next line that begins with a character other than whitespace or `#`, so
/// keys that belong to later top-level sections are never picked up by mistake.
///
/// String fields that are missing from the section become empty strings and boolean
/// fields become `false`; a document without a `constraints:` header yields all defaults.
pub(crate) fn extract_constraints(yaml_text: &str) -> Constraints {
    const HEADER: &str = "constraints:";

    // Byte offset of the header line. The explicit start-of-text check covers a
    // document whose very first line is the header, which has no preceding newline.
    let header_start = if yaml_text.starts_with(HEADER) {
        Some(0)
    } else {
        yaml_text.find("\nconstraints:").map(|pos| pos + 1)
    };

    let constraints_section = header_start.map_or("", |start| {
        let from_header = &yaml_text[start..];
        // Skip the header line itself; only the lines beneath it hold the keys.
        let body_start = from_header
            .find('\n')
            .map_or(from_header.len(), |newline| newline + 1);
        let body = &from_header[body_start..];

        // Extend the section line by line until the next top-level line.
        let mut end = 0;
        for line in body.split_inclusive('\n') {
            let starts_new_section = line
                .chars()
                .next()
                .map_or(false, |first| !first.is_whitespace() && first != '#');
            if starts_new_section {
                break;
            }
            end += line.len();
        }
        &body[..end]
    });

    Constraints {
        attention_type: yaml_str(constraints_section, "attention_type").unwrap_or_default(),
        activation: yaml_str(constraints_section, "activation").unwrap_or_default(),
        norm_type: yaml_str(constraints_section, "norm_type").unwrap_or_default(),
        mlp_type: yaml_str(constraints_section, "mlp_type").unwrap_or_default(),
        positional_encoding: yaml_str(constraints_section, "positional_encoding")
            .unwrap_or_default(),
        has_bias: yaml_bool(constraints_section, "has_bias"),
        tied_embeddings: yaml_bool(constraints_section, "tied_embeddings"),
    }
}
/// Maps a set of architectural constraints to a [`KernelClass`].
///
/// The rules are tried from top to bottom and the first one that matches wins:
///
/// | class    | attention      | norm        | activation | mlp         | positional |
/// |----------|----------------|-------------|------------|-------------|------------|
/// | `A`      | `gqa` or `mha` | `rmsnorm`   | `silu`     | `swiglu`    | `rope`     |
/// | `F`      | any            | `rmsnorm`   | `gelu`     | `gated_mlp` | `rope`     |
/// | `B`      | `mha`          | `layernorm` | `gelu`     | any         | any        |
/// | `C`      | `mqa`          | `layernorm` | `gelu`     | any         | `alibi`    |
/// | `D`      | `gqa`, or any with `silu` activation | `layernorm` | see attention | any | any |
/// | `Ssm`    | `ssm`          | any         | any        | any         | any        |
/// | `Linear` | `linear`       | any         | any        | any         | any        |
///
/// Anything that matches none of the rows is `KernelClass::Unknown`.
pub(crate) fn derive_kernel_class(c: &Constraints) -> KernelClass {
    let signature = (
        c.attention_type.as_str(),
        c.norm_type.as_str(),
        c.activation.as_str(),
        c.mlp_type.as_str(),
        c.positional_encoding.as_str(),
    );

    // Tuple layout: (attention, norm, activation, mlp, positional encoding).
    // Arm order matters: earlier arms take precedence over later ones.
    match signature {
        ("gqa" | "mha", "rmsnorm", "silu", "swiglu", "rope") => KernelClass::A,
        (_, "rmsnorm", "gelu", "gated_mlp", "rope") => KernelClass::F,
        ("mha", "layernorm", "gelu", _, _) => KernelClass::B,
        ("mqa", "layernorm", "gelu", _, "alibi") => KernelClass::C,
        ("gqa", "layernorm", ..) | (_, "layernorm", "silu", ..) => KernelClass::D,
        ("ssm", ..) => KernelClass::Ssm,
        ("linear", ..) => KernelClass::Linear,
        _ => KernelClass::Unknown,
    }
}
/// Produces one [`FamilyInfo`] per `(name, yaml_text)` entry of `FAMILY_YAMLS`, in order.
///
/// For each entry the constraints are extracted from the YAML text and the kernel class
/// is derived from them. The display name comes from the `display_name` key and falls
/// back to the family name when that key yields nothing; `architectures` is read as a
/// list from the key of the same name.
pub fn load_families() -> Vec<FamilyInfo> {
    let mut families = Vec::new();
    for (name, yaml_text) in FAMILY_YAMLS.iter() {
        let constraints = extract_constraints(yaml_text);
        let kernel_class = derive_kernel_class(&constraints);
        let display_name = match yaml_str(yaml_text, "display_name") {
            Some(label) => label,
            None => name.to_string(),
        };
        families.push(FamilyInfo {
            family: (*name).to_string(),
            display_name,
            architectures: yaml_list(yaml_text, "architectures"),
            constraints,
            kernel_class,
        });
    }
    families
}