use std::process::ExitCode;
use std::sync::Arc;
use rullama::gguf::{FileFetcher, GgufReader};
use rullama::reference::Weights;
use rullama::reference::diffusion::DiffusionConfig;
use rullama::reference::diffusion::forward::diffusion_forward;
/// Loads the file at `path` as a sequence of little-endian 32-bit integers.
///
/// Each complete 4-byte group is decoded as an `i32` and then reinterpreted
/// as `u32` with an `as` cast, so `-1` on disk becomes `u32::MAX`. Bytes left
/// over after the last complete group are ignored.
///
/// # Panics
/// Panics with the message `read i32 file` if the file cannot be read.
fn read_i32(path: &str) -> Vec<u32> {
    let raw = std::fs::read(path).expect("read i32 file");
    let mut values = Vec::with_capacity(raw.len() / 4);
    for word in raw.chunks_exact(4) {
        let signed = i32::from_le_bytes([word[0], word[1], word[2], word[3]]);
        values.push(signed as u32);
    }
    values
}
/// Loads the file at `path` as a sequence of little-endian `f32` values.
///
/// Only complete 4-byte groups are decoded; bytes left over after the last
/// complete group are ignored.
///
/// # Panics
/// Panics with the message `read f32 file` if the file cannot be read.
fn read_f32(path: &str) -> Vec<f32> {
    let raw = std::fs::read(path).expect("read f32 file");
    raw.chunks_exact(4)
        .map(|word| {
            let mut le = [0u8; 4];
            le.copy_from_slice(word);
            f32::from_le_bytes(le)
        })
        .collect()
}
/// Returns the index of the largest element of `row`.
///
/// When several elements compare equal to the maximum, the last of them wins
/// (the documented behaviour of `Iterator::max_by`).
///
/// # Panics
/// Panics if `row` is empty or contains NaN. `main` rules out both before
/// calling this helper.
fn argmax(row: &[f32]) -> usize {
    row.iter()
        .enumerate()
        .max_by(|x, y| {
            x.1.partial_cmp(y.1)
                .expect("NaN logits are rejected before argmax")
        })
        .expect("logit rows are non-empty because vocab > 0")
        .0
}

