use realizar::gguf::{MappedGGUFModel, OwnedQuantizedKVCache, OwnedQuantizedModel};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let path = "/home/noah/.cache/huggingface/hub/models--Qwen--Qwen2-0.5B-Instruct-GGUF/snapshots/198f08841147e5196a6a69bd0053690fb1fd3857/qwen2-0_5b-instruct-q4_0.gguf";
let mapped = MappedGGUFModel::from_path(path)?;
let model = OwnedQuantizedModel::from_mapped(&mapped)?;
let vocab = mapped.model.vocabulary().expect("vocab");
println!("=== Single Token Full Forward Trace ===\n");
let tok = 17u32; let mut cache = OwnedQuantizedKVCache::new(
model.config().num_layers,
model.config().num_kv_heads * (model.config().hidden_dim / model.config().num_heads),
1024,
);
let logits_cached = model.forward_cached(tok, &mut cache, 0)?;
println!("forward_cached (single token, position 0):");
let mut indexed: Vec<_> = logits_cached.iter().enumerate().collect();
indexed.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
println!("Top 10 predictions:");
for (tok_id, logit) in indexed.iter().take(10) {
let tok_str = vocab.get(*tok_id).map_or("?", |s| s.as_str());
println!(" Token {} ({:?}): logit={:.4}", tok_id, tok_str, logit);
}
println!("\nforward([17]) (seq_len=1):");
let logits_forward = model.forward(&[tok])?;
let mut indexed2: Vec<_> = logits_forward.iter().enumerate().collect();
indexed2.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
println!("Top 10 predictions:");
for (tok_id, logit) in indexed2.iter().take(10) {
let tok_str = vocab.get(*tok_id).map_or("?", |s| s.as_str());
println!(" Token {} ({:?}): logit={:.4}", tok_id, tok_str, logit);
}
let max_diff = logits_cached
.iter()
.zip(logits_forward.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
println!(
"\nMax difference between forward_cached and forward: {:.6}",
max_diff
);
println!("\nSpecific token logits (forward_cached / forward):");
println!(
" Token 17 (\"2\"): {:.4} / {:.4}",
logits_cached[17], logits_forward[17]
);
println!(
" Token 19 (\"4\"): {:.4} / {:.4}",
logits_cached[19], logits_forward[19]
);
println!(
" Token 0 (\"!\"): {:.4} / {:.4}",
logits_cached[0], logits_forward[0]
);
println!("\n=== Hidden State Evolution (manual trace) ===");
let hidden_dim = model.config().hidden_dim;
let num_heads = model.config().num_heads;
let num_kv_heads = model.config().num_kv_heads;
let head_dim = hidden_dim / num_heads;
let q_dim = num_heads * head_dim;
let k_dim = num_kv_heads * head_dim;
let emb_start = tok as usize * hidden_dim;
let hidden = model.token_embedding()[emb_start..emb_start + hidden_dim].to_vec();
println!(
"After embedding: norm={:.4}",
hidden.iter().map(|x| x * x).sum::<f32>().sqrt()
);
for layer_idx in 0..1 {
let layer = &model.layers()[layer_idx];
let sum_sq: f32 = hidden.iter().map(|x| x * x).sum();
let inv_rms = 1.0 / ((sum_sq / hidden_dim as f32) + model.config().eps).sqrt();
let normed: Vec<f32> = hidden
.iter()
.zip(layer.attn_norm_weight.iter())
.map(|(h, w)| h * inv_rms * w)
.collect();
println!(
"After RMSNorm (layer {}): norm={:.4}",
layer_idx,
normed.iter().map(|x| x * x).sum::<f32>().sqrt()
);
let mut qkv = model.qkv_matmul(&normed, &layer.qkv_weight)?;
if let Some(ref bias) = layer.qkv_bias {
for i in 0..qkv.len() {
qkv[i] += bias[i];
}
}
let _q = &qkv[0..q_dim];
let _k = &qkv[q_dim..q_dim + k_dim];
let v = &qkv[q_dim + k_dim..];
let group_size = num_heads / num_kv_heads;
let mut attn_out = vec![0.0f32; q_dim];
for h in 0..num_heads {
let kv_head = h / group_size;
let v_start = kv_head * head_dim;
let out_start = h * head_dim;
attn_out[out_start..out_start + head_dim]
.copy_from_slice(&v[v_start..v_start + head_dim]);
}
println!(
"After attention (layer {}): attn_out norm={:.4}",
layer_idx,
attn_out.iter().map(|x| x * x).sum::<f32>().sqrt()
);
}
println!("\n=== Final Logit Analysis ===");
let logit_min = logits_forward.iter().copied().fold(f32::INFINITY, f32::min);
let logit_max = logits_forward
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let logit_mean: f32 = logits_forward.iter().sum::<f32>() / logits_forward.len() as f32;
let logit_std: f32 = (logits_forward
.iter()
.map(|x| (x - logit_mean).powi(2))
.sum::<f32>()
/ logits_forward.len() as f32)
.sqrt();
println!("Logit statistics:");
println!(" min: {:.4}", logit_min);
println!(" max: {:.4}", logit_max);
println!(" mean: {:.4}", logit_mean);
println!(" std: {:.4}", logit_std);
println!("\nDigit token logits:");
for d in 0..=9 {
let digit_str = d.to_string();
let tok_id = vocab
.iter()
.enumerate()
.find(|(_, s)| s.as_str() == digit_str)
.map(|(i, _)| i);
if let Some(tok_id) = tok_id {
println!(
" '{}' (token {}): logit={:.4}",
d, tok_id, logits_forward[tok_id]
);
}
}
Ok(())
}