tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! In-context learning transformer — port of
//! `tabicl._model.learning.ICLearning`.
//!
//! Flow (single forward pass, no hierarchical decoding):
//!
//!   1. Embed `y_train` via `OneHotAndLinear` (classification) or a
//!      `Linear(1 → d_model)` (regression). Shape `(B, train_size, d_model)`.
//!   2. Add the y-embedding back into the first `train_size` positions of
//!      the row representations `R`.
//!   3. Run the modified `R` through the ICL transformer encoder with
//!      train-size masking (test rows attend only to the first `train_size`
//!      rows).
//!   4. Apply a final LayerNorm if `norm_first`.
//!   5. Decode with a 2-layer MLP: `Linear(d_model → 2*d_model) → GELU →
//!      Linear(2*d_model → out_dim)`.
//!
//! Status: reference (host fp32) implementation of the standard
//! (non-hierarchical) classification path and the regression path,
//! including the `train_size`-aware attention mask. Hierarchical
//! many-class decoding (`_fit_hierarchical` / `_predict_hierarchical`)
//! is a follow-up item.

use ndarray::{Array2, Array3, ArrayView2, ArrayView3};
use serde::{Deserialize, Serialize};

use crate::encoders::{EncoderStack, MabConfig};
use crate::layers::{OneHotAndLinear, layer_norm_last, linear3d};
use crate::state_dict::{StateDict, StateDictError};
use crate::tabicl::Activation;

/// Static config for the ICL transformer. Mirrors `ICLearning.__init__`.
///
/// Field order matches the Python constructor and is kept stable because the
/// struct is (de)serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ICLearningConfig {
    /// Number of classes for the one-hot y-encoder; `0` selects regression.
    pub max_classes: usize,
    /// Width of the decoder output (last axis returned by `ICLearning::forward`).
    pub out_dim: usize,
    /// Model width of the row representations and of the encoder blocks.
    pub d_model: usize,
    /// Number of blocks in the ICL encoder stack.
    pub num_blocks: usize,
    /// Attention heads per encoder block.
    pub nhead: usize,
    /// Feed-forward width forwarded to each encoder block.
    pub dim_feedforward: usize,
    /// Dropout setting forwarded to each encoder block.
    pub dropout: f32,
    /// Activation forwarded to each encoder block.
    pub activation: Activation,
    /// Forwarded to the encoder blocks; also enables the final LayerNorm
    /// applied before the decoder.
    pub norm_first: bool,
    /// Forwarded to the encoder blocks; when set, the final LayerNorm is
    /// created without a β term.
    pub bias_free_ln: bool,
    /// ssmax variant name, parsed with `SsmaxKind::parse` when loading weights
    /// (an unrecognised name falls back to `SsmaxKind::None`).
    pub ssmax: String,
    /// Carried over from the Python config; not read by this module.
    pub recompute: bool,
}

impl ICLearningConfig {
    /// Returns `true` when the config describes a regression head, i.e. when
    /// `max_classes` is zero.
    pub fn is_regression(&self) -> bool {
        matches!(self.max_classes, 0)
    }
}

/// Parameters owned by the ICL transformer.
///
/// Which optional fields are populated is decided by the config at
/// construction time (see `ICLearningParams::zeros`); `ICLearning::load_from`
/// then overwrites them from a state dict.
#[derive(Debug, Clone)]
pub struct ICLearningParams {
    /// Classification y-encoder, `Some` iff `max_classes > 0`.
    pub y_one_hot: Option<OneHotAndLinear>,
    /// Regression y-encoder `(weight, bias)`, `Some` iff `max_classes == 0`.
    ///   - the weight has shape `(d_model, 1)`,
    ///   - the optional bias has length `d_model`.
    pub y_linear: Option<(Array2<f32>, Option<Vec<f32>>)>,
    /// Final LayerNorm γ (length `d_model`). `None` when `norm_first == false`,
    /// in which case the final norm is skipped (Identity).
    pub ln_gamma: Option<Vec<f32>>,
    /// Final LayerNorm β (length `d_model`). Initialised to `None` when
    /// `norm_first == false` or `bias_free_ln == true`.
    pub ln_beta: Option<Vec<f32>>,
    /// 2-layer decoder: `Linear(d_model → 2*d_model)`, GELU,
    /// `Linear(2*d_model → out_dim)`.
    /// First decoder weight, shape `(2*d_model, d_model)`.
    pub decoder_w1: Array2<f32>,
    /// First decoder bias, length `2*d_model`.
    pub decoder_b1: Option<Vec<f32>>,
    /// Second decoder weight, shape `(out_dim, 2*d_model)`.
    pub decoder_w2: Array2<f32>,
    /// Second decoder bias, length `out_dim`.
    pub decoder_b2: Option<Vec<f32>>,
}

