#[cfg(not(feature = "cuda"))]
fn main() {
eprintln!("This example requires the 'cuda' feature. Run with --features cuda");
}
#[cfg(feature = "cuda")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
use realizar::gguf::{MappedGGUFModel, OwnedQuantizedModel, OwnedQuantizedModelCuda};
let path = std::env::var("MODEL_PATH").unwrap_or_else(|_| {
"/home/noah/models/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf".to_string()
});
println!("CORRECTNESS-011: Hidden State Comparison Before Output Norm");
println!("============================================================");
println!("Model: {}", path);
let mapped = MappedGGUFModel::from_path(&path)?;
let model = OwnedQuantizedModel::from_mapped(&mapped)?;
let token_id = 791u32;
let position: usize = 0;
println!("\nToken ID: {}", token_id);
println!("Position: {}", position);
println!("Hidden dim: {}", model.config().hidden_dim);
println!("\n=== CPU Forward (tracing hidden state) ===");
let embedding = model.embed(&[token_id]);
println!("Embedding sum: {:.6}", embedding.iter().sum::<f32>());
let cpu_logits = model.forward(&[token_id])?;
let cpu_argmax = cpu_logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, v)| (i, *v));
println!("CPU argmax: {:?}", cpu_argmax);
println!("\n=== GPU Forward ===");
let mut cuda_model = OwnedQuantizedModelCuda::new(model.clone(), 0)?;
cuda_model.preload_weights_gpu()?;
cuda_model.clear_decode_graph();
std::env::set_var("GPU_DEBUG", "1");
std::env::set_var("CUDA_GRAPH_DISABLE", "1");
let mut dummy_cache = realizar::gguf::OwnedQuantizedKVCache::new(
model.config().num_layers,
model.config().num_kv_heads * (model.config().hidden_dim / model.config().num_heads),
100,
);
let gpu_logits = cuda_model.forward_gpu_resident(token_id, &mut dummy_cache, position)?;
let gpu_argmax = gpu_logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, v)| (i, *v));
println!("GPU argmax: {:?}", gpu_argmax);
println!("\n=== Analysis ===");
if cpu_logits.len() == gpu_logits.len() {
let mean_cpu: f32 = cpu_logits.iter().sum::<f32>() / cpu_logits.len() as f32;
let mean_gpu: f32 = gpu_logits.iter().sum::<f32>() / gpu_logits.len() as f32;
let mut cov = 0.0f32;
let mut var_cpu = 0.0f32;
let mut var_gpu = 0.0f32;
for (c, g) in cpu_logits.iter().zip(gpu_logits.iter()) {
let dc = c - mean_cpu;
let dg = g - mean_gpu;
cov += dc * dg;
var_cpu += dc * dc;
var_gpu += dg * dg;
}
let corr = cov / (var_cpu.sqrt() * var_gpu.sqrt() + 1e-10);
let slope = cov / (var_cpu + 1e-10);
let intercept = mean_gpu - slope * mean_cpu;
println!("Correlation: {:.6}", corr);
println!("Linear fit: GPU ≈ {:.4}*CPU + {:.4}", slope, intercept);
if (slope - 1.0).abs() < 0.1 {
println!("\nDIAGNOSIS: Slope ≈ 1.0, likely offset/bias issue");
} else if slope < 1.0 {
println!("\nDIAGNOSIS: Slope = {:.4} < 1.0", slope);
println!("GPU output is scaled DOWN by {:.1}%", (1.0 - slope) * 100.0);
println!("\nPossible causes:");
println!(" 1. Missing weight multiplication in RMSNorm");
println!(" 2. Wrong scaling in attention (1/sqrt(d_k))");
println!(" 3. Missing residual connection");
println!(" 4. Accumulated numerical error through layers");
} else {
println!("\nDIAGNOSIS: Slope = {:.4} > 1.0", slope);
println!("GPU output is scaled UP by {:.1}%", (slope - 1.0) * 100.0);
}
println!("\n=== Key Position Analysis ===");
for pos in [16, 13, 15, 74403]
.iter()
.filter(|&&p| p < cpu_logits.len())
{
let cpu_val = cpu_logits[*pos];
let gpu_val = gpu_logits[*pos];
let predicted = slope * cpu_val + intercept;
let residual = (gpu_val - predicted).abs();
println!(
"pos={}: CPU={:.4}, GPU={:.4}, predicted={:.4}, residual={:.4}",
pos, cpu_val, gpu_val, predicted, residual
);
if residual > 5.0 {
println!(
" ^^ LARGE RESIDUAL - this position deviates significantly from linear fit!"
);
}
}
}
println!("\n=== Result ===");
if cpu_argmax.map(|(i, _)| i) == gpu_argmax.map(|(i, _)| i) {
println!("PASS: CPU and GPU argmax match");
} else {
println!("FAIL: Argmax mismatch");
println!("\nNext steps:");
println!("1. Check the [CORRECTNESS-001] debug output above");
println!("2. Compare hidden state sum/rms with CPU");
println!("3. If they differ, trace earlier layers");
println!("4. Per spec ROOT CAUSE: 'Simplified trace omitted RoPE/Cache state management'");
}
Ok(())
}