use std::path::Path;
use std::sync::OnceLock;
use serde::{Deserialize, Serialize};
use crate::diagnostics::shape::ShapeExpr;
use crate::diagnostics::DiagnosticError;
/// Maximum number of base profiles [`Profile::resolve`] will follow through
/// `extends` links. A chain that would need more bases than this is rejected
/// with [`DiagnosticError::LoadProfile`].
pub const MAX_INHERITANCE_DEPTH: usize = 4;
/// Built-in profiles as `(name, TOML source)` pairs, embedded at compile time
/// with `include_str!` (paths are relative to this source file).
///
/// The first element is the key used by `extends` lookups for built-ins and
/// by `load_builtin_profile`. The order of this table is the order reported by
/// `list_builtin_profile_names`.
pub const BUILTIN_PROFILE_SOURCES: &[(&str, &str)] = &[
    ("llama3", include_str!("../../profiles/llama3.toml")),
    ("qwen3", include_str!("../../profiles/qwen3.toml")),
    ("qwen35moe", include_str!("../../profiles/qwen35moe.toml")),
];
/// Declarative description of the metadata keys, symbols and tensors a model
/// of one architecture is expected to have.
///
/// Profiles are authored in TOML (see [`Profile::parse_toml`]) and may inherit
/// from another profile through `extends`; [`Profile::resolve`] flattens that
/// chain into a single profile.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Profile {
    /// Identifier of this profile; it is the first entry of the chain that
    /// `resolve` uses for cycle detection.
    pub name: String,
    /// Architecture string this profile applies to; `find_profile_for_architecture`
    /// matches on it.
    pub architecture: String,
    /// Name of the base profile to inherit from, resolved through the lookup
    /// passed to `resolve`. Always `None` after a successful `resolve`.
    #[serde(default)]
    pub extends: Option<String>,
    /// Named symbols. On inheritance, a child symbol replaces the base symbol
    /// with the same `name`. NOTE(review): presumably referenced from `shape` /
    /// `per_layer` expressions — confirm against the evaluator.
    #[serde(default)]
    pub symbols: Vec<SymbolDef>,
    /// Metadata keys that must be present; inherited entries are unioned by key.
    #[serde(default)]
    pub required_metadata: Vec<RequiredMetadata>,
    /// Tensors that must be present; a child entry replaces the base entry with
    /// the same `name`.
    #[serde(default)]
    pub expected_tensors: Vec<ExpectedTensor>,
    /// Tensors that may be present. During inheritance, any entry whose name is
    /// also in `expected_tensors` is dropped.
    #[serde(default)]
    pub optional_tensors: Vec<ExpectedTensor>,
    /// Human-readable hints; inherited hints are concatenated, base first.
    #[serde(default)]
    pub hints: Vec<Hint>,
    /// Metadata key prefixes, unioned on inheritance. NOTE(review): presumably
    /// keys under these prefixes are not reported as unexpected — confirm.
    #[serde(default)]
    pub allowed_metadata_prefixes: Vec<String>,
}
/// A named symbol and where its value comes from.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SymbolDef {
    /// Symbol name; it is the identity key when merging inherited profiles.
    pub name: String,
    /// Source of the value, written as `metadata:<key>` or `derived:<expr>` in TOML.
    pub source: SymbolSource,
}
/// Where a symbol's value comes from.
///
/// On the wire (TOML) this is a single string of the form `metadata:<key>` or
/// `derived:<expr>`; [`SymbolSource::as_wire_string`] produces that form.
#[derive(Debug, Clone, PartialEq)]
pub enum SymbolSource {
    /// The text after `metadata:` — a metadata key.
    Metadata(String),
    /// The text after `derived:` — an expression. NOTE(review): presumably
    /// evaluated over other symbols; confirm against the evaluator.
    Derived(String),
}