impl ICLearningParams {
    /// Builds a parameter set shaped for `cfg`. Every weight and bias is zero,
    /// and the LayerNorm γ is one.
    ///
    /// Only the y-encoder matching the task type is allocated. The final
    /// LayerNorm parameters exist only when `cfg.norm_first`; β additionally
    /// requires `!cfg.bias_free_ln`.
    pub fn zeros(cfg: &ICLearningConfig) -> Self {
        let width = cfg.d_model;
        let hidden = 2 * width;
        let is_classification = cfg.max_classes > 0;

        let y_one_hot = is_classification.then(|| {
            OneHotAndLinear::from_raw_weight(
                Array2::<f32>::zeros((width, cfg.max_classes)),
                Some(vec![0.0; width]),
            )
        });
        let y_linear = cfg
            .is_regression()
            .then(|| (Array2::<f32>::zeros((width, 1)), Some(vec![0.0; width])));
        let ln_gamma = cfg.norm_first.then(|| vec![1.0; width]);
        let ln_beta = (cfg.norm_first && !cfg.bias_free_ln).then(|| vec![0.0; width]);

        Self {
            y_one_hot,
            y_linear,
            ln_gamma,
            ln_beta,
            decoder_w1: Array2::<f32>::zeros((hidden, width)),
            decoder_b1: Some(vec![0.0; hidden]),
            decoder_w2: Array2::<f32>::zeros((cfg.out_dim, hidden)),
            decoder_b2: Some(vec![0.0; cfg.out_dim]),
        }
    }
}

/// In-context learning transformer.
///
/// Combines a y-encoder, the ICL encoder stack, an optional final LayerNorm
/// and a 2-layer MLP decoder.
#[derive(Debug, Clone)]
pub struct ICLearning {
    /// Hyper-parameters this head was built from.
    pub config: ICLearningConfig,
    /// y-encoder, final-LayerNorm and decoder weights.
    pub params: ICLearningParams,
    /// ICL encoder stack (`config.num_blocks` blocks).
    pub encoder: EncoderStack,
}

impl ICLearning {
    /// Load weights from a Python state dict under `{prefix}`. Keys used:
    ///
    ///   - `{prefix}.tf_icl.blocks.i.…`  — encoder stack
    ///   - `{prefix}.ln.weight` / `.bias` — final LN (norm_first only)
    ///   - `{prefix}.y_encoder.weight` / `.bias` — Linear or OneHotAndLinear
    ///   - `{prefix}.decoder.0.weight` / `.bias` — Linear(d_model → 2*d_model)
    ///   - `{prefix}.decoder.2.weight` / `.bias` — Linear(2*d_model → out_dim)
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        let ssmax_kind = crate::ssmax::SsmaxKind::parse(&self.config.ssmax)
            .unwrap_or(crate::ssmax::SsmaxKind::None);
        self.encoder
            .load_from_with_ssmax(sd, &format!("{prefix}.tf_icl"), ssmax_kind)?;

        if self.config.norm_first {
            self.params.ln_gamma =
                Some(sd.take_vec(&format!("{prefix}.ln.weight"), self.config.d_model)?);
            let beta_key = format!("{prefix}.ln.bias");
            if sd.tensors.contains_key(&beta_key) {
                self.params.ln_beta = Some(sd.take_vec(&beta_key, self.config.d_model)?);
            }
        }

        let d = self.config.d_model;
        let hid = d * 2;

        // y-encoder (one branch active depending on task type).
        if let Some(enc) = self.params.y_one_hot.as_mut() {
            enc.load_from(sd, &format!("{prefix}.y_encoder"))?;
        } else if let Some((w, bias)) = self.params.y_linear.as_mut() {
            *w = sd.take_array2(&format!("{prefix}.y_encoder.weight"), d, 1)?;
            let bias_key = format!("{prefix}.y_encoder.bias");
            if sd.tensors.contains_key(&bias_key) {
                *bias = Some(sd.take_vec(&bias_key, d)?);
            }
        }

