// tabicl-model 2.1.1
//
// TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Cross-stack parity with SSMax enabled in the column embedder
//! (`col_ssmax="qassmax-mlp-elementwise"`). Exercises the SSMax MLP +
//! query-aware modulation paths against Python's reference output.

use std::path::Path;
use tabicl_model::TabICL;
use tabicl_model::tabicl::{ColFeatureGroup, TabICLConfig};

/// Reads a raw little-endian `f32` blob from `path`, interpreting every
/// 4 bytes of the file as one value.
///
/// Returns `None` if the file cannot be read, or if its length is not a
/// multiple of 4 bytes. A trailing partial value means the fixture is
/// truncated or was written with a different dtype. Silently dropping those
/// bytes would let a corrupt fixture masquerade as a shorter valid one.
fn read_f32_blob(path: &Path) -> Option<Vec<f32>> {
    let bytes = std::fs::read(path).ok()?;
    let chunks = bytes.chunks_exact(4);
    if !chunks.remainder().is_empty() {
        return None;
    }
    Some(
        chunks
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    )
}

/// Largest element-wise absolute difference between `a` and `b`, taken over
/// the common prefix of the two slices (`0.0` if either is empty).
///
/// A NaN difference propagates into the result instead of being skipped.
/// A plain `if d > max` loop would skip it, because `NaN > x` is always false.
/// With propagation, a NaN or infinite output can never look like a match.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0_f32, |acc, d| if d.is_nan() || d > acc { d } else { acc })
}

/// Loads a Python-exported SSMax checkpoint and runs a Rust forward pass on
/// the exported input. The outputs for the rows after the training rows are
/// compared against Python's saved reference values.
///
/// Fixtures are read from `/tmp/tabicl_parity_ssmax/` (`ckpt.json`,
/// `input.bin`, `output.bin`). When the checkpoint is absent, the test prints
/// a message and returns early, so it passes.
#[test]
fn python_ssmax_forward_matches_rust() {
    const TOLERANCE: f32 = 1e-5;

    let ckpt_path = Path::new("/tmp/tabicl_parity_ssmax/ckpt.json");
    let input_path = Path::new("/tmp/tabicl_parity_ssmax/input.bin");
    let output_path = Path::new("/tmp/tabicl_parity_ssmax/output.bin");

    if !ckpt_path.exists() {
        eprintln!(
            "ssmax parity: fixture not found at {:?}, skipping",
            ckpt_path
        );
        return;
    }

    // Must match the Python config that produced the checkpoint.
    let mut cfg = TabICLConfig::default();
    cfg.max_classes = 3;
    cfg.num_quantiles = 5;
    cfg.embed_dim = 8;
    cfg.col_num_blocks = 1;
    cfg.col_nhead = 2;
    cfg.col_num_inds = 4;
    cfg.col_affine = false;
    cfg.col_feature_group = ColFeatureGroup::None;
    cfg.col_target_aware = false;
    cfg.col_ssmax = "qassmax-mlp-elementwise".into();
    cfg.row_num_blocks = 1;
    cfg.row_nhead = 2;
    cfg.row_num_cls = 2;
    cfg.row_rope_base = 100_000.0;
    cfg.row_rope_interleaved = false;
    cfg.icl_num_blocks = 1;
    cfg.icl_nhead = 2;
    cfg.icl_ssmax = "none".into();
    cfg.ff_factor = 2;
    cfg.dropout = 0.0;
    cfg.norm_first = true;
    cfg.bias_free_ln = false;

    // Only `max_classes` is needed after construction, so move `cfg` into the
    // model instead of cloning the whole config.
    let max_classes = cfg.max_classes;
    let mut model = TabICL::new(cfg);
    model
        .load_from_file(ckpt_path)
        .expect("load SSMax checkpoint");

    let x_flat = read_f32_blob(input_path).expect("read input");
    let y_ref = read_f32_blob(output_path).expect("read output");

    let x = ndarray::Array3::from_shape_vec((1, 4, 3), x_flat)
        .expect("input.bin must contain exactly 1 x 4 x 3 f32 values");
    let y_train = ndarray::Array2::from_shape_vec((1, 3), vec![0_usize, 1, 2]).unwrap();

    // Derive row counts from the arrays instead of hard-coding them.
    let n_train = y_train.shape()[1];
    let n_rows = x.shape()[1];

    // Validate the reference fixture before comparing. Integer division
    // would otherwise silently drop a ragged tail, and an empty reference
    // would make the comparison vacuously pass.
    assert_eq!(
        y_ref.len() % max_classes,
        0,
        "output.bin holds {} values, not a multiple of max_classes={max_classes}",
        y_ref.len()
    );
    let n_test = y_ref.len() / max_classes;
    assert!(n_test > 0, "output.bin is empty; nothing to compare");
    assert_eq!(
        n_train + n_test,
        n_rows,
        "fixture mismatch: {n_train} train rows + {n_test} reference rows != {n_rows} input rows"
    );

    let out = model
        .forward(x.view(), Some(y_train.view()), None)
        .expect("Rust forward");

    eprintln!("Python ref: {:?}", y_ref);
    let mut rust_test: Vec<f32> = Vec::with_capacity(n_test * max_classes);
    for row in n_train..n_train + n_test {
        for c in 0..max_classes {
            rust_test.push(out[(0, row, c)]);
        }
    }
    eprintln!("Rust out: {:?}", rust_test);

    // NaN-propagating: a NaN/inf Rust output fails the assertion below.
    let max_diff = max_abs_diff(&rust_test, &y_ref);
    eprintln!("Max diff: {max_diff}");
    assert!(
        max_diff < TOLERANCE,
        "SSMax parity exceeded tolerance: max_diff={max_diff}"
    );
}