impl SymbolSource {
    /// Renders this source in its `prefix:payload` wire form,
    /// e.g. `metadata:qwen3.embedding_length` or `derived:n_heads * head_dim`.
    pub fn as_wire_string(&self) -> String {
        let (prefix, payload) = match self {
            Self::Metadata(key) => ("metadata", key),
            Self::Derived(expr) => ("derived", expr),
        };
        format!("{prefix}:{payload}")
    }
}
/// Serializes a [`SymbolSource`] as the single string produced by
/// [`SymbolSource::as_wire_string`].
impl Serialize for SymbolSource {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let wire = self.as_wire_string();
        serializer.serialize_str(wire.as_str())
    }
}
/// Parses the `metadata:<key>` / `derived:<expr>` wire form.
///
/// Only the first `:` separates the prefix from the payload, so a key or
/// expression may itself contain `:`. A blank payload (e.g. `"metadata:"` or
/// `"derived:  "`) is rejected, because it can neither name a metadata key nor
/// express anything to evaluate.
impl<'de> Deserialize<'de> for SymbolSource {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let s = String::deserialize(deserializer)?;
        match s.split_once(':') {
            Some(("metadata", rest)) if !rest.trim().is_empty() => {
                Ok(SymbolSource::Metadata(rest.to_string()))
            }
            Some(("derived", rest)) if !rest.trim().is_empty() => {
                Ok(SymbolSource::Derived(rest.to_string()))
            }
            _ => Err(serde::de::Error::custom(format!(
                "symbol source `{s}` must be `metadata:<key>` or `derived:<expr>`"
            ))),
        }
    }
}
/// A metadata key that a profile requires to be present.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RequiredMetadata {
    /// The required key; it is the identity used when merging inherited profiles.
    pub key: String,
}
/// A tensor entry in `expected_tensors` or `optional_tensors`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExpectedTensor {
    /// Tensor name, possibly templated (e.g. `blk.{layer}.attn_q.weight`).
    /// It is the identity key when merging inherited profiles.
    pub name: String,
    /// Optional shape, one expression per dimension (e.g. `"n_heads * head_dim"`).
    #[serde(default)]
    pub shape: Option<Vec<ShapeExpr>>,
    /// Optional symbol name (e.g. `n_layers`). NOTE(review): presumably gives the
    /// number of layers `{layer}` is expanded over — confirm against the checker.
    #[serde(default)]
    pub per_layer: Option<String>,
    /// Optional dtype constraint as a free-form string; format not validated here.
    #[serde(default)]
    pub dtype: Option<String>,
}
/// A human-readable message tied to tensor-name conditions.
///
/// NOTE(review): the patterns in `when_missing` / `when_unexpected` use `*` as
/// a wildcard in the tests (e.g. `blk.*.ffn_norm.weight`); the exact matching
/// rules live in the checker — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Hint {
    /// Tensor-name patterns associated with tensors that are missing.
    #[serde(default)]
    pub when_missing: Vec<String>,
    /// Tensor-name patterns associated with tensors that are present but unexpected.
    #[serde(default)]
    pub when_unexpected: Vec<String>,
    /// Message text for the hint.
    pub message: String,
    /// Optional identifier for the hint.
    #[serde(default)]
    pub name: Option<String>,
}
impl Profile {
    /// Parses a single profile from TOML without resolving `extends`.
    ///
    /// # Errors
    /// Returns [`DiagnosticError::LoadProfile`] carrying the parser's message
    /// when `source` is not valid TOML or does not match the profile schema.
    pub fn parse_toml(source: &str) -> Result<Self, DiagnosticError> {
        toml::from_str(source).map_err(|e| DiagnosticError::LoadProfile(e.to_string()))
    }

