tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Shared transformer building blocks — port of `tabicl._model.layers`.
//!
//! The Python module (792 LOC) defines six classes, summarized below:
//!
//! | Python class            | scope                                         | port status      |
//! |-------------------------|-----------------------------------------------|------------------|
//! | `ClassNode`             | hierarchical classification tree node          | will move to `learning` |
//! | `OneHotAndLinear`       | int→embedding (one-hot ⨯ linear)               | reference here   |
//! | `SkippableLinear`       | linear that no-ops on sentinel rows            | reference here   |
//! | `MultiheadAttention`    | attention block with RoPE / SSMax / KV cache   | see [`crate::attention`] |
//! | `MultiheadAttentionBlock` (MAB) | transformer encoder layer              | placeholder      |
//! | `InducedSelfAttentionBlock` (ISAB) | Set Transformer block               | placeholder      |
//!
//! This file currently ports the two small, fully self-contained utility
//! classes as ndarray reference implementations so they can be parity-tested
//! independently, together with the `linear3d` and `layer_norm_last` helpers.
//! The MAB / ISAB ports are kept with the modules that use them (embedding
//! for ISAB, learning for the transformer encoder stack).

use ndarray::{Array2, Array3, ArrayView2, ArrayView3, Axis};

use crate::state_dict::{StateDict, StateDictError};

/// One-hot embedding via dense linear projection. Mirrors
/// `tabicl._model.layers.OneHotAndLinear` exactly: a single `nn.Linear`
/// applied to a one-hot row, equivalent to `weight.T[src]` plus optional
/// bias.
#[derive(Debug, Clone)]
pub struct OneHotAndLinear {
    pub num_classes: usize,
    pub embed_dim: usize,
    /// Shape `(embed_dim, num_classes)`, matching PyTorch `nn.Linear.weight`.
    pub weight: Array2<f32>,
    /// Optional bias, shape `(embed_dim,)`. `nn.Linear` defaults to bias=True.
    pub bias: Option<Vec<f32>>,
}

impl OneHotAndLinear {
    /// Construct with caller-supplied weight/bias (e.g. loaded from a
    /// Python checkpoint). Use `from_raw_weight` for the common case.
    pub fn from_raw_weight(weight: Array2<f32>, bias: Option<Vec<f32>>) -> Self {
        let (embed_dim, num_classes) = (weight.shape()[0], weight.shape()[1]);
        Self {
            num_classes,
            embed_dim,
            weight,
            bias,
        }
    }

    /// Load from a state dict. PyTorch keys: `{prefix}.weight` (shape
    /// `(embed_dim, num_classes)`) and `{prefix}.bias` (shape `(embed_dim,)`,
    /// optional — `nn.Linear` defaults to bias=True).
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        self.weight = sd.take_array2(
            &format!("{prefix}.weight"),
            self.embed_dim,
            self.num_classes,
        )?;
        // Bias is optional in this port — accept either presence.
        let bias_key = format!("{prefix}.bias");
        if sd.tensors.contains_key(&bias_key) {
            self.bias = Some(sd.take_vec(&bias_key, self.embed_dim)?);
        }
        Ok(())
    }

    /// Apply to integer indices `src` of shape `(B, T)`. Returns `(B, T, E)`.
    ///
    /// Indices >= `num_classes` are clamped to `num_classes - 1` (the
    /// Python code calls `F.one_hot(src.long(), num_classes)` which would
    /// raise; we trade the assertion for a debug-mode `debug_assert!`).
    pub fn forward(&self, src: ArrayView2<usize>) -> Array3<f32> {
        let (b, t) = (src.shape()[0], src.shape()[1]);
        let mut out = Array3::<f32>::zeros((b, t, self.embed_dim));
        for bi in 0..b {
            for ti in 0..t {
                let c = src[(bi, ti)];
                debug_assert!(c < self.num_classes, "class index out of range");
                for e in 0..self.embed_dim {
                    out[(bi, ti, e)] = self.weight[(e, c)];
                }
                if let Some(b_) = &self.bias {
                    for e in 0..self.embed_dim {
                        out[(bi, ti, e)] += b_[e];
                    }
                }
            }
        }
        out
    }
}

