// tabicl-model 2.1.1
//
// TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
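//! Cross-stack gradient parity against PyTorch autograd.
//!
//! Each test loads input tensors and PyTorch-computed gradients from raw
//! little-endian `f32` fixture files under `/tmp` (skipping when the fixture
//! is absent). It replays the same forward + backward pass in Rust and checks
//! element-wise agreement:
//!
//! - Linear: `dL/dW` and `dL/db` for a linear forward with an MSE loss, via
//!   `tabicl_model::autodiff::linear3d_backward`.
//! - LayerNorm: `dL/dx`, `dL/dgamma` and `dL/dbeta`, via
//!   `tabicl_model::autodiff::layer_norm_backward`.
//!
//! Agreement is checked with absolute tolerances: 1e-5, or 1e-4 for the
//! LayerNorm input gradient. That is within fp32 rounding noise, which is
//! evidence that the training path's gradients match PyTorch's autograd.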

use std::path::Path;
use tabicl_model::autodiff::linear3d_backward;

/// Read a fixture file of raw little-endian `f32` values.
///
/// Returns `None` if the file cannot be read, or if its length is not a
/// multiple of 4 bytes. A truncated or corrupt fixture is thereby reported
/// instead of having its trailing partial value silently dropped.
fn read_f32(path: &Path) -> Option<Vec<f32>> {
    let bytes = std::fs::read(path).ok()?;
    let words = bytes.chunks_exact(4);
    if !words.remainder().is_empty() {
        return None;
    }
    Some(
        words
            .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
            .collect(),
    )
}

/// LayerNorm backward parity against PyTorch autograd.
///
/// Loads a fixture from `/tmp/tabicl_grad_ln`, or skips if `x.bin` is absent.
/// Runs `layer_norm_last` forward with eps = 1e-5 and forms the upstream
/// gradient `dL/dy = y - target`. It then compares the `(dx, dgamma, dbeta)`
/// returned by `layer_norm_backward` against PyTorch's values.
#[test]
fn layer_norm_gradient_matches_pytorch_autograd() {
    // Largest |ours - theirs| over paired elements. NaN propagates: once a NaN
    // difference is seen the result stays NaN, so the `< tol` asserts fail.
    // `f32::max` would silently discard NaN and let a broken gradient pass.
    fn max_abs_diff<'a>(
        ours: impl IntoIterator<Item = &'a f32>,
        theirs: impl IntoIterator<Item = &'a f32>,
    ) -> f32 {
        ours.into_iter()
            .zip(theirs)
            .map(|(r, p)| (r - p).abs())
            .fold(0.0_f32, |m, d| if d > m || d.is_nan() { d } else { m })
    }

    let base = Path::new("/tmp/tabicl_grad_ln");
    if !base.join("x.bin").exists() {
        eprintln!("LN grad parity: fixture not at {base:?}, skipping");
        return;
    }
    // Fixture: n rows of feature width d, treated as a single batch (1, n, d).
    let d = 6;
    let n = 4;
    let load = |name: &str| {
        read_f32(&base.join(name)).unwrap_or_else(|| panic!("failed to read fixture {name}"))
    };
    let x = load("x.bin");
    let gamma = load("gamma.bin");
    let beta = load("beta.bin");
    let target = load("target.bin");
    let py_dx = load("dx.bin");
    let py_dgamma = load("dgamma.bin");
    let py_dbeta = load("dbeta.bin");

    // zip() stops at the shorter side, so without these checks a short fixture
    // would silently shrink the comparison instead of failing.
    assert_eq!(gamma.len(), d, "gamma.bin length");
    assert_eq!(beta.len(), d, "beta.bin length");
    assert_eq!(py_dx.len(), n * d, "dx.bin length");
    assert_eq!(py_dgamma.len(), d, "dgamma.bin length");
    assert_eq!(py_dbeta.len(), d, "dbeta.bin length");

    let x_arr = ndarray::Array3::from_shape_vec((1, n, d), x).unwrap();
    let target_arr = ndarray::Array3::from_shape_vec((1, n, d), target).unwrap();
    let y = tabicl_model::layers::layer_norm_last(x_arr.view(), &gamma, Some(&beta), 1e-5);
    // Pin the shape so the subtraction below cannot silently broadcast.
    assert_eq!(y.dim(), (1, n, d), "LN output shape");

    // dL/dy = y - target, the gradient of 0.5 * sum((y - target)^2).
    let dy = &y - &target_arr;

    let (dx, dgamma, dbeta) =
        tabicl_model::autodiff::layer_norm_backward(x_arr.view(), &gamma, dy.view(), 1e-5);
    assert_eq!(dx.dim(), (1, n, d), "dx shape");
    assert_eq!(dgamma.len(), d, "dgamma length");
    assert_eq!(dbeta.len(), d, "dbeta length");

    // Array iteration is in logical row-major order, matching py_dx[i * d + k].
    let max_dx_diff = max_abs_diff(dx.iter(), &py_dx);
    let max_dgamma_diff = max_abs_diff(dgamma.iter(), &py_dgamma);
    let max_dbeta_diff = max_abs_diff(dbeta.iter(), &py_dbeta);
    eprintln!(
        "LN grad max diffs: dx={max_dx_diff} dgamma={max_dgamma_diff} dbeta={max_dbeta_diff}"
    );
    assert!(max_dx_diff < 1e-4, "dx mismatch: {max_dx_diff}");
    assert!(max_dgamma_diff < 1e-5, "dgamma mismatch: {max_dgamma_diff}");
    assert!(max_dbeta_diff < 1e-5, "dbeta mismatch: {max_dbeta_diff}");
}

/// Linear-layer backward parity against PyTorch autograd.
///
/// Loads a fixture from `/tmp/tabicl_grad`, or skips if `x.bin` is absent.
/// The fixture has n = 8 samples, in = 4 and out = 3. The test runs `linear3d`
/// forward (`pred = x @ w^T + b`) and forms `dL/dpred` for
/// `L = 0.5 * mean((pred - y)^2)`. It then compares the dW and db returned by
/// `linear3d_backward` with PyTorch's values.
#[test]
fn linear_gradient_matches_pytorch_autograd() {
    // Largest |ours - theirs| over paired elements. NaN propagates: once a NaN
    // difference is seen the result stays NaN, so the `< tol` asserts fail.
    // `f32::max` would silently discard NaN and let a broken gradient pass.
    fn max_abs_diff<'a>(
        ours: impl IntoIterator<Item = &'a f32>,
        theirs: impl IntoIterator<Item = &'a f32>,
    ) -> f32 {
        ours.into_iter()
            .zip(theirs)
            .map(|(r, p)| (r - p).abs())
            .fold(0.0_f32, |m, d| if d > m || d.is_nan() { d } else { m })
    }

    let base = Path::new("/tmp/tabicl_grad");
    if !base.join("x.bin").exists() {
        eprintln!("gradient parity: fixture not at {base:?}, skipping");
        return;
    }
    // PyTorch fixture: n=8, in=4, out=3.
    let n = 8;
    let in_dim = 4;
    let out_dim = 3;
    let load = |name: &str| {
        read_f32(&base.join(name)).unwrap_or_else(|| panic!("failed to read fixture {name}"))
    };
    let x = load("x.bin");
    let y = load("y.bin");
    let w = load("w.bin");
    let b = load("b.bin");
    let py_dw = load("dw.bin");
    let py_db = load("db.bin");

    // zip() stops at the shorter side, so without these checks a short fixture
    // would silently shrink the comparison instead of failing.
    assert_eq!(b.len(), out_dim, "b.bin length");
    assert_eq!(py_dw.len(), out_dim * in_dim, "dw.bin length");
    assert_eq!(py_db.len(), out_dim, "db.bin length");

    // Rust convention is (B, T, ...); the fixture is a single batch: B = 1, T = n.
    let x_arr = ndarray::Array3::from_shape_vec((1, n, in_dim), x).unwrap();
    let y_arr = ndarray::Array3::from_shape_vec((1, n, out_dim), y).unwrap();
    let w_arr = ndarray::Array2::from_shape_vec((out_dim, in_dim), w).unwrap();

    // Forward: pred = x @ w^T + b.
    let pred = tabicl_model::layers::linear3d(x_arr.view(), w_arr.view(), Some(&b));
    // Pin the shape so the subtraction below cannot silently broadcast.
    assert_eq!(pred.dim(), (1, n, out_dim), "pred shape");

    // Loss = 0.5 * mean((pred - y)^2), so dL/dpred = (pred - y) / N,
    // where N is the total number of elements.
    let total = (n * out_dim) as f32;
    let dpred = (&pred - &y_arr) / total;

    // Backward.
    let (_dx, dw, db) = linear3d_backward(x_arr.view(), w_arr.view(), dpred.view(), true);
    let db = db.expect("linear3d_backward returned no db");
    assert_eq!(dw.dim(), (out_dim, in_dim), "dW shape");
    assert_eq!(db.len(), out_dim, "db length");

    // Array iteration is in logical row-major order, matching py_dw[o * in_dim + k].
    let max_dw_diff = max_abs_diff(dw.iter(), &py_dw);
    let max_db_diff = max_abs_diff(db.iter(), &py_db);
    eprintln!("Max dW diff: {max_dw_diff}");
    eprintln!("Max db diff: {max_db_diff}");
    assert!(max_dw_diff < 1e-5, "dW mismatch: {max_dw_diff}");
    assert!(max_db_diff < 1e-5, "db mismatch: {max_db_diff}");
}