    /// Follows the `extends` chain and returns a single flattened profile.
    ///
    /// Each `extends` name is passed to `lookup`, the returned TOML is parsed,
    /// and the profile accumulated so far is merged on top of that base, so
    /// entries closer to `self` win. The result keeps `self`'s `name` and
    /// `architecture` and has `extends == None`.
    ///
    /// # Errors
    /// - [`DiagnosticError::CyclicProfileInheritance`] if a name repeats in the
    ///   chain, which starts with `self.name` followed by each `extends` name.
    /// - [`DiagnosticError::LoadProfile`] if more than
    ///   [`MAX_INHERITANCE_DEPTH`] bases would be followed, or a base fails to parse.
    /// - [`DiagnosticError::UnknownProfileBase`] if `lookup` returns `None`.
    ///   `name` is the profile whose `extends` referenced the missing `base`.
    pub fn resolve<F>(mut self, lookup: &F) -> Result<Self, DiagnosticError>
    where
        F: Fn(&str) -> Option<String>,
    {
        let mut chain: Vec<String> = vec![self.name.clone()];
        // The profile whose `extends` is currently being followed. This lets an
        // unknown-base error name the real culprit in multi-level chains,
        // rather than always naming the root profile.
        let mut declared_by = self.name.clone();
        let mut current = self.extends.take();
        while let Some(base_name) = current {
            if chain.contains(&base_name) {
                chain.push(base_name);
                return Err(DiagnosticError::CyclicProfileInheritance {
                    chain: chain.join(" -> "),
                });
            }
            chain.push(base_name.clone());
            // `chain` includes the root, so its length is the number of bases plus one.
            if chain.len() > MAX_INHERITANCE_DEPTH + 1 {
                return Err(DiagnosticError::LoadProfile(format!(
                    "inheritance chain exceeds max depth {}: {}",
                    MAX_INHERITANCE_DEPTH,
                    chain.join(" -> ")
                )));
            }
            // Build the error only on the failure path (no per-hop clone on success).
            let Some(base_src) = lookup(&base_name) else {
                return Err(DiagnosticError::UnknownProfileBase {
                    name: declared_by,
                    base: base_name,
                });
            };
            let mut base = Profile::parse_toml(&base_src)?;
            current = base.extends.take();
            declared_by = base_name;
            self = merge(base, self);
        }
        Ok(self)
    }
}
fn merge(mut base: Profile, child: Profile) -> Profile {
let merged_name = child.name;
let merged_arch = child.architecture;
let merged_extends = child.extends;
let mut symbols = base.symbols;
for c in child.symbols {
if let Some(pos) = symbols.iter().position(|s| s.name == c.name) {
symbols[pos] = c;
} else {
symbols.push(c);
}
}
for c in child.required_metadata {
if !base
.required_metadata
.iter()
.any(|b| b.key == c.key)
{
base.required_metadata.push(c);
}
}
let mut expected = base.expected_tensors;
for c in child.expected_tensors {
if let Some(pos) = expected.iter().position(|t| t.name == c.name) {
expected[pos] = c;
} else {
expected.push(c);
}
}
let mut optional = base.optional_tensors;
optional.retain(|o| !expected.iter().any(|e| e.name == o.name));
for c in child.optional_tensors {
if expected.iter().any(|e| e.name == c.name) {
continue;
}
if let Some(pos) = optional.iter().position(|t| t.name == c.name) {
optional[pos] = c;
} else {
optional.push(c);
}
}
let mut hints = base.hints;
hints.extend(child.hints);
let mut prefixes = base.allowed_metadata_prefixes;
for p in child.allowed_metadata_prefixes {
if !prefixes.contains(&p) {
prefixes.push(p);
}
}
Profile {
name: merged_name,
architecture: merged_arch,
extends: merged_extends,
symbols,
required_metadata: base.required_metadata,
expected_tensors: expected,
optional_tensors: optional,
hints,
allowed_metadata_prefixes: prefixes,
}
}
/// Process-wide cache of every built-in profile, parsed and inheritance-resolved.
static RESOLVED_BUILTINS: OnceLock<Vec<(&'static str, Profile)>> = OnceLock::new();

/// Returns all built-in profiles in declaration order, fully resolved.
///
/// The work happens once, on first call; later calls return the cached slice.
///
/// # Panics
/// Panics on first call if any embedded profile fails to parse or resolve.
/// Built-in sources are compiled into the binary, so such a failure is a bug,
/// not a user error.
fn resolved_builtins() -> &'static [(&'static str, Profile)] {
    RESOLVED_BUILTINS.get_or_init(|| {
        let mut resolved = Vec::with_capacity(BUILTIN_PROFILE_SOURCES.len());
        for &(name, source) in BUILTIN_PROFILE_SOURCES {
            match Profile::parse_toml(source).and_then(|raw| raw.resolve(&builtin_source)) {
                Ok(profile) => resolved.push((name, profile)),
                Err(e) => panic!("built-in profile `{name}` failed to load: {e}"),
            }
        }
        resolved
    })
}
/// Returns a clone of the resolved built-in profile registered under `name`.
///
/// # Errors
/// Returns [`DiagnosticError::LoadProfile`] when no built-in has that name.
///
/// # Panics
/// The first access to the built-in cache panics if an embedded profile fails
/// to load.
pub fn load_builtin_profile(name: &str) -> Result<Profile, DiagnosticError> {
    match resolved_builtins().iter().find(|(key, _)| *key == name) {
        Some((_, profile)) => Ok(profile.clone()),
        None => Err(DiagnosticError::LoadProfile(format!(
            "no built-in profile named `{name}`"
        ))),
    }
}
/// Picks the built-in profile for `arch`, if any.
///
/// Among built-ins whose `architecture` equals `arch`, returns the one
/// declaring the most `expected_tensors`. On a tie, the one declared later in
/// `BUILTIN_PROFILE_SOURCES` wins (`Iterator::max_by_key` keeps the last
/// maximum). The current implementation never returns `Err`.
pub fn find_profile_for_architecture(
    arch: &str,
) -> Result<Option<Profile>, DiagnosticError> {
    let best = resolved_builtins()
        .iter()
        .filter_map(|(_, candidate)| (candidate.architecture == arch).then_some(candidate))
        .max_by_key(|candidate| candidate.expected_tensors.len())
        .cloned();
    Ok(best)
}
/// Reads a profile from a TOML file and resolves its `extends` chain.
///
/// Bases are looked up among the built-in profiles only.
///
/// # Errors
/// - [`DiagnosticError::LoadProfile`] if the file cannot be read (the message
///   includes the path) or its contents fail to parse.
/// - Any error from [`Profile::resolve`].
pub fn load_profile_file(path: &Path) -> Result<Profile, DiagnosticError> {
    std::fs::read_to_string(path)
        .map_err(|e| DiagnosticError::LoadProfile(format!("read {}: {e}", path.display())))
        .and_then(|text| Profile::parse_toml(&text))
        .and_then(|raw| raw.resolve(&builtin_source))
}
pub fn list_builtin_profile_names() -> Vec<&'static str> {
BUILTIN_PROFILE_SOURCES.iter().map(|(n, _)| *n).collect()
}
/// `extends` lookup for built-ins: returns an owned copy of the embedded TOML
/// registered under `name`, or `None` if there is no such built-in.
fn builtin_source(name: &str) -> Option<String> {
    BUILTIN_PROFILE_SOURCES
        .iter()
        .find_map(|&(key, toml)| (key == name).then(|| toml.to_owned()))
}
// Unit tests for TOML parsing, `extends` resolution/merging, and the built-in profile registry.
#[cfg(test)]
mod tests {
use super::*;
// Test helper: parse TOML into a `Profile`, panicking on any error.
fn parse(src: &str) -> Profile {
Profile::parse_toml(src).expect("parse")
}
// Only `name` and `architecture` are mandatory; every list field defaults to empty.
#[test]
fn parses_minimal_profile() {
let src = r#"
name = "llama3"
architecture = "llama"
"#;
let p = parse(src);
assert_eq!(p.name, "llama3");
assert_eq!(p.architecture, "llama");
assert!(p.extends.is_none());
assert!(p.symbols.is_empty());
}
// Shape entries are parsed as expressions and keep their source text.
#[test]
fn parses_tensor_with_shape_expressions() {
let src = r#"
name = "qwen2"
architecture = "qwen2"
[[expected_tensors]]
name = "blk.{layer}.attn_q.weight"
per_layer = "n_layers"
shape = ["n_heads * head_dim", "hidden"]
"#;
let p = parse(src);
assert_eq!(p.expected_tensors.len(), 1);
let t = &p.expected_tensors[0];
assert_eq!(t.name, "blk.{layer}.attn_q.weight");
assert_eq!(t.per_layer.as_deref(), Some("n_layers"));
let shape = t.shape.as_ref().unwrap();
assert_eq!(shape[0].as_source(), "n_heads * head_dim");
assert_eq!(shape[1].as_source(), "hidden");
}
#[test]
fn parses_hints() {
let src = r#"
name = "qwen35moe"
architecture = "qwen35moe"
[[hints]]
name = "ffn_rename"
when_missing = ["blk.*.ffn_norm.weight"]
when_unexpected = ["blk.*.post_attention_norm.weight"]
message = "FFN norm renamed."
"#;
let p = parse(src);
assert_eq!(p.hints.len(), 1);
let h = &p.hints[0];
assert_eq!(h.name.as_deref(), Some("ffn_rename"));
assert_eq!(h.when_missing, vec!["blk.*.ffn_norm.weight"]);
assert_eq!(h.when_unexpected, vec!["blk.*.post_attention_norm.weight"]);
assert_eq!(h.message, "FFN norm renamed.");
}
// Without `extends`, resolution must not touch the profile (the lookup is never called).
#[test]
fn resolve_with_no_extends_is_identity() {
let src = r#"
name = "x"
architecture = "x"
"#;
let p = parse(src);
let resolved = p.clone().resolve(&|_| None).unwrap();
assert_eq!(resolved, p);
}
// The base lists post_attention_norm as optional; the child lists it as expected.
// Merging must promote it and drop the optional entry.
#[test]
fn resolve_inherits_symbols_and_tensors() {
let base = r#"
name = "base"
architecture = "qwen3moe"
[[symbols]]
name = "hidden"
source = "metadata:qwen3moe.embedding_length"
[[expected_tensors]]
name = "token_embd.weight"
shape = ["vocab", "hidden"]
[[optional_tensors]]
name = "blk.{layer}.post_attention_norm.weight"
per_layer = "n_layers"
"#;
let child = r#"
name = "child"
architecture = "qwen35moe"
extends = "base"
[[expected_tensors]]
name = "blk.{layer}.post_attention_norm.weight"
per_layer = "n_layers"
shape = ["hidden"]
[[hints]]
when_unexpected = ["blk.*.ssm_beta.weight"]
message = "DeltaNet variant."
"#;
let lookup = move |n: &str| if n == "base" { Some(base.to_string()) } else { None };
let resolved = Profile::parse_toml(child).unwrap().resolve(&lookup).unwrap();
assert_eq!(resolved.symbols.len(), 1);
assert_eq!(resolved.symbols[0].name, "hidden");
assert!(resolved.extends.is_none());
assert!(
resolved
.expected_tensors
.iter()
.any(|t| t.name == "blk.{layer}.post_attention_norm.weight"),
"expected post_attention_norm to be promoted to expected"
);
// The base's optional entry is gone because the child now requires it.
assert!(resolved.optional_tensors.is_empty());
assert!(resolved
.expected_tensors
.iter()
.any(|t| t.name == "token_embd.weight"));
assert_eq!(resolved.hints.len(), 1);
}
// A child symbol with the same name replaces the base symbol instead of duplicating it.
#[test]
fn resolve_child_symbol_overrides_base() {
let base = r#"
name = "base"
architecture = "x"
[[symbols]]
name = "hidden"
source = "metadata:old_key"
"#;
let child = r#"
name = "child"
architecture = "x"
extends = "base"
[[symbols]]
name = "hidden"
source = "metadata:new_key"
"#;
let lookup = move |n: &str| if n == "base" { Some(base.to_string()) } else { None };
let resolved = Profile::parse_toml(child).unwrap().resolve(&lookup).unwrap();
assert_eq!(resolved.symbols.len(), 1);
assert_eq!(
resolved.symbols[0].source,
SymbolSource::Metadata("new_key".into())
);
}
// a -> b -> a must be reported as a cycle, and the chain must mention both profiles.
#[test]
fn resolve_rejects_cyclic_inheritance() {
let a = r#"
name = "a"
architecture = "x"
extends = "b"
"#;
let b = r#"
name = "b"
architecture = "x"
extends = "a"
"#;
let lookup = move |n: &str| match n {
"a" => Some(a.to_string()),
"b" => Some(b.to_string()),
_ => None,
};
let err = Profile::parse_toml(a).unwrap().resolve(&lookup).unwrap_err();
match err {
DiagnosticError::CyclicProfileInheritance { chain } => {
assert!(chain.contains("a"));
assert!(chain.contains("b"));
}
other => panic!("expected CyclicProfileInheritance, got {other:?}"),
}
}
#[test]
fn resolve_rejects_unknown_base() {
let src = r#"
name = "x"
architecture = "x"
extends = "ghost"
"#;
let err = Profile::parse_toml(src)
.unwrap()
.resolve(&|_| None)
.unwrap_err();
assert!(matches!(
err,
DiagnosticError::UnknownProfileBase { ref base, .. } if base == "ghost"
));
}
// a -> a_base -> l2 -> l3 -> l4 -> l5 needs five bases, one more than
// MAX_INHERITANCE_DEPTH (4), so resolution stops with a depth error.
#[test]
fn resolve_rejects_deeper_than_max_depth() {
let tomls: Vec<(&str, String)> = vec![
("a", "name = \"a\"\narchitecture = \"x\"\nextends = \"a_base\"".into()),
("a_base", "name = \"a_base\"\narchitecture = \"x\"\nextends = \"l2\"".into()),
("l2", "name = \"l2\"\narchitecture = \"x\"\nextends = \"l3\"".into()),
("l3", "name = \"l3\"\narchitecture = \"x\"\nextends = \"l4\"".into()),
("l4", "name = \"l4\"\narchitecture = \"x\"\nextends = \"l5\"".into()),
("l5", "name = \"l5\"\narchitecture = \"x\"".into()),
];
let tomls_clone = tomls.clone();
let lookup = move |n: &str| {
tomls_clone
.iter()
.find(|(name, _)| *name == n)
.map(|(_, s)| s.clone())
};
let err = Profile::parse_toml(&tomls[0].1).unwrap().resolve(&lookup).unwrap_err();
assert!(
matches!(err, DiagnosticError::LoadProfile(ref msg) if msg.contains("max depth")),
"got {err:?}"
);
}
#[test]
fn load_builtin_profile_returns_error_for_unknown_name() {
let err = load_builtin_profile("no-such-profile").unwrap_err();
assert!(matches!(err, DiagnosticError::LoadProfile(_)));
}
// Order must match the declaration order of BUILTIN_PROFILE_SOURCES.
#[test]
fn list_builtin_profile_names_returns_declared_order() {
let names = list_builtin_profile_names();
assert_eq!(names, vec!["llama3", "qwen3", "qwen35moe"]);
}
// Guards the embedded TOML files: each must parse, resolve, and keep its registry name.
#[test]
fn every_builtin_profile_loads_and_resolves_inheritance() {
for name in list_builtin_profile_names() {
let profile = load_builtin_profile(name).unwrap_or_else(|e| {
panic!("failed to load built-in profile `{name}`: {e}")
});
assert_eq!(profile.name, name);
assert!(profile.extends.is_none(), "`{name}`'s extends should be resolved");
assert!(!profile.expected_tensors.is_empty(),
"`{name}` should declare at least one expected tensor");
}
}
// Checks the real built-in chain: inherited tensors are present and the
// `hidden` symbol is overridden to the qwen35moe metadata key.
#[test]
fn qwen35moe_inherits_attention_tensors_from_ancestors() {
let p = load_builtin_profile("qwen35moe").unwrap();
assert!(p.expected_tensors.iter().any(|t| t.name == "blk.{layer}.attn_q.weight"));
assert!(p.expected_tensors.iter().any(|t| t.name == "blk.{layer}.attn_k_norm.weight"));
assert!(p.expected_tensors.iter().any(|t| t.name == "blk.{layer}.post_attention_norm.weight"));
assert!(p.expected_tensors.iter().any(|t| t.name == "token_embd.weight"));
let hidden = p.symbols.iter().find(|s| s.name == "hidden").unwrap();
assert_eq!(
hidden.source,
SymbolSource::Metadata("qwen35moe.embedding_length".into())
);
}
}