/// Linear that no-ops on sentinel rows. Mirrors
/// `tabicl._model.layers.SkippableLinear`: applies the standard
/// `F.linear(src, weight, bias)`, then for any input row whose
/// *every* feature equals `skip_value`, replaces the entire output row
/// with `skip_value`.
#[derive(Debug, Clone)]
pub struct SkippableLinear {
    /// Shape `(out_features, in_features)`.
    pub weight: Array2<f32>,
    pub bias: Option<Vec<f32>>,
    /// Python default is `-100.0`.
    pub skip_value: f32,
}

impl SkippableLinear {
    pub fn new(weight: Array2<f32>, bias: Option<Vec<f32>>, skip_value: f32) -> Self {
        Self {
            weight,
            bias,
            skip_value,
        }
    }

    pub fn in_features(&self) -> usize {
        self.weight.shape()[1]
    }

    pub fn out_features(&self) -> usize {
        self.weight.shape()[0]
    }

    /// Load from a PyTorch state dict. `SkippableLinear` extends
    /// `nn.Linear`, so the keys match `nn.Linear`: `{prefix}.weight`
    /// (shape `(out, in)`) and `{prefix}.bias` (shape `(out,)`).
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        let (out_f, in_f) = (self.out_features(), self.in_features());
        self.weight = sd.take_array2(&format!("{prefix}.weight"), out_f, in_f)?;
        let bias_key = format!("{prefix}.bias");
        if sd.tensors.contains_key(&bias_key) {
            self.bias = Some(sd.take_vec(&bias_key, out_f)?);
        }
        Ok(())
    }

    /// Forward pass on a 3-D batch of shape `(B, T, in_features)`. Returns
    /// `(B, T, out_features)`.
    pub fn forward(&self, src: ArrayView3<f32>) -> Array3<f32> {
        let (b, t, in_f) = (src.shape()[0], src.shape()[1], src.shape()[2]);
        assert_eq!(in_f, self.in_features());
        let out_f = self.out_features();

        let mut out = Array3::<f32>::zeros((b, t, out_f));
        for bi in 0..b {
            for ti in 0..t {
                // skip-mask: are all input values equal to skip_value?
                let skip = (0..in_f).all(|k| src[(bi, ti, k)] == self.skip_value);
                if skip {
                    for k in 0..out_f {
                        out[(bi, ti, k)] = self.skip_value;
                    }
                    continue;
                }
                // standard linear: out = src @ weight.T + bias
                for k in 0..out_f {
                    let mut acc = 0.0_f32;
                    for j in 0..in_f {
                        acc += src[(bi, ti, j)] * self.weight[(k, j)];
                    }
                    if let Some(b_) = &self.bias {
                        acc += b_[k];
                    }
                    out[(bi, ti, k)] = acc;
                }
            }
        }
        out
    }
}

/// Dense projection `(B, T, in_f) -> (B, T, out_f)` with an optional bias,
/// i.e. PyTorch `F.linear(src, weight, bias)` with `weight` of shape
/// `(out_f, in_f)`.
///
/// Every dot product is accumulated in `f64` (terms added in increasing
/// input-feature order, bias added last) and only the final value is rounded
/// to `f32`. The goal is less rounding drift than a plain `f32` accumulation
/// when comparing against the PyTorch reference.
///
/// # Panics
///
/// Panics if the last dimension of `src` differs from the second dimension
/// of `weight`, or (for non-empty output) if `bias` is shorter than `out_f`.
pub fn linear3d(
    src: ArrayView3<f32>,
    weight: ArrayView2<f32>,
    bias: Option<&[f32]>,
) -> Array3<f32> {
    let (batch, seq, in_features) = src.dim();
    let (out_features, weight_in) = weight.dim();
    assert_eq!(in_features, weight_in);

    let mut out = Array3::<f32>::zeros((batch, seq, out_features));
    for ((bi, ti, k), slot) in out.indexed_iter_mut() {
        let dot = (0..in_features).fold(0.0_f64, |acc, j| {
            acc + f64::from(src[(bi, ti, j)]) * f64::from(weight[(k, j)])
        });
        let total = match bias {
            Some(shift) => dot + f64::from(shift[k]),
            None => dot,
        };
        *slot = total as f32;
    }
    out
}

