use std::path::Path;
use tabicl_model::TabICL;
use tabicl_model::tabicl::{ColFeatureGroup, TabICLConfig};
/// Reads a headerless blob of little-endian `f32` values from `path`.
///
/// Returns `None` if the file cannot be read, or if its length is not a
/// multiple of four bytes. A non-multiple means the blob is truncated or
/// malformed, and silently dropping the tail would hide a broken fixture.
/// An empty file yields `Some(vec![])`.
fn read_f32_blob(path: &Path) -> Option<Vec<f32>> {
    let bytes = std::fs::read(path).ok()?;
    let words = bytes.chunks_exact(4);
    // `chunks_exact` would otherwise ignore trailing bytes without complaint.
    if !words.remainder().is_empty() {
        return None;
    }
    Some(
        words
            .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
            .collect(),
    )
}
/// Parity test: runs the Rust `TabICL` forward pass on a checkpoint and input
/// exported from Python and compares the test-row logits to the Python ones.
///
/// Fixtures are read from `/tmp/tabicl_parity/`:
/// - `ckpt.json`: model weights, loaded via `TabICL::load_from_file`
/// - `input.bin`: raw little-endian f32 features, reshaped to `(1, N_ROWS, N_FEATURES)`
/// - `output.bin`: raw little-endian f32 reference logits, row-major
///   `(test_row, class)`
///
/// The test returns early (is skipped) when `ckpt.json` is absent. If the
/// checkpoint exists, missing input/output fixtures are treated as errors.
/// NOTE(review): the config below presumably must mirror the Python model that
/// produced the checkpoint. Confirm against the export script.
#[test]
fn python_checkpoint_forward_matches_rust() {
    // Absolute per-logit tolerance. Shared by the assertion and the success
    // message so the two cannot disagree.
    const TOL: f32 = 1e-5;
    // Fixture table: one table with N_ROWS rows and N_FEATURES features.
    const N_ROWS: usize = 4;
    const N_FEATURES: usize = 3;

    let ckpt_path = Path::new("/tmp/tabicl_parity/ckpt.json");
    let input_path = Path::new("/tmp/tabicl_parity/input.bin");
    let output_path = Path::new("/tmp/tabicl_parity/output.bin");
    if !ckpt_path.exists() {
        eprintln!(
            "python_checkpoint_parity: fixture not found at {:?}, skipping",
            ckpt_path
        );
        return;
    }

    let mut cfg = TabICLConfig::default();
    cfg.max_classes = 3;
    cfg.num_quantiles = 5;
    cfg.embed_dim = 8;
    cfg.col_num_blocks = 1;
    cfg.col_nhead = 2;
    cfg.col_num_inds = 4;
    cfg.col_affine = false;
    cfg.col_feature_group = ColFeatureGroup::None;
    cfg.col_target_aware = false;
    cfg.col_ssmax = "none".into();
    cfg.row_num_blocks = 1;
    cfg.row_nhead = 2;
    cfg.row_num_cls = 2;
    cfg.row_rope_base = 100_000.0;
    cfg.row_rope_interleaved = false;
    cfg.icl_num_blocks = 1;
    cfg.icl_nhead = 2;
    cfg.icl_ssmax = "none".into();
    cfg.ff_factor = 2;
    cfg.dropout = 0.0;
    cfg.norm_first = true;
    cfg.bias_free_ln = false;

    let mut model = TabICL::new(cfg.clone());
    model
        .load_from_file(ckpt_path)
        .expect("load Python checkpoint");

    let x_flat = read_f32_blob(input_path).expect("read input");
    let y_ref = read_f32_blob(output_path).expect("read output");

    // The first `n_train` rows are labelled (context). The remaining rows are
    // the test rows whose logits are compared against Python.
    let train_labels = vec![0_usize, 1, 2];
    let n_train = train_labels.len();
    let n_test = N_ROWS - n_train;

    let x = ndarray::Array3::from_shape_vec((1, N_ROWS, N_FEATURES), x_flat)
        .expect("input.bin does not match (1, N_ROWS, N_FEATURES)");
    let y_train = ndarray::Array2::from_shape_vec((1, n_train), train_labels)
        .expect("train labels shape");

    let out = model
        .forward(x.view(), Some(y_train.view()), None)
        .expect("Rust forward");

    let expected_len = n_test * cfg.max_classes;
    assert_eq!(
        y_ref.len(),
        expected_len,
        "expected {} reference values ({} test rows x {} classes), got {}",
        expected_len,
        n_test,
        cfg.max_classes,
        y_ref.len()
    );

    eprintln!("Python ref: {:?}", y_ref);
    eprintln!("Rust out shape: {:?}", out.shape());
    eprintln!(
        "Rust out at test pos: {:?}",
        (0..cfg.max_classes)
            .map(|c| out[(0, n_train, c)])
            .collect::<Vec<_>>()
    );

    // `y_ref` is laid out row-major as (test_row, class).
    for i in 0..n_test {
        for c in 0..cfg.max_classes {
            let rust_val = out[(0, n_train + i, c)];
            let py_val = y_ref[i * cfg.max_classes + c];
            let diff = (rust_val - py_val).abs();
            // `<` (not `>=` negated) so a NaN diff also fails the check.
            assert!(
                diff < TOL,
                "logit[{i},{c}] Rust={rust_val} Python={py_val} diff={diff}"
            );
        }
    }
    eprintln!("✓ Rust output matches Python within {} tolerance", TOL);
}