// tabicl-model 2.1.1
//
// TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Numerical parity tests against hand-computed Python-equivalent values.
//!
//! Each test fixes a small input and parameter set, computes the
//! expected output using the same formula Python `tabicl` would apply,
//! and asserts the Rust output matches within fp32 numerical noise.
//!
//! This is the closest we get to "100% parity" without downloading a
//! TabICL checkpoint and running Python alongside. Each fixture
//! verifies a different layer's math.

use ndarray::{Array, Array2, Array4, array};
use tabicl_model::attention::{AttentionConfig, AttentionParams, multi_head_attention_forward};
use tabicl_model::layers::{OneHotAndLinear, SkippableLinear, layer_norm_last, linear3d};
use tabicl_model::rope::{RopeConfig, RopeTables, apply_rotary_emb_ref};

/// Checks `layer_norm_last` on the single row `[1, 2, 3, 4]`, with an all-ones
/// scale and `None` for the shift, against values recorded from
/// `torch.nn.functional.layer_norm(x, (4,), eps=1e-5)`.
///
/// Despite the test name, agreement is only required to within `1e-5`.
#[test]
fn layer_norm_matches_pytorch_output_bit_for_bit() {
    // Reference output captured from `torch.nn.functional.layer_norm`.
    let pytorch_truth = [
        -1.3416354656219482_f32,
        -0.4472118318080902,
        0.4472118318080902,
        1.3416354656219482,
    ];
    let input = array![[[1.0_f32, 2.0, 3.0, 4.0]]];
    let y = layer_norm_last(input.view(), &[1.0; 4], None, 1e-5);
    for (k, &want) in pytorch_truth.iter().enumerate() {
        let got = y[(0, 0, k)];
        let diff = (got - want).abs();
        assert!(
            diff < 1e-5,
            "LN[{k}] Rust={} PyTorch={} diff={}",
            got,
            want,
            diff
        );
    }
}

/// Verifies the affine path of `layer_norm_last`: the normalized value is
/// multiplied by the per-feature γ and then shifted by the per-feature β.
///
/// The expectation is rebuilt in the test from the mean (2) and the biased
/// variance (2/3) of the input row `[1, 2, 3]`.
#[test]
fn layer_norm_with_affine_matches_python_formula() {
    let gamma = [2.0, 0.5, 1.5];
    let beta = [10.0, 20.0, -5.0];
    let x = array![[[1.0_f32, 2.0, 3.0]]];
    let y = layer_norm_last(x.view(), &gamma, Some(&beta), 1e-5);

    // Statistics of [1, 2, 3], worked out by hand.
    let mean = 2.0_f32;
    let var = 2.0_f32 / 3.0;
    let inv_std = 1.0 / (var + 1e-5).sqrt();
    // (x - mean) * inv_std * γ + β, spelled out per feature.
    let expected = [
        (1.0 - mean) * inv_std * 2.0 + 10.0,
        (2.0 - mean) * inv_std * 0.5 + 20.0,
        (3.0 - mean) * inv_std * 1.5 - 5.0,
    ];
    for (k, &want) in expected.iter().enumerate() {
        assert!((y[(0, 0, k)] - want).abs() < 1e-5);
    }
}

/// Pins `linear3d` to a hand-worked example: one 2-feature input row, a 2×2
/// weight matrix and a 2-element bias.
#[test]
fn linear3d_matches_python_formula() {
    let input = array![[[1.0_f32, 2.0]]]; // shape (1, 1, 2)
    let weight = array![[0.5_f32, 1.0], [-1.0, 2.0]]; // shape (2, 2)
    let bias = [0.1_f32, 0.2];
    let y = linear3d(input.view(), weight.view(), Some(&bias));
    // Output feature j = dot(input, weight row j) + bias[j]:
    //   j = 0: 1*0.5  + 2*1.0 + 0.1 = 2.6
    //   j = 1: 1*(-1) + 2*2.0 + 0.2 = 3.2
    let expected = [2.6, 3.2];
    for (j, &want) in expected.iter().enumerate() {
        assert!((y[(0, 0, j)] - want).abs() < 1e-6);
    }
}

