use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Write;
use crate::diagnostics::model_source::ModelSource;
use crate::diagnostics::types::TensorRecord;
/// Generate a scaffold profile, as TOML text, from what `source` exposes.
///
/// The output is a starting point for hand editing, not a finished profile. It contains:
///
/// * top-level `name` (= `profile_name`) and `architecture`. The architecture is the
///   source's declared one, or `profile_name` when none is declared.
/// * three `[[symbols]]` entries (`hidden`, `n_heads`, `n_layers`) wired to guessed
///   `metadata:<arch>.*` keys. These are emitted whether or not the metadata has the keys.
/// * one `[[required_metadata]]` entry per metadata key starting with `<arch>.`, sorted.
/// * one `[[expected_tensors]]` entry per tensor that `collapse_tensors` does not group.
/// * one `[[expected_tensors]]` entry per per-layer pattern (e.g.
///   `blk.{layer}.attn_q.weight`) with `per_layer = "n_layers"`. A `# NOTE` comment is
///   added when the pattern covers a different number of distinct layers than the block
///   count.
///
/// Every string taken from the caller or the source is escaped as a TOML basic string.
/// That covers the profile name, architecture, metadata keys and tensor names. Names
/// containing quotes, backslashes or control characters therefore still yield parseable
/// TOML.
pub fn generate_profile(source: &dyn ModelSource, profile_name: &str) -> String {
    let arch = source
        .declared_architecture()
        .unwrap_or_else(|| profile_name.to_string());
    let metadata = source.metadata();
    let tensors = source.tensors();

    // Block count: prefer `<arch>.block_count`, else the first `*.block_count` key the
    // metadata iterator yields. 0 means "unknown" and suppresses the coverage notes below.
    let n_layers = metadata
        .get(&format!("{arch}.block_count"))
        .and_then(|v| v.as_symbol_value())
        .or_else(|| {
            metadata
                .iter()
                .find(|(k, _)| k.ends_with(".block_count"))
                .and_then(|(_, v)| v.as_symbol_value())
        })
        .unwrap_or(0);
    let (per_layer, one_off) = collapse_tensors(&tensors, n_layers);

    // Escaped once; reused in the header, the guidance comment and the symbol sources.
    // The suffixes appended below contain no characters that need escaping.
    let arch_esc = escape_toml_basic(&arch);

    let mut s = String::new();
    let _ = writeln!(&mut s, "name = \"{}\"", escape_toml_basic(profile_name));
    let _ = writeln!(&mut s, "architecture = \"{arch_esc}\"");
    let _ = writeln!(&mut s);
    let _ = writeln!(
        &mut s,
        "# Symbols — wire these to your metadata keys. The generator"
    );
    let _ = writeln!(
        &mut s,
        "# guesses based on `{arch_esc}.*` prefixes where possible.",
    );
    add_symbol(&mut s, "hidden", &format!("metadata:{arch_esc}.embedding_length"));
    add_symbol(&mut s, "n_heads", &format!("metadata:{arch_esc}.attention.head_count"));
    add_symbol(&mut s, "n_layers", &format!("metadata:{arch_esc}.block_count"));
    let _ = writeln!(&mut s);

    // Metadata keys in the architecture's namespace, deduplicated and sorted by the set.
    // The prefix is built once instead of once per key.
    let arch_prefix = format!("{arch}.");
    let arch_keys: BTreeSet<&str> = metadata
        .keys()
        .filter(|k| k.starts_with(arch_prefix.as_str()))
        .map(|k| k.as_str())
        .collect();
    if !arch_keys.is_empty() {
        let _ = writeln!(&mut s, "# Required metadata. Trim aggressively.");
        for key in &arch_keys {
            let _ = writeln!(&mut s, "[[required_metadata]]");
            let _ = writeln!(&mut s, "key = \"{}\"", escape_toml_basic(key));
            let _ = writeln!(&mut s);
        }
    }

    for t in one_off {
        let _ = writeln!(&mut s, "[[expected_tensors]]");
        let _ = writeln!(&mut s, "name = \"{}\"", escape_toml_basic(&t.name));
        let _ = writeln!(&mut s);
    }

    for (pattern, layers) in per_layer {
        let _ = writeln!(&mut s, "[[expected_tensors]]");
        let _ = writeln!(&mut s, "name = \"{}\"", escape_toml_basic(&pattern));
        let _ = writeln!(&mut s, "per_layer = \"n_layers\"");
        // Count distinct layer indices. A tensor name repeated in the inventory must not
        // inflate the coverage figure or hide a missing layer.
        let covered = layers.iter().collect::<BTreeSet<_>>().len();
        if n_layers > 0 && covered as u64 != n_layers {
            let _ = writeln!(
                &mut s,
                "# NOTE: present on {covered}/{n_layers} layers. Move to [[optional_tensors]] if irregular.",
            );
        }
        let _ = writeln!(&mut s);
    }
    s
}