/// Reference LayerNorm over the last axis of a `(B, T, D)` tensor.
///
/// Computes `y = (x - mean) / sqrt(var + eps) * gamma + beta` per `(b, t)`
/// row, using the biased (population) variance. Mean, variance and the
/// normalisation factor are computed in `f64`. The normalised value is
/// rounded to `f32` before the affine `gamma` / `beta` step.
///
/// # Panics
///
/// Panics if `gamma.len() != D`, or if `beta` is shorter than `D`.
pub fn layer_norm_last(
    x: ArrayView3<f32>,
    gamma: &[f32],
    beta: Option<&[f32]>,
    eps: f32,
) -> Array3<f32> {
    let (batch, seq, dim) = x.dim();
    assert_eq!(gamma.len(), dim);

    let scale = 1.0_f64 / dim as f64;
    let eps = f64::from(eps);
    let mut out = Array3::<f32>::zeros((batch, seq, dim));

    for bi in 0..batch {
        for ti in 0..seq {
            let value = |k: usize| f64::from(x[(bi, ti, k)]);

            // Explicit folds (starting at +0.0) keep the exact summation order.
            let mean = (0..dim).fold(0.0_f64, |acc, k| acc + value(k)) * scale;
            let var = (0..dim).fold(0.0_f64, |acc, k| {
                let centered = value(k) - mean;
                acc + centered * centered
            }) * scale;
            let inv_std = (var + eps).sqrt().recip();

            for k in 0..dim {
                let normed = ((value(k) - mean) * inv_std) as f32 * gamma[k];
                out[(bi, ti, k)] = if let Some(shift) = beta {
                    normed + shift[k]
                } else {
                    normed
                };
            }
        }
    }
    out
}