        // Decoder is `nn.Sequential(Linear, GELU, Linear)`; index 0 and 2
        // are the linears (1 is the activation, no params).
        self.params.decoder_w1 = sd.take_array2(&format!("{prefix}.decoder.0.weight"), hid, d)?;
        let b1k = format!("{prefix}.decoder.0.bias");
        if sd.tensors.contains_key(&b1k) {
            self.params.decoder_b1 = Some(sd.take_vec(&b1k, hid)?);
        }
        self.params.decoder_w2 = sd.take_array2(
            &format!("{prefix}.decoder.2.weight"),
            self.config.out_dim,
            hid,
        )?;
        let b2k = format!("{prefix}.decoder.2.bias");
        if sd.tensors.contains_key(&b2k) {
            self.params.decoder_b2 = Some(sd.take_vec(&b2k, self.config.out_dim)?);
        }
        Ok(())
    }

    pub fn new(config: ICLearningConfig) -> Self {
        let params = ICLearningParams::zeros(&config);
        let mab_cfg = MabConfig {
            d_model: config.d_model,
            nhead: config.nhead,
            dim_feedforward: config.dim_feedforward,
            dropout: config.dropout,
            activation: config.activation,
            norm_first: config.norm_first,
            bias_free_ln: config.bias_free_ln,
        };
        let encoder = EncoderStack::new(config.num_blocks, mab_cfg, None)
            .expect("ICLearning: d_model must be divisible by nhead");
        Self {
            config,
            params,
            encoder,
        }
    }

    /// Forward — port of `ICLearning._icl_predictions`.
    ///
    /// `r` shape: `(B, T, d_model)`. `y_train` is either:
    ///   - integer class indices `(B, train_size)` when classification,
    ///   - float targets `(B, train_size)` when regression.
    ///
    /// Output: `(B, T, out_dim)`.
    pub fn forward(
        &self,
        r: ArrayView3<f32>,
        y_train_class: Option<ArrayView2<usize>>,
        y_train_reg: Option<ArrayView2<f32>>,
    ) -> Array3<f32> {
        let (b, t, d) = (r.shape()[0], r.shape()[1], r.shape()[2]);
        assert_eq!(d, self.config.d_model);

        // 1. Embed y_train into d_model space.
        let ry_train: Array3<f32> = match (
            &self.params.y_one_hot,
            &self.params.y_linear,
            y_train_class,
            y_train_reg,
        ) {
            (Some(enc), _, Some(y_cls), None) => enc.forward(y_cls),
            (_, Some((w, bias)), None, Some(y_reg)) => {
                // y_reg is (B, train_size); unsqueeze last dim → (B, T, 1).
                let (br, tr) = (y_reg.shape()[0], y_reg.shape()[1]);
                let mut y3 = Array3::<f32>::zeros((br, tr, 1));
                for bi in 0..br {
                    for ti in 0..tr {
                        y3[(bi, ti, 0)] = y_reg[(bi, ti)];
                    }
                }
                linear3d(y3.view(), w.view(), bias.as_deref())
            }
            _ => panic!("y_train shape doesn't match task type (classification vs regression)"),
        };

        let train_size = ry_train.shape()[1];
        assert!(train_size <= t, "train_size {train_size} > total {t}");

        // 2. Add y-embedding into the first `train_size` positions.
        let mut r_aug = r.to_owned();
        for bi in 0..b {
            for ti in 0..train_size {
                for di in 0..d {
                    r_aug[(bi, ti, di)] += ry_train[(bi, ti, di)];
                }
            }
        }

        // 3. Encoder forward with train-size masking: test rows only see
        //    training rows (the first `train_size` positions). Mirrors
        //    Python `tf_icl(R, train_size=train_size)`.
        let src = self
            .encoder
            .forward_train_size(r_aug.view(), Some(train_size));

        // 4. Final LayerNorm (norm_first path).
        let src_normed = match &self.params.ln_gamma {
            Some(g) => layer_norm_last(src.view(), g, self.params.ln_beta.as_deref(), 1e-5),
            None => src,
        };

        // 5. Decoder MLP: linear → GELU → linear.
        let mut h = linear3d(
            src_normed.view(),
            self.params.decoder_w1.view(),
            self.params.decoder_b1.as_deref(),
        );
        // GELU with exact erf — matches PyTorch's `nn.GELU()` default.
        for v in h.iter_mut() {
            let xv = *v;
            *v = 0.5 * xv * (1.0 + erf_f32(xv / std::f32::consts::SQRT_2));
        }
        linear3d(
            h.view(),
            self.params.decoder_w2.view(),
            self.params.decoder_b2.as_deref(),
        )
    }
}