/// Escape `raw` for use between the double quotes of a TOML basic string.
///
/// Quotes and backslashes get backslash escapes. Newline, carriage return and tab use
/// their short escapes. Any other control character becomes `\uXXXX`, because TOML
/// forbids raw control characters in basic strings and comments.
fn escape_toml_basic(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    for c in raw.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if c.is_control() => {
                let _ = write!(out, "\\u{:04X}", c as u32);
            }
            c => out.push(c),
        }
    }
    out
}
/// Append a single `[[symbols]]` table to `s` that binds `name` to `source`.
///
/// Both values are written verbatim between double quotes. Any TOML escaping is the
/// caller's responsibility.
fn add_symbol(s: &mut String, name: &str, source: &str) {
    // One formatted write produces the same three newline-terminated lines.
    let _ = write!(
        s,
        "[[symbols]]\nname = \"{name}\"\nsource = \"{source}\"\n"
    );
}
/// Split a tensor inventory into per-layer patterns and one-off tensors.
///
/// Names that `extract_layer` recognizes are grouped under their pattern. The layer
/// indices are kept in inventory order, with duplicates retained. Every other tensor is
/// cloned into the one-off list, in inventory order. The second parameter is accepted
/// but currently unused.
fn collapse_tensors(
    tensors: &[TensorRecord],
    _n_layers: u64,
) -> (BTreeMap<String, Vec<u32>>, Vec<TensorRecord>) {
    let seed = (
        BTreeMap::<String, Vec<u32>>::new(),
        Vec::<TensorRecord>::new(),
    );
    tensors
        .iter()
        .fold(seed, |(mut grouped, mut singles), record| {
            match extract_layer(&record.name) {
                Some((pattern, layer)) => grouped.entry(pattern).or_insert_with(Vec::new).push(layer),
                None => singles.push(record.clone()),
            }
            (grouped, singles)
        })
}
/// Split a `blk.<N>.<tail>` tensor name into its per-layer pattern and layer index.
///
/// Returns `Some(("blk.{layer}.<tail>", N))` when all of these hold:
/// * `name` starts with `blk.`;
/// * the next dot-delimited segment is a canonical decimal `u32`;
/// * another `.` follows that segment.
///
/// Otherwise it returns `None`.
///
/// "Canonical" means ASCII digits only, with no sign and no leading zeros ("0" itself is
/// fine). `str::parse::<u32>` alone would also accept `+1` or `01`. Collapsing those to
/// `blk.{layer}.…` loses the original spelling: a profile that expands `{layer}` as a
/// plain decimal index would then look for `blk.1.…`, not the tensor that actually exists.
fn extract_layer(name: &str) -> Option<(String, u32)> {
    let rest = name.strip_prefix("blk.")?;
    let (digits, tail) = rest.split_once('.')?;
    let canonical = !digits.is_empty()
        && digits.bytes().all(|b| b.is_ascii_digit())
        && (digits == "0" || !digits.starts_with('0'));
    if !canonical {
        return None;
    }
    // Can still fail on overflow past u32::MAX; such names stay one-offs.
    let layer: u32 = digits.parse().ok()?;
    // `{{layer}}` emits a literal `{layer}` placeholder.
    Some((format!("blk.{{layer}}.{tail}"), layer))
}
// Unit tests for `generate_profile`, driven by an in-memory `ModelSource` mock.
#[cfg(test)]
mod tests {
use super::*;
use crate::diagnostics::profile::Profile;
use crate::diagnostics::types::{MetadataBundle, MetadataValue, TensorDtype};
/// In-memory `ModelSource`: every accessor returns a clone of the fixed field below.
struct MockSource {
arch: Option<String>,
metadata: MetadataBundle,
tensors: Vec<TensorRecord>,
}
impl ModelSource for MockSource {
fn declared_architecture(&self) -> Option<String> {
self.arch.clone()
}
fn metadata(&self) -> MetadataBundle {
self.metadata.clone()
}
fn tensors(&self) -> Vec<TensorRecord> {
self.tensors.clone()
}
fn format_label(&self) -> String {
"mock".into()
}
}
/// Build a placeholder `TensorRecord` called `name`. Shape `[1]` and dtype F16 are
/// filler; the generator itself only reads the tensor's name.
fn tensor(name: &str) -> TensorRecord {
TensorRecord {
name: name.into(),
shape: vec![1],
dtype: TensorDtype::F16,
}
}
// The output must parse via `Profile::parse_toml` and keep the given name and
// architecture. Non-`blk.` tensors must appear as one-offs, and `blk.N.*` tensors must
// collapse into a single `blk.{layer}.*` pattern.
#[test]
fn emits_parseable_toml_for_a_trivial_source() {
let mut md = MetadataBundle::new();
md.insert("llama.embedding_length".into(), MetadataValue::UInt(4096));
md.insert("llama.block_count".into(), MetadataValue::UInt(2));
let src = MockSource {
arch: Some("llama".into()),
metadata: md,
tensors: vec![
tensor("token_embd.weight"),
tensor("output_norm.weight"),
tensor("blk.0.attn_q.weight"),
tensor("blk.1.attn_q.weight"),
tensor("blk.0.ffn_norm.weight"),
tensor("blk.1.ffn_norm.weight"),
],
};
let toml = generate_profile(&src, "scaffold");
let p = Profile::parse_toml(&toml).expect("generated TOML must parse");
assert_eq!(p.name, "scaffold");
assert_eq!(p.architecture, "llama");
assert!(
p.expected_tensors.iter().any(|t| t.name == "token_embd.weight"),
"expected token_embd.weight as a one-off"
);
assert!(
p.expected_tensors.iter().any(|t| t.name == "blk.{layer}.attn_q.weight"),
"expected blk.*.attn_q.weight to collapse"
);
}
// block_count claims 4 layers but `ssm_alpha` exists on only 2, so the generator must
// emit the "present on 2/4 layers" NOTE comment.
#[test]
fn flags_irregular_per_layer_coverage_in_comments() {
let mut md = MetadataBundle::new();
md.insert("test.block_count".into(), MetadataValue::UInt(4));
let src = MockSource {
arch: Some("test".into()),
metadata: md,
tensors: vec![
tensor("blk.0.ssm_alpha.weight"),
tensor("blk.1.ssm_alpha.weight"),
],
};
let toml = generate_profile(&src, "scaffold");
assert!(
toml.contains("NOTE: present on 2/4 layers"),
"generator should flag irregular coverage: {toml}"
);
}
// With no declared architecture, the profile name is used as the architecture.
#[test]
fn falls_back_to_profile_name_when_no_arch_declared() {
let src = MockSource {
arch: None,
metadata: MetadataBundle::new(),
tensors: vec![tensor("stuff.weight")],
};
let toml = generate_profile(&src, "custom");
assert!(toml.contains("architecture = \"custom\""));
}
// End to end: generate a profile, parse it, then resolve it with a resolver closure
// that returns `None` for every lookup. Comparing it against the same source must yield
// `Verdict::Matches`. Both layers carry `attn_q`, so no coverage NOTE is involved.
#[test]
fn generated_profile_round_trips_through_compare_matches() {
use crate::diagnostics::compare::{compare, ReportContext};
use crate::diagnostics::report::Verdict;
let mut md = MetadataBundle::new();
md.insert("arch.embedding_length".into(), MetadataValue::UInt(32));
md.insert("arch.attention.head_count".into(), MetadataValue::UInt(4));
md.insert("arch.block_count".into(), MetadataValue::UInt(2));
let src = MockSource {
arch: Some("arch".into()),
metadata: md,
tensors: vec![
tensor("token_embd.weight"),
tensor("output_norm.weight"),
tensor("blk.0.attn_q.weight"),
tensor("blk.1.attn_q.weight"),
],
};
let toml = generate_profile(&src, "arch");
let profile = Profile::parse_toml(&toml).unwrap().resolve(&|_| None).unwrap();
// Placeholder context values; only the verdict is asserted below.
let ctx = ReportContext {
file_path: "/x",
arch_source: "general.architecture",
format_kind: "mock",
};
let report = compare(&src, Some(&profile), &ctx);
assert_eq!(
report.verdict,
Verdict::Matches,
"generated profile must match its seed inventory; report: {report:#?}"
);
}
}