use crate::error::CliError;
use crate::output;
use colored::Colorize;
use std::path::Path;
#[cfg(feature = "inference")]
use realizar::gguf::{MappedGGUFModel, OwnedQuantizedModel};
#[cfg(feature = "inference")]
use realizar::apr::AprV2Model;
#[cfg(feature = "inference")]
use realizar::safetensors_infer::SafetensorsToAprConverter;
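/// Result of one self-test stage: the stage name, a plain-language ("ELI5")
/// summary, whether it passed, and optional details for the report.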
#[derive(Debug)]
#[allow(dead_code)]
struct StageResult {
name: &'static str,
    eli5: &'static str,
    passed: bool,
details: Option<String>,
}
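/// Entry point for the model self-test command: runs staged validation checks
/// against the model at `path` and prints the results as a table or as JSON.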
pub(crate) fn run(path: &Path, no_gpu: bool, json: bool, verbose: bool) -> Result<(), CliError> {
contract_pre_apr_model_validity!();
contract_pre_model_integrity_check!();
if !json {
output::section("Model Self-Test (PMAT-112: Real Validation)");
println!("Model: {}\n", path.display().to_string().cyan());
if verbose {
println!(" [verbose] Running all stages with detailed output\n");
}
}
#[cfg(feature = "inference")]
let results = run_real_checks(path, no_gpu)?;
#[cfg(not(feature = "inference"))]
let results = {
let _ = no_gpu;
if !json {
output::warn("Inference feature not enabled. Cannot run real validation.");
output::warn("Build with: cargo build --features inference");
}
vec![StageResult {
name: "N/A",
eli5: "Requires inference",
passed: false,
details: Some("Build with --features inference".to_string()),
}]
};
let passed_count = results.iter().filter(|r| r.passed).count();
let total_count = results.len();
if json {
return print_json(&results, path, passed_count, total_count);
}
print_results_table(&results);
if passed_count == total_count {
println!(
"\n{}",
format!(
"✅ {}/{} STAGES PASSED. MODEL PROVEN CORRECT.",
passed_count, total_count
)
.green()
.bold()
);
contract_post_apr_model_validity!(&());
contract_post_model_integrity_check!(&());
Ok(())
} else {
println!(
"\n{}",
format!(
"❌ {}/{} STAGES PASSED. CHECK STAGE LOGS.",
passed_count, total_count
)
.red()
.bold()
);
Err(CliError::ValidationFailed(
"Model self-test failed".to_string(),
))
}
}
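/// Prints the stage results as pretty-printed JSON, including the model path,
/// per-stage status, and pass/total counts.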
#[allow(clippy::disallowed_methods)]
fn print_json(
results: &[StageResult],
path: &Path,
passed_count: usize,
total_count: usize,
) -> Result<(), CliError> {
let stages_json: Vec<serde_json::Value> = results
.iter()
.map(|r| {
serde_json::json!({
"name": r.name,
"status": if r.passed { "PASS" } else { "FAIL" },
"details": r.details.as_deref().unwrap_or(""),
})
})
.collect();
let output = serde_json::json!({
"model": path.display().to_string(),
"stages": stages_json,
"passed": passed_count,
"total": total_count,
"all_passed": passed_count == total_count,
});
println!(
"{}",
serde_json::to_string_pretty(&output).unwrap_or_default()
);
Ok(())
}
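/// Dispatches to format-specific checks based on the file extension
/// (`.apr`, `.gguf`, or `.safetensors`).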
#[cfg(feature = "inference")]
fn run_real_checks(path: &Path, no_gpu: bool) -> Result<Vec<StageResult>, CliError> {
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
match ext.to_lowercase().as_str() {
"apr" => run_real_checks_apr(path),
"gguf" => run_real_checks_gguf(path, no_gpu),
"safetensors" => run_real_checks_safetensors(path),
_ => Err(CliError::InvalidFormat(format!(
"Unsupported format: {}. Use .apr, .gguf, or .safetensors",
ext
))),
}
}
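/// Builds a `StageResult` for a simple presence check, selecting `found_msg`
/// or `missing_msg` as the details depending on whether the tensor was found.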
fn tensor_check_stage(
name: &'static str,
eli5: &'static str,
found: bool,
found_msg: &str,
missing_msg: &str,
) -> StageResult {
StageResult {
name,
eli5,
passed: found,
details: Some(if found { found_msg } else { missing_msg }.to_string()),
}
}
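/// Returns true if any tensor name contains at least one of the patterns.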
#[cfg(feature = "inference")]
fn any_name_contains(names: &[&str], patterns: &[&str]) -> bool {
names.iter().any(|n| patterns.iter().any(|p| n.contains(p)))
}
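/// Returns true if every pattern group is matched by at least one tensor name
/// (AND across groups, OR within a group).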
#[cfg(feature = "inference")]
fn all_groups_match(names: &[&str], groups: &[&[&str]]) -> bool {
groups.iter().all(|group| any_name_contains(names, group))
}
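// A minimal sanity test for the name-matching helpers above. The tensor names
// here are illustrative examples, not taken from a specific model file.
#[cfg(all(test, feature = "inference"))]
mod name_matching_tests {
    use super::*;
    #[test]
    fn any_name_contains_matches_substrings() {
        let names = ["blk.0.attn_q.weight", "token_embd.weight"];
        assert!(any_name_contains(&names, &["attn_q", "q_proj"]));
        assert!(!any_name_contains(&names, &["lm_head"]));
    }
    #[test]
    fn all_groups_match_requires_every_group() {
        let names = ["layers.0.self_attn.q_proj.weight", "layers.0.self_attn.k_proj.weight"];
        assert!(all_groups_match(&names, &[&["q_proj"], &["k_proj"]]));
        assert!(!all_groups_match(&names, &[&["q_proj"], &["v_proj"]]));
    }
}
/// Runs a single-token forward pass on the APR model and checks that the
/// logits are non-empty and contain no NaN or Inf values.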
#[cfg(feature = "inference")]
fn check_apr_logits(model: &AprV2Model) -> StageResult {
match model.forward(&[1u32]) {
Ok(logits) => {
let has_nan = logits.iter().any(|x| x.is_nan());
let has_inf = logits.iter().any(|x| x.is_infinite());
let valid = !has_nan && !has_inf && !logits.is_empty();
let details = if has_nan {
"NaN detected".to_string()
} else if has_inf {
"Inf detected".to_string()
} else {
format!("logits[{}]", logits.len())
};
StageResult {
name: "Logits → Probs",
eli5: "Scores → percentages",
passed: valid,
details: Some(details),
}
}
Err(e) => StageResult {
name: "Logits → Probs",
eli5: "Scores → percentages",
passed: false,
details: Some(format!("Forward failed: {e}")),
},
}
}
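/// Runs a forward pass and checks that a max-shifted softmax over the logits
/// sums to ~1.0, i.e. the scores form a valid probability distribution.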
#[cfg(feature = "inference")]
fn check_apr_sampler(model: &AprV2Model) -> StageResult {
match model.forward(&[1u32]) {
Ok(logits) => {
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = logits.iter().map(|x| (x - max_logit).exp()).sum();
let prob_sum: f32 = logits.iter().map(|x| (x - max_logit).exp() / exp_sum).sum();
let valid = (prob_sum - 1.0).abs() < 0.001;
StageResult {
name: "Sampler/Decode",
eli5: "Pick word, return",
passed: valid,
details: Some(format!("softmax sum = {:.6}", prob_sum)),
}
}
Err(e) => StageResult {
name: "Sampler/Decode",
eli5: "Pick word, return",
passed: false,
details: Some(format!("Forward failed: {e}")),
},
}
}
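/// Stage checks for `.apr` models: loads the model, inspects tensor names for
/// the expected transformer components, and exercises a real forward pass.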
#[cfg(feature = "inference")]
fn run_real_checks_apr(path: &Path) -> Result<Vec<StageResult>, CliError> {
let model = AprV2Model::load(path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to load APR: {e}")))?;
let metadata = model.metadata();
let num_layers = metadata.num_layers.unwrap_or(0);
let vocab_size = metadata.vocab_size.unwrap_or(0);
let names: Vec<&str> = model.tensor_names();
let has_embed = any_name_contains(&names, &["emb", "wte", "token_embd"]);
let has_rope = any_name_contains(&names, &["rope", "rotary"]);
let has_lm_head = any_name_contains(&names, &["lm_head"]) || names.contains(&"output.weight");
let test_tokens = vec![1u32, 2];
let forward_ok = model.forward(&test_tokens).is_ok();
Ok(vec![
StageResult {
name: "Tokenizer",
eli5: "Words → numbers",
passed: forward_ok,
details: Some(format!("tokens={test_tokens:?}")),
},
tensor_check_stage(
"Embedding",
"Numbers → vectors",
has_embed,
"Found embedding tensor",
"Missing embedding tensor",
),
StageResult {
name: "Positional Encoding",
eli5: "\"You are word #3\"",
passed: true,
details: Some(
if has_rope {
"RoPE tensors found"
} else {
"RoPE computed inline"
}
.to_string(),
),
},
tensor_check_stage(
"Q/K/V Projection",
"Make 3 question copies",
all_groups_match(
&names,
&[
&["q_proj", "attn_q"],
&["k_proj", "attn_k"],
&["v_proj", "attn_v"],
],
),
"Q/K/V found",
"Missing Q/K/V",
),
tensor_check_stage(
"Attention Scores",
"\"Who to look at?\"",
any_name_contains(&names, &["o_proj", "attn_output"]),
"Attention output found",
"Missing attention output",
),
tensor_check_stage(
"Feed-Forward (MLP)",
"\"Think about it\"",
all_groups_match(
&names,
&[
&["gate_proj", "ffn_gate"],
&["up_proj", "ffn_up"],
&["down_proj", "ffn_down"],
],
),
"MLP found",
"Missing MLP",
),
StageResult {
name: "Layer Norm",
eli5: "Keep numbers stable",
passed: all_groups_match(
&names,
&[
&["input_layernorm", "attn_norm"],
&["post_attention_layernorm", "ffn_norm"],
],
) && num_layers > 0,
details: Some(format!("{num_layers} layers")),
},
StageResult {
name: "LM Head",
eli5: "Vector → vocab scores",
passed: has_lm_head || has_embed,
details: Some(format!(
"vocab_size={vocab_size}{}",
if has_lm_head { "" } else { " (tied)" }
)),
},
check_apr_logits(&model),
check_apr_sampler(&model),
])
}
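/// Checks that the GGUF model can project hidden states to `vocab_size`,
/// either via an explicit `output.weight`/`lm_head` tensor or tied embeddings.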
#[cfg(feature = "inference")]
fn check_gguf_lm_head(mapped: &MappedGGUFModel, vocab_size: usize) -> StageResult {
let has_explicit = mapped.model.tensors.iter().any(|t| {
(t.name == "output.weight" || t.name.contains("lm_head"))
&& t.dims.iter().any(|&d| d as usize == vocab_size)
});
let has_tied = mapped.model.tensors.iter().any(|t| {
(t.name.contains("token_embd") || t.name.contains("embed_tokens"))
&& t.dims.iter().any(|&d| d as usize == vocab_size)
});
let passed = (has_explicit || has_tied) && vocab_size > 0;
let details = if has_explicit {
format!("vocab_size={vocab_size}")
} else if has_tied {
format!("vocab_size={vocab_size} (tied embeddings)")
} else {
"Missing LM head tensor".to_string()
};
StageResult {
name: "LM Head",
eli5: "Vector → vocab scores",
passed,
details: Some(details),
}
}
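/// Stage checks for `.gguf` models: memory-maps the file, builds a quantized
/// model, and validates tensors, positional encoding, logits, sampling, and
/// data quality.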
#[cfg(feature = "inference")]
fn run_real_checks_gguf(path: &Path, no_gpu: bool) -> Result<Vec<StageResult>, CliError> {
if no_gpu {
eprintln!("Warning: --no-gpu is not yet implemented for check validation. Flag ignored.");
}
let mapped = MappedGGUFModel::from_path(path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to mmap GGUF: {e}")))?;
let model = OwnedQuantizedModel::from_mapped(&mapped)
.map_err(|e| CliError::ValidationFailed(format!("Failed to create model: {e}")))?;
let names: Vec<&str> = mapped
.model
.tensors
.iter()
.map(|t| t.name.as_str())
.collect();
let rope_theta = mapped.model.rope_freq_base().unwrap_or(10000.0);
let has_embed = any_name_contains(&names, &["token_embd", "embed_tokens"]);
Ok(vec![
check_tokenizer_real(&model),
tensor_check_stage(
"Embedding",
"Numbers → vectors",
has_embed,
"Found embedding tensor",
"Missing embedding tensor",
),
StageResult {
name: "Positional Encoding",
eli5: "\"You are word #3\"",
passed: rope_theta > 1.0,
details: Some(format!("rope_theta={rope_theta:.1}")),
},
tensor_check_stage(
"Q/K/V Projection",
"Make 3 question copies",
all_groups_match(
&names,
&[
&["blk.0.attn_q", "layers.0.self_attn.q_proj"],
&["blk.0.attn_k", "layers.0.self_attn.k_proj"],
&["blk.0.attn_v", "layers.0.self_attn.v_proj"],
],
) || any_name_contains(&names, &["attn_qkv", "qkv_proj"]),
"Q/K/V tensors found",
"Missing Q/K/V tensors",
),
tensor_check_stage(
"Attention Scores",
"\"Who to look at?\"",
any_name_contains(&names, &["attn_output", "o_proj"]),
"Attention output tensor found",
"Missing attention output tensor",
),
tensor_check_stage(
"Feed-Forward (MLP)",
"\"Think about it\"",
all_groups_match(
&names,
&[
&["ffn_gate", "gate_proj"],
&["ffn_up", "up_proj"],
&["ffn_down", "down_proj"],
],
) || all_groups_match(
&names,
&[&["ffn_up", "up_proj"], &["ffn_down", "down_proj"]],
),
"MLP tensors found",
"Missing MLP tensors",
),
StageResult {
name: "Layer Norm",
eli5: "Keep numbers stable",
passed: (all_groups_match(
&names,
&[
&["attn_norm", "input_layernorm"],
&["ffn_norm", "post_attention_layernorm", "post_ffw_norm"],
],
) || any_name_contains(&names, &["attn_norm", "input_layernorm"]))
&& model.config().num_layers > 0,
details: Some(format!("{} layers", model.config().num_layers)),
},
check_gguf_lm_head(&mapped, model.config().vocab_size),
check_logits_real(&model),
check_sampler_real(&model),
check_gguf_data_quality(path),
])
}
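/// Scans the GGUF file for corrupt tensors (e.g. all-zero data) using the
/// Rosetta Stone validator from `aprender`.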
#[cfg(feature = "inference")]
fn check_gguf_data_quality(path: &Path) -> StageResult {
use aprender::format::rosetta::RosettaStone;
let rosetta = RosettaStone::new();
match rosetta.validate(path) {
Ok(report) => {
let issues = report.failed_tensor_count;
let all_zeros = report.all_zero_tensors.len();
if report.is_valid && all_zeros == 0 {
StageResult {
name: "Data Quality",
eli5: "No corrupt tensors",
passed: true,
details: Some(format!(
"{} tensors checked, 0 violations",
report.tensors.len()
)),
}
} else {
StageResult {
name: "Data Quality",
eli5: "No corrupt tensors",
passed: false,
details: Some(format!(
"{issues} failed, {all_zeros} all-zeros (run `apr validate` for details)"
)),
}
}
}
Err(e) => StageResult {
name: "Data Quality",
eli5: "No corrupt tensors",
passed: false,
details: Some(format!("Validation failed: {e}")),
},
}
}
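/// Stage checks for `.safetensors` models: converts the weights to the
/// internal transformer representation, checks each component, and runs a
/// real forward pass including logits and softmax validation.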
#[cfg(feature = "inference")]
fn run_real_checks_safetensors(path: &Path) -> Result<Vec<StageResult>, CliError> {
let transformer = SafetensorsToAprConverter::convert(path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to load SafeTensors: {e}")))?;
let config = transformer.config();
let num_layers = config.num_layers;
let vocab_size = config.vocab_size;
let hidden_dim = config.hidden_dim;
let has_embed = !transformer.token_embedding.is_empty();
let has_layers = !transformer.layers.is_empty();
let has_lm_head = !transformer.lm_head_weight.is_empty();
let (has_qkv, has_attn_out, has_ffn) = if let Some(layer) = transformer.layers.first() {
(
!layer.qkv_weight.is_empty(),
!layer.attn_output_weight.is_empty(),
!layer.ffn_up_weight.is_empty() && !layer.ffn_down_weight.is_empty(),
)
} else {
(false, false, false)
};
let has_gate = transformer
.layers
.first()
.is_some_and(|l| l.ffn_gate_weight.is_some());
let test_tokens = vec![1u32, 2];
let forward_ok = transformer.forward(&test_tokens).is_ok();
Ok(vec![
StageResult {
name: "Tokenizer",
eli5: "Words → numbers",
passed: forward_ok,
details: Some(format!("tokens={test_tokens:?}")),
},
tensor_check_stage(
"Embedding",
"Numbers → vectors",
has_embed,
&format!("vocab={vocab_size} × hidden={hidden_dim}"),
"Missing embedding tensor",
),
StageResult {
name: "Positional Encoding",
eli5: "\"You are word #3\"",
passed: true,
details: Some(format!("RoPE theta={:.1}", config.rope_theta)),
},
tensor_check_stage(
"Q/K/V Projection",
"Make 3 question copies",
has_qkv,
"Q/K/V found",
"Missing Q/K/V",
),
tensor_check_stage(
"Attention Scores",
"\"Who to look at?\"",
has_attn_out,
"Attention output found",
"Missing attention output",
),
tensor_check_stage(
"Feed-Forward (MLP)",
"\"Think about it\"",
has_ffn,
&format!(
"MLP found{}",
if has_gate { " (SwiGLU)" } else { " (GELU)" }
),
"Missing MLP",
),
StageResult {
name: "Layer Norm",
eli5: "Keep numbers stable",
passed: has_layers && num_layers > 0,
details: Some(format!("{num_layers} layers")),
},
StageResult {
name: "LM Head",
eli5: "Vector → vocab scores",
passed: has_lm_head || has_embed,
details: Some(format!(
"vocab_size={vocab_size}{}",
if has_lm_head { "" } else { " (tied)" }
)),
},
{
match transformer.forward(&[1u32]) {
Ok(logits) => {
let has_nan = logits.iter().any(|x| x.is_nan());
let has_inf = logits.iter().any(|x| x.is_infinite());
let valid = !has_nan && !has_inf && !logits.is_empty();
StageResult {
name: "Logits → Probs",
eli5: "Scores → percentages",
passed: valid,
details: Some(if has_nan {
"NaN detected".to_string()
} else if has_inf {
"Inf detected".to_string()
} else {
format!("logits[{}]", logits.len())
}),
}
}
Err(e) => StageResult {
name: "Logits → Probs",
eli5: "Scores → percentages",
passed: false,
details: Some(format!("Forward failed: {e}")),
},
}
},
{
match transformer.forward(&[1u32]) {
Ok(logits) => {
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = logits.iter().map(|x| (x - max_logit).exp()).sum();
let prob_sum: f32 =
logits.iter().map(|x| (x - max_logit).exp() / exp_sum).sum();
let valid = (prob_sum - 1.0).abs() < 0.001;
StageResult {
name: "Sampler/Decode",
eli5: "Pick word, return",
passed: valid,
details: Some(format!("softmax sum = {:.6}", prob_sum)),
}
}
Err(e) => StageResult {
name: "Sampler/Decode",
eli5: "Pick word, return",
passed: false,
details: Some(format!("Forward failed: {e}")),
},
}
},
])
}
include!("stage.rs");
include!("check_03.rs");