/// Approximates the error function with Abramowitz–Stegun formula 7.1.26.
///
/// The formula's absolute error is at most about 1.5e-7, which is on the
/// order of fp32 rounding. That is close enough to match PyTorch's default
/// (erf-based) `nn.GELU()`. The function is odd: `erf_f32(-x) == -erf_f32(x)`.
fn erf_f32(x: f32) -> f32 {
    // Polynomial coefficients a5..a1, highest power first, for Horner's rule.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_72,
        0.254_829_6,
    ];
    const P: f32 = 0.327_591_1;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Folding from 0.0 reproduces ((((a5·t + a4)·t + a3)·t + a2)·t + a1).
    // The trailing `* t` completes the degree-5 polynomial in t.
    let poly = COEFFS.iter().fold(0.0_f32, |acc, &c| acc * t + c) * t;
    let tail = poly * (-magnitude * magnitude).exp();
    x.signum() * (1.0 - tail)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny head: d_model = 4, one block, two heads, `norm_first` on.
    fn small_cfg(max_classes: usize, out_dim: usize) -> ICLearningConfig {
        ICLearningConfig {
            max_classes,
            out_dim,
            d_model: 4,
            num_blocks: 1,
            nhead: 2,
            dim_feedforward: 8,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
            ssmax: String::from("none"),
            recompute: false,
        }
    }

    #[test]
    fn classification_forward_output_shape() {
        let head = ICLearning::new(small_cfg(3, 3));
        // B = 2, T = 5, train_size = 3.
        let rows = Array3::from_shape_fn((2, 5, 4), |(bi, ti, di)| {
            (bi * 100 + ti * 10 + di) as f32 * 0.001
        });
        let labels = Array2::from_shape_vec((2, 3), vec![0_usize, 1, 2, 2, 1, 0]).unwrap();
        let logits = head.forward(rows.view(), Some(labels.view()), None);
        assert_eq!(logits.shape(), &[2, 5, 3]);
    }

    #[test]
    fn regression_forward_output_shape() {
        // max_classes = 0 selects the regression head.
        let head = ICLearning::new(small_cfg(0, 999));
        let rows = Array3::from_shape_fn((1, 4, 4), |(bi, ti, di)| {
            (bi * 16 + ti * 4 + di) as f32 * 0.01
        });
        let targets = Array2::from_shape_vec((1, 2), vec![0.5_f32, 1.5]).unwrap();
        let preds = head.forward(rows.view(), None, Some(targets.view()));
        assert_eq!(preds.shape(), &[1, 4, 999]);
    }

    #[test]
    fn zero_init_decoder_gives_zero_logits() {
        let head = ICLearning::new(small_cfg(4, 4));
        let rows = Array3::from_shape_fn((1, 3, 4), |(_, ti, di)| (ti * 4 + di) as f32 * 0.01);
        let labels = Array2::from_shape_vec((1, 2), vec![0_usize, 1]).unwrap();
        let logits = head.forward(rows.view(), Some(labels.view()), None);
        // All decoder weights and biases start at zero, so every logit must
        // be zero. The negated `<` also rejects NaN.
        if let Some(bad) = logits.iter().find(|v| !(v.abs() < 1e-5)) {
            panic!("expected zero logit, got {}", bad);
        }
    }

    #[test]
    fn rejects_mismatched_task_type() {
        // A classification head fed a regression target must panic.
        let head = ICLearning::new(small_cfg(3, 3));
        let rows = Array3::<f32>::zeros((1, 2, 4));
        let y_reg = Array2::from_elem((1, 1), 1.0_f32);
        let outcome = std::panic::catch_unwind(|| {
            head.forward(rows.view(), None, Some(y_reg.view()));
        });
        assert!(outcome.is_err());
    }
}