// NOTE(review): never called; it appears to exist only to keep the `Axis`
// import referenced so it does not raise an unused-import warning.
// `dead_code` is allowed because nothing calls this function.
#[allow(dead_code)]
fn _silence_unused(_a: Axis) {}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn one_hot_and_linear_picks_column() {
        // weight[:, c] should be the output for index c (ignoring bias).
        let weight = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // (E=2, C=3)
        let m = OneHotAndLinear::from_raw_weight(weight, None);
        let src = array![[0_usize, 1, 2]];
        let out = m.forward(src.view());
        assert_eq!(out.shape(), &[1, 3, 2]);
        // out[0, t, :] == weight[:, t]
        for t in 0..3 {
            for e in 0..2 {
                assert_eq!(out[(0, t, e)], (e * 3 + t) as f32 + 1.0);
            }
        }
    }

    #[test]
    fn one_hot_and_linear_adds_bias() {
        let weight = array![[1.0, 2.0], [3.0, 4.0]]; // (E=2, C=2)
        let m = OneHotAndLinear::from_raw_weight(weight, Some(vec![10.0, 100.0]));
        let src = array![[0_usize]];
        let out = m.forward(src.view());
        assert_eq!(out[(0, 0, 0)], 1.0 + 10.0);
        assert_eq!(out[(0, 0, 1)], 3.0 + 100.0);
    }

    #[test]
    fn one_hot_and_linear_handles_multiple_batches() {
        // Every (b, t) position picks its own column independently.
        let weight = array![[1.0, 2.0], [3.0, 4.0]]; // (E=2, C=2)
        let m = OneHotAndLinear::from_raw_weight(weight, None);
        let src = array![[1_usize, 0], [0, 1]]; // (B=2, T=2)
        let out = m.forward(src.view());
        assert_eq!(out.shape(), &[2, 2, 2]);
        assert_eq!((out[(0, 0, 0)], out[(0, 0, 1)]), (2.0, 4.0));
        assert_eq!((out[(0, 1, 0)], out[(0, 1, 1)]), (1.0, 3.0));
        assert_eq!((out[(1, 0, 0)], out[(1, 0, 1)]), (1.0, 3.0));
        assert_eq!((out[(1, 1, 0)], out[(1, 1, 1)]), (2.0, 4.0));
    }

    #[test]
    #[should_panic]
    fn one_hot_and_linear_rejects_out_of_range_index() {
        // Python's `F.one_hot` raises for an index >= num_classes.
        let weight = array![[1.0, 2.0], [3.0, 4.0]]; // (E=2, C=2)
        let m = OneHotAndLinear::from_raw_weight(weight, None);
        let src = array![[2_usize]];
        let _ = m.forward(src.view());
    }

    #[test]
    fn skippable_linear_skips_sentinel_rows() {
        let weight = array![[1.0, 0.0], [0.0, 1.0]]; // identity (out_f=2, in_f=2)
        let m = SkippableLinear::new(weight, None, -100.0);
        // Row 0 is normal; row 1 is all sentinels.
        let src = array![[[1.0, 2.0], [-100.0, -100.0]]];
        let out = m.forward(src.view());
        assert_eq!(out[(0, 0, 0)], 1.0);
        assert_eq!(out[(0, 0, 1)], 2.0);
        // Skipped row gets sentinel back, *not* the linear output.
        assert_eq!(out[(0, 1, 0)], -100.0);
        assert_eq!(out[(0, 1, 1)], -100.0);
    }

    #[test]
    fn skippable_linear_partial_sentinel_is_not_skipped() {
        // Only fully-sentinel rows are skipped per the Python spec
        // (`(src == skip).all(dim=-1)`).
        let weight = array![[1.0, 1.0]]; // (out_f=1, in_f=2)
        let m = SkippableLinear::new(weight, None, -100.0);
        let src = array![[[-100.0, 1.0]]]; // mixed
        let out = m.forward(src.view());
        assert_eq!(out[(0, 0, 0)], -99.0); // normal linear path
    }

    #[test]
    fn skippable_linear_bias_applies_only_to_normal_rows() {
        let weight = array![[2.0, 0.0], [0.0, 3.0]]; // (out_f=2, in_f=2)
        let m = SkippableLinear::new(weight, Some(vec![1.0, -1.0]), -100.0);
        let src = array![[[1.0, 1.0], [-100.0, -100.0]]];
        let out = m.forward(src.view());
        // Normal row: (2*1 + 1, 3*1 - 1).
        assert_eq!(out[(0, 0, 0)], 3.0);
        assert_eq!(out[(0, 0, 1)], 2.0);
        // Sentinel row: bias must not be added on top of the sentinel.
        assert_eq!(out[(0, 1, 0)], -100.0);
        assert_eq!(out[(0, 1, 1)], -100.0);
    }

    #[test]
    fn layer_norm_zero_mean_unit_var() {
        // x = [-1, 0, 1] → mean 0, var 2/3 → after LN with γ=1, β=0: scaled.
        let x = array![[[-1.0_f32, 0.0, 1.0]]];
        let y = layer_norm_last(x.view(), &[1.0, 1.0, 1.0], None, 1e-5);
        // mean of y should be ~0, variance of y close to 1.
        let mean: f32 = y.iter().sum::<f32>() / 3.0;
        let var: f32 = y.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / 3.0;
        assert!(mean.abs() < 1e-5);
        assert!((var - 1.0).abs() < 1e-3);
    }

    #[test]
    fn layer_norm_applies_gamma_and_beta() {
        // x = [1, 3] → mean 2, var 1; with eps = 0 the normalised row is
        // exactly [-1, 1], then scaled by γ and shifted by β.
        let x = array![[[1.0_f32, 3.0]]];
        let y = layer_norm_last(x.view(), &[2.0, 0.5], Some(&[10.0, 20.0]), 0.0);
        assert_eq!(y[(0, 0, 0)], -1.0 * 2.0 + 10.0);
        assert_eq!(y[(0, 0, 1)], 1.0 * 0.5 + 20.0);
    }

    #[test]
    fn linear3d_matches_manual_compute() {
        let x = array![[[1.0_f32, 2.0], [3.0, 4.0]]]; // (1, 2, 2)
        let w = array![[1.0, 1.0], [0.0, 2.0]]; // (out=2, in=2)
        let b = [0.0, 0.5];
        let y = linear3d(x.view(), w.view(), Some(&b));
        // y[0,0] = (1+2 + 0,  0+4 + .5) = (3, 4.5)
        assert_eq!(y[(0, 0, 0)], 3.0);
        assert_eq!(y[(0, 0, 1)], 4.5);
        // y[0,1] = (3+4, 0+8 + .5)
        assert_eq!(y[(0, 1, 0)], 7.0);
        assert_eq!(y[(0, 1, 1)], 8.5);
    }

    #[test]
    fn linear3d_without_bias() {
        let x = array![[[1.0_f32, 2.0]]]; // (1, 1, 2)
        let w = array![[3.0_f32, 4.0]]; // (out=1, in=2)
        let y = linear3d(x.view(), w.view(), None);
        assert_eq!(y.shape(), &[1, 1, 1]);
        assert_eq!(y[(0, 0, 0)], 11.0);
    }
}