/// `OneHotAndLinear` on class indices: for class `c` the expected output is
/// column `c` of the weight plus the bias.
///
/// Every element of the hand-computed output is asserted, so a defect that
/// touches only one (position, feature) pair cannot slip through.
#[test]
fn one_hot_and_linear_matches_python_formula() {
    // weight = [[1, 2, 3], [4, 5, 6]] (E=2, C=3)
    // bias = [10, 20]
    let weight = array![[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let m = OneHotAndLinear::from_raw_weight(weight, Some(vec![10.0, 20.0]));
    let src = array![[0_usize, 2, 1]];
    let out = m.forward(src.view());
    // For class c, out = weight[:, c] + bias. Rows below follow the order of `src`.
    let expected = [
        [11.0, 24.0], // c=0: [1, 4] + [10, 20]
        [13.0, 26.0], // c=2: [3, 6] + [10, 20]
        [12.0, 25.0], // c=1: [2, 5] + [10, 20]
    ];
    for (pos, row) in expected.iter().enumerate() {
        for (e, &want) in row.iter().enumerate() {
            assert_eq!(out[(0, pos, e)], want, "out[(0, {pos}, {e})]");
        }
    }
}

/// `SkippableLinear` must return rows that consist entirely of the sentinel
/// value (-100) unchanged, while every other row gets the linear map applied.
#[test]
fn skippable_linear_sentinel_matches_python() {
    // Identity weight, so an ordinary row comes out as `row + bias`.
    let weight = array![[1.0_f32, 0.0], [0.0, 1.0]];
    let m = SkippableLinear::new(weight, Some(vec![10.0, 20.0]), -100.0);
    let src = array![[[1.0_f32, 2.0], [-100.0, -100.0], [3.0, 4.0]]];
    let out = m.forward(src.view());
    let expected = [
        [11.0, 22.0],     // normal row: [1 + 10, 2 + 20]
        [-100.0, -100.0], // sentinel row passes through
        [13.0, 24.0],     // normal row: [3 + 10, 4 + 20]
    ];
    for (row, want_row) in expected.iter().enumerate() {
        for (col, &want) in want_row.iter().enumerate() {
            assert_eq!(out[(0, row, col)], want);
        }
    }
}

/// Non-interleaved RoPE checked against values recorded from Python.
///
/// Fixture: `head_dim = 2`, `base = 10000`, input `[0.5, 0.5]` at position 1.
/// Recorded truth:
///   out_lo = -0.15058433946987837
///   out_hi =  0.6908866453380181
///
/// Despite the test name, agreement is only required to within `1e-6`.
#[test]
fn rope_non_interleaved_matches_python_bit_for_bit() {
    let tables = RopeTables::new(
        RopeConfig {
            head_dim: 2,
            base: 10_000.0,
            interleaved: false,
        },
        2,
    );
    // Axis 2 carries the two positions: [1, 0] at position 0, [0.5, 0.5] at position 1.
    let x = Array4::from_shape_vec((1, 1, 2, 2), vec![1.0_f32, 0.0, 0.5, 0.5]).unwrap();
    let y = apply_rotary_emb_ref(&x.view(), &tables);

    // Position 0 must come back unchanged.
    for (d, &want) in [1.0_f32, 0.0].iter().enumerate() {
        assert!((y[(0, 0, 0, d)] - want).abs() < 1e-6);
    }

    // Position 1 is compared with the recorded Python values.
    let python_lo = -0.15058433946987837_f32;
    let python_hi = 0.6908866453380181_f32;
    let rust_lo = y[(0, 0, 1, 0)];
    let rust_hi = y[(0, 0, 1, 1)];
    let diff_lo = (rust_lo - python_lo).abs();
    let diff_hi = (rust_hi - python_hi).abs();
    assert!(
        diff_lo < 1e-6,
        "RoPE lo Rust={} Python={} diff={}",
        rust_lo,
        python_lo,
        diff_lo
    );
    assert!(
        diff_hi < 1e-6,
        "RoPE hi Rust={} Python={} diff={}",
        rust_hi,
        python_hi,
        diff_hi
    );
}

/// Multi-head attention compared with output recorded from PyTorch SDPA.
///
/// Fixture: B=1, T=2, E=4, H=2, with
///   q = [[1,0,0,0],[0,1,0,0]], k = q, v = [[0.5,1,1.5,2],[3,2.5,2,1.5]]
/// Recorded PyTorch output (flattened, row-major):
///   [1.3255960941314697, 1.4953577518463135, 1.75, 1.75,
///    2.1744039058685303, 2.0046422481536865, 1.75, 1.75]
///
/// Despite the test name, agreement is only required to within `1e-5`.
#[test]
fn multi_head_attention_matches_pytorch_bit_for_bit() {
    let embed_dim = 4;
    let num_heads = 2;
    let cfg = AttentionConfig {
        embed_dim,
        num_heads,
        dropout: 0.0,
        bias: true,
    };

    // Packed input projection: three stacked identity blocks (the Q, K and V
    // slots), so each supplied tensor passes through the projection unchanged.
    let in_proj_weight =
        Array2::<f32>::from_shape_fn((3 * embed_dim, embed_dim), |(row, col)| {
            if row % embed_dim == col { 1.0 } else { 0.0 }
        });
    // The output projection is the identity as well; no biases are supplied.
    let params = AttentionParams {
        in_proj_weight,
        in_proj_bias: None,
        out_proj_weight: Array2::<f32>::eye(embed_dim),
        out_proj_bias: None,
    };

    // The same tensor serves as query and key; the value tensor is separate.
    let qk_input =
        Array::from_shape_vec((1, 2, 4), vec![1.0_f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
    let v_input =
        Array::from_shape_vec((1, 2, 4), vec![0.5_f32, 1.0, 1.5, 2.0, 3.0, 2.5, 2.0, 1.5]).unwrap();

    // The two trailing optional arguments are left as `None`, as in the recorded run.
    let out = multi_head_attention_forward(
        qk_input.view(),
        qk_input.view(),
        v_input.view(),
        &params,
        &cfg,
        None,
        None,
    );

    let pytorch_truth = [
        1.3255960941314697_f32,
        1.4953577518463135,
        1.75,
        1.75,
        2.1744039058685303,
        2.0046422481536865,
        1.75,
        1.75,
    ];
    // The truth table is row-major over (token, feature) with 4 features per token.
    for (flat, &want) in pytorch_truth.iter().enumerate() {
        let (t, e) = (flat / 4, flat % 4);
        let got = out[(0, t, e)];
        let diff = (got - want).abs();
        assert!(
            diff < 1e-5,
            "attn[{t},{e}] Rust={} PyTorch={} diff={}",
            got,
            want,
            diff
        );
    }
}

/// Single-head attention on a 2-token input where Q = K = V = identity, so the
/// result softmax(QK^T / sqrt(D)) @ V can be written down by hand.
#[test]
fn attention_with_known_qkv_matches_softmax_v() {
    let embed_dim = 2;
    let cfg = AttentionConfig {
        embed_dim,
        num_heads: 1,
        dropout: 0.0,
        bias: true,
    };
    // Q, K and V projections are identity blocks stacked into 3*E rows of E
    // columns; the output projection is the identity too. No biases are supplied.
    let params = AttentionParams {
        in_proj_weight: Array2::<f32>::from_shape_fn((3 * embed_dim, embed_dim), |(row, col)| {
            if row % embed_dim == col { 1.0 } else { 0.0 }
        }),
        in_proj_bias: None,
        out_proj_weight: Array2::<f32>::eye(embed_dim),
        out_proj_bias: None,
    };
    // Two distinct tokens: [1, 0] and [0, 1].
    let x = Array::from_shape_vec((1, 2, embed_dim), vec![1.0_f32, 0.0, 0.0, 1.0]).unwrap();
    let y = multi_head_attention_forward(x.view(), x.view(), x.view(), &params, &cfg, None, None);

    // QK^T is the 2×2 identity; after dividing by sqrt(D = 2) each row holds
    // one entry equal to 1/√2 and one equal to 0. Softmax of such a row is
    // [e^{1/√2}, 1] / (e^{1/√2} + 1).
    let s = 1.0_f32 / 2.0_f32.sqrt();
    let e_s = s.exp();
    let z0 = e_s + 1.0;
    let a00 = e_s / z0; // weight a token puts on itself
    let a01 = 1.0 / z0; // weight a token puts on the other token

    let close = |got: f32, want: f32| (got - want).abs() < 1e-5;
    // Row 0 = a00 * V[0] + a01 * V[1] = a00 * [1, 0] + a01 * [0, 1] = [a00, a01].
    assert!(close(y[(0, 0, 0)], a00));
    assert!(close(y[(0, 0, 1)], a01));
    // Row 1 is the mirror image.
    assert!(close(y[(0, 1, 0)], a01));
    assert!(close(y[(0, 1, 1)], a00));
}

/// Evaluates the tanh-approximated GELU at 0.5 inline and compares it with the
/// value recorded from PyTorch:
/// `F.gelu(torch.tensor(0.5), approximate='tanh') = 0.3457140028476715`.
///
/// The formula is computed inside the test itself. Despite the test name,
/// agreement is only required to within `1e-6`.
#[test]
fn gelu_matches_pytorch_bit_for_bit() {
    let pytorch_truth = 0.3457140028476715_f32;

    // Constants of the tanh approximation: sqrt(2/π) and the cubic coefficient.
    let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    let cubic_coeff = 0.044715_f32;

    let v = 0.5_f32;
    let tanh_arg = sqrt_2_over_pi * (v + cubic_coeff * v * v * v);
    let g = 0.5 * v * (1.0 + tanh_arg.tanh());

    let diff = (g - pytorch_truth).abs();
    assert!(
        diff < 1e-6,
        "GELU(0.5) Rust={g} PyTorch={pytorch_truth} diff={diff}"
    );
}

/// Computes softmax and cross-entropy inline for logits `[2, 1, 0]` and
/// compares them with values recorded from PyTorch:
///   F.softmax([2, 1, 0]) = [0.6652409, 0.2447285, 0.0900306]
///   CE(label=0)          = 0.4076059829985112
///
/// Despite the test name, agreement is only required to within `1e-6`.
#[test]
fn softmax_and_cross_entropy_match_pytorch_bit_for_bit() {
    let logits = [2.0_f32, 1.0, 0.0];

    // Subtract the largest logit before exponentiating.
    let maxv = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&logit| (logit - maxv).exp()).collect();
    let z: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|&num| num / z).collect();

    let pytorch_softmax = [
        0.6652409434318542_f32,
        0.2447284758090973,
        0.09003057330846786,
    ];
    for (k, (&got, &want)) in probs.iter().zip(pytorch_softmax.iter()).enumerate() {
        let diff = (got - want).abs();
        assert!(
            diff < 1e-6,
            "softmax[{k}] Rust={} PyTorch={} diff={diff}",
            got,
            want
        );
    }

    // Cross-entropy for label 0 is the negative log-probability of class 0.
    let pytorch_ce = 0.4076059829985112_f32;
    let ce = -probs[0].ln();
    let diff = (ce - pytorch_ce).abs();
    assert!(diff < 1e-6, "CE Rust={ce} PyTorch={pytorch_ce} diff={diff}");
}