use crate::error::CliError;
use crate::output;
use aprender::format::HEADER_SIZE;
use colored::Colorize;
use serde::Serialize;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::Path;
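
/// Per-layer record captured during traced inference. The stats fields are
/// `None` when the corresponding tensor was not observed for that layer.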
#[derive(Serialize, Clone)]
pub(crate) struct LayerTrace {
pub name: String,
pub index: Option<usize>,
pub input_stats: Option<TensorStats>,
pub output_stats: Option<TensorStats>,
pub weight_stats: Option<TensorStats>,
pub anomalies: Vec<String>,
}
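
/// Single-pass summary statistics over a tensor's elements. Non-finite
/// values are tallied in `nan_count`/`inf_count` and excluded from the
/// remaining fields.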
#[derive(Serialize, Clone)]
#[allow(dead_code)]
pub(crate) struct TensorStats {
pub count: usize,
pub mean: f32,
pub std: f32,
pub l2_norm: f32,
pub min: f32,
pub max: f32,
pub max_abs: f32,
pub nan_count: usize,
pub inf_count: usize,
}
impl TensorStats {
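    /// Computes statistics over `data` in one pass, accumulating sums in
    /// `f64` to limit rounding error on large tensors.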
#[allow(dead_code, clippy::cast_lossless)]
pub(crate) fn from_slice(data: &[f32]) -> Self {
let count = data.len();
if count == 0 {
return Self {
count: 0,
mean: 0.0,
std: 0.0,
l2_norm: 0.0,
min: 0.0,
max: 0.0,
max_abs: 0.0,
nan_count: 0,
inf_count: 0,
};
}
let mut sum = 0.0_f64;
let mut sum_sq = 0.0_f64;
let mut min = f32::INFINITY;
let mut max = f32::NEG_INFINITY;
let mut max_abs = 0.0_f32;
let mut nan_count = 0;
let mut inf_count = 0;
for &v in data {
if v.is_nan() {
nan_count += 1;
continue;
}
if v.is_infinite() {
inf_count += 1;
continue;
}
sum += v as f64;
sum_sq += (v as f64) * (v as f64);
min = min.min(v);
max = max.max(v);
max_abs = max_abs.max(v.abs());
}
let valid_count = count - nan_count - inf_count;
let mean = if valid_count > 0 {
(sum / valid_count as f64) as f32
} else {
0.0
};
let variance = if valid_count > 1 {
((sum_sq / valid_count as f64) - (mean as f64).powi(2)).max(0.0)
} else {
0.0
};
        // Take roots in f64 before narrowing to f32; the f64 accumulators
        // can exceed f32::MAX even when their roots fit.
        let std = variance.sqrt() as f32;
        let l2_norm = sum_sq.sqrt() as f32;
Self {
count,
mean,
std,
l2_norm,
min: if min.is_finite() { min } else { 0.0 },
max: if max.is_finite() { max } else { 0.0 },
max_abs,
nan_count,
inf_count,
}
}
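
    /// Flags common numeric pathologies (NaN/Inf values, near-zero variance,
    /// exploding magnitudes, large mean bias) using fixed heuristic
    /// thresholds; `name` prefixes each message for attribution.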
#[allow(dead_code)]
pub(crate) fn detect_anomalies(&self, name: &str) -> Vec<String> {
let mut anomalies = Vec::new();
if self.nan_count > 0 {
anomalies.push(format!(
"{name}: {}/{} NaN values",
self.nan_count, self.count
));
}
if self.inf_count > 0 {
anomalies.push(format!(
"{name}: {}/{} Inf values",
self.inf_count, self.count
));
}
if self.std < 1e-8 && self.count > 1 {
anomalies.push(format!("{name}: near-zero variance (std={:.2e})", self.std));
}
if self.max_abs > 100.0 {
anomalies.push(format!(
"{name}: large values (max_abs={:.2})",
self.max_abs
));
}
if self.mean.abs() > 10.0 {
anomalies.push(format!("{name}: large mean bias ({:.4})", self.mean));
}
anomalies
}
}
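
// Minimal sanity checks for the statistics above; a sketch rather than an
// exhaustive suite, using hypothetical inputs rather than real trace data.
#[cfg(test)]
mod tensor_stats_tests {
    use super::TensorStats;

    #[test]
    fn from_slice_excludes_non_finite_values() {
        let stats = TensorStats::from_slice(&[1.0, 2.0, 3.0, f32::NAN, f32::INFINITY]);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.nan_count, 1);
        assert_eq!(stats.inf_count, 1);
        // The mean covers only the three finite values.
        assert!((stats.mean - 2.0).abs() < 1e-6);
    }

    #[test]
    fn detect_anomalies_reports_nan_values() {
        let stats = TensorStats::from_slice(&[f32::NAN; 4]);
        let anomalies = stats.detect_anomalies("layer.0.weight");
        assert!(anomalies.iter().any(|a| a.contains("NaN")));
    }
}

/// Top-level trace result, serialized for machine-readable output.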
#[derive(Serialize)]
struct TraceResult {
file: String,
format: String,
layers: Vec<LayerTrace>,
summary: TraceSummary,
}
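
/// Aggregate counts reported at the end of a trace.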
#[derive(Serialize)]
struct TraceSummary {
total_layers: usize,
total_parameters: usize,
anomaly_count: usize,
anomalies: Vec<String>,
}
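
/// Handles the interactive, payload, and diff modes. Returns `Some(result)`
/// when a mode fully handled the request, or `None` to fall through to the
/// default trace flow.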
fn handle_special_modes(
path: &Path,
reference: Option<&Path>,
payload: bool,
diff: bool,
interactive: bool,
) -> Option<Result<(), CliError>> {
if interactive {
println!("Starting interactive trace (TUI) for {}", path.display());
println!("(TUI mode not yet fully implemented)");
return Some(Ok(()));
}
if payload {
return Some(run_traced_inference(path));
}
    if diff {
        if let Some(ref_path) = reference {
            println!(
                "Diffing trace between {} and {}",
                path.display(),
                ref_path.display()
            );
            println!("(diff mode not yet fully implemented)");
            return Some(Ok(()));
        }
        println!("Diff mode requires --reference");
    }
    None
}
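
/// Resolves `hf://org/repo[/file]` references by downloading the model from
/// HuggingFace; plain filesystem paths are returned unchanged.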
fn resolve_model_path(path: &Path) -> Result<std::path::PathBuf, CliError> {
use super::run::{download_hf_model, ModelSource};
let path_str = path.to_string_lossy();
if !path_str.starts_with("hf://") {
println!("Model: {}", path.display());
println!();
return Ok(path.to_path_buf());
}
let source = ModelSource::parse(&path_str)?;
match source {
ModelSource::HuggingFace { org, repo, file } => {
println!(
"Model: hf://{}/{}{}",
org,
repo,
file.as_ref().map(|f| format!("/{}", f)).unwrap_or_default()
);
println!();
eprintln!("{}", "Downloading from HuggingFace...".yellow());
download_hf_model(&org, &repo, file.as_deref())
}
_ => Ok(path.to_path_buf()),
}
}
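
/// Best-effort PMAT-235 contract validation via the Rosetta Stone checker.
/// Violations are printed as warnings; the trace proceeds either way.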
fn preflight_contract_check(local_path: &Path) {
use aprender::format::rosetta::RosettaStone;
let rosetta = RosettaStone::new();
match rosetta.validate(local_path) {
Ok(report) => {
let contract_failures: Vec<String> = report
.tensors
.iter()
.flat_map(|t| t.failures.iter().map(move |f| format!("{}: {}", t.name, f)))
.collect();
if contract_failures.is_empty() {
println!(
"{}",
format!(
"Contract: {} tensors pass PMAT-235 gates",
report.tensor_count
)
.green()
);
} else {
println!(
"{}",
format!(
"Contract: {} violations in {} tensors",
contract_failures.len(),
report.failed_tensor_count
)
.red()
.bold()
);
for failure in contract_failures.iter().take(5) {
println!(" {}", failure.red());
}
if contract_failures.len() > 5 {
println!(" ... and {} more", contract_failures.len() - 5);
}
println!();
println!(
"{}",
"WARNING: Contract violations may cause garbage output."
.yellow()
.bold()
);
}
}
Err(e) => {
println!("{}", format!("Contract: validation skipped ({e})").yellow());
}
}
println!();
}
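
/// Routes to a format-specific tracer based on the file extension.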
fn dispatch_by_format(local_path: &Path) -> Result<(), CliError> {
let ext = local_path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
match ext.to_lowercase().as_str() {
"gguf" => run_traced_inference_gguf(local_path),
"apr" => run_traced_inference_apr(local_path),
"safetensors" => run_traced_inference_safetensors(local_path),
_ => Err(CliError::InvalidFormat(format!(
"Unknown format: {}. Supported: .gguf, .apr, .safetensors",
ext
))),
}
}
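
/// Traced-inference entry point: resolve the model path, run the contract
/// preflight, then dispatch on format.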
fn run_traced_inference(path: &Path) -> Result<(), CliError> {
output::section("Traced Inference (APR-TRACE-001)");
let local_path = resolve_model_path(path)?;
preflight_contract_check(&local_path);
dispatch_by_format(&local_path)
}
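
/// Traced inference for GGUF models: loads the quantized weights, runs a
/// short greedy generation, and decodes token-by-token to surface garbage
/// output early.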
#[cfg(feature = "inference")]
fn run_traced_inference_gguf(path: &Path) -> Result<(), CliError> {
use realizar::gguf::{MappedGGUFModel, OwnedQuantizedModel, QuantizedGenerateConfig};
println!("{}", "Format: GGUF (quantized)".cyan());
println!();
let mapped = MappedGGUFModel::from_path(path)
.map_err(|e| CliError::ModelLoadFailed(format!("Failed to load GGUF: {e}")))?;
let model = OwnedQuantizedModel::from_mapped(&mapped)
.map_err(|e| CliError::ModelLoadFailed(format!("Failed to create quantized model: {e}")))?;
let config = model.config();
println!("Architecture: {}", config.architecture);
println!(" Layers: {}", config.num_layers);
println!(" Hidden dim: {}", config.hidden_dim);
println!(" Vocab size: {}", config.vocab_size);
println!(
" Heads: {} (KV: {})",
config.num_heads, config.num_kv_heads
);
println!();
let test_prompt = "What is 2+2?";
let test_tokens = mapped
.model
.encode(test_prompt)
.unwrap_or_else(|| vec![1u32]);
println!("{}", format!("Test prompt: {:?}", test_prompt).cyan());
println!("{}", format!("Encoded tokens: {:?}", test_tokens).cyan());
println!();
println!("{}", "GENERATION (max 8 tokens):".green().bold());
    let gen_config = QuantizedGenerateConfig {
        max_tokens: 8,
        temperature: 0.0,
        top_k: 1,
        ..Default::default()
    };
let output_tokens = model
.generate_with_cache(&test_tokens, &gen_config)
.map_err(|e| CliError::InferenceFailed(format!("Generation failed: {e}")))?;
let generated = &output_tokens[test_tokens.len()..];
println!(" Generated token IDs: {:?}", generated);
println!();
println!("{}", "TOKEN-BY-TOKEN DECODE:".green().bold());
for (i, &token_id) in generated.iter().enumerate() {
let decoded = mapped.model.decode(&[token_id]);
let is_garbage = is_likely_garbage(&decoded);
if is_garbage {
println!(
" {}. token_id={} → {:?} {}",
i + 1,
token_id,
decoded,
"⚠ GARBAGE".red().bold()
);
} else {
println!(" {}. token_id={} → {:?}", i + 1, token_id, decoded);
}
}
let full_decoded = mapped.model.decode(generated);
println!();
println!("{}", "FULL OUTPUT:".green().bold());
println!(" {:?}", full_decoded);
println!();
if is_likely_garbage(&full_decoded) {
println!("{}", "⚠ GARBAGE OUTPUT DETECTED!".red().bold());
println!();
println!("Likely causes:");
println!(" 1. LAYOUT-001: Column-major vs row-major kernel mismatch");
println!(" 2. Weight tensor corruption during loading");
println!(" 3. Tokenizer vocabulary mismatch");
println!();
println!("Debug steps:");
println!(" 1. Check if SafeTensors produces correct output (same model)");
println!(" 2. Compare token IDs between GGUF and SafeTensors");
println!(" 3. Verify quantization type is supported");
} else {
println!("{}", "✓ Output appears reasonable".green());
}
Ok(())
}
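
/// Stub used when the `inference` feature is disabled.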
#[cfg(not(feature = "inference"))]
fn run_traced_inference_gguf(_path: &Path) -> Result<(), CliError> {
Err(CliError::FeatureDisabled(
"Traced inference for GGUF models requires the 'inference' feature. Build with --features inference".to_string(),
))
}
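
// Sibling files are compiled into this module by textual inclusion so they
// share the imports and pub(crate) items above; they supply, among others,
// `is_likely_garbage` and the APR/SafeTensors trace paths.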
include!("vector_stats.rs");
include!("trace_likely_has_repeated.rs");
include!("layer.rs");
include!("trace_05.rs");