#![allow(clippy::uninlined_format_args)]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
use crate::error::{Error, Result};
/// Default location of the tensor layout contract YAML, relative to this
/// crate's directory (lives in the sibling `aprender` project).
pub const DEFAULT_CONTRACT_PATH: &str = "../aprender/contracts/tensor-layout-v1.yaml";
/// Root document of the tensor layout contract, deserialized from YAML.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorLayoutContract {
    /// Version/authorship information for the contract document.
    pub metadata: ContractMetadata,
    /// Per-format layout conventions, keyed by format name.
    pub formats: HashMap<String, FormatConvention>,
    /// Kernel-level weight layout and computation convention.
    pub kernel: KernelConvention,
    /// Tensor specifications, keyed by a logical tensor id.
    pub tensors: HashMap<String, TensorSpec>,
    /// Declared validation rules (see [`get_validation_rules`]).
    pub validation_rules: Vec<ValidationRule>,
    /// Optional value-level validation thresholds; absent in older contracts.
    #[serde(default)]
    pub semantic_validation: Option<SemanticValidation>,
}
/// Authorship and revision metadata for the contract document.
#[derive(Debug, Clone, Deserialize)]
pub struct ContractMetadata {
    /// Contract version string (e.g. "1.0").
    pub version: String,
    /// Creation date, as written in the YAML (format not enforced here).
    pub created: String,
    /// Last-updated date, as written in the YAML.
    pub updated: String,
    pub author: String,
    pub description: String,
}
/// Layout convention for one serialization format (e.g. GGUF, APR).
#[derive(Debug, Clone, Deserialize)]
pub struct FormatConvention {
    /// Memory layout description (free-form text from the contract).
    pub layout: String,
    /// How shapes are written for this format (free-form text).
    pub shape_convention: String,
    /// Optional free-form note; absent in most entries.
    #[serde(default)]
    pub note: Option<String>,
}
/// Kernel-side conventions: how weights are shaped and bytes computed.
#[derive(Debug, Clone, Deserialize)]
pub struct KernelConvention {
    /// Kernel signature description (free-form text).
    pub signature: String,
    /// Expected weight shape convention, e.g. "[out, in]".
    pub weight_shape: String,
    /// Computation description, e.g. "y = Wx".
    pub computation: String,
    /// Formula describing how byte sizes are derived.
    pub byte_calculation: String,
    /// Block sizes keyed by quantization type name.
    pub block_sizes: HashMap<String, u32>,
    /// K-quant super-block size; serialized under the YAML key "QK_K".
    #[serde(rename = "QK_K")]
    pub qk_k: u32,
}
/// Specification for one tensor: its names and shapes in each format,
/// whether conversion transposes it, and which kernel consumes it.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorSpec {
    /// Tensor name in the GGUF file (e.g. "output.weight").
    pub gguf_name: String,
    /// Tensor name in the APR/SafeTensors file (e.g. "lm_head.weight");
    /// may contain a "{n}" placeholder for per-layer tensors.
    pub apr_name: String,
    /// Symbolic shape in GGUF convention, e.g. "[hidden, vocab]".
    pub gguf_shape: String,
    /// Symbolic shape in APR convention, e.g. "[vocab, hidden]".
    pub apr_shape: String,
    /// True when GGUF -> APR conversion transposes; also used in this
    /// module to distinguish 2-D weight specs from 1-D specs.
    pub transpose: bool,
    /// Kernel that consumes the tensor (e.g. "matmul", "lookup").
    pub kernel: String,
    /// Symbolic name of the kernel's output dimension, when applicable.
    #[serde(default)]
    pub kernel_out_dim: Option<String>,
    /// Symbolic name of the kernel's input dimension, when applicable.
    #[serde(default)]
    pub kernel_in_dim: Option<String>,
    /// Optional validation hint (free-form; unused in this module).
    #[serde(default)]
    pub validation: Option<String>,
    /// Marks tensors whose mis-layout is a critical failure.
    #[serde(default)]
    pub critical: bool,
    /// Optional free-form note (e.g. bug references like "GH-202").
    #[serde(default)]
    pub note: Option<String>,
}
/// A named validation rule declared by the contract.
/// Serialize is derived so rules can be re-emitted in reports.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ValidationRule {
    /// Stable rule identifier (e.g. "F-LAYOUT-CONTRACT-001").
    pub id: String,
    pub name: String,
    pub description: String,
    /// Severity label (free-form string from the contract).
    pub severity: String,
    /// Whether a failure of this rule should be treated as critical.
    #[serde(default)]
    pub critical: bool,
    /// Optional external reference (e.g. issue link).
    #[serde(default)]
    pub reference: Option<String>,
}
/// Optional value-level validation thresholds; every section may be
/// omitted in the YAML. Not consumed by this module's shape checks.
#[derive(Debug, Clone, Deserialize)]
pub struct SemanticValidation {
    #[serde(default)]
    pub density: Option<DensityConfig>,
    #[serde(default)]
    pub numeric: Option<NumericConfig>,
    #[serde(default)]
    pub distribution: Option<DistributionConfig>,
}
/// Zero-density thresholds — presumably the maximum percentage of zero
/// values tolerated per tensor class (not used in this file; semantics
/// defined by the contract's consumer — TODO confirm).
#[derive(Debug, Clone, Deserialize)]
pub struct DensityConfig {
    pub embedding_max_zero_pct: f64,
    pub weight_max_zero_pct: f64,
}
/// Numeric-value policy: whether NaN/Inf values are permitted.
/// (Declared here for contract parsing; enforcement lives elsewhere.)
#[derive(Debug, Clone, Deserialize)]
pub struct NumericConfig {
    pub allow_nan: bool,
    pub allow_inf: bool,
}
/// Distribution sanity thresholds — minimum L2 norm and a flag that
/// values must vary (enforcement is outside this module).
#[derive(Debug, Clone, Deserialize)]
pub struct DistributionConfig {
    pub min_l2_norm: f64,
    pub require_variation: bool,
}
pub fn load_contract() -> Result<TensorLayoutContract> {
load_contract_from(DEFAULT_CONTRACT_PATH)
}
pub fn load_contract_from<P: AsRef<Path>>(path: P) -> Result<TensorLayoutContract> {
let path = path.as_ref();
let content = std::fs::read_to_string(path).map_err(|e| {
Error::Execution(format!(
"Failed to read tensor layout contract from {}: {e}",
path.display()
))
})?;
serde_yaml::from_str(&content).map_err(|e| {
Error::Execution(format!(
"Failed to parse tensor layout contract from {}: {e}",
path.display()
))
})
}
/// Outcome of checking one tensor against one rule.
#[derive(Debug, Clone, Serialize)]
pub struct TensorValidationResult {
    /// Tensor (or file path, for parse errors) the check applied to.
    pub tensor_name: String,
    /// Rule identifier, e.g. "F-LAYOUT-CONTRACT-001" or "PARSE-ERROR".
    pub rule_id: String,
    pub passed: bool,
    /// Human-readable explanation of the outcome.
    pub details: String,
    /// Expected value description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expected: Option<String>,
    /// Observed value description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub actual: Option<String>,
}
/// Aggregate result of validating one model against the contract.
#[derive(Debug, Clone, Serialize)]
pub struct ModelValidationResult {
    pub model_path: PathBuf,
    /// True only when no rule failed AND no critical failure was recorded.
    pub passed: bool,
    /// Total number of per-tensor checks performed.
    pub rules_checked: usize,
    pub rules_passed: usize,
    pub rules_failed: usize,
    pub tensor_results: Vec<TensorValidationResult>,
    /// Messages for failures escalated to critical severity.
    pub critical_failures: Vec<String>,
}
/// Upper bound (10 MiB) on the SafeTensors JSON header size; guards
/// against allocating huge buffers from a corrupt or hostile length field.
const MAX_HEADER_SIZE: usize = 10 * 1024 * 1024;
/// Validate a model's tensor layout against the contract.
///
/// Missing files yield a failing result; models with no SafeTensors data
/// pass vacuously. Otherwise every applicable shape rule is checked.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` wrapper is kept for
/// interface stability.
pub fn validate_model(
    model_path: &Path,
    contract: &TensorLayoutContract,
) -> Result<ModelValidationResult> {
    // Short-circuit on missing file / absence of SafeTensors data.
    if let Some(early) = check_model_path_preconditions(model_path) {
        return Ok(early);
    }
    let (tensor_results, critical_failures) = run_all_validations(model_path, contract);
    let mut passed_count = 0usize;
    let mut failed_count = 0usize;
    for outcome in &tensor_results {
        if outcome.passed {
            passed_count += 1;
        } else {
            failed_count += 1;
        }
    }
    Ok(ModelValidationResult {
        model_path: model_path.to_path_buf(),
        passed: failed_count == 0 && critical_failures.is_empty(),
        rules_checked: tensor_results.len(),
        rules_passed: passed_count,
        rules_failed: failed_count,
        tensor_results,
        critical_failures,
    })
}
/// Return an early result when full validation cannot or need not run:
/// a missing path fails outright, a model with no SafeTensors files
/// passes vacuously. `None` means the caller should run full validation.
fn check_model_path_preconditions(model_path: &Path) -> Option<ModelValidationResult> {
    if !model_path.exists() {
        let missing = TensorValidationResult {
            tensor_name: "N/A".to_string(),
            rule_id: "FILE-EXISTS".to_string(),
            passed: false,
            details: format!("Model file not found: {}", model_path.display()),
            expected: Some("File exists".to_string()),
            actual: Some("File not found".to_string()),
        };
        return Some(ModelValidationResult {
            model_path: model_path.to_path_buf(),
            passed: false,
            rules_checked: 0,
            rules_passed: 0,
            rules_failed: 1,
            tensor_results: vec![missing],
            critical_failures: vec!["Model file not found".to_string()],
        });
    }
    if find_safetensors_files(model_path).is_empty() {
        // Nothing to validate: report a vacuous pass with zero checks.
        return Some(ModelValidationResult {
            model_path: model_path.to_path_buf(),
            passed: true,
            rules_checked: 0,
            rules_passed: 0,
            rules_failed: 0,
            tensor_results: Vec::new(),
            critical_failures: Vec::new(),
        });
    }
    None
}
/// Run every layout validation for the model, returning per-tensor
/// results plus any critical failure messages.
fn run_all_validations(
    model_path: &Path,
    contract: &TensorLayoutContract,
) -> (Vec<TensorValidationResult>, Vec<String>) {
    let mut outcomes = Vec::new();
    let mut criticals = Vec::new();
    // Collect shapes first; unreadable files are recorded as failures.
    let shapes = collect_tensor_metadata(model_path, &mut outcomes);
    // Best-effort config load; missing fields just skip their checks.
    let cfg = find_and_load_config(model_path);
    validate_lm_head(&shapes, &cfg, contract, &mut outcomes, &mut criticals);
    validate_2d_tensors(contract, &shapes, &cfg, &mut outcomes);
    validate_1d_tensors(contract, &shapes, &cfg, &mut outcomes);
    (outcomes, criticals)
}
/// Merge tensor-name -> shape metadata from every SafeTensors file under
/// `model_path`; files that cannot be parsed are pushed to `results` as
/// failed "PARSE-ERROR" entries rather than aborting collection.
fn collect_tensor_metadata(
    model_path: &Path,
    results: &mut Vec<TensorValidationResult>,
) -> HashMap<String, Vec<usize>> {
    let mut merged: HashMap<String, Vec<usize>> = HashMap::new();
    for file in find_safetensors_files(model_path) {
        match read_safetensors_metadata(&file) {
            Ok(tensors) => merged.extend(tensors),
            Err(e) => results.push(TensorValidationResult {
                tensor_name: file.display().to_string(),
                rule_id: "PARSE-ERROR".to_string(),
                passed: false,
                details: format!("Failed to read SafeTensors metadata: {e}"),
                expected: None,
                actual: None,
            }),
        }
    }
    merged
}
/// Validate `lm_head.weight` when present. A failed shape check under
/// rule F-LAYOUT-CONTRACT-002 is escalated into `critical_failures`.
fn validate_lm_head(
    all_tensors: &HashMap<String, Vec<usize>>,
    config: &ModelConfig,
    contract: &TensorLayoutContract,
    results: &mut Vec<TensorValidationResult>,
    critical_failures: &mut Vec<String>,
) {
    let Some(shape) = all_tensors.get("lm_head.weight") else {
        return; // Tensor absent: nothing to check.
    };
    let outcome = validate_lm_head_shape(shape, config, contract);
    let escalate = !outcome.passed && outcome.rule_id == "F-LAYOUT-CONTRACT-002";
    if escalate {
        critical_failures.push(outcome.details.clone());
    }
    results.push(outcome);
}
/// Check every 2-D (transposed) tensor spec in the contract against the
/// collected shapes; specs containing "{n}" are expanded per layer.
fn validate_2d_tensors(
    contract: &TensorLayoutContract,
    all_tensors: &HashMap<String, Vec<usize>>,
    config: &ModelConfig,
    results: &mut Vec<TensorValidationResult>,
) {
    let two_d_specs = contract.tensors.iter().filter(|(_, s)| s.transpose);
    for (name, spec) in two_d_specs {
        if spec.apr_name.contains("{n}") {
            validate_layer_tensors(&spec.apr_name, all_tensors, config, spec, results);
        } else if let Some(shape) = all_tensors.get(&spec.apr_name) {
            results.push(validate_2d_tensor_shape(name, shape, spec, config));
        }
        // Non-layer tensors that are simply absent are skipped silently.
    }
}
/// Check every 1-D (non-transposed) tensor spec in the contract against
/// the collected shapes; specs containing "{n}" are expanded per layer.
fn validate_1d_tensors(
    contract: &TensorLayoutContract,
    all_tensors: &HashMap<String, Vec<usize>>,
    config: &ModelConfig,
    results: &mut Vec<TensorValidationResult>,
) {
    let one_d_specs = contract.tensors.iter().filter(|(_, s)| !s.transpose);
    for (name, spec) in one_d_specs {
        if spec.apr_name.contains("{n}") {
            validate_1d_layer_tensors(&spec.apr_name, all_tensors, config, spec, results);
        } else if let Some(shape) = all_tensors.get(&spec.apr_name) {
            results.push(validate_1d_tensor_shape(name, shape, spec, config));
        }
        // Absent non-layer tensors are skipped silently.
    }
}
/// Subset of a model `config.json` used to resolve symbolic dimensions.
/// Every field is optional; `None` means "unknown" and causes the
/// corresponding shape checks to be skipped rather than fail.
#[derive(Debug, Default)]
struct ModelConfig {
    vocab_size: Option<usize>,
    hidden_size: Option<usize>,
    intermediate_size: Option<usize>,
    num_attention_heads: Option<usize>,
    num_key_value_heads: Option<usize>,
    // Drives how many times "{n}" layer patterns are expanded.
    num_hidden_layers: Option<usize>,
}
/// Find all `.safetensors` files for `path`.
///
/// A file path is returned as-is when it has the `.safetensors` extension;
/// a directory is scanned non-recursively, preferring its `safetensors/`
/// subdirectory when one exists. Unreadable directories yield an empty list.
fn find_safetensors_files(path: &Path) -> Vec<PathBuf> {
    if path.is_file() {
        let is_safetensors = path.extension().is_some_and(|ext| ext == "safetensors");
        return if is_safetensors {
            vec![path.to_path_buf()]
        } else {
            Vec::new()
        };
    }
    let subdir = path.join("safetensors");
    let search_dir = if subdir.exists() {
        subdir
    } else {
        path.to_path_buf()
    };
    match search_dir.read_dir() {
        Ok(entries) => entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|p| p.extension().is_some_and(|ext| ext == "safetensors"))
            .collect(),
        Err(_) => Vec::new(),
    }
}
/// Read tensor name -> shape pairs from a SafeTensors file's JSON header.
///
/// SafeTensors layout: the first 8 bytes are a little-endian `u64` header
/// length, followed by a UTF-8 JSON object mapping tensor names to their
/// metadata. Only the `shape` array is extracted; the `__metadata__` key
/// is skipped, and entries without a well-formed shape are dropped.
///
/// Returns a `String` error (not `crate::error::Error`) so the caller can
/// decide how to report per-file failures.
fn read_safetensors_metadata(
    path: &Path,
) -> std::result::Result<HashMap<String, Vec<usize>>, String> {
    let mut file = File::open(path).map_err(|e| format!("Failed to open: {e}"))?;
    let mut header_len_bytes = [0u8; 8];
    file.read_exact(&mut header_len_bytes)
        .map_err(|e| format!("Failed to read header length: {e}"))?;
    // Bounds-check in u64 BEFORE casting to usize: `as usize` silently
    // truncates on 32-bit targets, which could let an oversized/corrupt
    // length value slip past the MAX_HEADER_SIZE guard.
    let header_len_u64 = u64::from_le_bytes(header_len_bytes);
    if header_len_u64 > MAX_HEADER_SIZE as u64 {
        return Err(format!("Header too large: {header_len_u64}"));
    }
    let header_len = header_len_u64 as usize;
    let mut header_bytes = vec![0u8; header_len];
    file.read_exact(&mut header_bytes)
        .map_err(|e| format!("Failed to read header: {e}"))?;
    let header_str =
        std::str::from_utf8(&header_bytes).map_err(|e| format!("Invalid UTF-8: {e}"))?;
    let header: serde_json::Value =
        serde_json::from_str(header_str).map_err(|e| format!("JSON parse error: {e}"))?;
    let obj = header.as_object().ok_or("Header is not JSON object")?;
    let tensors = obj
        .iter()
        .filter(|(name, _)| *name != "__metadata__")
        .filter_map(|(name, value)| {
            let shape = value.as_object()?.get("shape")?.as_array()?;
            // Non-integer entries are silently skipped, matching the
            // tolerant parsing used throughout this module.
            let dims: Vec<usize> = shape
                .iter()
                .filter_map(|v| v.as_u64().map(|n| n as usize))
                .collect();
            Some((name.clone(), dims))
        })
        .collect();
    Ok(tensors)
}
fn get_usize(json: &serde_json::Value, key: &str) -> Option<usize> {
json.get(key)
.and_then(serde_json::Value::as_u64)
.map(|n| n as usize)
}
fn find_and_load_config(model_path: &Path) -> ModelConfig {
let config_paths = if model_path.is_file() {
let parent = model_path.parent().unwrap_or(model_path);
let stem = model_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("");
vec![
parent.join(format!("{stem}.config.json")),
parent.join("config.json"),
]
} else {
vec![
model_path.join("config.json"),
model_path.join("safetensors/config.json"),
]
};
for path in config_paths {
if let Ok(content) = std::fs::read_to_string(&path) {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
return ModelConfig {
vocab_size: get_usize(&json, "vocab_size"),
hidden_size: get_usize(&json, "hidden_size"),
intermediate_size: get_usize(&json, "intermediate_size"),
num_attention_heads: get_usize(&json, "num_attention_heads"),
num_key_value_heads: get_usize(&json, "num_key_value_heads"),
num_hidden_layers: get_usize(&json, "num_hidden_layers"),
};
}
}
}
ModelConfig::default()
}
/// Validate `lm_head.weight` against the config's vocab/hidden sizes
/// (guards against the GH-202 transposed-lm-head bug pattern).
/// Dimensions missing from the config are skipped rather than failed.
fn validate_lm_head_shape(
    actual_shape: &[usize],
    config: &ModelConfig,
    _contract: &TensorLayoutContract,
) -> TensorValidationResult {
    const RULE: &str = "F-LAYOUT-CONTRACT-002";
    const TENSOR: &str = "lm_head.weight";
    if actual_shape.len() != 2 {
        return TensorValidationResult {
            tensor_name: TENSOR.to_string(),
            rule_id: RULE.to_string(),
            passed: false,
            details: "lm_head.weight must be 2D tensor".to_string(),
            expected: Some("[vocab_size, hidden_size]".to_string()),
            actual: Some(format!("{actual_shape:?}")),
        };
    }
    let expected_vocab = config.vocab_size;
    let expected_hidden = config.hidden_size;
    // Each known dimension must match; unknown dimensions are not checked.
    let vocab_ok = expected_vocab.map_or(true, |v| actual_shape[0] == v);
    let hidden_ok = expected_hidden.map_or(true, |h| actual_shape[1] == h);
    let shape_valid = vocab_ok && hidden_ok;
    let details = if shape_valid {
        format!("lm_head.weight shape correct: {actual_shape:?}")
    } else {
        format!(
            "lm_head.weight shape MISMATCH (GH-202 bug pattern): expected [{expected_vocab:?}, {expected_hidden:?}], got {actual_shape:?}"
        )
    };
    TensorValidationResult {
        tensor_name: TENSOR.to_string(),
        rule_id: RULE.to_string(),
        passed: shape_valid,
        details,
        expected: Some(format!("[{expected_vocab:?}, {expected_hidden:?}]")),
        actual: Some(format!("{actual_shape:?}")),
    }
}
/// Check a 2-D tensor's actual shape against the spec's symbolic
/// `apr_shape`, resolving symbols via the model config. An expected
/// shape that cannot be resolved is treated as passing.
fn validate_2d_tensor_shape(
    name: &str,
    actual_shape: &[usize],
    spec: &TensorSpec,
    config: &ModelConfig,
) -> TensorValidationResult {
    if actual_shape.len() != 2 {
        return TensorValidationResult {
            tensor_name: spec.apr_name.clone(),
            rule_id: "F-LAYOUT-CONTRACT-001".to_string(),
            passed: false,
            details: format!("{name} must be 2D, got {}D", actual_shape.len()),
            expected: Some(spec.apr_shape.clone()),
            actual: Some(format!("{actual_shape:?}")),
        };
    }
    let shape_valid = match parse_expected_shape(&spec.apr_shape, config) {
        Some((dim0, dim1)) => actual_shape[0] == dim0 && actual_shape[1] == dim1,
        None => true, // Unresolvable expected dims: skip the check.
    };
    let details = if shape_valid {
        format!("{name} shape correct: {actual_shape:?}")
    } else {
        format!("{name} shape mismatch")
    };
    TensorValidationResult {
        tensor_name: spec.apr_name.clone(),
        rule_id: "F-LAYOUT-CONTRACT-001".to_string(),
        passed: shape_valid,
        details,
        expected: Some(spec.apr_shape.clone()),
        actual: Some(format!("{actual_shape:?}")),
    }
}
/// Expand a "{n}" per-layer pattern over all hidden layers and validate
/// each tensor that exists. With an unknown layer count (config missing
/// `num_hidden_layers`), no checks run.
fn validate_layer_tensors(
    pattern: &str,
    all_tensors: &HashMap<String, Vec<usize>>,
    config: &ModelConfig,
    spec: &TensorSpec,
    results: &mut Vec<TensorValidationResult>,
) {
    let layer_count = config.num_hidden_layers.unwrap_or(0);
    for idx in 0..layer_count {
        let tensor_name = pattern.replace("{n}", &idx.to_string());
        let Some(shape) = all_tensors.get(&tensor_name) else {
            continue; // Missing layers are skipped, not failed.
        };
        results.push(validate_2d_tensor_shape(&tensor_name, shape, spec, config));
    }
}
/// 1-D counterpart of [`validate_layer_tensors`]: expand "{n}" per layer
/// and validate each present tensor as rank-1.
fn validate_1d_layer_tensors(
    pattern: &str,
    all_tensors: &HashMap<String, Vec<usize>>,
    config: &ModelConfig,
    spec: &TensorSpec,
    results: &mut Vec<TensorValidationResult>,
) {
    let layer_count = config.num_hidden_layers.unwrap_or(0);
    for idx in 0..layer_count {
        let tensor_name = pattern.replace("{n}", &idx.to_string());
        let Some(shape) = all_tensors.get(&tensor_name) else {
            continue; // Missing layers are skipped, not failed.
        };
        results.push(validate_1d_tensor_shape(&tensor_name, shape, spec, config));
    }
}
/// Check a 1-D tensor: must be rank 1 and, when the config provides
/// `hidden_size`, its single dimension must equal it. An unknown hidden
/// size skips the dimension check.
fn validate_1d_tensor_shape(
    name: &str,
    actual_shape: &[usize],
    spec: &TensorSpec,
    config: &ModelConfig,
) -> TensorValidationResult {
    if actual_shape.len() != 1 {
        return TensorValidationResult {
            tensor_name: name.to_string(),
            rule_id: "F-LAYOUT-CONTRACT-003".to_string(),
            passed: false,
            details: format!("{name} must be 1D, got {}D", actual_shape.len()),
            expected: Some(spec.apr_shape.clone()),
            actual: Some(format!("{actual_shape:?}")),
        };
    }
    let shape_valid = match config.hidden_size {
        Some(hidden) => actual_shape[0] == hidden,
        None => true,
    };
    let details = if shape_valid {
        format!("{name} shape correct: {actual_shape:?}")
    } else {
        format!(
            "{name} shape mismatch: expected [{}], got {actual_shape:?}",
            config.hidden_size.unwrap_or(0)
        )
    };
    TensorValidationResult {
        tensor_name: name.to_string(),
        rule_id: "F-LAYOUT-CONTRACT-003".to_string(),
        passed: shape_valid,
        details,
        expected: Some(spec.apr_shape.clone()),
        actual: Some(format!("{actual_shape:?}")),
    }
}
/// Resolve a symbolic 2-D shape string like "[vocab, hidden]" into
/// concrete dimensions. Returns `None` when the string is not 2-D or
/// when either symbol cannot be resolved from the config.
fn parse_expected_shape(shape_str: &str, config: &ModelConfig) -> Option<(usize, usize)> {
    let dims = parse_shape_dims(shape_str);
    match dims.as_slice() {
        [first, second] => {
            let dim0 = resolve_dimension(first, config)?;
            let dim1 = resolve_dimension(second, config)?;
            Some((dim0, dim1))
        }
        _ => None,
    }
}
/// Resolve one symbolic dimension name to a concrete size.
///
/// Supports known aliases (vocab/hidden/intermediate/heads/kv_heads),
/// the derived "head_dim" (hidden / heads), binary products such as
/// "heads * head_dim" (resolved recursively), and plain integer literals.
/// Returns `None` when the required config value is absent.
fn resolve_dimension(dim: &str, config: &ModelConfig) -> Option<usize> {
    // Products are handled first; alias names never contain '*'.
    if dim.contains('*') {
        let factors: Vec<&str> = dim.split('*').map(str::trim).collect();
        return match factors.as_slice() {
            [lhs, rhs] => {
                let left = resolve_dimension(lhs, config)?;
                let right = resolve_dimension(rhs, config)?;
                Some(left * right)
            }
            _ => None, // Only two-factor products are supported.
        };
    }
    match dim {
        "vocab" | "vocab_size" => config.vocab_size,
        "hidden" | "hidden_dim" | "hidden_size" => config.hidden_size,
        "intermediate" | "intermediate_dim" | "intermediate_size" => config.intermediate_size,
        "heads" | "num_heads" | "num_attention_heads" => config.num_attention_heads,
        "kv_heads" | "num_kv_heads" | "num_key_value_heads" => config.num_key_value_heads,
        "head_dim" => match (config.hidden_size, config.num_attention_heads) {
            // Guard n > 0 to avoid division by zero.
            (Some(h), Some(n)) if n > 0 => Some(h / n),
            _ => None,
        },
        literal => literal.parse().ok(),
    }
}
/// Borrow the contract's declared validation rules as a slice.
#[must_use]
pub fn get_validation_rules(contract: &TensorLayoutContract) -> &[ValidationRule] {
    contract.validation_rules.as_slice()
}
/// Collect references to every tensor spec flagged `critical: true`.
#[must_use]
pub fn get_critical_tensors(contract: &TensorLayoutContract) -> Vec<&TensorSpec> {
    let mut critical = Vec::new();
    for spec in contract.tensors.values() {
        if spec.critical {
            critical.push(spec);
        }
    }
    critical
}
/// True when a symbolic shape string describes exactly two dimensions,
/// i.e. contains exactly one comma (e.g. "[vocab, hidden]").
#[must_use]
pub fn is_2d_shape(shape: &str) -> bool {
    let comma_count = shape.chars().filter(|&c| c == ',').count();
    comma_count == 1
}
/// Split a symbolic shape string like "[vocab, hidden]" into its
/// trimmed dimension names, stripping surrounding brackets.
#[must_use]
pub fn parse_shape_dims(shape: &str) -> Vec<String> {
    let inner = shape.trim_matches(|c| c == '[' || c == ']');
    inner
        .split(',')
        .map(|part| part.trim().to_string())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Exactly one comma means 2-D; anything else is not.
    #[test]
    fn test_is_2d_shape() {
        assert!(is_2d_shape("[vocab, hidden]"));
        assert!(is_2d_shape("[hidden, vocab]"));
        assert!(!is_2d_shape("[hidden]"));
        assert!(!is_2d_shape("[a, b, c]"));
    }

    // Brackets are stripped and each dimension name is trimmed.
    #[test]
    fn test_parse_shape_dims() {
        let dims = parse_shape_dims("[vocab, hidden]");
        assert_eq!(dims, vec!["vocab", "hidden"]);
        let dims = parse_shape_dims("[hidden]");
        assert_eq!(dims, vec!["hidden"]);
    }

    // A missing contract file surfaces as Err, not a panic.
    #[test]
    fn test_load_contract_missing_file() {
        let result = load_contract_from("/nonexistent/path.yaml");
        assert!(result.is_err());
    }

    // A missing model path yields Ok(result) with passed == false and a
    // critical failure recorded — validate_model itself never errors here.
    #[test]
    fn test_validate_model_missing_file() {
        let contract = TensorLayoutContract {
            metadata: ContractMetadata {
                version: "1.0".to_string(),
                created: "2026-01-01".to_string(),
                updated: "2026-01-01".to_string(),
                author: "test".to_string(),
                description: "test".to_string(),
            },
            formats: HashMap::new(),
            kernel: KernelConvention {
                signature: "test".to_string(),
                weight_shape: "[out, in]".to_string(),
                computation: "y = Wx".to_string(),
                byte_calculation: "out * in".to_string(),
                block_sizes: HashMap::new(),
                qk_k: 256,
            },
            tensors: HashMap::new(),
            validation_rules: vec![],
            semantic_validation: None,
        };
        let result = validate_model(Path::new("/nonexistent/model.apr"), &contract);
        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(!result.passed);
        assert!(!result.critical_failures.is_empty());
    }

    // Only specs with critical == true are returned; the non-critical
    // embedding spec below must be filtered out.
    #[test]
    fn test_get_critical_tensors() {
        let mut tensors = HashMap::new();
        tensors.insert(
            "lm_head".to_string(),
            TensorSpec {
                gguf_name: "output.weight".to_string(),
                apr_name: "lm_head.weight".to_string(),
                gguf_shape: "[hidden, vocab]".to_string(),
                apr_shape: "[vocab, hidden]".to_string(),
                transpose: true,
                kernel: "matmul".to_string(),
                kernel_out_dim: Some("vocab_size".to_string()),
                kernel_in_dim: Some("hidden_dim".to_string()),
                validation: None,
                critical: true,
                note: Some("GH-202".to_string()),
            },
        );
        tensors.insert(
            "embedding".to_string(),
            TensorSpec {
                gguf_name: "token_embd.weight".to_string(),
                apr_name: "model.embed_tokens.weight".to_string(),
                gguf_shape: "[hidden, vocab]".to_string(),
                apr_shape: "[vocab, hidden]".to_string(),
                transpose: true,
                kernel: "lookup".to_string(),
                kernel_out_dim: None,
                kernel_in_dim: None,
                validation: None,
                critical: false,
                note: None,
            },
        );
        let contract = TensorLayoutContract {
            metadata: ContractMetadata {
                version: "1.0".to_string(),
                created: "2026-01-01".to_string(),
                updated: "2026-01-01".to_string(),
                author: "test".to_string(),
                description: "test".to_string(),
            },
            formats: HashMap::new(),
            kernel: KernelConvention {
                signature: "test".to_string(),
                weight_shape: "[out, in]".to_string(),
                computation: "y = Wx".to_string(),
                byte_calculation: "out * in".to_string(),
                block_sizes: HashMap::new(),
                qk_k: 256,
            },
            tensors,
            validation_rules: vec![],
            semantic_validation: None,
        };
        let critical = get_critical_tensors(&contract);
        assert_eq!(critical.len(), 1);
        assert_eq!(critical[0].apr_name, "lm_head.weight");
    }
}