/// Compares rullama's diffusion forward pass against a reference ("oracle")
/// logit dump and reports per-position drift and argmax agreement.
///
/// Command line: `<model.gguf> <prompt.i32> <canvas.i32> <oracle.bin>`.
/// If the `DG_PREV_LOGITS` environment variable is set, the file it names is
/// read as `f32` values and the self-conditioning entry point
/// `diffusion_forward_sc` is used instead of `diffusion_forward`.
///
/// Side effect: the computed logits are written (best effort) to
/// `<oracle.bin>.mine.bin` as little-endian `f32`.
///
/// Exit codes:
/// * `0` — at least 97% of canvas positions agree on the argmax token.
/// * `1` — agreement below 97%, or NaN logits on either side.
/// * `2` — wrong number of arguments, or an empty canvas.
///
/// # Panics
/// Panics if an input file cannot be read, the model or its config cannot be
/// loaded, the forward pass fails, or the oracle / output sizes do not match
/// `canvas.len() * vocab_size`.
fn main() -> ExitCode {
    // Minimum fraction of positions whose argmax must match for a PASS.
    // The PASS/FAIL messages below spell this out as "97%".
    const MIN_AGREEMENT: f32 = 0.97;

    // One canvas position where the two argmax tokens differ.
    struct Mismatch {
        pos: usize,
        mine_tok: usize,
        mine_at_mine_tok: f32,
        oracle_at_mine_tok: f32,
        oracle_tok: usize,
        mine_at_oracle_tok: f32,
        oracle_at_oracle_tok: f32,
        max_abs: f32,
    }

    let mut args = std::env::args().skip(1);
    let (Some(model), Some(pf), Some(cf), Some(of)) =
        (args.next(), args.next(), args.next(), args.next())
    else {
        eprintln!("usage: diffusion_parity <model.gguf> <prompt.i32> <canvas.i32> <oracle.bin>");
        return ExitCode::from(2);
    };

    let prompt = read_i32(&pf);
    let canvas = read_i32(&cf);
    let oracle = read_f32(&of);
    println!(
        "prompt={} canvas={} oracle_floats={}",
        prompt.len(),
        canvas.len(),
        oracle.len()
    );

    // With zero positions the statistics below would index an empty vector
    // and divide by zero, so bail out before loading the model.
    let c = canvas.len();
    if c == 0 {
        eprintln!("canvas is empty: nothing to compare");
        return ExitCode::from(2);
    }

    let fetcher = FileFetcher::open(std::path::Path::new(&model)).expect("open");
    let r = pollster::block_on(GgufReader::new_streaming(Arc::new(fetcher))).expect("gguf");
    let cfg = DiffusionConfig::from_gguf(&r).expect("diffusion config");
    let vocab = cfg.base.vocab_size as usize;
    // Rows are sliced in chunks of `vocab`; a zero width cannot be chunked.
    assert!(vocab > 0, "vocab_size must be non-zero");
    assert_eq!(oracle.len(), c * vocab, "oracle shape mismatch");
    let weights = Weights::new(Arc::new(r));

    let prev = std::env::var("DG_PREV_LOGITS").ok().map(|p| read_f32(&p));
    let t = std::time::Instant::now();
    let mine = if let Some(pl) = &prev {
        eprintln!("self-conditioning ENABLED ({} floats)", pl.len());
        // NOTE(review): the meaning of the trailing `1.0` is defined by
        // `diffusion_forward_sc`, which is not visible in this file.
        rullama::reference::diffusion::forward::diffusion_forward_sc(
            &cfg,
            &weights,
            &prompt,
            &canvas,
            Some(pl),
            1.0,
        )
        .expect("sc forward")
    } else {
        diffusion_forward(&cfg, &weights, &prompt, &canvas).expect("forward")
    };
    println!("rullama forward: {:.1?}", t.elapsed());
    assert_eq!(
        mine.len(),
        oracle.len(),
        "forward output length does not match oracle"
    );

    // Dump our logits next to the oracle for offline inspection. This stays
    // best effort, but the outcome is now reported truthfully.
    {
        let dump_path = format!("{of}.mine.bin");
        let mut bytes = Vec::with_capacity(mine.len() * 4);
        for &x in &mine {
            bytes.extend_from_slice(&x.to_le_bytes());
        }
        match std::fs::write(&dump_path, &bytes) {
            Ok(()) => eprintln!("wrote {dump_path}"),
            Err(e) => eprintln!("warning: could not write {dump_path}: {e}"),
        }
    }

    // NaN has no ordering, so argmax would be meaningless, and `f32::max`
    // would silently drop NaN differences. Report it as a failure instead.
    let nan_mine = mine.iter().filter(|x| x.is_nan()).count();
    let nan_oracle = oracle.iter().filter(|x| x.is_nan()).count();
    if nan_mine > 0 || nan_oracle > 0 {
        println!("FAIL (NaN logits present: mine={nan_mine} oracle={nan_oracle})");
        return ExitCode::from(1);
    }

    let mut mismatches = Vec::new();
    let mut per_pos_maxabs = Vec::with_capacity(c);
    let rows = mine.chunks_exact(vocab).zip(oracle.chunks_exact(vocab));
    for (ci, (m, o)) in rows.enumerate() {
        let (am_m, am_o) = (argmax(m), argmax(o));
        // Largest absolute logit difference at this position.
        let pm = m
            .iter()
            .zip(o)
            .fold(0f32, |acc, (a, b)| acc.max((a - b).abs()));
        per_pos_maxabs.push(pm);
        if am_m != am_o {
            mismatches.push(Mismatch {
                pos: ci,
                mine_tok: am_m,
                mine_at_mine_tok: m[am_m],
                oracle_at_mine_tok: o[am_m],
                oracle_tok: am_o,
                mine_at_oracle_tok: m[am_o],
                oracle_at_oracle_tok: o[am_o],
                max_abs: pm,
            });
        }
    }

    // Take the mean in position order first, then sort in place for the
    // order statistics (no copy needed). Values are never NaN here because
    // `f32::max` discards NaN operands.
    let mean = per_pos_maxabs.iter().sum::<f32>() / c as f32;
    per_pos_maxabs.sort_unstable_by(f32::total_cmp);
    let sorted = per_pos_maxabs;
    println!(
        "per-position logit max_abs: median={:.3} mean={:.3} p90={:.3} max={:.3}",
        sorted[c / 2],
        mean,
        sorted[c * 9 / 10],
        sorted[c - 1]
    );

    println!("argmax mismatch: {}/{c}", mismatches.len());
    for mm in &mismatches {
        println!(
            " pos {ci}: mine→tok{am_m}(logit {lm_m:.2}; oracle has {lm_o:.2}) oracle→tok{am_o}(mine {lo_m:.2}; oracle {lo_o:.2}) posMaxAbs={pm:.2}",
            ci = mm.pos,
            am_m = mm.mine_tok,
            lm_m = mm.mine_at_mine_tok,
            lm_o = mm.oracle_at_mine_tok,
            am_o = mm.oracle_tok,
            lo_m = mm.mine_at_oracle_tok,
            lo_o = mm.oracle_at_oracle_tok,
            pm = mm.max_abs,
        );
    }

    let agree = (c - mismatches.len()) as f32 / c as f32;
    if agree >= MIN_AGREEMENT {
        println!(
            "PASS (argmax agreement {:.1}% — logit drift is MoE routing-boundary accumulation, not a bug; see layer bisection)",
            agree * 100.0
        );
        ExitCode::SUCCESS
    } else {
        println!(
            "FAIL (argmax agreement {:.1}% — below 97%, investigate)",
            agree * 100.0
        );
        ExitCode::from(1)
    }
}