tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Row-wise interaction transformer — port of
//! `tabicl._model.interaction.RowInteraction`.
//!
//! Per-row attention over the feature-token sequence + `num_cls` learnable
//! CLS tokens, with RoPE on Q/K. The last block uses only the CLS tokens as
//! queries (`q = embeddings[:, :, :num_cls, :]`) to aggregate features into
//! `(B, T, num_cls * embed_dim)` row representations.
//!
//! This file holds the config, the parameter container, and a reference
//! forward pass built on [`crate::encoders::EncoderStack`]. The KV-cache fast
//! path is not ported yet, and the forward pass accepts no padding mask.

use ndarray::{s, Array2, Array3, Array4, ArrayView4};
use serde::{Deserialize, Serialize};

use crate::encoders::{EncoderStack, MabConfig};
use crate::layers::layer_norm_last;
use crate::rope::RopeConfig;
use crate::state_dict::{StateDict, StateDictError};
use crate::tabicl::Activation;

/// Hyper-parameters of the row interaction transformer, one field per
/// argument of Python's `RowInteraction.__init__`. [`Default`] reproduces the
/// Python defaults (`num_cls=4`, `rope_base=100_000`, `rope_interleaved=True`).
///
/// NOTE: the module-level default for `rope_interleaved` is `True`, whereas
/// [`crate::tabicl::TabICLConfig`] passes `row_rope_interleaved=false` (its own
/// model-level default). Python has the same split on purpose, so the module
/// default here stays decoupled from the TabICL default.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RowInteractionConfig {
    /// Token embedding width `E`; also the encoder's `d_model`.
    pub embed_dim: usize,
    /// Number of encoder blocks in the stack.
    pub num_blocks: usize,
    /// Attention heads per block; `embed_dim` must be divisible by it.
    pub nhead: usize,
    /// Feed-forward hidden width of each block.
    pub dim_feedforward: usize,
    /// Number of learnable CLS tokens, written into the first `num_cls` token
    /// slots of every row; the output row width is `num_cls * embed_dim`.
    pub num_cls: usize,
    /// RoPE base, forwarded as `RopeConfig::base`.
    pub rope_base: f32,
    /// RoPE layout flag, forwarded as `RopeConfig::interleaved`.
    pub rope_interleaved: bool,
    /// Dropout rate forwarded to each block.
    pub dropout: f32,
    /// Feed-forward activation forwarded to each block.
    pub activation: Activation,
    /// Pre-norm blocks; also enables the final LayerNorm on the CLS outputs.
    pub norm_first: bool,
    /// Omit LayerNorm biases (the final LayerNorm's and, via `MabConfig`, the
    /// blocks').
    pub bias_free_ln: bool,
    /// Mirrors the Python argument; not read anywhere in this file.
    pub recompute: bool,
}

impl Default for RowInteractionConfig {
    fn default() -> Self {
        Self {
            embed_dim: 128,
            num_blocks: 3,
            nhead: 8,
            dim_feedforward: 256,
            num_cls: 4,
            rope_base: 100_000.0,
            rope_interleaved: true,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
            recompute: false,
        }
    }
}

impl RowInteractionConfig {
    /// Width of one attention head: `embed_dim / nhead` (integer division).
    ///
    /// # Panics
    /// Panics with a division by zero when `nhead == 0`.
    pub fn head_dim(&self) -> usize {
        let (width, heads) = (self.embed_dim, self.nhead);
        width / heads
    }

    /// Width of one aggregated row representation: every CLS output
    /// concatenated, i.e. `num_cls * embed_dim`.
    pub fn repr_dim(&self) -> usize {
        let (tokens, width) = (self.num_cls, self.embed_dim);
        tokens * width
    }
}

/// Parameters owned directly by [`RowInteraction`]. The per-block encoder
/// weights are not stored here; they live in `RowInteraction::encoder`.
///
///   - `cls_tokens`: `(num_cls, embed_dim)` — broadcast over `(B, T)` at
///     forward time.
///   - `out_ln_gamma` / `out_ln_beta`: the final LayerNorm applied to the CLS
///     outputs when `norm_first` is set.
#[derive(Debug, Clone)]
pub struct RowInteractionParams {
    /// Learnable CLS tokens, one row per CLS slot: `(num_cls, embed_dim)`.
    pub cls_tokens: Array2<f32>,
    /// Scale of the final LayerNorm. `None` indicates the `norm_first=False`
    /// Identity path.
    pub out_ln_gamma: Option<Vec<f32>>,
    /// Shift of the final LayerNorm. `None` means no bias is applied (the
    /// `bias_free_ln` case, or the Identity path).
    pub out_ln_beta: Option<Vec<f32>>,
}

