tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Cross-stack parity test: load a Python TabICL checkpoint, run forward
//! on a known input, verify Rust output matches Python's reference within
//! fp32 numerical noise.
//!
//! The fixture (Python checkpoint plus reference input/output tensors) is
//! produced on the Python side and read from:
//!   /tmp/tabicl_parity/{ckpt.json,ckpt.bin,input.bin,output.bin}
//!
//! When the fixture is absent (detected by a missing `ckpt.json`, e.g. CI
//! without Python), the test returns early without failing, so the rest of
//! the test suite still passes.

use std::path::Path;
use tabicl_model::TabICL;
use tabicl_model::tabicl::{ColFeatureGroup, TabICLConfig};

/// Reads a file of packed little-endian `f32` values into a `Vec<f32>`.
///
/// Returns `None` if the file cannot be read. Also returns `None` if its
/// length is not a multiple of 4 bytes. A truncated or corrupt fixture is
/// rejected instead of silently dropping its trailing bytes.
fn read_f32_blob(path: &Path) -> Option<Vec<f32>> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    let bytes = std::fs::read(path).ok()?;
    if bytes.len() % F32_BYTES != 0 {
        // Partial trailing value: the blob is not a clean f32 array.
        return None;
    }
    Some(
        bytes
            .chunks_exact(F32_BYTES)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    )
}

/// Loads a Python-generated TabICL checkpoint and runs the Rust forward pass
/// on the fixture input. It then checks the test-row logits against Python's
/// reference output, within `TOL`.
///
/// Returns early without failing when `ckpt.json` is absent, so environments
/// without the Python-generated fixture still pass.
#[test]
fn python_checkpoint_forward_matches_rust() {
    /// Maximum absolute per-logit difference tolerated between Rust and Python.
    /// fp32 accumulation drift through 1 transformer block + LayerNorm is
    /// typically a few × 1e-5, so 1e-3 leaves headroom without hiding real bugs.
    const TOL: f32 = 1e-3;

    let ckpt_path = Path::new("/tmp/tabicl_parity/ckpt.json");
    let input_path = Path::new("/tmp/tabicl_parity/input.bin");
    let output_path = Path::new("/tmp/tabicl_parity/output.bin");

    if !ckpt_path.exists() {
        // Fixture missing — skip without failing (cross-platform CI).
        eprintln!(
            "python_checkpoint_parity: fixture not found at {:?}, skipping",
            ckpt_path
        );
        return;
    }

    // Mirror the Python TabICL config used to generate the fixture.
    let mut cfg = TabICLConfig::default();
    cfg.max_classes = 3;
    cfg.num_quantiles = 5;
    cfg.embed_dim = 8;
    cfg.col_num_blocks = 1;
    cfg.col_nhead = 2;
    cfg.col_num_inds = 4;
    cfg.col_affine = false;
    cfg.col_feature_group = ColFeatureGroup::None;
    cfg.col_target_aware = false;
    cfg.col_ssmax = "none".into();
    cfg.row_num_blocks = 1;
    cfg.row_nhead = 2;
    cfg.row_num_cls = 2;
    cfg.row_rope_base = 100_000.0;
    cfg.row_rope_interleaved = false;
    cfg.icl_num_blocks = 1;
    cfg.icl_nhead = 2;
    cfg.icl_ssmax = "none".into();
    cfg.ff_factor = 2;
    cfg.dropout = 0.0;
    cfg.norm_first = true;
    cfg.bias_free_ln = false;

    // Capture what the comparison needs before the config is moved into the model.
    let n_classes = cfg.max_classes;

    let mut model = TabICL::new(cfg);
    model
        .load_from_file(ckpt_path)
        .expect("load Python checkpoint");

    // Load input + reference output.
    let x_flat = read_f32_blob(input_path).expect("read input");
    let y_ref = read_f32_blob(output_path).expect("read output");

    // X shape: (1, 4, 3). y_train: (1, 3) with labels [0, 1, 2], so rows 0..3
    // are training rows and the remaining row is the test row.
    let x = ndarray::Array3::from_shape_vec((1, 4, 3), x_flat).unwrap();
    let y_train = ndarray::Array2::from_shape_vec((1, 3), vec![0_usize, 1, 2]).unwrap();
    let n_train = y_train.shape()[1];

    let out = model
        .forward(x.view(), Some(y_train.view()), None)
        .expect("Rust forward");

    // Python reference has shape (1, 1, 3): only the test sample's logits.
    assert_eq!(
        y_ref.len(),
        3,
        "expected 3 reference values, got {}",
        y_ref.len()
    );
    let n_test = y_ref.len() / n_classes;

    // The Python output is `_train_forward(X, y_train)`, which returns
    // predictions only for the test rows. Align with the last `n_test` rows
    // of the Rust output ("slice from the end").
    let n_rows = out.shape()[1];
    assert!(
        n_rows >= n_test,
        "Rust output has {n_rows} rows (n_train={n_train}), expected at least {n_test} test rows"
    );
    let first_test_row = n_rows - n_test;

    eprintln!("Python ref: {:?}", y_ref);
    eprintln!("Rust out shape: {:?}", out.shape());
    for i in 0..n_test {
        let row: Vec<_> = (0..n_classes)
            .map(|c| out[(0, first_test_row + i, c)])
            .collect();
        eprintln!("Rust out at test pos {i}: {row:?}");
    }

    for i in 0..n_test {
        for c in 0..n_classes {
            let rust_val = out[(0, first_test_row + i, c)];
            let py_val = y_ref[i * n_classes + c];
            let diff = (rust_val - py_val).abs();
            // `diff < TOL` is false for NaN, so NaN logits also fail here.
            assert!(
                diff < TOL,
                "logit[{i},{c}] Rust={rust_val} Python={py_val} diff={diff} tol={TOL}"
            );
        }
    }
    eprintln!("✓ Rust output matches Python within {} tolerance", TOL);
}