use crate::error::CliError;
use crate::output;
use aprender::format::rosetta::{
FormatType, RosettaStone, ValidationReport as RosettaValidationReport,
};
use aprender::format::validation::{AprValidator, Category, CheckStatus, ValidationReport};
use colored::Colorize;
use std::fs;
use std::path::Path;
/// Entry point for the `validate` subcommand.
///
/// Validates the model file at `path`, dispatching by detected container
/// format: native APR files go through the 100-point `AprValidator`, while
/// GGUF/SafeTensors files go through the Rosetta Stone tensor-physics
/// pipeline. The returned `Result` drives the process exit code.
#[provable_contracts_macros::contract("apr-cli-safety-v1", equation = "validate_exit_code")]
pub(crate) fn run(
    path: &Path,
    quality: bool,
    strict: bool,
    min_score: Option<u8>,
    json: bool,
    skip_contract: bool,
) -> Result<(), CliError> {
    // Precondition assertion generated by the contract macro above
    // (apr-cli-safety-v1 / validate_exit_code).
    contract_pre_validate_exit_code_consistency!();
    // --min-score is a percentage; reject out-of-range values before any I/O.
    if let Some(score) = min_score {
        if score > 100 {
            return Err(CliError::ValidationFailed(format!(
                "Invalid --min-score value: {}. Must be in range 0-100.",
                score
            )));
        }
    }
    validate_path(path)?;
    if !json {
        println!("Validating {}...\n", path.display());
    }
    // Prefer magic-byte sniffing; fall back to the file extension.
    let format = FormatType::from_magic(path)
        .or_else(|_| FormatType::from_extension(path))
        .map_err(|e| CliError::InvalidFormat(format!("Cannot detect format: {e}")))?;
    let result = match format {
        FormatType::Apr => {
            run_apr_validation(path, quality, strict, min_score, json, skip_contract)
        }
        FormatType::Gguf | FormatType::SafeTensors => {
            run_rosetta_validation(path, format, quality, strict, json, skip_contract)
        }
    };
    // Postcondition check runs only on the success path; error paths
    // propagate through `result` untouched.
    if let Ok(ref r) = result {
        contract_post_validate_exit_code_consistency!(r);
    }
    result
}
/// Runs the 100-point APR validator over the file at `path`.
///
/// Emits either JSON or a human-readable table, then enforces (in order):
/// the per-check failures via `print_summary`, the optional `--min-score`
/// floor, and finally the 50/100 contract gate unless `skip_contract`.
fn run_apr_validation(
    path: &Path,
    quality: bool,
    strict: bool,
    min_score: Option<u8>,
    json: bool,
    skip_contract: bool,
) -> Result<(), CliError> {
    let bytes = fs::read(path)?;
    let mut validator = AprValidator::new();
    let report = validator.validate_bytes(&bytes);
    if json {
        return print_apr_validation_json(path, report, strict, min_score);
    }
    print_check_results(report);
    print_summary(report, strict)?;
    if quality {
        print_quality_assessment(report);
    }
    // Explicit --min-score floor, when one was supplied.
    match min_score {
        Some(min) if report.total_score < min => {
            return Err(CliError::ValidationFailed(format!(
                "Score {}/100 below minimum {min}",
                report.total_score
            )));
        }
        _ => {}
    }
    // Contract gate: under 50/100 fails unless explicitly skipped.
    if skip_contract || report.total_score >= 50 {
        Ok(())
    } else {
        Err(CliError::ValidationFailed(format!(
            "Score {}/100 (below 50% threshold)",
            report.total_score
        )))
    }
}
fn run_rosetta_validation(
path: &Path,
format: FormatType,
quality: bool,
strict: bool,
json: bool,
skip_contract: bool,
) -> Result<(), CliError> {
let rosetta = RosettaStone::new();
let report = rosetta
.validate(path)
.map_err(|e| CliError::ValidationFailed(format!("Validation failed: {e}")))?;
if json {
if strict
&& !skip_contract
&& (report.total_nan_count > 0
|| report.total_inf_count > 0
|| !report.all_zero_tensors.is_empty())
{
let mut issues = Vec::new();
if report.total_nan_count > 0 {
issues.push(format!("{} NaN values", report.total_nan_count));
}
if report.total_inf_count > 0 {
issues.push(format!("{} Inf values", report.total_inf_count));
}
if !report.all_zero_tensors.is_empty() {
issues.push(format!(
"{} all-zero tensors",
report.all_zero_tensors.len()
));
}
let _ = print_rosetta_validation_json(path, &report, format, quality);
return Err(CliError::ValidationFailed(format!(
"Strict mode: {}",
issues.join(", ")
)));
}
return print_rosetta_validation_json(path, &report, format, quality);
}
output::header(&format!("Validate: {} (Rosetta Stone)", format));
let mut rows: Vec<Vec<String>> = Vec::new();
for tv in &report.tensors {
let badge = if tv.is_valid {
output::badge_pass("PASS")
} else {
output::badge_fail("FAIL")
};
let failures_str = if tv.failures.is_empty() {
String::new()
} else {
tv.failures.join("; ")
};
rows.push(vec![tv.name.clone(), badge, failures_str]);
}
if !rows.is_empty() {
println!(
"{}",
output::table(&["Tensor", "Status", "Failures"], &rows)
);
}
println!();
println!("{}", report.summary());
if quality {
print_quality_constraints(&report);
}
if strict
&& !skip_contract
&& (report.total_nan_count > 0
|| report.total_inf_count > 0
|| !report.all_zero_tensors.is_empty())
{
let mut issues = Vec::new();
if report.total_nan_count > 0 {
issues.push(format!("{} NaN values", report.total_nan_count));
}
if report.total_inf_count > 0 {
issues.push(format!("{} Inf values", report.total_inf_count));
}
if !report.all_zero_tensors.is_empty() {
issues.push(format!(
"{} all-zero tensors",
report.all_zero_tensors.len()
));
}
return Err(CliError::ValidationFailed(format!(
"Strict mode: {}",
issues.join(", ")
)));
}
if report.tensors.is_empty() {
return Err(CliError::ValidationFailed(
"Model contains 0 tensors (truncated or corrupt file)".to_string(),
));
}
if skip_contract || report.is_valid {
Ok(())
} else {
Err(CliError::ValidationFailed(format!(
"{} tensors failed validation",
report.failed_tensor_count
)))
}
}
/// Prints the physics-constraint section (APR-SPEC 10.9): NaN/Inf totals,
/// all-zero tensor count, timing, then either an all-clear line or the
/// grouped PMAT-235 contract violations.
fn print_quality_constraints(report: &RosettaValidationReport) {
    println!();
    println!(
        "{}",
        "=== Physics Constraints (APR-SPEC 10.9) ===".cyan().bold()
    );
    println!(" Total NaN: {}", report.total_nan_count);
    println!(" Total Inf: {}", report.total_inf_count);
    println!(" All-zeros: {}", report.all_zero_tensors.len());
    println!(" Duration: {} ms", report.duration_ms);
    // Pair every failure message with its owning tensor for grouping below.
    let mut violations: Vec<(&str, &str)> = Vec::new();
    for tensor in &report.tensors {
        for failure in &tensor.failures {
            violations.push((tensor.name.as_str(), failure.as_str()));
        }
    }
    if violations.is_empty() {
        println!();
        println!(
            " {} All tensors pass PMAT-235 contract gates",
            "[OK]".green()
        );
    } else {
        print_contract_violations(&violations);
    }
}
/// Groups tensor failures by their "[RULE]" prefix (falling back to
/// "UNKNOWN") and prints at most five offending tensor names per rule.
fn print_contract_violations(failures: &[(&str, &str)]) {
    println!();
    println!("{}", "=== PMAT-235 Contract Violations ===".red().bold());
    // BTreeMap keeps rule output deterministic and alphabetically ordered.
    let mut grouped: std::collections::BTreeMap<&str, Vec<&str>> =
        std::collections::BTreeMap::new();
    for &(tensor_name, failure) in failures {
        // "[RULE] ..." keys by RULE; a missing '[' or ']' lands in UNKNOWN.
        let rule_id = failure
            .strip_prefix('[')
            .and_then(|rest| rest.find(']').map(|end| &rest[..end]))
            .unwrap_or("UNKNOWN");
        grouped.entry(rule_id).or_default().push(tensor_name);
    }
    const MAX_SHOWN: usize = 5;
    for (rule, tensors) in &grouped {
        println!(" {} {} tensor(s) failed", rule.red(), tensors.len());
        for name in tensors.iter().take(MAX_SHOWN) {
            println!(" - {}", name);
        }
        let hidden = tensors.len().saturating_sub(MAX_SHOWN);
        if hidden > 0 {
            println!(" ... and {} more", hidden);
        }
    }
}
/// Serializes an APR validation report as pretty-printed JSON on stdout.
///
/// `passed` requires zero failed checks and (when given) meeting the
/// `--min-score` floor; a failing report is also returned as an error so
/// the caller exits non-zero. The `strict` flag is currently ignored here
/// and only produces a warning.
#[allow(clippy::disallowed_methods)]
fn print_apr_validation_json(
    path: &Path,
    report: &ValidationReport,
    strict: bool,
    min_score: Option<u8>,
) -> Result<(), CliError> {
    if strict {
        eprintln!(
            "Warning: --strict is not yet implemented for APR JSON validation. Flag ignored."
        );
    }
    let passed =
        report.failed_checks().is_empty() && min_score.is_none_or(|min| report.total_score >= min);
    // Serialize every check. A previous filter dropped WARN/SKIP entries,
    // which made the `checks` array inconsistent with `total_checks`
    // (reported as report.checks.len()) and left the Warn/Skip match arms
    // unreachable.
    let checks_json: Vec<serde_json::Value> = report
        .checks
        .iter()
        .map(|c| {
            let (status, detail) = match &c.status {
                CheckStatus::Pass => ("PASS", String::new()),
                CheckStatus::Fail(r) => ("FAIL", r.clone()),
                CheckStatus::Warn(r) => ("WARN", r.clone()),
                CheckStatus::Skip(r) => ("SKIP", r.clone()),
            };
            serde_json::json!({
                "id": c.id,
                "name": c.name,
                "status": status,
                "detail": detail,
                "points": c.points,
            })
        })
        .collect();
    let output = serde_json::json!({
        "model": path.display().to_string(),
        "format": "apr",
        "total_score": report.total_score,
        "grade": report.grade(),
        "checks": checks_json,
        "total_checks": report.checks.len(),
        "failed": report.failed_checks().len(),
        "passed": passed,
    });
    println!(
        "{}",
        serde_json::to_string_pretty(&output).unwrap_or_default()
    );
    if !passed {
        return Err(CliError::ValidationFailed(format!(
            "Score {}/100",
            report.total_score
        )));
    }
    Ok(())
}
/// Emits the Rosetta validation report as pretty-printed JSON on stdout,
/// optionally with a "quality" physics section, and returns an error when
/// any tensor failed validation so the caller exits non-zero.
#[allow(clippy::disallowed_methods)]
fn print_rosetta_validation_json(
    path: &Path,
    report: &RosettaValidationReport,
    format: FormatType,
    quality: bool,
) -> Result<(), CliError> {
    // One entry per tensor: PASS/FAIL plus a "; "-joined failure summary
    // (joining an empty failure list yields the empty string).
    let mut checks_json: Vec<serde_json::Value> = Vec::with_capacity(report.tensors.len());
    for tv in &report.tensors {
        checks_json.push(serde_json::json!({
            "name": tv.name,
            "status": if tv.is_valid { "PASS" } else { "FAIL" },
            "detail": tv.failures.join("; "),
        }));
    }
    let format_str = match format {
        FormatType::SafeTensors => "safetensors",
        FormatType::Gguf => "gguf",
        FormatType::Apr => "apr",
    };
    let mut output = serde_json::json!({
        "model": path.display().to_string(),
        "format": format_str,
        "total_tensors": report.tensor_count,
        "failed_tensors": report.failed_tensor_count,
        "total_nan": report.total_nan_count,
        "total_inf": report.total_inf_count,
        "duration_ms": report.duration_ms,
        "checks": checks_json,
        "total_checks": report.tensor_count,
        "failed": report.failed_tensor_count,
        "passed": report.is_valid,
    });
    if quality {
        let zero_names: Vec<&str> = report
            .all_zero_tensors
            .iter()
            .map(String::as_str)
            .collect();
        let physics_pass = report.total_nan_count == 0
            && report.total_inf_count == 0
            && report.all_zero_tensors.is_empty();
        output["quality"] = serde_json::json!({
            "total_nan": report.total_nan_count,
            "total_inf": report.total_inf_count,
            "all_zero_tensors": zero_names,
            "all_zero_count": report.all_zero_tensors.len(),
            "physics_pass": physics_pass,
        });
    }
    println!(
        "{}",
        serde_json::to_string_pretty(&output).unwrap_or_default()
    );
    if report.is_valid {
        Ok(())
    } else {
        Err(CliError::ValidationFailed(format!(
            "{} tensors failed validation",
            report.failed_tensor_count
        )))
    }
}
/// Ensures `path` names an existing regular file before validation starts.
fn validate_path(path: &Path) -> Result<(), CliError> {
    match path {
        p if !p.exists() => Err(CliError::FileNotFound(p.to_path_buf())),
        p if !p.is_file() => Err(CliError::NotAFile(p.to_path_buf())),
        _ => Ok(()),
    }
}
/// Renders every validation check as a four-column table on stdout:
/// check id, name, PASS/FAIL/WARN/SKIP badge, and the status detail text.
fn print_check_results(report: &ValidationReport) {
    let mut rows: Vec<Vec<String>> = Vec::new();
    for check in &report.checks {
        let (badge, detail) = match &check.status {
            CheckStatus::Pass => (output::badge_pass("PASS"), String::new()),
            CheckStatus::Fail(reason) => (output::badge_fail("FAIL"), reason.clone()),
            CheckStatus::Warn(reason) => (output::badge_warn("WARN"), reason.clone()),
            CheckStatus::Skip(reason) => (output::badge_skip("SKIP"), reason.clone()),
        };
        rows.push(vec![
            // `to_string` over `format!("{}", …)`: same output, clearer intent.
            check.id.to_string(),
            check.name.to_string(),
            badge,
            detail,
        ]);
    }
    println!(
        "{}",
        output::table(&["#", "Check", "Status", "Detail"], &rows)
    );
}
/// Prints the VALID/INVALID summary line and returns an error when any
/// checks failed so the caller can propagate a non-zero exit status.
/// The `strict` flag is not implemented here and only warns.
fn print_summary(report: &ValidationReport, strict: bool) -> Result<(), CliError> {
    if strict {
        eprintln!(
            "Warning: --strict is not yet implemented for APR validation summary. Flag ignored."
        );
    }
    println!();
    let failures = report.failed_checks();
    if !failures.is_empty() {
        println!(
            " {} {} checks failed",
            output::badge_fail("INVALID"),
            failures.len()
        );
        return Err(CliError::ValidationFailed(format!(
            "{} validation checks failed",
            failures.len()
        )));
    }
    println!(
        " {} {}/100 points",
        output::badge_pass("VALID"),
        report.total_score
    );
    Ok(())
}
/// Prints the 100-point quality breakdown: a per-category score table with
/// progress bars, the overall score and letter grade, and the list of
/// failed checks (if any).
fn print_quality_assessment(report: &ValidationReport) {
    output::header("100-Point Quality Assessment");
    // The four scoring categories, 25 points each.
    const MAX_PER_CATEGORY: usize = 25;
    let categories = [
        (Category::Structure, "A. Format & Structural Integrity"),
        (Category::Physics, "B. Tensor Physics & Statistics"),
        (Category::Tooling, "C. Tooling & Operations"),
        (Category::Conversion, "D. Conversion & Interoperability"),
    ];
    let mut rows: Vec<Vec<String>> = Vec::new();
    for (cat, name) in &categories {
        // A category absent from the map scores 0 rather than panicking.
        let score = report.category_scores.get(cat).copied().unwrap_or(0);
        let bar = output::progress_bar(score as usize, MAX_PER_CATEGORY, 20);
        rows.push(vec![
            (*name).to_string(),
            format!("{score}/{MAX_PER_CATEGORY}"),
            bar,
        ]);
    }
    println!(
        "{}",
        output::table(&["Category", "Score", "Progress"], &rows)
    );
    let grade = report.grade();
    println!(
        "\n TOTAL: {}/100 Grade: {}",
        // `to_string` over `format!("{}", …)` for plain stringification.
        report.total_score.to_string().white().bold(),
        output::grade_color(grade),
    );
    let failed = report.failed_checks();
    if !failed.is_empty() {
        output::subheader("Failed Checks");
        for check in failed {
            // failed_checks() is expected to yield only Fail variants;
            // the `if let` guards against anything else slipping through.
            if let CheckStatus::Fail(reason) = &check.status {
                println!(
                    " {} #{}: {} - {}",
                    "✗".red().bold(),
                    check.id,
                    check.name,
                    reason.dimmed()
                );
            }
        }
    }
}
#[cfg(test)]
#[path = "validate_tests.rs"]
mod tests;