impl RowInteractionParams {
    /// Build the initial parameters for `cfg`.
    ///
    /// The CLS tokens start at zero. When `cfg.norm_first` is set, the final
    /// LayerNorm starts with `gamma = 1`, plus `beta = 0` unless
    /// `cfg.bias_free_ln` drops the bias. Otherwise both are `None`
    /// (the Identity path).
    pub fn zeros(cfg: &RowInteractionConfig) -> Self {
        let dim = cfg.embed_dim;
        let has_out_ln = cfg.norm_first;
        let has_out_ln_bias = has_out_ln && !cfg.bias_free_ln;
        Self {
            cls_tokens: Array2::zeros((cfg.num_cls, dim)),
            out_ln_gamma: has_out_ln.then(|| vec![1.0; dim]),
            out_ln_beta: has_out_ln_bias.then(|| vec![0.0; dim]),
        }
    }
}

/// Row-wise interaction transformer: learnable CLS tokens, an
/// [`EncoderStack`] with RoPE, and an optional final LayerNorm.
#[derive(Clone, Debug)]
pub struct RowInteraction {
    /// Hyper-parameters the module was built from.
    pub config: RowInteractionConfig,
    /// CLS tokens and the final LayerNorm parameters.
    pub params: RowInteractionParams,
    /// The `num_blocks` encoder blocks (loaded from `{prefix}.tf_row`).
    pub encoder: EncoderStack,
}

