use std::collections::BTreeMap;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::process::ExitCode;
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct CheckpointFile {
path: PathBuf,
step: u32,
layer: u32,
label: String,
shape: Vec<u64>,
data: Vec<f32>,
}
fn parse_filename(path: &Path) -> Option<(u32, u32, String)> {
let name = path.file_name()?.to_str()?;
if !name.ends_with(".bin") {
return None;
}
let stem = name.strip_suffix(".bin")?;
let mut parts = stem.splitn(3, '_');
let step_part = parts.next()?;
let layer_part = parts.next()?;
let label = parts.next()?.to_string();
let step = step_part.strip_prefix("step")?.parse().ok()?;
let layer = layer_part.strip_prefix("layer")?.parse().ok()?;
Some((step, layer, label))
}
fn read_dump(path: &Path) -> std::io::Result<CheckpointFile> {
let mut f = File::open(path)?;
let mut magic = [0u8; 4];
f.read_exact(&mut magic)?;
if &magic != b"BST1" {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"bad magic",
));
}
let mut rank_buf = [0u8; 1];
f.read_exact(&mut rank_buf)?;
let rank = rank_buf[0] as usize;
let mut shape = Vec::with_capacity(rank);
for _ in 0..rank {
let mut dim_buf = [0u8; 8];
f.read_exact(&mut dim_buf)?;
shape.push(u64::from_le_bytes(dim_buf));
}
let mut dtype_buf = [0u8; 1];
f.read_exact(&mut dtype_buf)?;
if dtype_buf[0] != 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unsupported dtype",
));
}
let n_elem: u64 = shape.iter().product();
let mut raw = vec![0u8; (n_elem * 4) as usize];
f.read_exact(&mut raw)?;
let mut data = Vec::with_capacity(n_elem as usize);
for chunk in raw.chunks_exact(4) {
data.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
let (step, layer, label) = parse_filename(path)
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "bad filename"))?;
Ok(CheckpointFile {
path: path.to_path_buf(),
step,
layer,
label,
shape,
data,
})
}
fn list_dumps(dir: &Path) -> std::io::Result<Vec<CheckpointFile>> {
let mut out = Vec::new();
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let p = entry.path();
if p.extension().and_then(|s| s.to_str()) == Some("bin") {
match read_dump(&p) {
Ok(cf) => out.push(cf),
Err(e) => eprintln!("warn: skipping {} ({})", p.display(), e),
}
}
}
Ok(out)
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Diff {
step: u32,
layer: u32,
label: String,
n_elem: usize,
a_l2: f64,
b_l2: f64,
max_abs_diff: f64,
l2_diff: f64,
relative: f64,
nan_a: usize,
nan_b: usize,
}
fn diff_pair(a: &CheckpointFile, b: &CheckpointFile) -> Diff {
let n = a.data.len().min(b.data.len());
let mut max_abs = 0.0_f64;
let mut sum_sq_a = 0.0_f64;
let mut sum_sq_b = 0.0_f64;
let mut sum_sq_d = 0.0_f64;
let mut nan_a = 0usize;
let mut nan_b = 0usize;
for i in 0..n {
let av = a.data[i];
let bv = b.data[i];
if !av.is_finite() {
nan_a += 1;
}
if !bv.is_finite() {
nan_b += 1;
}
if av.is_finite() && bv.is_finite() {
let d = (av - bv) as f64;
sum_sq_a += (av as f64) * (av as f64);
sum_sq_b += (bv as f64) * (bv as f64);
sum_sq_d += d * d;
let ad = d.abs();
if ad > max_abs {
max_abs = ad;
}
}
}
let a_l2 = sum_sq_a.sqrt();
let b_l2 = sum_sq_b.sqrt();
let l2_diff = sum_sq_d.sqrt();
let denom = a_l2.max(b_l2).max(1e-12);
let relative = l2_diff / denom;
Diff {
step: a.step,
layer: a.layer,
label: a.label.clone(),
n_elem: n,
a_l2,
b_l2,
max_abs_diff: max_abs,
l2_diff,
relative,
nan_a,
nan_b,
}
}
fn main() -> ExitCode {
let args: Vec<String> = std::env::args().collect();
if args.len() < 3 {
eprintln!(
"usage: {} <dir-a> <dir-b> [--threshold <rel>]\n\n\
Compares two CANDLE_BISECT_DUMP_DIR outputs. Reports the first\n\
checkpoint where relative L2 difference exceeds the threshold\n\
(default 0.001 = 0.1%) and the top-5 worst checkpoints.",
args.first()
.map(|s| s.as_str())
.unwrap_or("gemma4_bisect_diff"),
);
return ExitCode::from(2);
}
let dir_a = PathBuf::from(&args[1]);
let dir_b = PathBuf::from(&args[2]);
let mut threshold = 0.001_f64;
let mut layer_filter: Option<u32> = None;
let mut i = 3;
while i < args.len() {
match args[i].as_str() {
"--threshold" => {
if let Some(v) = args.get(i + 1).and_then(|s| s.parse::<f64>().ok()) {
threshold = v;
i += 2;
continue;
}
}
"--layer" => {
if let Some(v) = args.get(i + 1).and_then(|s| s.parse::<u32>().ok()) {
layer_filter = Some(v);
i += 2;
continue;
}
}
_ => {}
}
i += 1;
}
let a = match list_dumps(&dir_a) {
Ok(v) => v,
Err(e) => {
eprintln!("error: cannot read {}: {}", dir_a.display(), e);
return ExitCode::from(1);
}
};
let b = match list_dumps(&dir_b) {
Ok(v) => v,
Err(e) => {
eprintln!("error: cannot read {}: {}", dir_b.display(), e);
return ExitCode::from(1);
}
};
let mut by_key: BTreeMap<(u32, u32, String), (Option<CheckpointFile>, Option<CheckpointFile>)> =
BTreeMap::new();
for cf in a {
let key = (cf.step, cf.layer, cf.label.clone());
by_key.entry(key).or_insert((None, None)).0 = Some(cf);
}
for cf in b {
let key = (cf.step, cf.layer, cf.label.clone());
by_key.entry(key).or_insert((None, None)).1 = Some(cf);
}
let mut diffs: Vec<Diff> = Vec::new();
let mut a_only = 0usize;
let mut b_only = 0usize;
let mut shape_mismatch = 0usize;
for ((_step, _layer, _label), (oa, ob)) in by_key.iter() {
match (oa, ob) {
(Some(a), Some(b)) => {
if a.shape != b.shape {
shape_mismatch += 1;
eprintln!(
"warn: shape mismatch step={} layer={} label={}: {:?} vs {:?}",
a.step, a.layer, a.label, a.shape, b.shape,
);
continue;
}
diffs.push(diff_pair(a, b));
}
(Some(_), None) => a_only += 1,
(None, Some(_)) => b_only += 1,
_ => {}
}
}
println!(
"==> Bisect diff: {} ↔ {}\n matched checkpoints: {}, only-A: {}, only-B: {}, shape-mismatch: {}",
dir_a.display(),
dir_b.display(),
diffs.len(),
a_only,
b_only,
shape_mismatch,
);
let first_diverge = diffs.iter().find(|d| d.relative > threshold);
println!(
"\n==> First divergence above threshold (rel L2 > {})",
threshold
);
if let Some(d) = first_diverge {
println!(
" step={:4} layer={:3} {:30} | max|Δ|={:.4e} L2_diff={:.4e} rel={:.4e}",
d.step, d.layer, d.label, d.max_abs_diff, d.l2_diff, d.relative,
);
} else {
println!(
" NO checkpoint diverged above threshold ({}). Devices match within tolerance.",
threshold
);
}
let mut sorted = diffs.clone();
sorted.sort_by(|a, b| {
b.relative
.partial_cmp(&a.relative)
.unwrap_or(std::cmp::Ordering::Equal)
});
println!("\n==> Top-10 worst divergences (by relative L2)");
println!(
" {:>4} {:>3} {:30} {:>11} {:>11} {:>11} {:>10} {:>10}",
"step", "lyr", "label", "max|Δ|", "L2(diff)", "rel", "L2(A)", "L2(B)",
);
for d in sorted.iter().take(10) {
println!(
" {:>4} {:>3} {:30} {:>11.3e} {:>11.3e} {:>11.3e} {:>10.3e} {:>10.3e}",
d.step, d.layer, d.label, d.max_abs_diff, d.l2_diff, d.relative, d.a_l2, d.b_l2,
);
}
let mut by_layer: BTreeMap<(u32, u32), Vec<f64>> = BTreeMap::new();
for d in &diffs {
by_layer
.entry((d.step, d.layer))
.or_default()
.push(d.relative);
}
println!("\n==> Per-layer summary (max + mean rel L2 across checkpoints in that layer)");
println!(
" {:>4} {:>3} {:>11} {:>11} {:>5}",
"step", "lyr", "max-rel", "mean-rel", "n"
);
let mut lay_rows: Vec<((u32, u32), f64, f64, usize)> = by_layer
.iter()
.map(|((s, l), v)| {
let max = v.iter().cloned().fold(0.0_f64, f64::max);
let mean = v.iter().sum::<f64>() / v.len() as f64;
((*s, *l), max, mean, v.len())
})
.collect();
lay_rows.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
for ((s, l), max, mean, n) in lay_rows.iter().take(15) {
println!(
" {:>4} {:>3} {:>11.3e} {:>11.3e} {:>5}",
s, l, max, mean, n
);
}
if let Some(filter) = layer_filter {
println!(
"\n==> Full checkpoint chain for layer {} (sorted by step, label)",
filter,
);
println!(
" {:>4} {:30} {:>11} {:>11} {:>11} {:>10}",
"step", "label", "max|Δ|", "L2(diff)", "rel", "L2(A)",
);
let mut chain: Vec<&Diff> = diffs.iter().filter(|d| d.layer == filter).collect();
chain.sort_by(|a, b| a.step.cmp(&b.step).then_with(|| a.label.cmp(&b.label)));
for d in chain {
println!(
" {:>4} {:30} {:>11.3e} {:>11.3e} {:>11.3e} {:>10.3e}",
d.step, d.label, d.max_abs_diff, d.l2_diff, d.relative, d.a_l2,
);
}
}
let nan_total_a: usize = diffs.iter().map(|d| d.nan_a).sum();
let nan_total_b: usize = diffs.iter().map(|d| d.nan_b).sum();
if nan_total_a > 0 || nan_total_b > 0 {
println!(
"\n==> NaN/Inf detected: A={}, B={}",
nan_total_a, nan_total_b
);
}
if first_diverge.is_some() {
ExitCode::from(1)
} else {
ExitCode::from(0)
}
}