use std::process::ExitCode;
use std::sync::Arc;
use std::time::Instant;
use rullama::backend::{BindGroupCache, Pipelines, WeightCache, WgpuCtx};
use rullama::gguf::{FileFetcher, GgufReader};
use rullama::reference::Weights;
use rullama::reference::diffusion::DiffusionConfig;
use rullama::reference::diffusion::forward::diffusion_forward;
use rullama::reference::diffusion::gpu::diffusion_forward_gpu;
fn read_i32(path: &str) -> Vec<u32> {
std::fs::read(path)
.expect("read i32 file")
.chunks_exact(4)
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as u32)
.collect()
}
fn argmax(row: &[f32]) -> usize {
let mut best = 0;
let mut bv = f32::NEG_INFINITY;
for (i, &v) in row.iter().enumerate() {
if v > bv {
bv = v;
best = i;
}
}
best
}
fn main() -> ExitCode {
let mut pos = Vec::new();
let mut canvas_trunc: Option<usize> = None;
for a in std::env::args().skip(1) {
if let Some(v) = a.strip_prefix("--canvas=") {
canvas_trunc = Some(v.parse().expect("canvas N"));
} else if a.starts_with("--") {
eprintln!("unknown flag {a}");
} else {
pos.push(a);
}
}
let (Some(model), Some(pf), Some(cf)) = (pos.first(), pos.get(1), pos.get(2)) else {
eprintln!(
"usage: diffusion_gpu_parity <model.gguf> <prompt.i32> <canvas.i32> [--canvas=N]"
);
return ExitCode::from(2);
};
let prompt_ids = read_i32(pf);
let mut canvas_ids = read_i32(cf);
if let Some(n) = canvas_trunc {
canvas_ids.truncate(n);
}
println!(
"prompt {} tok, canvas {} tok",
prompt_ids.len(),
canvas_ids.len()
);
let fetcher = FileFetcher::open(std::path::Path::new(model)).expect("open");
let r = pollster::block_on(GgufReader::new_streaming(Arc::new(fetcher))).expect("gguf");
let cfg = DiffusionConfig::from_gguf(&r).expect("diffusion config");
let r_arc = Arc::new(r);
let weights = Weights::new(r_arc.clone());
println!("running CPU oracle ...");
let t0 = Instant::now();
let cpu = diffusion_forward(&cfg, &weights, &prompt_ids, &canvas_ids).expect("cpu forward");
println!(" cpu forward in {:?}", t0.elapsed());
let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
let pipes = Pipelines::new(&ctx.device);
let bind_cache = Arc::new(BindGroupCache::default());
let wcache = WeightCache::new(
r_arc.clone(),
ctx.device.clone(),
ctx.queue.clone(),
bind_cache,
);
println!("running GPU hybrid ...");
let t0 = Instant::now();
let gpu = pollster::block_on(diffusion_forward_gpu(
&cfg,
&ctx,
&pipes,
&wcache,
&weights,
&prompt_ids,
&canvas_ids,
None,
1.0,
))
.expect("gpu forward");
println!(" gpu forward in {:?}", t0.elapsed());
let vocab = cfg.base.vocab_size as usize;
let c = canvas_ids.len();
assert_eq!(cpu.len(), c * vocab);
assert_eq!(gpu.len(), c * vocab);
let mut max_abs = 0f32;
let mut argmax_match = 0usize;
for ci in 0..c {
let cr = &cpu[ci * vocab..(ci + 1) * vocab];
let gr = &gpu[ci * vocab..(ci + 1) * vocab];
for (a, b) in cr.iter().zip(gr.iter()) {
max_abs = max_abs.max((a - b).abs());
}
if argmax(cr) == argmax(gr) {
argmax_match += 1;
}
}
println!("\nmax abs logit diff: {max_abs:.5}");
println!("argmax match: {argmax_match}/{c} canvas positions");
if argmax_match == c {
println!("PASS");
ExitCode::SUCCESS
} else {
println!("FAIL — argmax diverges (see above)");
ExitCode::FAILURE
}
}