impl RowInteraction {
    /// Load weights from a Python state dict under `{prefix}`. Keys used:
    ///
    ///   - `{prefix}.cls_tokens`        — (num_cls, embed_dim)
    ///   - `{prefix}.out_ln.weight`     — (embed_dim,) when norm_first
    ///   - `{prefix}.out_ln.bias`       — (embed_dim,) when norm_first & !bias_free_ln
    ///   - `{prefix}.tf_row.blocks.i.…` — for each MAB in the encoder stack
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        let cls_key = format!("{prefix}.cls_tokens");
        self.params.cls_tokens =
            sd.take_array2(&cls_key, self.config.num_cls, self.config.embed_dim)?;
        if self.config.norm_first {
            self.params.out_ln_gamma =
                Some(sd.take_vec(&format!("{prefix}.out_ln.weight"), self.config.embed_dim)?);
            let beta_key = format!("{prefix}.out_ln.bias");
            if sd.tensors.contains_key(&beta_key) {
                self.params.out_ln_beta = Some(sd.take_vec(&beta_key, self.config.embed_dim)?);
            }
        }
        self.encoder.load_from(sd, &format!("{prefix}.tf_row"))?;
        Ok(())
    }

    pub fn new(config: RowInteractionConfig) -> Self {
        let params = RowInteractionParams::zeros(&config);
        let mab_cfg = MabConfig {
            d_model: config.embed_dim,
            nhead: config.nhead,
            dim_feedforward: config.dim_feedforward,
            dropout: config.dropout,
            activation: config.activation,
            norm_first: config.norm_first,
            bias_free_ln: config.bias_free_ln,
        };
        let rope = Some(RopeConfig {
            head_dim: config.head_dim(),
            base: config.rope_base,
            interleaved: config.rope_interleaved,
        });
        let encoder = EncoderStack::new(config.num_blocks, mab_cfg, rope)
            .expect("RowInteraction: d_model must be divisible by nhead");
        Self {
            config,
            params,
            encoder,
        }
    }

    /// Forward pass — port of `RowInteraction.forward` / `_aggregate_embeddings`.
    ///
    /// Input shape: `(B, T, H + num_cls, E)`. The first `num_cls` positions
    /// along the third axis are overwritten with the learnable CLS tokens
    /// (broadcast over `(B, T)`).
    ///
    /// Output shape: `(B, T, num_cls * E)`. Only the CLS outputs are kept;
    /// the rest of the encoder output is discarded.
    ///
    /// The Python module has a special last-block pattern (q=CLS, k=v=all)
    /// for compute savings. This reference implementation runs full
    /// self-attention through every block — mathematically equivalent for
    /// the CLS outputs since each token's attention output depends only on
    /// K/V (which match) and its own Q (the CLS slice).
    pub fn forward(&self, embeddings: ArrayView4<f32>) -> Array3<f32> {
        let (b, t, hc, e) = (
            embeddings.shape()[0],
            embeddings.shape()[1],
            embeddings.shape()[2],
            embeddings.shape()[3],
        );
        assert_eq!(e, self.config.embed_dim, "embed_dim mismatch");
        assert!(hc >= self.config.num_cls, "fewer tokens than CLS slots");

        // 1. Build a (B, T, H+C, E) tensor with CLS tokens written into
        //    the first num_cls slots. `cls_tokens` is (num_cls, E),
        //    broadcast over (B, T).
        let mut buf = embeddings.to_owned();
        for bi in 0..b {
            for ti in 0..t {
                for ci in 0..self.config.num_cls {
                    for ei in 0..e {
                        buf[(bi, ti, ci, ei)] = self.params.cls_tokens[(ci, ei)];
                    }
                }
            }
        }

        // 2. Flatten (B, T) → batch dim so the 3-D encoder kernels can run.
        let bt = b * t;
        let mut flat = Array3::<f32>::zeros((bt, hc, e));
        for bi in 0..b {
            for ti in 0..t {
                for hi in 0..hc {
                    for ei in 0..e {
                        flat[(bi * t + ti, hi, ei)] = buf[(bi, ti, hi, ei)];
                    }
                }
            }
        }

        // 3. Run the encoder stack.
        let out_flat = self.encoder.forward(flat.view());

        // 4. Slice out CLS outputs and apply out-LN if pre-norm.
        let mut cls_out = Array3::<f32>::zeros((bt, self.config.num_cls, e));
        for bti in 0..bt {
            for ci in 0..self.config.num_cls {
                for ei in 0..e {
                    cls_out[(bti, ci, ei)] = out_flat[(bti, ci, ei)];
                }
            }
        }
        let cls_norm = match (&self.params.out_ln_gamma, &self.params.out_ln_beta) {
            (Some(g), beta) => layer_norm_last(cls_out.view(), g, beta.as_deref(), 1e-5),
            _ => cls_out, // norm_first=False → Identity
        };

        // 5. Flatten last two dims: (B*T, C, E) → (B, T, C*E).
        let repr_dim = self.config.repr_dim();
        let mut out = Array3::<f32>::zeros((b, t, repr_dim));
        for bi in 0..b {
            for ti in 0..t {
                for ci in 0..self.config.num_cls {
                    for ei in 0..e {
                        out[(bi, ti, ci * e + ei)] = cls_norm[(bi * t + ti, ci, ei)];
                    }
                }
            }
        }
        out
    }
}

