use realizar::gguf::{MappedGGUFModel, OwnedQuantizedModel};
/// Model loaded when no path is given on the command line.
const DEFAULT_MODEL_PATH: &str = "/tmp/parity-bench/tinyllama-1.1b-q4_k_m.gguf";

/// Token id of the `'uola'` piece whose rank in the full forward pass is reported.
///
/// NOTE(review): presumably a suspicious prediction observed during the PAR-001b
/// investigation; confirm against the investigation notes.
const UOLA_TOKEN_ID: usize = 29568;

/// Pairs every score with its index and sorts the pairs from highest to lowest score.
///
/// Ties keep ascending index order (the sort is stable). Comparison uses
/// `f32::total_cmp`, which is a total order even when a score is NaN; a
/// `partial_cmp(..).unwrap_or(Equal)` comparator is not, and `sort_by` may panic on
/// a comparator that is not a total order.
fn rank_by_score(scores: impl IntoIterator<Item = f32>) -> Vec<(usize, f32)> {
    let mut ranked: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    ranked
}

/// Renders vocabulary entry `id` for display.
///
/// Replaces the word-boundary markers `'▁'` (U+2581, SentencePiece style) and `'Ġ'`
/// (U+0120, GPT-2 byte-level BPE style) with spaces. Ids outside the vocabulary are
/// rendered as `<id>` instead of panicking.
fn decode_token<S: AsRef<str>>(vocab: &[S], id: usize) -> String {
    match vocab.get(id) {
        Some(piece) => piece.as_ref().replace(['▁', '\u{0120}'], " "),
        None => format!("<{}>", id),
    }
}

/// PAR-001b: sanity-checks the output path of a quantized GGUF model.
///
/// Usage: `<binary> [MODEL_PATH]`. Without an argument, [`DEFAULT_MODEL_PATH`] is used.
///
/// Steps:
/// 1. RMS-normalizes the embedding of token id 1 with the model's output-norm weights
///    and prints the L2 norm before and after.
/// 2. Runs one forward pass on token id 1 with a fresh KV cache. Prints the top-10
///    predictions and the rank of [`UOLA_TOKEN_ID`].
/// 3. For several token ids, runs an independent single-token forward pass (each with
///    its own fresh cache) and prints the top-1 prediction.
///
/// # Panics
/// Panics if the model or its vocabulary cannot be loaded, or if a forward pass fails.
fn main() {
    use realizar::gguf::OwnedQuantizedKVCache;

    let path = std::env::args()
        .nth(1)
        .unwrap_or_else(|| DEFAULT_MODEL_PATH.to_string());
    println!("=== PAR-001b: LM Head Verification ===\n");
    let mapped = MappedGGUFModel::from_path(path.as_str()).expect("Failed to load model");
    let model =
        OwnedQuantizedModel::from_mapped(&mapped).expect("Failed to build quantized model");
    let vocab = mapped
        .model
        .vocabulary()
        .expect("Failed to read model vocabulary");

    let token_id: u32 = 1;
    println!("Testing with BOS token (id=1)");
    let config = model.config();
    let hidden_dim = config.hidden_dim;

    // RMSNorm by hand: x / sqrt(mean(x^2) + eps), then scaled element-wise by the
    // output-norm weights.
    let embedding = model.embed(&[token_id]);
    let sum_sq: f32 = embedding.iter().map(|x| x * x).sum();
    let inv_rms = 1.0 / (sum_sq / hidden_dim as f32 + config.eps).sqrt();
    let normed: Vec<f32> = embedding
        .iter()
        .zip(model.output_norm_weight().iter())
        .map(|(&x, &w)| x * inv_rms * w)
        .collect();
    // sqrt(sum_sq) is exactly the embedding's L2 norm; no need to recompute it.
    println!("Embedding L2: {}", sum_sq.sqrt());
    println!(
        "Normed L2: {}",
        normed.iter().map(|x| x * x).sum::<f32>().sqrt()
    );

    println!("\n\n=== Full forward pass (1 token) ===");
    let head_dim = hidden_dim / config.num_heads;
    let kv_dim = config.num_kv_heads * head_dim;
    let num_layers = config.num_layers;
    // Every forward pass below starts from an empty cache. NOTE(review): 256 is
    // presumably the cache's maximum sequence length — confirm against
    // `OwnedQuantizedKVCache::new`.
    let new_cache = || OwnedQuantizedKVCache::new(num_layers, kv_dim, 256);

    let mut cache = new_cache();
    let logits_full = model
        .forward_single_with_cache(token_id, &mut cache, 0)
        .expect("Forward pass failed");
    let ranked_full = rank_by_score(logits_full.iter().copied());
    println!("Top 10 predictions (full forward pass):");
    for (rank, &(idx, score)) in ranked_full.iter().take(10).enumerate() {
        println!(
            " #{}: {} (id={}, score={:.4})",
            rank + 1,
            decode_token(&vocab, idx),
            idx,
            score
        );
    }
    let uola_rank_full = ranked_full
        .iter()
        .position(|&(idx, _)| idx == UOLA_TOKEN_ID);
    println!(
        "\n'uola' ({}) rank in full forward: {:?}",
        UOLA_TOKEN_ID,
        uola_rank_full.map(|r| r + 1)
    );

    println!("\n\n=== Testing multiple tokens ===");
    for &test_token in &[1u32, 2, 26222, 14990, 263] {
        let mut test_cache = new_cache();
        let test_logits = model
            .forward_single_with_cache(test_token, &mut test_cache, 0)
            .expect("Forward pass failed");
        let token_str = decode_token(&vocab, test_token as usize);
        // `.first()` instead of `[0]`: an empty logits vector is reported, not a panic.
        match rank_by_score(test_logits.iter().copied()).first() {
            Some(&(top_id, top_score)) => println!(
                " Token '{}' ({}): top-1 = '{}' ({}, score={:.2})",
                token_str,
                test_token,
                decode_token(&vocab, top_id),
                top_id,
                top_score
            ),
            None => println!(
                " Token '{}' ({}): no logits returned",
                token_str, test_token
            ),
        }
    }
    println!("\n=== Verification complete ===");
}