tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Scalable Softmax (SSMax) variants.
//!
//! Port of `tabicl._model.ssmax`. Provides the five SSMax flavors plus a
//! "none" pass-through, all selectable by string (see [`SsmaxKind`]):
//!
//!   - `none`                    — no scaling, returns `q` unchanged.
//!   - `ssmax`                   — `q * (s * log n)` with per-head learnable `s`.
//!   - `ssmax-mlp`               — `q * mlp(log n)`.
//!   - `ssmax-mlp-elementwise`   — same with per-(head, head_dim) output.
//!   - `qassmax-mlp`             — query-aware scaling.
//!   - `qassmax-mlp-elementwise` — query-aware, per-(head, head_dim).
//!
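//! This file holds the *spec* (config, parameter shapes, kind
//! discriminator), the parameter container with its state-dict loader, and
//! an `ndarray` evaluation of the scale tensor ([`compute_query_scale`]).
//! Multiplying that scale into `q`, and any graph construction, belongs
//! with the attention port that consumes it — SSMax is only meaningful in
//! the context of an attention layer.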

use thiserror::Error;

/// Errors produced while selecting or configuring an SSMax variant.
#[derive(Debug, Error)]
pub enum SsmaxError {
    /// The string handed to [`SsmaxKind::parse`] is not one of the recognised
    /// `ssmax_type` names; carries the offending input verbatim.
    #[error("unknown ssmax kind: {0:?}")]
    UnknownKind(String),
    /// No per-head dimension could be derived. [`SsmaxSpec::create`] returns
    /// this when `num_heads` or `embed_dim` is zero, or when `embed_dim` is
    /// not a multiple of `num_heads` — for every active kind, not only the
    /// elementwise ones the message mentions.
    #[error("`head_dim` is required for `elementwise=true` SSMax variants")]
    MissingHeadDim,
}

/// Selector matching the Python `ssmax_type` strings.
///
/// Each variant corresponds to exactly one string accepted by
/// [`SsmaxKind::parse`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SsmaxKind {
    /// `"none"` — no scaling.
    None,
    /// `"ssmax"` — per-head learnable scale multiplied by `log n`.
    Ssmax,
    /// `"ssmax-mlp"` — scale produced by an MLP of `log n`, one value per head.
    SsmaxMlp,
    /// `"ssmax-mlp-elementwise"` — as `SsmaxMlp`, but one value per
    /// (head, head_dim) pair.
    SsmaxMlpElementwise,
    /// `"qassmax-mlp"` — query-aware scaling, one value per head.
    QassmaxMlp,
    /// `"qassmax-mlp-elementwise"` — query-aware scaling, one value per
    /// (head, head_dim) pair.
    QassmaxMlpElementwise,
}

impl SsmaxKind {
    pub fn parse(s: &str) -> Result<Self, SsmaxError> {
        Ok(match s {
            "none" => Self::None,
            "ssmax" => Self::Ssmax,
            "ssmax-mlp" => Self::SsmaxMlp,
            "ssmax-mlp-elementwise" => Self::SsmaxMlpElementwise,
            "qassmax-mlp" => Self::QassmaxMlp,
            "qassmax-mlp-elementwise" => Self::QassmaxMlpElementwise,
            other => return Err(SsmaxError::UnknownKind(other.to_string())),
        })
    }

    /// Boolean shorthand matching the Python module: `True` ↦ qassmax-mlp-elementwise,
    /// `False` ↦ none.
    pub fn from_bool(b: bool) -> Self {
        if b {
            Self::QassmaxMlpElementwise
        } else {
            Self::None
        }
    }

    pub fn is_active(self) -> bool {
        !matches!(self, Self::None)
    }
}

/// Parameter layout for a single SSMax block, independent of any graph.
///
/// Mirrors the `nn.Parameter` / `nn.Linear` shapes constructed by the
/// Python `SSMax`, `SSMaxMLP`, `QASSMaxMLP` modules so we can faithfully
/// load Python checkpoints. The Python attributes map onto the fields of
/// [`SsmaxParams`] as follows:
///
///   - `SSMax.scales`                     ↦ `scales`
///   - `SSMaxMLP.mlp.{0,2}.{weight,bias}` ↦ `base_*` (same fields as the QASSMax base MLP)
///   - `QASSMaxMLP.{base,query}_mlp.…`    ↦ `base_*`, `query_*`
#[derive(Debug, Clone)]
pub struct SsmaxSpec {
    /// Which SSMax variant this block implements.
    pub kind: SsmaxKind,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head width; [`SsmaxSpec::create`] sets it to `embed_dim / num_heads`.
    pub head_dim: usize,
    /// MLP hidden dim; the Python default is 64 across all variants.
    pub n_hidden: usize,
}

impl SsmaxSpec {
    /// Build the spec for one SSMax block, mirroring the Python factory
    /// `create_ssmax_layer(ssmax_type, num_heads, embed_dim)`.
    ///
    /// Returns `Ok(None)` when `kind` is inactive ([`SsmaxKind::None`]); no
    /// dimension checks are performed in that case. Otherwise the spec gets
    /// `head_dim = embed_dim / num_heads` and the default hidden width of 64.
    ///
    /// # Errors
    ///
    /// Returns [`SsmaxError::MissingHeadDim`] for an active kind when
    /// `num_heads` or `embed_dim` is zero, or when `embed_dim` is not a
    /// multiple of `num_heads`.
    pub fn create(
        kind: SsmaxKind,
        num_heads: usize,
        embed_dim: usize,
    ) -> Result<Option<Self>, SsmaxError> {
        if !kind.is_active() {
            return Ok(None);
        }

        // A head dimension exists only when both sizes are non-zero and the
        // embedding splits evenly across heads.
        let splits_evenly =
            num_heads != 0 && embed_dim != 0 && embed_dim.is_multiple_of(num_heads);
        if !splits_evenly {
            return Err(SsmaxError::MissingHeadDim);
        }

        let head_dim = embed_dim / num_heads;
        Ok(Some(Self {
            kind,
            num_heads,
            head_dim,
            n_hidden: 64,
        }))
    }

    /// Output dimension of the *base* path: `num_heads * head_dim` for the
    /// elementwise variants, `num_heads` for every other kind.
    pub fn base_out_dim(&self) -> usize {
        match self.kind {
            SsmaxKind::SsmaxMlpElementwise | SsmaxKind::QassmaxMlpElementwise => {
                self.num_heads * self.head_dim
            }
            SsmaxKind::None
            | SsmaxKind::Ssmax
            | SsmaxKind::SsmaxMlp
            | SsmaxKind::QassmaxMlp => self.num_heads,
        }
    }

    /// Output dimension of the *query* path: `head_dim` for
    /// `qassmax-mlp-elementwise`, 1 for `qassmax-mlp`, and 0 for the
    /// variants that are not query-aware.
    pub fn query_out_dim(&self) -> usize {
        match self.kind {
            SsmaxKind::QassmaxMlpElementwise => self.head_dim,
            SsmaxKind::QassmaxMlp => 1,
            SsmaxKind::None
            | SsmaxKind::Ssmax
            | SsmaxKind::SsmaxMlp
            | SsmaxKind::SsmaxMlpElementwise => 0,
        }
    }
}

/// Runtime parameters for an active SSMax block (kind ≠ None).
/// Mirrors the Python `nn.Linear(in, hidden) → GELU → nn.Linear(hidden, out)`
/// pair for both the base and query MLPs. The base MLP runs every step;
/// the query MLP runs only for QASSMax variants.
///
/// Use [`SsmaxParams::zeros`] to obtain tensors sized for a given
/// [`SsmaxSpec`], then [`SsmaxParams::load_from`] to fill them from a
/// checkpoint.
#[derive(Debug, Clone, Default)]
pub struct SsmaxParams {
    /// For [`SsmaxKind::Ssmax`]: a `(num_heads,)` learnable scaling
    /// vector. Not read for the MLP variants, although
    /// [`SsmaxParams::zeros`] still fills it with `num_heads` ones.
    pub scales: Vec<f32>,
    /// `base_mlp.0` weight: (n_hidden, 1).
    pub base_w1: ndarray::Array2<f32>,
    /// `base_mlp.0` bias: (n_hidden,).
    pub base_b1: Vec<f32>,
    /// `base_mlp.2` weight: (out_dim, n_hidden) where
    /// out_dim ∈ {num_heads, num_heads*head_dim}.
    pub base_w2: ndarray::Array2<f32>,
    /// `base_mlp.2` bias: (out_dim,).
    pub base_b2: Vec<f32>,
    /// `query_mlp.0` weight: (n_hidden, head_dim). Present only for QASSMax
    /// variants; `(0, 0)` otherwise.
    pub query_w1: ndarray::Array2<f32>,
    /// `query_mlp.0` bias: (n_hidden,). Empty unless the kind is QASSMax.
    pub query_b1: Vec<f32>,
    /// `query_mlp.2` weight: (q_out_dim, n_hidden). Present only for QASSMax
    /// variants; `(0, 0)` otherwise.
    pub query_w2: ndarray::Array2<f32>,
    /// `query_mlp.2` bias: (q_out_dim,). Empty unless the kind is QASSMax.
    pub query_b2: Vec<f32>,
}

impl SsmaxParams {
    pub fn zeros(spec: &SsmaxSpec) -> Self {
        let base_out = spec.base_out_dim();
        let q_out = spec.query_out_dim();
        Self {
            scales: vec![1.0; spec.num_heads],
            base_w1: ndarray::Array2::<f32>::zeros((spec.n_hidden, 1)),
            base_b1: vec![0.0; spec.n_hidden],
            base_w2: ndarray::Array2::<f32>::zeros((base_out, spec.n_hidden)),
            base_b2: vec![0.0; base_out],
            query_w1: if q_out > 0 {
                ndarray::Array2::<f32>::zeros((spec.n_hidden, spec.head_dim))
            } else {
                ndarray::Array2::<f32>::zeros((0, 0))
            },
            query_b1: if q_out > 0 {
                vec![0.0; spec.n_hidden]
            } else {
                Vec::new()
            },
            query_w2: if q_out > 0 {
                ndarray::Array2::<f32>::zeros((q_out, spec.n_hidden))
            } else {
                ndarray::Array2::<f32>::zeros((0, 0))
            },
            query_b2: if q_out > 0 {
                vec![0.0; q_out]
            } else {
                Vec::new()
            },
        }
    }

    /// Load from a PyTorch state dict under `{prefix}.ssmax_layer.`.
    pub fn load_from(
        &mut self,
        sd: &crate::state_dict::StateDict,
        prefix: &str,
        spec: &SsmaxSpec,
    ) -> Result<(), crate::state_dict::StateDictError> {
        use crate::state_dict::StateDictError;
        let p = format!("{prefix}.ssmax_layer");
        // Detect bias vector for per-head SSMax variant.
        let scales_key = format!("{p}.scales");
        if sd.tensors.contains_key(&scales_key) {
            self.scales = sd.take_vec(&scales_key, spec.num_heads)?;
            return Ok(());
        }
        let base_out = spec.base_out_dim();
        self.base_w1 = sd
            .take_array2(&format!("{p}.base_mlp.0.weight"), spec.n_hidden, 1)
            .map_err(|e| StateDictError::MissingKey(format!("{p}.base_mlp.0.weight: {e}")))?;
        self.base_b1 = sd.take_vec(&format!("{p}.base_mlp.0.bias"), spec.n_hidden)?;
        self.base_w2 =
            sd.take_array2(&format!("{p}.base_mlp.2.weight"), base_out, spec.n_hidden)?;
        self.base_b2 = sd.take_vec(&format!("{p}.base_mlp.2.bias"), base_out)?;

        let q_out = spec.query_out_dim();
        if q_out > 0 {
            self.query_w1 = sd.take_array2(
                &format!("{p}.query_mlp.0.weight"),
                spec.n_hidden,
                spec.head_dim,
            )?;
            self.query_b1 = sd.take_vec(&format!("{p}.query_mlp.0.bias"), spec.n_hidden)?;
            self.query_w2 =
                sd.take_array2(&format!("{p}.query_mlp.2.weight"), q_out, spec.n_hidden)?;
            self.query_b2 = sd.take_vec(&format!("{p}.query_mlp.2.bias"), q_out)?;
        }
        Ok(())
    }
}

/// Compute the per-query scale tensor `(B, H, T, D_or_1)` for the given
/// SSMax kind, source-sequence length `n`, and projected query tensor
/// `q` of shape `(B, H, T, D)`.
pub fn compute_query_scale(
    spec: &SsmaxSpec,
    params: &SsmaxParams,
    q: ndarray::ArrayView4<f32>,
    n_src: usize,
) -> ndarray::Array4<f32> {
    let (b, h, t, d) = (q.shape()[0], q.shape()[1], q.shape()[2], q.shape()[3]);
    let log_n = (n_src.max(1) as f32).ln();

    let base_scales: Vec<f32> = match spec.kind {
        SsmaxKind::None => return ndarray::Array4::<f32>::ones((1, h, 1, 1)),
        SsmaxKind::Ssmax => {
            // base = scales * log_n (per head).
            params.scales.iter().map(|s| s * log_n).collect()
        }
        _ => {
            // MLP: out = W2 · GELU(W1 · [log_n] + b1) + b2 with exact erf.
            let n_hidden = params.base_w1.shape()[0];
            let mut h1 = vec![0.0_f32; n_hidden];
            for (k, h1k) in h1.iter_mut().enumerate().take(n_hidden) {
                let pre = params.base_w1[(k, 0)] * log_n + params.base_b1[k];
                *h1k = 0.5 * pre * (1.0 + erf_ss(pre / std::f32::consts::SQRT_2));
            }
            let out_dim = params.base_w2.shape()[0];
            let mut h2 = vec![0.0_f32; out_dim];
            for (o, h2o) in h2.iter_mut().enumerate().take(out_dim) {
                let mut s = params.base_b2[o];
                for (k, h1k) in h1.iter().enumerate().take(n_hidden) {
                    s += params.base_w2[(o, k)] * h1k;
                }
                *h2o = s;
            }
            h2
        }
    };

    // Reshape base_scales into (1, H, 1, scale_d) where scale_d ∈ {1, D}.
    let scale_d = match spec.kind {
        SsmaxKind::Ssmax | SsmaxKind::SsmaxMlp | SsmaxKind::QassmaxMlp => 1,
        SsmaxKind::SsmaxMlpElementwise | SsmaxKind::QassmaxMlpElementwise => d,
        SsmaxKind::None => 1,
    };
    let mut base = ndarray::Array4::<f32>::zeros((1, h, 1, scale_d));
    for hi in 0..h {
        for di in 0..scale_d {
            let idx = if scale_d == 1 { hi } else { hi * scale_d + di };
            base[(0, hi, 0, di)] = base_scales[idx];
        }
    }

    // For QASSMax: also compute the query-modulation MLP.
    // modulation[b,h,t,o] = tanh(W2 · GELU(W1 · q[b,h,t,:] + b1) + b2) + 1
    // out = base * modulation (broadcast over batch + time).
    if matches!(
        spec.kind,
        SsmaxKind::QassmaxMlp | SsmaxKind::QassmaxMlpElementwise
    ) {
        let n_hidden = params.query_w1.shape()[0];
        let q_out = params.query_w2.shape()[0];
        let mut out = ndarray::Array4::<f32>::zeros((b, h, t, q_out));
        for bi in 0..b {
            for hi in 0..h {
                for ti in 0..t {
                    let mut h1 = vec![0.0_f32; n_hidden];
                    for (k, h1k) in h1.iter_mut().enumerate().take(n_hidden) {
                        let mut s = params.query_b1[k];
                        for di in 0..d {
                            s += params.query_w1[(k, di)] * q[(bi, hi, ti, di)];
                        }
                        *h1k = 0.5 * s * (1.0 + erf_ss(s / std::f32::consts::SQRT_2));
                    }
                    for o in 0..q_out {
                        let mut s2 = params.query_b2[o];
                        for (k, h1k) in h1.iter().enumerate().take(n_hidden) {
                            s2 += params.query_w2[(o, k)] * h1k;
                        }
                        out[(bi, hi, ti, o)] = 1.0 + s2.tanh();
                    }
                }
            }
        }
        // Multiply base (1, H, 1, scale_d) by out (B, H, T, q_out).
        // scale_d and q_out match per the SSMax spec.
        let mut result = ndarray::Array4::<f32>::zeros((b, h, t, q_out));
        for bi in 0..b {
            for hi in 0..h {
                for ti in 0..t {
                    for di in 0..q_out {
                        let base_v = base[(0, hi, 0, di.min(scale_d - 1))];
                        result[(bi, hi, ti, di)] = base_v * out[(bi, hi, ti, di)];
                    }
                }
            }
        }
        return result;
    }

    base
}

/// Approximate the error function with the Abramowitz–Stegun rational
/// polynomial (formula 7.1.26).
///
/// SSMax's GELU uses this as a stand-in for the erf-based default of
/// PyTorch's `nn.GELU()`; it is an approximation, not an exact `erf`. The
/// polynomial is evaluated on `|x|` and the sign of `x` is re-applied at the
/// end, so the result is odd in `x`.
fn erf_ss(x: f32) -> f32 {
    // Coefficients a5..a1, highest order first for Horner evaluation.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_72,
        0.254_829_6,
    ];
    const P: f32 = 0.327_591_1;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let horner = COEFFS.iter().fold(0.0_f32, |acc, &c| acc * t + c);
    let tail = horner * t * (-magnitude * magnitude).exp();
    x.signum() * (1.0 - tail)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented `ssmax_type` string parses to its variant, and an
    /// unknown string is rejected.
    #[test]
    fn parse_round_trip() {
        let expected = [
            ("none", SsmaxKind::None),
            ("ssmax", SsmaxKind::Ssmax),
            ("ssmax-mlp", SsmaxKind::SsmaxMlp),
            ("ssmax-mlp-elementwise", SsmaxKind::SsmaxMlpElementwise),
            ("qassmax-mlp", SsmaxKind::QassmaxMlp),
            ("qassmax-mlp-elementwise", SsmaxKind::QassmaxMlpElementwise),
        ];
        for (name, kind) in expected {
            let parsed = SsmaxKind::parse(name).expect("documented kind must parse");
            assert_eq!(parsed, kind);
        }

        let unknown = SsmaxKind::parse("nope");
        assert!(unknown.is_err());
    }

    /// The boolean shorthand follows the Python default:
    /// `ssmax=True` ↔ "qassmax-mlp-elementwise", `ssmax=False` ↔ "none".
    #[test]
    fn bool_shorthand_matches_python_default() {
        let enabled = SsmaxKind::from_bool(true);
        let disabled = SsmaxKind::from_bool(false);
        assert_eq!(enabled, SsmaxKind::QassmaxMlpElementwise);
        assert_eq!(disabled, SsmaxKind::None);
    }

    /// Derived dimensions agree with the Python layer shapes.
    #[test]
    fn dims_match_python_layout() {
        // embed_dim=128, heads=8 → head_dim=16
        let elementwise = SsmaxSpec::create(SsmaxKind::QassmaxMlpElementwise, 8, 128)
            .expect("valid dimensions")
            .expect("active kind yields a spec");
        assert_eq!(elementwise.head_dim, 16);
        assert_eq!(elementwise.base_out_dim(), 8 * 16);
        assert_eq!(elementwise.query_out_dim(), 16);

        // Plain SSMax keeps one scale per head and has no query path.
        let per_head = SsmaxSpec::create(SsmaxKind::Ssmax, 8, 128)
            .expect("valid dimensions")
            .expect("active kind yields a spec");
        assert_eq!(per_head.base_out_dim(), 8);
        assert_eq!(per_head.query_out_dim(), 0);
    }
}