/// Never called: its only job is to reference `Array4`, which nothing else in
/// this file uses, so the `ndarray` import list does not trigger an unused
/// import warning.
#[allow(dead_code)] // intentionally uncalled, see doc comment above
fn _silence(_unused: Array4<f32>) {}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{s, Array};

    /// One-block, two-head config for fast forward tests. Every field not
    /// listed keeps its `Default` value (`rope_base=100_000`, `dropout=0`,
    /// GELU, `norm_first=true`, `bias_free_ln=false`, `recompute=false`);
    /// `rope_interleaved` is turned off.
    fn small_config(
        embed_dim: usize,
        dim_feedforward: usize,
        num_cls: usize,
    ) -> RowInteractionConfig {
        RowInteractionConfig {
            embed_dim,
            num_blocks: 1,
            nhead: 2,
            dim_feedforward,
            num_cls,
            rope_interleaved: false,
            ..Default::default()
        }
    }

    /// Deterministic, position-dependent input of the given `(B, T, H+C, E)` shape.
    fn ramp(shape: (usize, usize, usize, usize)) -> Array4<f32> {
        Array::from_shape_fn(shape, |(b, t, h, e)| {
            ((b * 100 + t * 10 + h) as f32) * 0.001 + (e as f32) * 0.0001
        })
    }

    #[test]
    fn forward_output_shape() {
        // embed_dim=8, 2 heads, 1 block, 4 CLS tokens.
        // 2 batches × 3 rows × 6 (=2 features + 4 CLS) tokens.
        let ri = RowInteraction::new(small_config(8, 16, 4));
        let out = ri.forward(ramp((2, 3, 6, 8)).view());
        assert_eq!(out.shape(), &[2, 3, 4 * 8]);
    }

    #[test]
    fn forward_cls_tokens_are_overwritten_then_propagated() {
        // Assumes a freshly built encoder is the identity on the residual
        // stream (zero-initialised out_proj / linear2 in pre-norm blocks), so
        // the CLS outputs are just the out-LN'd CLS tokens.
        let mut ri = RowInteraction::new(small_config(4, 8, 2));
        // Distinctive CLS tokens: row 0 is all +1, row 1 is all -1.
        ri.params.cls_tokens.row_mut(0).fill(1.0);
        ri.params.cls_tokens.row_mut(1).fill(-1.0);
        // Input embeddings — anything in the CLS positions is overwritten.
        let emb = Array::from_shape_fn((1, 2, 4, 4), |(_, _, h, e)| (h * 10 + e) as f32);
        let out = ri.forward(emb.view());
        // LN(constant vector) = 0 with γ=1, β=0, so both CLS slots (all 8
        // entries of repr_dim) should be ~0 in every row.
        for ((b, t, k), v) in out.indexed_iter() {
            assert!(
                v.abs() < 1e-4,
                "constant CLS row should LN to zero: out[{b},{t},{k}] = {v}"
            );
        }
    }

    #[test]
    fn forward_ignores_input_values_in_cls_slots() {
        let mut ri = RowInteraction::new(small_config(8, 16, 4));
        // Non-constant CLS tokens so the out-LN result is not trivially zero.
        for ((c, e), v) in ri.params.cls_tokens.indexed_iter_mut() {
            *v = c as f32 * 0.5 - e as f32 * 0.125;
        }
        let clean = ramp((2, 3, 6, 8));
        let mut noisy = clean.clone();
        noisy.slice_mut(s![.., .., ..4, ..]).fill(42.0);
        let a = ri.forward(clean.view());
        let b = ri.forward(noisy.view());
        assert_eq!(a.shape(), b.shape());
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x - y).abs() < 1e-6, "CLS-slot input leaked into output: {x} vs {y}");
        }
    }

    #[test]
    fn forward_rows_are_processed_independently() {
        // Pins the (B, T) → B*T flatten and the inverse reshape: each row's
        // output must match running that row on its own.
        let ri = RowInteraction::new(small_config(8, 16, 4));
        let emb = ramp((2, 3, 6, 8));
        let batched = ri.forward(emb.view());
        for bi in 0..2 {
            for ti in 0..3 {
                let single = ri.forward(emb.slice(s![bi..bi + 1, ti..ti + 1, .., ..]));
                assert_eq!(single.shape(), &[1, 1, 4 * 8]);
                let expected = batched.slice(s![bi, ti, ..]);
                for (x, y) in expected.iter().zip(single.iter()) {
                    assert!((x - y).abs() < 1e-5, "row ({bi}, {ti}) differs: {x} vs {y}");
                }
            }
        }
    }

    #[test]
    fn defaults_match_python_signature() {
        let c = RowInteractionConfig::default();
        assert_eq!(c.num_cls, 4);
        assert_eq!(c.embed_dim, 128);
        assert_eq!(c.head_dim(), 16);
        assert_eq!(c.repr_dim(), 4 * 128);
        assert_eq!(c.rope_base, 100_000.0);
        assert!(c.rope_interleaved);
        assert!(c.norm_first);
        assert!(!c.bias_free_ln);
    }

    #[test]
    fn norm_first_controls_out_ln_params() {
        // (norm_first, bias_free_ln) → (gamma present, beta present)
        let cases = [
            ((true, false), (true, true)),
            ((true, true), (true, false)),
            ((false, true), (false, false)),
            ((false, false), (false, false)),
        ];
        for ((norm_first, bias_free_ln), (has_gamma, has_beta)) in cases {
            let cfg = RowInteractionConfig {
                norm_first,
                bias_free_ln,
                ..Default::default()
            };
            let p = RowInteractionParams::zeros(&cfg);
            assert_eq!(
                p.out_ln_gamma.is_some(),
                has_gamma,
                "gamma: norm_first={norm_first}, bias_free_ln={bias_free_ln}"
            );
            assert_eq!(
                p.out_ln_beta.is_some(),
                has_beta,
                "beta: norm_first={norm_first}, bias_free_ln={bias_free_ln}"
            );
        }
    }
}