use std::path::Path;
use tabicl_model::TabICL;
use tabicl_model::tabicl::{ColFeatureGroup, TabICLConfig};
/// Reads a raw dump of little-endian `f32` values from `path`.
///
/// Returns `None` if the file cannot be read, or if its length is not a
/// multiple of four bytes (a truncated or wrongly-typed dump). Such a file is
/// rejected rather than having its trailing partial value silently dropped.
fn read_f32(path: &Path) -> Option<Vec<f32>> {
    let bytes = std::fs::read(path).ok()?;
    // `chunks_exact` would otherwise discard an incomplete trailing chunk.
    if bytes.len() % 4 != 0 {
        return None;
    }
    Some(
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    )
}
/// Parity check of a depth-2 `TabICL` forward pass against reference values.
///
/// Reads a fixture directory `/tmp/tabicl_d2` containing:
/// - `ckpt.json`: weights loaded via `TabICL::load_from_file`;
/// - `input.bin`: raw little-endian `f32` features, reshaped to `(1, 6, 4)`;
/// - `output.bin`: raw little-endian `f32` reference values, indexed as a dense
///   row-major matrix with `active` columns per test row.
///
/// The first `y_train.ncols()` rows of the input are labelled; the remaining
/// rows are the test rows whose first `active` output columns are compared
/// against `output.bin`. The test is skipped, with a note on stderr, when
/// `ckpt.json` is absent.
#[test]
fn depth2_full_config_parity() {
    let base = Path::new("/tmp/tabicl_d2");
    let ckpt = base.join("ckpt.json");
    if !ckpt.exists() {
        // Report the skip so a missing fixture is not mistaken for a pass.
        eprintln!(
            "skipping depth2_full_config_parity: {} not found",
            ckpt.display()
        );
        return;
    }

    let mut cfg = TabICLConfig::default();
    cfg.max_classes = 10;
    cfg.num_quantiles = 999;
    cfg.embed_dim = 128;
    cfg.col_num_blocks = 2;
    cfg.col_nhead = 8;
    cfg.col_num_inds = 16;
    cfg.col_feature_group = ColFeatureGroup::Same;
    cfg.col_feature_group_size = 3;
    cfg.col_target_aware = true;
    cfg.col_ssmax = "qassmax-mlp-elementwise".into();
    cfg.row_num_blocks = 2;
    cfg.row_nhead = 8;
    cfg.row_num_cls = 4;
    cfg.icl_num_blocks = 2;
    cfg.icl_nhead = 8;
    cfg.icl_ssmax = "qassmax-mlp-elementwise".into();

    let mut model = TabICL::new(cfg);
    model.load_from_file(&ckpt).expect("load");

    let x_flat = read_f32(&base.join("input.bin")).expect("failed to read input.bin");
    let y_ref = read_f32(&base.join("output.bin")).expect("failed to read output.bin");
    let x = ndarray::Array3::from_shape_vec((1, 6, 4), x_flat)
        .expect("input.bin must contain exactly 1 * 6 * 4 f32 values");
    let y_train = ndarray::Array2::from_shape_vec((1, 4), vec![0_usize, 1, 2, 3]).unwrap();
    let out = model.forward(x.view(), Some(y_train.view()), None).unwrap();

    // Labelled rows come first; every row after them is a test row.
    let n_train = y_train.ncols();
    let n_test = x.dim().1 - n_train;
    // Number of distinct labels in `y_train` (0..=3); only these output columns are compared.
    let active = 4;
    assert!(
        y_ref.len() >= n_test * active,
        "output.bin holds {} values, expected at least {} ({n_test} rows x {active} classes)",
        y_ref.len(),
        n_test * active
    );

    let mut max_diff = 0.0_f32;
    for i in 0..n_test {
        for c in 0..active {
            let r = out[(0, n_train + i, c)];
            let p = y_ref[i * active + c];
            // A NaN difference compares false against `max_diff`. Without this
            // check, a NaN output would be ignored and the parity check would pass.
            assert!(
                r.is_finite() && p.is_finite(),
                "non-finite value at test row {i}, class {c}: rust={r}, py={p}"
            );
            max_diff = max_diff.max((r - p).abs());
        }
    }

    eprintln!("depth=2 max diff: {max_diff}");
    eprintln!("py row0: {:?}", &y_ref[..active]);
    let rust_row0: Vec<_> = (0..active).map(|c| out[(0, n_train, c)]).collect();
    eprintln!("rs row0: {rust_row0:?}");
    assert!(max_diff < 0.1, "depth=2 too large: {max_diff}");
}