use llama_rs::backend::Backend;
use llama_rs::backend::cpu::CpuBackend;
use llama_rs::gguf::GgufFile;
use llama_rs::tensor::{DType, Tensor};
use std::io::Read;
use std::path::Path;
fn main() {
let model_path = "/home/joseph/Models/qwen2.5-0.5b-instruct-q4_k_m.gguf";
eprintln!("Loading model...");
let gguf = GgufFile::open(Path::new(model_path)).expect("Failed to open GGUF");
let backend = CpuBackend::new();
let info = gguf
.data
.get_tensor("output.weight")
.expect("tensor not found");
let data = gguf.tensor_data("output.weight").expect("data not found");
let shape: Vec<usize> = info.dims.iter().map(|&d| d as usize).collect();
println!("output.weight tensor:");
println!(" GGUF shape (ne): {:?}", shape); println!(" dtype: {:?}", DType::from(info.dtype));
println!(" raw data len: {} bytes", data.len());
let expected_bytes = 896 * 151936 / 32 * 34;
println!(" expected bytes (Q8_0): {}", expected_bytes);
let tensor = Tensor::new(data.to_vec(), shape.clone(), DType::from(info.dtype)).unwrap();
let total_elements = 896 * 151936;
let mut out = Tensor::zeros(vec![total_elements], DType::F32);
backend
.dequantize(&tensor, &mut out)
.expect("dequantize failed");
let dequant = out.as_f32().unwrap();
println!(" dequantized len: {}", dequant.len());
println!("\n=== Checking dequantization layout ===");
println!("First 5 values: {:?}", &dequant[..5]);
let ref_path = "/tmp/output_weight_row0.npy";
if let Ok(mut file) = std::fs::File::open(ref_path) {
let mut npy_data = Vec::new();
file.read_to_end(&mut npy_data).unwrap();
let header_end = npy_data.iter().position(|&b| b == b'\n').unwrap_or(0) + 1;
if npy_data.len() > 128 {
let magic = &npy_data[..6];
if magic == b"\x93NUMPY" {
let _version = (npy_data[6], npy_data[7]);
let header_len = u16::from_le_bytes([npy_data[8], npy_data[9]]) as usize;
let data_start = 10 + header_len;
let float_data: Vec<f64> = npy_data[data_start..]
.chunks_exact(8)
.map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()))
.collect();
println!("\nPython reference (row 0 = vocab token 0):");
println!(" First 5: {:?}", &float_data[..5]);
let our_row0: Vec<f64> = dequant[..896].iter().map(|&x| x as f64).collect();
let py_row0: Vec<f64> = float_data[..896].to_vec();
let mut match_count = 0;
let mut max_diff = 0.0f64;
for (i, (&ours, &theirs)) in our_row0.iter().zip(py_row0.iter()).enumerate() {
let diff = (ours - theirs).abs();
if diff < 1e-4 {
match_count += 1;
}
if diff > max_diff {
max_diff = diff;
}
if i < 5 {
println!(
" [{}] ours={:.6}, py={:.6}, diff={:.6}",
i, ours, theirs, diff
);
}
}
println!("\n Matching values (< 1e-4 diff): {}/{}", match_count, 896);
println!(" Max difference: {:.6}", max_diff);
if match_count == 896 {
println!(
"\n ✓ Layout confirmed: dequant[j * 896 + i] = weight for vocab j, hidden i"
);
} else {
println!("\n ✗ Layout mismatch! Let's check alternative indexing...");
let alt_row0: Vec<f64> = (0..896).map(|i| dequant[i * 151936] as f64).collect();
let mut alt_match = 0;
for (&ours, &theirs) in alt_row0.iter().zip(py_row0.iter()) {
if (ours - theirs).abs() < 1e-4 {
alt_match += 1;
}
}
println!(" Alternative layout match: {}/{}", alt_match, 896);
}
}
}
} else {
println!("Could not load Python reference - run compare_dequant_values.py first");
}
println!("\n=== Test logit computation ===");
let hidden: Vec<f32> = vec![1.0; 896];
let logit_0: f32 = hidden.iter().zip(&dequant[..896]).map(|(h, w)| h * w).sum();
println!("Logit for vocab 0 with hidden=[1.0; 896]: {:.4}", logit_0);
let logit_17: f32 = hidden
.iter()
.zip(&dequant[17 * 896..18 * 896])
.map(|(h, w)| h * w)
.sum();
println!("Logit for vocab 17 with hidden=[1.0; 896]: {:.4}", logit_17);
println!("\n(Python row 0 sum was ~0.227, row 17 sum was ~0.220)");
}