use std::collections::{BTreeMap, BTreeSet};
use crate::diagnostics::model_source::ModelSource;
use crate::diagnostics::profile::{ExpectedTensor, Hint, Profile, SymbolDef, SymbolSource};
use crate::diagnostics::report::{
ArchitectureInfo, DiagnosticReport, FormatInfo, Hypothesis, HypothesisTriggers,
MetadataDelta, MetadataDeltas, MissingTensor, ProfileRef, ResolvedSymbol,
ShapeComparisonSkipped, ShapeMismatch, Summary, TensorPattern, UnexpectedTensor, Verdict,
Warning, SCHEMA_VERSION,
};
use crate::diagnostics::shape::{ShapeExpr, SymbolTable};
use crate::diagnostics::types::{MetadataBundle, TensorRecord};
/// Placeholder substring inside a per-layer tensor pattern name.
/// `expand_patterns` replaces it with each concrete layer index.
const LAYER_PLACEHOLDER: &str = "{layer}";
/// Caller-supplied strings that `compare` copies verbatim into the report.
#[derive(Debug, Clone)]
pub struct ReportContext<'a> {
/// Copied into the report's `file` field.
pub file_path: &'a str,
/// Copied into the report's `architecture.source` field.
pub arch_source: &'a str,
/// Copied into the report's `format.kind` field.
pub format_kind: &'a str,
}
/// Compares a model's tensors and metadata against `profile` and builds the
/// resulting [`DiagnosticReport`].
///
/// When `profile` is `None` the function falls back to
/// `inventory_only_report`, which lists every tensor as unexpected.
///
/// Warnings are accumulated in a fixed order: symbol resolution first, then
/// expansion of the required patterns, then expansion of the optional ones.
pub fn compare(
    source: &dyn ModelSource,
    profile: Option<&Profile>,
    ctx: &ReportContext<'_>,
) -> DiagnosticReport {
    let metadata = source.metadata();
    let inventory = source.tensors();

    let profile = match profile {
        Some(p) => p,
        None => return inventory_only_report(source, &metadata, &inventory, ctx),
    };

    // Resolve symbols, then expand both pattern lists against them.
    let mut warnings: Vec<Warning> = Vec::new();
    let (symbols, resolved_symbols) = resolve_symbols(&profile.symbols, &metadata, &mut warnings);
    let (required_patterns, required_warnings) =
        expand_patterns(&profile.expected_tensors, &symbols, false);
    warnings.extend(required_warnings);
    let (optional_patterns, optional_warnings) =
        expand_patterns(&profile.optional_tensors, &symbols, true);
    warnings.extend(optional_warnings);

    let classification = classify(&inventory, &required_patterns, &optional_patterns, &symbols);
    let metadata_deltas = classify_metadata(profile, &metadata);
    let hypotheses =
        fire_hypotheses(&profile.hints, &classification.missing, &classification.unexpected);
    let verdict = classification.verdict_with(&metadata_deltas);

    // Every missing entry is either optional or required, so one count
    // determines the other.
    let optional_missing = classification
        .missing
        .iter()
        .filter(|entry| entry.optional)
        .count();
    let required_missing = classification.missing.len() - optional_missing;

    let summary = Summary {
        required_missing: required_missing as u32,
        optional_missing: optional_missing as u32,
        unexpected_patterns: classification.unexpected.len() as u32,
        shape_mismatches: classification.shape_mismatches.len() as u32,
        hypotheses: hypotheses.len() as u32,
        warnings: warnings.len() as u32,
    };

    let format = FormatInfo {
        kind: ctx.format_kind.to_string(),
        label: source.format_label(),
        tensor_count: inventory.len() as u32,
        metadata_count: metadata.len() as u32,
    };
    let architecture = ArchitectureInfo {
        declared: source.declared_architecture(),
        source: ctx.arch_source.to_string(),
    };

    DiagnosticReport {
        schema_version: SCHEMA_VERSION,
        file: ctx.file_path.to_string(),
        format,
        architecture,
        // NOTE(review): `extends` is always reported as `None` even though
        // `Profile` carries an `extends` field — confirm this is intentional.
        profile: ProfileRef::Builtin {
            name: profile.name.clone(),
            extends: None,
        },
        verdict,
        symbols: resolved_symbols,
        missing_tensors: classification.missing,
        unexpected_tensors: classification.unexpected,
        shape_mismatches: classification.shape_mismatches,
        shape_comparisons_skipped: classification.shape_comparisons_skipped,
        metadata_deltas,
        hypotheses,
        warnings,
        summary,
    }
}
/// Builds the report used when no profile is available.
///
/// Every tensor in `inventory` is listed individually as unexpected (no
/// per-layer grouping), the verdict is `Verdict::UnknownArchitecture`, and all
/// profile-dependent sections are left empty.
fn inventory_only_report(
    source: &dyn ModelSource,
    metadata: &MetadataBundle,
    inventory: &[crate::diagnostics::types::TensorRecord],
    ctx: &ReportContext<'_>,
) -> DiagnosticReport {
    let tensor_total = inventory.len() as u32;

    let mut unexpected: Vec<UnexpectedTensor> = Vec::with_capacity(inventory.len());
    for record in inventory {
        unexpected.push(UnexpectedTensor {
            pattern: TensorPattern {
                name: record.name.clone(),
                per_layer_count: None,
                layers: Vec::new(),
            },
            shape: record.shape.clone(),
            dtype: record.dtype.clone(),
        });
    }

    let format = FormatInfo {
        kind: ctx.format_kind.to_string(),
        label: source.format_label(),
        tensor_count: tensor_total,
        metadata_count: metadata.len() as u32,
    };
    let architecture = ArchitectureInfo {
        declared: source.declared_architecture(),
        source: ctx.arch_source.to_string(),
    };
    // Only the unexpected count is non-zero: there is nothing to compare against.
    let summary = Summary {
        required_missing: 0,
        optional_missing: 0,
        unexpected_patterns: tensor_total,
        shape_mismatches: 0,
        hypotheses: 0,
        warnings: 0,
    };

    DiagnosticReport {
        schema_version: SCHEMA_VERSION,
        file: ctx.file_path.to_string(),
        format,
        architecture,
        profile: ProfileRef::None,
        verdict: Verdict::UnknownArchitecture,
        symbols: Vec::new(),
        missing_tensors: Vec::new(),
        unexpected_tensors: unexpected,
        shape_mismatches: Vec::new(),
        shape_comparisons_skipped: Vec::new(),
        metadata_deltas: MetadataDeltas::default(),
        hypotheses: Vec::new(),
        warnings: Vec::new(),
        summary,
    }
}
/// Resolves the profile's symbol definitions into concrete values.
///
/// Definitions are processed in order, so a derived symbol can only see
/// symbols that were resolved before it. A symbol that cannot be resolved is
/// left out of both the table and the resolved list, and a warning is pushed
/// onto `warnings` instead.
///
/// Returns the lookup table together with the resolved symbols in definition
/// order.
fn resolve_symbols(
    defs: &[SymbolDef],
    metadata: &MetadataBundle,
    warnings: &mut Vec<Warning>,
) -> (SymbolTable, Vec<ResolvedSymbol>) {
    let mut table = SymbolTable::new();
    let mut resolved = Vec::new();

    for def in defs {
        let wire = def.source.as_wire_string();

        // Reduce both source kinds to "value or warning" and handle the
        // outcome once below.
        let outcome = match &def.source {
            SymbolSource::Metadata(key) => metadata
                .get(key)
                .and_then(|v| v.as_symbol_value())
                .ok_or_else(|| Warning {
                    code: "unresolved_symbol".into(),
                    message: format!(
                        "symbol `{}` could not resolve from metadata key `{key}`",
                        def.name
                    ),
                }),
            SymbolSource::Derived(expr_src) => ShapeExpr::from_str(expr_src)
                .evaluate(&table)
                .map_err(|e| Warning {
                    code: "unresolved_derived_symbol".into(),
                    message: format!("symbol `{}`: {e}", def.name),
                }),
        };

        match outcome {
            Ok(value) => {
                table.insert(def.name.clone(), value);
                resolved.push(ResolvedSymbol {
                    name: def.name.clone(),
                    value,
                    source: wire,
                });
            }
            Err(warning) => warnings.push(warning),
        }
    }

    (table, resolved)
}
/// One tensor expectation produced by `expand_patterns` from a profile entry.
#[derive(Debug, Clone)]
struct ExpandedPattern {
/// Pattern name exactly as written in the profile entry.
source_name: String,
/// Name with the layer placeholder replaced by a layer index; equal to
/// `source_name` when no expansion took place.
concrete_name: String,
/// Layer index for an expanded per-layer pattern, otherwise `None`.
layer: Option<u32>,
/// Name of the symbol that gives the per-layer count, if the entry is per-layer.
per_layer_symbol: Option<String>,
/// Whether the pattern came from the optional tensor list.
optional: bool,
/// Expected shape expressions copied from the profile entry, if any.
shape: Option<Vec<ShapeExpr>>,
/// Set when the per-layer symbol could not be resolved, so the pattern was
/// left unexpanded.
wildcard: bool,
}
/// Expands profile tensor entries into concrete [`ExpandedPattern`]s.
///
/// * An entry without `per_layer` yields one pattern as-is.
/// * A per-layer entry whose name lacks the `{layer}` placeholder is dropped
///   with a `missing_layer_placeholder` warning.
/// * A per-layer entry whose count symbol is unknown yields a single wildcard
///   pattern plus a `could_not_expand_per_layer` warning.
/// * Otherwise one pattern is emitted per layer index `0..count`.
///
/// `optional` is stamped onto every emitted pattern.
fn expand_patterns(
    entries: &[ExpectedTensor],
    symbols: &SymbolTable,
    optional: bool,
) -> (Vec<ExpandedPattern>, Vec<Warning>) {
    let mut expanded = Vec::new();
    let mut warnings = Vec::new();

    // Shared constructor; `per_layer_symbol` always mirrors the entry's
    // `per_layer` field.
    let build = |entry: &ExpectedTensor,
                 concrete_name: String,
                 layer: Option<u32>,
                 wildcard: bool| ExpandedPattern {
        source_name: entry.name.clone(),
        concrete_name,
        layer,
        per_layer_symbol: entry.per_layer.clone(),
        optional,
        shape: entry.shape.clone(),
        wildcard,
    };

    for entry in entries {
        match entry.per_layer.as_ref() {
            None => expanded.push(build(entry, entry.name.clone(), None, false)),
            Some(sym) if !entry.name.contains(LAYER_PLACEHOLDER) => {
                warnings.push(Warning {
                    code: "missing_layer_placeholder".into(),
                    message: format!(
                        "tensor pattern `{}` declares per_layer = `{sym}` but has no `{{layer}}` placeholder",
                        entry.name
                    ),
                });
            }
            Some(sym) => match symbols.get(sym) {
                Some(&count) => {
                    expanded.extend((0..count as u32).map(|i| {
                        build(
                            entry,
                            entry.name.replace(LAYER_PLACEHOLDER, &i.to_string()),
                            Some(i),
                            false,
                        )
                    }));
                }
                None => {
                    warnings.push(Warning {
                        code: "could_not_expand_per_layer".into(),
                        message: format!(
                            "tensor pattern `{}` needs symbol `{sym}` for per-layer expansion",
                            entry.name
                        ),
                    });
                    // Keep the unexpanded pattern as a wildcard so matching
                    // tensors are not reported as unexpected.
                    expanded.push(build(entry, entry.name.clone(), None, true));
                }
            },
        }
    }

    (expanded, warnings)
}
/// Result of comparing the tensor inventory against the expanded patterns,
/// as produced by `classify`.
#[derive(Debug, Default)]
struct Classification {
/// Expected tensors (required or optional) that were not found, grouped by pattern.
missing: Vec<MissingTensor>,
/// Inventory tensors that matched no pattern, grouped by `blk.{layer}.` key.
unexpected: Vec<UnexpectedTensor>,
/// Matched tensors whose actual shape differs from the resolved expected shape.
shape_mismatches: Vec<ShapeMismatch>,
/// Matched tensors whose expected shape could not be evaluated.
shape_comparisons_skipped: Vec<ShapeComparisonSkipped>,
}
impl Classification {
    /// Derives the overall verdict from this classification and the metadata
    /// deltas.
    ///
    /// * `ProfileMismatch` — a required tensor is missing, any shape
    ///   mismatches, or required metadata is missing.
    /// * `Matches` — nothing is missing and nothing is unexpected.
    /// * `OptionalExtras` — anything else (optional tensors missing, or
    ///   unexpected tensors/metadata present).
    fn verdict_with(&self, metadata_deltas: &MetadataDeltas) -> Verdict {
        let hard_failure = self.missing.iter().any(|m| !m.optional)
            || !self.shape_mismatches.is_empty()
            || !metadata_deltas.missing_required.is_empty();
        let nothing_to_report = self.missing.is_empty()
            && self.unexpected.is_empty()
            && metadata_deltas.unexpected.is_empty();

        match (hard_failure, nothing_to_report) {
            (true, _) => Verdict::ProfileMismatch,
            (false, true) => Verdict::Matches,
            (false, false) => Verdict::OptionalExtras,
        }
    }
}
fn classify(
inventory: &[TensorRecord],
required: &[ExpandedPattern],
optional: &[ExpandedPattern],
symbols: &SymbolTable,
) -> Classification {
let mut by_name: BTreeMap<&str, &ExpandedPattern> = BTreeMap::new();
for p in required.iter().chain(optional.iter()) {
if !p.wildcard {
by_name.insert(p.concrete_name.as_str(), p);
}
}
let mut matched: BTreeSet<String> = BTreeSet::new();
let mut missing_map: BTreeMap<(String, bool), Vec<u32>> = BTreeMap::new();
let mut unexpected_map: BTreeMap<String, Vec<&TensorRecord>> = BTreeMap::new();
let mut shape_mismatches: Vec<ShapeMismatch> = Vec::new();
let mut skipped: Vec<ShapeComparisonSkipped> = Vec::new();
let wildcards: Vec<&ExpandedPattern> = required
.iter()
.chain(optional.iter())
.filter(|p| p.wildcard)
.collect();
for tensor in inventory {
if let Some(pat) = by_name.get(tensor.name.as_str()) {
matched.insert(tensor.name.clone());
if let Some(expected_shape) = pat.shape.as_ref() {
let resolved: Vec<Result<u64, _>> =
expected_shape.iter().map(|e| e.evaluate(symbols)).collect();
if resolved.iter().any(|r| r.is_err()) {
skipped.push(ShapeComparisonSkipped {
pattern: TensorPattern {
name: pat.source_name.clone(),
per_layer_count: pat
.per_layer_symbol
.as_ref()
.and_then(|s| symbols.get(s).copied())
.map(|n| n as u32),
layers: pat.layer.iter().copied().collect(),
},
reason: resolved
.iter()
.find_map(|r| r.as_ref().err().cloned())
.map(|e| e.to_string())
.unwrap_or_default(),
});
continue;
}
let resolved_ok: Vec<u64> =
resolved.into_iter().map(|r| r.unwrap()).collect();
let actual_u64: Vec<u64> = tensor.shape.iter().map(|&d| d as u64).collect();
if resolved_ok != actual_u64 {
shape_mismatches.push(ShapeMismatch {
pattern: TensorPattern {
name: pat.source_name.clone(),
per_layer_count: pat
.per_layer_symbol
.as_ref()
.and_then(|s| symbols.get(s).copied())
.map(|n| n as u32),
layers: pat.layer.iter().copied().collect(),
},
expected_shape: expected_shape.clone(),
actual_shape: tensor.shape.clone(),
resolved_expected: resolved_ok.into_iter().map(Some).collect(),
});
}
}
continue;
}
let absorbed_by_wildcard = wildcards.iter().any(|p| wildcard_matches(p, &tensor.name));
if absorbed_by_wildcard {
matched.insert(tensor.name.clone());
continue;
}
unexpected_map
.entry(unexpected_group_key(&tensor.name))
.or_default()
.push(tensor);
}
for pat in required.iter().chain(optional.iter()) {
if pat.wildcard {
continue;
}
if !matched.contains(pat.concrete_name.as_str()) {
missing_map
.entry((pat.source_name.clone(), pat.optional))
.or_default()
.extend(pat.layer);
}
}
let missing: Vec<MissingTensor> = missing_map
.into_iter()
.map(|((name, is_optional), mut layers)| {
layers.sort();
let per_layer_count = required
.iter()
.chain(optional.iter())
.find(|p| p.source_name == name)
.and_then(|p| p.per_layer_symbol.as_ref().and_then(|s| symbols.get(s).copied()))
.map(|n| n as u32);
let expected_shape = required
.iter()
.chain(optional.iter())
.find(|p| p.source_name == name)
.and_then(|p| p.shape.clone());
MissingTensor {
pattern: TensorPattern {
name,
per_layer_count,
layers,
},
expected_shape,
optional: is_optional,
}
})
.collect();
let unexpected: Vec<UnexpectedTensor> = unexpected_map
.into_iter()
.map(|(key, records)| {
let dtype = records[0].dtype.clone();
let shape = records[0].shape.clone();
let layers: Vec<u32> = records
.iter()
.filter_map(|r| layer_index_from_name(&r.name))
.collect();
let per_layer_count = if layers.is_empty() {
None
} else {
Some(records.len() as u32)
};
UnexpectedTensor {
pattern: TensorPattern {
name: key,
per_layer_count,
layers,
},
shape,
dtype,
}
})
.collect();
Classification {
missing,
unexpected,
shape_mismatches,
shape_comparisons_skipped: skipped,
}
}
/// Splits a per-layer tensor name of the form `blk.<index>.<tail>` into its
/// layer index and tail.
///
/// Returns `None` when the name lacks the `blk.` prefix, has no `.` after the
/// index, or the index is not a run of ASCII digits that fits in a `u32`.
fn parse_blk_name(name: &str) -> Option<(u32, &str)> {
    let rest = name.strip_prefix("blk.")?;
    let (digits, tail) = rest.split_once('.')?;
    // `u32::from_str` accepts a leading `+`, which would make `blk.+1.x`
    // look like layer 1; require plain digits.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let layer: u32 = digits.parse().ok()?;
    Some((layer, tail))
}
/// Returns the key under which an unexpected tensor is grouped.
///
/// Per-layer names (`blk.<n>.<tail>`) collapse to `blk.{layer}.<tail>`; any
/// other name is its own key.
fn unexpected_group_key(name: &str) -> String {
    parse_blk_name(name).map_or_else(
        || name.to_string(),
        |(_, tail)| format!("blk.{{layer}}.{tail}"),
    )
}
/// Extracts the layer index from a `blk.<n>.<tail>` tensor name, if present.
fn layer_index_from_name(name: &str) -> Option<u32> {
    let (layer, _tail) = parse_blk_name(name)?;
    Some(layer)
}
/// Reports whether a wildcard (unexpanded per-layer) pattern covers `name`.
///
/// Only patterns whose source name starts with `blk.{layer}.` can match. The
/// tensor name must parse as `blk.<n>.<tail>` with `<tail>` equal to the rest
/// of the pattern.
fn wildcard_matches(pattern: &ExpandedPattern, name: &str) -> bool {
    match (
        pattern.source_name.strip_prefix("blk.{layer}."),
        parse_blk_name(name),
    ) {
        (Some(suffix), Some((_, tail))) => tail == suffix,
        _ => false,
    }
}
fn classify_metadata(profile: &Profile, metadata: &MetadataBundle) -> MetadataDeltas {
let mut missing_required = Vec::new();
for req in &profile.required_metadata {
if !metadata.contains_key(&req.key) {
missing_required.push(MetadataDelta {
key: req.key.clone(),
value: None,
});
}
}
let default_prefixes = default_allowed_prefixes(&profile.architecture);
let effective_prefixes: Vec<String> = if profile.allowed_metadata_prefixes.is_empty() {
default_prefixes
} else {
profile.allowed_metadata_prefixes.clone()
};
let required_keys: BTreeSet<&str> = profile
.required_metadata
.iter()
.map(|r| r.key.as_str())
.collect();
let mut unexpected = Vec::new();
for (key, value) in metadata {
let allowed = required_keys.contains(key.as_str())
|| effective_prefixes.iter().any(|p| key.starts_with(p));
if !allowed {
unexpected.push(MetadataDelta {
key: key.clone(),
value: Some(value.clone()),
});
}
}
MetadataDeltas {
unexpected,
missing_required,
}
}
/// Returns the metadata key prefixes allowed when a profile declares none:
/// `general.`, `tokenizer.`, and `<arch>.`.
fn default_allowed_prefixes(arch: &str) -> Vec<String> {
    let mut prefixes: Vec<String> = ["general.", "tokenizer."]
        .iter()
        .map(|fixed| fixed.to_string())
        .collect();
    prefixes.push(format!("{arch}."));
    prefixes
}
/// Selects the hints whose triggers are all satisfied.
///
/// A hint fires when every `when_missing` pattern matches some missing tensor
/// and every `when_unexpected` pattern matches some unexpected tensor (see
/// `pattern_match`). A hint with no triggers therefore always fires.
///
/// The hypothesis `id` is the hint's 1-based position in `hints`, not its
/// position in the returned list.
fn fire_hypotheses(
    hints: &[Hint],
    missing: &[MissingTensor],
    unexpected: &[UnexpectedTensor],
) -> Vec<Hypothesis> {
    hints
        .iter()
        .enumerate()
        .filter(|(_, hint)| {
            let missing_ok = hint
                .when_missing
                .iter()
                .all(|pat| missing.iter().any(|m| pattern_match(pat, &m.pattern.name)));
            let unexpected_ok = hint
                .when_unexpected
                .iter()
                .all(|pat| unexpected.iter().any(|u| pattern_match(pat, &u.pattern.name)));
            missing_ok && unexpected_ok
        })
        .map(|(idx, hint)| Hypothesis {
            id: (idx as u32) + 1,
            name: hint.name.clone(),
            triggered_by: HypothesisTriggers {
                missing: hint.when_missing.clone(),
                unexpected: hint.when_unexpected.clone(),
            },
            message: hint.message.clone(),
        })
        .collect()
}
/// Compares a hint pattern with a tensor pattern name after both have been
/// brought to the canonical `blk.{layer}.` form by `normalize_blk`.
fn pattern_match(pattern: &str, candidate: &str) -> bool {
    let wanted = normalize_blk(pattern);
    let actual = normalize_blk(candidate);
    wanted == actual
}
/// Rewrites the per-layer spellings `blk.*.<tail>` and `blk.<n>.<tail>` to the
/// canonical `blk.{layer}.<tail>`; any other string is returned unchanged.
fn normalize_blk(s: &str) -> String {
    match s.strip_prefix("blk.*.") {
        Some(rest) => format!("blk.{{layer}}.{rest}"),
        // Already canonical.
        None if s.starts_with("blk.{layer}.") => s.to_string(),
        None => parse_blk_name(s).map_or_else(
            || s.to_string(),
            |(_, tail)| format!("blk.{{layer}}.{tail}"),
        ),
    }
}
// Unit tests that drive `compare` end-to-end through an in-memory `ModelSource`.
#[cfg(test)]
mod tests {
use super::*;
use crate::diagnostics::profile::{ExpectedTensor, Hint, Profile, SymbolDef, SymbolSource};
use crate::diagnostics::types::{MetadataValue, TensorDtype, TensorRecord};
/// In-memory `ModelSource` that returns clones of its fields.
struct MockSource {
arch: Option<String>,
metadata: MetadataBundle,
tensors: Vec<TensorRecord>,
}
impl ModelSource for MockSource {
fn declared_architecture(&self) -> Option<String> {
self.arch.clone()
}
fn metadata(&self) -> MetadataBundle {
self.metadata.clone()
}
fn tensors(&self) -> Vec<TensorRecord> {
self.tensors.clone()
}
fn format_label(&self) -> String {
"Mock".into()
}
}
/// Builds an F16 tensor record with the given name and shape.
fn tensor(name: &str, shape: Vec<usize>) -> TensorRecord {
TensorRecord {
name: name.into(),
shape,
dtype: TensorDtype::F16,
}
}
/// Fixed report context shared by most tests.
fn test_ctx() -> ReportContext<'static> {
ReportContext {
file_path: "/x",
arch_source: "test",
format_kind: "mock",
}
}
/// Profile with three metadata-backed symbols (`hidden`, `n_layers`, `vocab`),
/// one plain required tensor and one per-layer required tensor.
fn profile_with_symbols() -> Profile {
Profile {
name: "test".into(),
architecture: "test".into(),
extends: None,
symbols: vec![
SymbolDef {
name: "hidden".into(),
source: SymbolSource::Metadata("test.hidden".into()),
},
SymbolDef {
name: "n_layers".into(),
source: SymbolSource::Metadata("test.n_layers".into()),
},
SymbolDef {
name: "vocab".into(),
source: SymbolSource::Metadata("test.vocab".into()),
},
],
required_metadata: vec![],
expected_tensors: vec![
ExpectedTensor {
name: "token_embd.weight".into(),
shape: Some(vec![ShapeExpr::from_str("vocab"), ShapeExpr::from_str("hidden")]),
per_layer: None,
dtype: None,
},
ExpectedTensor {
name: "blk.{layer}.attn_q.weight".into(),
shape: Some(vec![ShapeExpr::from_str("hidden"), ShapeExpr::from_str("hidden")]),
per_layer: Some("n_layers".into()),
dtype: None,
},
],
optional_tensors: vec![],
hints: vec![],
allowed_metadata_prefixes: vec![],
}
}
/// Metadata that resolves every symbol of `profile_with_symbols`:
/// hidden = 8, n_layers = 2, vocab = 100.
fn base_metadata() -> MetadataBundle {
let mut md = MetadataBundle::new();
md.insert("test.hidden".into(), MetadataValue::UInt(8));
md.insert("test.n_layers".into(), MetadataValue::UInt(2));
md.insert("test.vocab".into(), MetadataValue::UInt(100));
md
}
// All required tensors present with the expected shapes: clean match.
#[test]
fn matches_when_everything_present_with_right_shape() {
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
assert_eq!(report.verdict, Verdict::Matches, "{report:#?}");
assert!(report.missing_tensors.is_empty());
assert!(report.unexpected_tensors.is_empty());
assert!(report.shape_mismatches.is_empty());
}
// Both per-layer tensors absent: a single missing entry listing layers 0 and 1.
#[test]
fn missing_required_tensor_yields_profile_mismatch() {
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![tensor("token_embd.weight", vec![100, 8])],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
assert_eq!(report.verdict, Verdict::ProfileMismatch);
assert_eq!(report.missing_tensors.len(), 1);
let m = &report.missing_tensors[0];
assert_eq!(m.pattern.name, "blk.{layer}.attn_q.weight");
assert_eq!(m.pattern.layers, vec![0, 1]);
assert_eq!(m.pattern.per_layer_count, Some(2));
assert!(!m.optional);
}
// Extra `blk.<n>.` tensors are grouped into one `blk.{layer}.` entry.
#[test]
fn unexpected_blk_tensors_collapse_by_layer() {
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
tensor("blk.0.post_attn.weight", vec![8]),
tensor("blk.1.post_attn.weight", vec![8]),
],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
assert_eq!(report.verdict, Verdict::OptionalExtras);
assert_eq!(report.unexpected_tensors.len(), 1);
let u = &report.unexpected_tensors[0];
assert_eq!(u.pattern.name, "blk.{layer}.post_attn.weight");
assert_eq!(u.pattern.per_layer_count, Some(2));
assert_eq!(u.pattern.layers, vec![0, 1]);
}
// A wrong dimension is reported with both the actual and the resolved expected shape.
#[test]
fn shape_mismatch_reported_with_expected_and_actual() {
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
// token_embd's second dimension is 16 while `hidden` resolves to 8.
tensor("token_embd.weight", vec![100, 16]), tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
assert_eq!(report.verdict, Verdict::ProfileMismatch);
assert_eq!(report.shape_mismatches.len(), 1);
let m = &report.shape_mismatches[0];
assert_eq!(m.pattern.name, "token_embd.weight");
assert_eq!(m.actual_shape, vec![100, 16]);
assert_eq!(m.resolved_expected, vec![Some(100), Some(8)]);
}
// An unresolved shape symbol must not produce a shape mismatch.
#[test]
fn unresolved_symbol_downgrades_shape_to_skipped() {
let mut md = base_metadata();
// Drop the metadata key that backs the `hidden` symbol.
md.remove("test.hidden"); let src = MockSource {
arch: Some("test".into()),
metadata: md,
tensors: vec![tensor("token_embd.weight", vec![100, 8])],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
assert!(!report.shape_comparisons_skipped.is_empty() || !report.warnings.is_empty());
assert!(report.shape_mismatches.is_empty());
}
// A missing optional tensor downgrades the verdict to OptionalExtras only.
#[test]
fn optional_missing_yields_optional_extras_verdict() {
let mut profile = profile_with_symbols();
profile.optional_tensors = vec![ExpectedTensor {
name: "output_norm.weight".into(),
shape: Some(vec![ShapeExpr::from_str("hidden")]),
per_layer: None,
dtype: None,
}];
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile), &test_ctx());
assert_eq!(report.verdict, Verdict::OptionalExtras);
assert_eq!(report.missing_tensors.len(), 1);
assert!(report.missing_tensors[0].optional);
}
// A required metadata key that is absent forces ProfileMismatch.
#[test]
fn missing_required_metadata_yields_profile_mismatch() {
let mut profile = profile_with_symbols();
profile.required_metadata.push(
crate::diagnostics::profile::RequiredMetadata {
key: "test.something".into(),
},
);
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile), &test_ctx());
assert_eq!(report.verdict, Verdict::ProfileMismatch);
assert_eq!(report.metadata_deltas.missing_required.len(), 1);
}
// With no profile prefixes, the defaults (`general.`, `tokenizer.`, `<arch>.`) apply.
#[test]
fn unexpected_metadata_filtered_by_allowed_prefixes() {
let mut md = base_metadata();
md.insert("general.quantization".into(), MetadataValue::String("q4_k".into()));
md.insert("tokenizer.type".into(), MetadataValue::String("bpe".into()));
md.insert("rogue.key".into(), MetadataValue::UInt(1));
let src = MockSource {
arch: Some("test".into()),
metadata: md,
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
let unexpected_keys: Vec<_> = report
.metadata_deltas
.unexpected
.iter()
.map(|d| d.key.clone())
.collect();
assert_eq!(unexpected_keys, vec!["rogue.key"], "{report:#?}");
}
// Both triggers satisfied (ffn_norm missing, post_attention_norm unexpected): hint fires.
#[test]
fn hypothesis_fires_only_when_all_triggers_match() {
let mut profile = profile_with_symbols();
profile.hints = vec![Hint {
when_missing: vec!["blk.*.ffn_norm.weight".into()],
when_unexpected: vec!["blk.*.post_attention_norm.weight".into()],
message: "Rename.".into(),
name: Some("rename".into()),
}];
profile.expected_tensors.push(ExpectedTensor {
name: "blk.{layer}.ffn_norm.weight".into(),
shape: Some(vec![ShapeExpr::from_str("hidden")]),
per_layer: Some("n_layers".into()),
dtype: None,
});
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
tensor("blk.0.post_attention_norm.weight", vec![8]),
tensor("blk.1.post_attention_norm.weight", vec![8]),
],
};
let report = compare(&src, Some(&profile), &test_ctx());
assert_eq!(report.hypotheses.len(), 1);
assert_eq!(report.hypotheses[0].name.as_deref(), Some("rename"));
}
// Only the `when_missing` trigger is satisfied, so the hint must stay silent.
#[test]
fn hypothesis_does_not_fire_when_only_partial_match() {
let mut profile = profile_with_symbols();
profile.hints = vec![Hint {
when_missing: vec!["blk.*.ffn_norm.weight".into()],
when_unexpected: vec!["blk.*.post_attention_norm.weight".into()],
message: "Rename.".into(),
name: Some("rename".into()),
}];
profile.expected_tensors.push(ExpectedTensor {
name: "blk.{layer}.ffn_norm.weight".into(),
shape: Some(vec![ShapeExpr::from_str("hidden")]),
per_layer: Some("n_layers".into()),
dtype: None,
});
let src = MockSource {
arch: Some("test".into()),
metadata: base_metadata(),
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile), &test_ctx());
assert!(report.hypotheses.is_empty(), "{report:#?}");
}
// Without a profile, every tensor is listed as unexpected and nothing is missing.
#[test]
fn inventory_only_report_for_unknown_architecture() {
let src = MockSource {
arch: None,
metadata: MetadataBundle::new(),
tensors: vec![tensor("weird.weight", vec![1, 2])],
};
let ctx = ReportContext {
file_path: "/x",
arch_source: "none",
format_kind: "mock",
};
let report = compare(&src, None, &ctx);
assert_eq!(report.verdict, Verdict::UnknownArchitecture);
assert_eq!(report.unexpected_tensors.len(), 1);
assert_eq!(report.missing_tensors.len(), 0);
assert!(matches!(report.profile, ProfileRef::None));
}
// With the layer count unresolved, a warning is emitted and the per-layer
// pattern is not reported as missing.
#[test]
fn per_layer_with_unresolved_count_emits_warning_not_spurious_missing() {
let mut md = base_metadata();
md.remove("test.n_layers");
let src = MockSource {
arch: Some("test".into()),
metadata: md,
tensors: vec![
tensor("token_embd.weight", vec![100, 8]),
tensor("blk.0.attn_q.weight", vec![8, 8]),
tensor("blk.1.attn_q.weight", vec![8, 8]),
],
};
let report = compare(&src, Some(&profile_with_symbols()), &test_ctx());
assert!(
report
.warnings
.iter()
.any(|w| w.code == "could_not_expand_per_layer"
|| w.code == "unresolved_symbol"),
"{report:#?}"
);
assert!(
report
.missing_tensors
.iter()
.all(|m| m.pattern.name != "blk.{layer}.attn_q.weight"),
"{report:#?}"
);
}
}