//! tabicl-model 2.1.1
//!
//! TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//!
//! Bisect: compare RowInteraction output between Python and Rust.

use std::path::Path;
use tabicl_model::TabICL;
use tabicl_model::tabicl::{ColFeatureGroup, TabICLConfig};

/// Loads a file as a flat sequence of little-endian `f32` values.
///
/// Returns `None` when the file cannot be read. Bytes are consumed four at a
/// time; a trailing remainder shorter than four bytes is ignored.
fn read_f32(path: &Path) -> Option<Vec<f32>> {
    let raw = match std::fs::read(path) {
        Ok(raw) => raw,
        Err(_) => return None,
    };
    let values = raw
        .chunks_exact(4)
        .map(|word| {
            // Each chunk is exactly four bytes, so the copy cannot panic.
            let mut le = [0_u8; 4];
            le.copy_from_slice(word);
            f32::from_le_bytes(le)
        })
        .collect();
    Some(values)
}

/// Parity check for the row-interaction stage against a Python reference dump.
///
/// Builds a `TabICL` from the checkpoint in `/tmp/tabicl_hf_parity`, runs the column
/// stage and then the row stage on the stored input, and compares the result
/// element-wise with `/tmp/tabicl_intermediates/row_out.bin`.
///
/// The test returns early (and therefore counts as passed) when the checkpoint or the
/// reference dump is missing. Once those exist, any other missing or malformed fixture
/// is a hard failure.
#[test]
fn hf_row_interactor_output_matches_python() {
    // Extents asserted on the row output below; the reference dump is indexed as a
    // row-major (ROWS, WIDTH) buffer.
    const ROWS: usize = 6;
    const WIDTH: usize = 512;
    // Largest absolute element-wise difference that still counts as a match.
    const TOLERANCE: f32 = 1e-3;

    let ckpt = Path::new("/tmp/tabicl_hf_parity/ckpt.json");
    let row_path = Path::new("/tmp/tabicl_intermediates/row_out.bin");
    if !ckpt.exists() || !row_path.exists() {
        // Make the skip visible instead of reporting a silent pass.
        eprintln!("skipping row-interactor parity test: fixtures not found");
        return;
    }

    // NOTE(review): these values presumably mirror the configuration of the exported
    // checkpoint — confirm against the Python export script.
    let mut cfg = TabICLConfig::default();
    cfg.max_classes = 10;
    cfg.num_quantiles = 999;
    cfg.embed_dim = 128;
    cfg.col_num_blocks = 3;
    cfg.col_nhead = 8;
    cfg.col_num_inds = 128;
    cfg.col_feature_group = ColFeatureGroup::Same;
    cfg.col_feature_group_size = 3;
    cfg.col_target_aware = true;
    cfg.col_ssmax = "qassmax-mlp-elementwise".into();
    cfg.row_num_blocks = 3;
    cfg.row_nhead = 8;
    cfg.row_num_cls = 4;
    cfg.icl_num_blocks = 12;
    cfg.icl_nhead = 8;
    cfg.icl_ssmax = "qassmax-mlp-elementwise".into();

    // `cfg` is not used again, so it is moved into the model rather than cloned.
    let mut model = TabICL::new(cfg);
    model.load_from_file(ckpt).expect("load");

    let x_flat = read_f32(Path::new("/tmp/tabicl_hf_parity/input.bin"))
        .expect("read /tmp/tabicl_hf_parity/input.bin");
    let x = ndarray::Array3::from_shape_vec((1, ROWS, 4), x_flat)
        .expect("input.bin must hold exactly 1*6*4 f32 values");
    let y_train = ndarray::Array2::from_shape_vec((1, 4), vec![0_usize, 1, 2, 3])
        .expect("label vector has exactly four entries");
    // NOTE(review): the trailing `4` presumably is the number of labelled rows (it
    // matches the length of `y_train`) — confirm against `forward_with_targets`.
    let col_out = model
        .col
        .forward_with_targets(x.view(), Some(y_train.view()), None, 4)
        .unwrap();
    let row_out = model.row.forward(col_out.view());

    let py = read_f32(row_path).expect("read /tmp/tabicl_intermediates/row_out.bin");
    assert_eq!(row_out.shape(), &[1, ROWS, WIDTH]);
    // Check the reference length up front so a truncated dump fails with a clear
    // message rather than an out-of-bounds panic inside the loop.
    assert_eq!(
        py.len(),
        ROWS * WIDTH,
        "row_out.bin holds {} f32 values, expected {}",
        py.len(),
        ROWS * WIDTH
    );

    let mut max_diff = 0.0_f32;
    let mut at = (0_usize, 0_usize);
    for t in 0..ROWS {
        for k in 0..WIDTH {
            let r = row_out[(0, t, k)];
            let p = py[t * WIDTH + k];
            let d = (r - p).abs();
            // A NaN difference compares false against everything, so without this
            // check a NaN in either output would slip past the max-tracking below.
            assert!(
                d.is_finite(),
                "non-finite row diff at t={t} k={k}: Python {p}, Rust {r}"
            );
            if d > max_diff {
                max_diff = d;
                at = (t, k);
            }
        }
    }
    eprintln!("row max diff: {max_diff} at t={} k={}", at.0, at.1);
    eprintln!(
        "  Python: {:.6}, Rust: {:.6}",
        py[at.0 * WIDTH + at.1],
        row_out[(0, at.0, at.1)]
    );
    assert!(max_diff < TOLERANCE, "row diff: {max_diff}");
}