tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Rotary positional embedding — port of `tabicl._model.rope`.
//!
//! The Python module wraps `lucidrains/rotary-embedding-torch` and exposes
//! many knobs (xpos, learned frequencies, pixel/constant variants,
//! interpolation, axial). TabICL only ever instantiates RoPE with the
//! `lang` defaults — that is the only path we need to reproduce exactly:
//!
//!   - `freqs = 1.0 / (theta ** (arange(0, dim, 2)[:dim//2] / dim))`
//!   - per-position phases: `phases[n, k] = n * freqs[k]`  for `n in 0..seq_len`
//!   - **interleaved** mode: `freqs_full = repeat(phases, "n d -> n (d r)", r=2)`,
//!     i.e. each phase value is duplicated side-by-side
//!     `[p0, p0, p1, p1, ...]`
//!   - **non-interleaved** mode: `phases` itself is the `(seq_len, dim/2)`
//!     table, then `apply_rotary_emb` doubles cos/sin by *concatenation*
//!     `[c0, c1, ..., c0, c1, ...]`.
//!
//! Two surfaces here:
//!
//!   1. [`RopeTables`] — host-side precomputation of cos/sin frequencies,
//!      bit-for-bit reproduction of Python's `freqs`.
//!   2. [`apply_rotary_emb_ref`] — numerical reference implementation
//!      operating on ndarray tensors. Used for parity tests. The
//!      rlx-graph-builder path constructs the same cos/sin tables and
//!      hands them to `rlx_ir::HirModule::rope`.

use ndarray::{Array2, Array4, ArrayView4, Axis};

/// Hyper-parameters of a single rotary-embedding instance.
///
/// Populated from [`crate::tabicl::TabICLConfig::row_rope_base`] and
/// [`crate::tabicl::TabICLConfig::row_rope_interleaved`].
#[derive(Debug, Clone, Copy)]
pub struct RopeConfig {
    /// Width of the vectors being rotated (the attention `head_dim`).
    pub head_dim: usize,
    /// Frequency base, a.k.a. theta. Python defaults to 10000; TabICL's row
    /// interaction transformer uses 100000.
    pub base: f32,
    /// `true`: rotate adjacent lane pairs `(2k, 2k+1)` (LLaMA-style).
    /// `false` (TabICL default): rotate lane `k` against lane `k + head_dim/2`.
    pub interleaved: bool,
}

impl RopeConfig {
    /// Number of rotated pairs: `head_dim / 2`, rounded down.
    #[inline]
    pub fn half(&self) -> usize {
        self.head_dim >> 1
    }
}

/// Cached cos/sin tables for one `(seq_len, RopeConfig)` combination.
///
/// Equivalent to evaluating `RotaryEmbedding.forward(arange(seq_len))` on the
/// host and splitting the result into cosine and sine parts. The Python module
/// regenerates these on every call (optionally cached); here they are built
/// once, up front, and reused.
#[derive(Debug, Clone)]
pub struct RopeTables {
    pub cfg: RopeConfig,
    pub seq_len: usize,
    /// Cosines. Shape `(seq_len, head_dim)` when interleaved (each pair's value
    /// stored twice, side by side), `(seq_len, head_dim/2)` otherwise — the
    /// non-interleaved consumer performs the `[cos, cos]` widening itself.
    pub cos: Array2<f32>,
    /// Sines, laid out exactly like `cos`.
    pub sin: Array2<f32>,
}

impl RopeTables {
    /// Precompute cos/sin for positions `0..seq_len`.
    ///
    /// Pair `k` (for `k in 0..head_dim/2`) has inverse frequency
    /// `1 / base^(2k / head_dim)`, and its angle at position `n` is
    /// `n * inv_freq[k]`; everything is evaluated in f32.
    ///
    /// # Panics
    ///
    /// If `cfg.head_dim` is odd.
    pub fn new(cfg: RopeConfig, seq_len: usize) -> Self {
        assert!(
            cfg.head_dim.is_multiple_of(2),
            "head_dim must be even for RoPE"
        );
        let half = cfg.half();
        let dim = cfg.head_dim as f32;

        // One inverse frequency per rotated pair (Python's `freqs`).
        let inv_freq: Vec<f32> = (0..half)
            .map(|k| {
                let exponent = (2 * k) as f32 / dim;
                1.0_f32 / cfg.base.powf(exponent)
            })
            .collect();

        // Rotation angle of pair `k` at position `n`.
        let angle = |n: usize, k: usize| n as f32 * inv_freq[k];

        let (cos, sin) = if cfg.interleaved {
            // Python `repeat(freqs, "... n -> ... (n r)", r=2)`: lane `j` reads
            // pair `j / 2`, so each value shows up twice in adjacent lanes.
            let shape = (seq_len, cfg.head_dim);
            (
                Array2::from_shape_fn(shape, |(n, j)| angle(n, j / 2).cos()),
                Array2::from_shape_fn(shape, |(n, j)| angle(n, j / 2).sin()),
            )
        } else {
            // Half-width tables; `apply_rotary_emb_ref` applies the `[c, c]`
            // doubling when it rotates.
            let angles = Array2::from_shape_fn((seq_len, half), |(n, k)| angle(n, k));
            (angles.mapv(f32::cos), angles.mapv(f32::sin))
        };

        RopeTables {
            cfg,
            seq_len,
            cos,
            sin,
        }
    }
}

/// Rotary embedding on a `(B, H, T, D)` tensor — numerical reference.
///
/// Reproduces `apply_rotary_emb(freqs, t, ...)` for the configuration TabICL
/// uses (`start_index=0`, `scale=1.0`, `seq_dim=-2`, f32). Each rotated pair
/// `(a, b)` at position `ti` becomes `(a*c - b*s, b*c + a*s)`, with `c`/`s`
/// read from row `ti` of `tables`. Pairs are adjacent lanes `(2k, 2k+1)` when
/// interleaved, and lanes `(k, k + head_dim/2)` otherwise.
///
/// Python's `t.ndim == 3` branch (`freqs = freqs[-seq_len:]`) does not apply
/// to 4-D input. Here an input shorter than the table reads the table's
/// leading rows, i.e. positions `0..T`.
///
/// Returns a freshly allocated array; `t` is left untouched.
///
/// # Panics
///
/// If `T > tables.seq_len` or `D != tables.cfg.head_dim`.
pub fn apply_rotary_emb_ref(t: &ArrayView4<f32>, tables: &RopeTables) -> Array4<f32> {
    let shape = t.shape();
    let (batch, heads, seq, width) = (shape[0], shape[1], shape[2], shape[3]);
    assert!(
        seq <= tables.seq_len,
        "input seq_len {seq} > table seq_len {}",
        tables.seq_len
    );
    assert_eq!(width, tables.cfg.head_dim, "input head_dim mismatch");

    let half = tables.cfg.half();
    let rotated_width = 2 * half;
    let interleaved = tables.cfg.interleaved;

    Array4::from_shape_fn((batch, heads, seq, width), |(bi, hi, ti, j)| {
        if j >= rotated_width {
            // Lanes outside the rotated span are copied through. With an even
            // head_dim (enforced by `RopeTables::new`) there are none.
            t[(bi, hi, ti, j)]
        } else if interleaved {
            // Pair k occupies lanes (2k, 2k+1); tables are full-width with
            // duplicated entries, so column 2k holds the pair's cos/sin.
            let k = j / 2;
            let c = tables.cos[(ti, 2 * k)];
            let s = tables.sin[(ti, 2 * k)];
            let x0 = t[(bi, hi, ti, 2 * k)];
            let x1 = t[(bi, hi, ti, 2 * k + 1)];
            if j % 2 == 0 {
                x0 * c + (-x1) * s
            } else {
                x1 * c + x0 * s
            }
        } else {
            // Pair k occupies lanes (k, k + half): x*[c, c] + cat(-x_high, x_low)*[s, s].
            let k = if j < half { j } else { j - half };
            let c = tables.cos[(ti, k)];
            let s = tables.sin[(ti, k)];
            let x_low = t[(bi, hi, ti, k)];
            let x_high = t[(bi, hi, ti, k + half)];
            if j < half {
                x_low * c - x_high * s
            } else {
                x_high * c + x_low * s
            }
        }
    })
}

/// Convenience: build a `(seq_len, head_dim)` cos table for passing into
/// `rlx_ir::HirModule::rope`. For interleaved mode this is the table
/// directly; for non-interleaved mode we expand `[cos, cos]` along the
/// last axis so rlx's RoPE op sees a head-dim-wide cosine.
pub fn cos_table_full(tables: &RopeTables) -> Array2<f32> {
    if tables.cfg.interleaved {
        tables.cos.clone()
    } else {
        ndarray::concatenate(Axis(1), &[tables.cos.view(), tables.cos.view()]).unwrap()
    }
}

/// Companion to [`cos_table_full`].
pub fn sin_table_full(tables: &RopeTables) -> Array2<f32> {
    if tables.cfg.interleaved {
        tables.sin.clone()
    } else {
        ndarray::concatenate(Axis(1), &[tables.sin.view(), tables.sin.view()]).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Freqs for the language formula reproduce Python's
    /// `1.0 / (theta ** (arange(0, dim, 2) / dim))` (to f32 tolerance).
    #[test]
    fn lang_freqs_match_python_formula() {
        let cfg = RopeConfig {
            head_dim: 16,
            base: 100_000.0,
            interleaved: false,
        };
        let t = RopeTables::new(cfg, 1);
        // phases[0,k] = 0 → cos = 1, sin = 0
        for k in 0..cfg.half() {
            assert_abs_diff_eq!(t.cos[(0, k)], 1.0, epsilon = 1e-7);
            assert_abs_diff_eq!(t.sin[(0, k)], 0.0, epsilon = 1e-7);
        }
        // phases[1,k] = freqs[k] = 1 / base^(2k/dim).
        // Specifically, k=0 → freqs[0]=1, so cos = cos(1), sin = sin(1).
        let t2 = RopeTables::new(cfg, 2);
        assert_abs_diff_eq!(t2.cos[(1, 0)], 1.0_f32.cos(), epsilon = 1e-6);
        assert_abs_diff_eq!(t2.sin[(1, 0)], 1.0_f32.sin(), epsilon = 1e-6);
        // k = half-1 → freqs = 1 / base^((dim-2)/dim).
        let k_last = cfg.half() - 1;
        let f_last = 1.0_f32 / cfg.base.powf((2 * k_last) as f32 / cfg.head_dim as f32);
        assert_abs_diff_eq!(t2.cos[(1, k_last)], f_last.cos(), epsilon = 1e-6);
        assert_abs_diff_eq!(t2.sin[(1, k_last)], f_last.sin(), epsilon = 1e-6);
    }

    fn assert_arrays_close(a: &Array4<f32>, b: &Array4<f32>, eps: f32) {
        assert_eq!(a.shape(), b.shape());
        for (x, y) in a.iter().zip(b.iter()) {
            assert!(
                (x - y).abs() <= eps,
                "values differ: {} vs {} (eps {})",
                x,
                y,
                eps
            );
        }
    }

    /// At position 0, rotation is the identity (cos=1, sin=0).
    #[test]
    fn identity_at_position_zero_non_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 100_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 4);
        let x = Array4::from_shape_fn((2, 3, 1, 8), |(b, h, _, d)| {
            (b as f32) * 0.1 + (h as f32) * 0.01 + (d as f32) * 0.001
        });
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        assert_arrays_close(&x, &y, 1e-6);
    }

    #[test]
    fn identity_at_position_zero_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 100_000.0,
            interleaved: true,
        };
        let tables = RopeTables::new(cfg, 4);
        let x = Array4::from_shape_fn((2, 3, 1, 8), |(b, h, _, d)| {
            (b as f32) * 0.1 + (h as f32) * 0.01 + (d as f32) * 0.001
        });
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        assert_arrays_close(&x, &y, 1e-6);
    }

    /// Rotation by angle θ preserves the per-pair L2 norm in interleaved
    /// mode, and the L2 norm of (x_low, x_high) pairs in non-interleaved.
    #[test]
    fn rotation_preserves_norm_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: true,
        };
        let tables = RopeTables::new(cfg, 5);
        let x = Array4::from_shape_fn((1, 2, 5, 8), |(_, _, t, d)| ((t * 8 + d + 1) as f32).sin());
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        for h in 0..2 {
            for t in 0..5 {
                for k in 0..cfg.half() {
                    let pre = x[(0, h, t, 2 * k)].powi(2) + x[(0, h, t, 2 * k + 1)].powi(2);
                    let post = y[(0, h, t, 2 * k)].powi(2) + y[(0, h, t, 2 * k + 1)].powi(2);
                    assert_abs_diff_eq!(pre, post, epsilon = 1e-5);
                }
            }
        }
    }

    #[test]
    fn rotation_preserves_norm_non_interleaved() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 5);
        let x = Array4::from_shape_fn((1, 2, 5, 8), |(_, _, t, d)| ((t * 8 + d + 1) as f32).cos());
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        let half = cfg.half();
        for h in 0..2 {
            for t in 0..5 {
                for k in 0..half {
                    let pre = x[(0, h, t, k)].powi(2) + x[(0, h, t, k + half)].powi(2);
                    let post = y[(0, h, t, k)].powi(2) + y[(0, h, t, k + half)].powi(2);
                    assert_abs_diff_eq!(pre, post, epsilon = 1e-5);
                }
            }
        }
    }

    /// Pins the rotation *direction* (non-interleaved). At position 1 the
    /// first frequency is exactly 1 rad, so the unit vector on pair 0's first
    /// lane must land on (cos 1, sin 1). Norm preservation alone would not
    /// catch a sign flip in the rotate-half term.
    #[test]
    fn rotation_direction_matches_closed_form_non_interleaved() {
        let cfg = RopeConfig {
            head_dim: 4,
            base: 10_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 2);
        let mut x = Array4::<f32>::zeros((1, 1, 2, 4));
        // Pair 0 = lanes (0, 2); set x_low = 1, x_high = 0 at position 1.
        x[(0, 0, 1, 0)] = 1.0;
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        assert_abs_diff_eq!(y[(0, 0, 1, 0)], 1.0_f32.cos(), epsilon = 1e-6);
        assert_abs_diff_eq!(y[(0, 0, 1, 2)], 1.0_f32.sin(), epsilon = 1e-6);
        // Pair 1 (lanes 1, 3) was zero and must stay zero.
        assert_abs_diff_eq!(y[(0, 0, 1, 1)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(y[(0, 0, 1, 3)], 0.0, epsilon = 1e-6);
    }

    /// Same direction check for interleaved mode, where pair 0 = lanes (0, 1).
    #[test]
    fn rotation_direction_matches_closed_form_interleaved() {
        let cfg = RopeConfig {
            head_dim: 4,
            base: 10_000.0,
            interleaved: true,
        };
        let tables = RopeTables::new(cfg, 2);
        let mut x = Array4::<f32>::zeros((1, 1, 2, 4));
        x[(0, 0, 1, 0)] = 1.0;
        let y = apply_rotary_emb_ref(&x.view(), &tables);
        assert_abs_diff_eq!(y[(0, 0, 1, 0)], 1.0_f32.cos(), epsilon = 1e-6);
        assert_abs_diff_eq!(y[(0, 0, 1, 1)], 1.0_f32.sin(), epsilon = 1e-6);
        assert_abs_diff_eq!(y[(0, 0, 1, 2)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(y[(0, 0, 1, 3)], 0.0, epsilon = 1e-6);
    }

    /// An input shorter than the table uses the table's leading rows, which
    /// are bit-identical to a table built for exactly that length.
    #[test]
    fn shorter_input_uses_leading_table_rows() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: false,
        };
        let long = RopeTables::new(cfg, 6);
        let short = RopeTables::new(cfg, 3);
        let x = Array4::from_shape_fn((1, 2, 3, 8), |(_, h, t, d)| {
            ((h * 31 + t * 8 + d) as f32).sin()
        });
        let a = apply_rotary_emb_ref(&x.view(), &long);
        let b = apply_rotary_emb_ref(&x.view(), &short);
        assert_eq!(a, b);
    }

    #[test]
    #[should_panic(expected = "input head_dim mismatch")]
    fn head_dim_mismatch_panics() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: false,
        };
        let tables = RopeTables::new(cfg, 2);
        let x = Array4::<f32>::zeros((1, 1, 2, 6));
        let _ = apply_rotary_emb_ref(&x.view(), &tables);
    }

    #[test]
    #[should_panic(expected = "head_dim must be even")]
    fn odd_head_dim_panics() {
        let cfg = RopeConfig {
            head_dim: 7,
            base: 10_000.0,
            interleaved: false,
        };
        let _ = RopeTables::new(cfg, 1);
    }

    #[test]
    fn full_cos_sin_tables_double_correctly() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: false,
        };
        let t = RopeTables::new(cfg, 3);
        let cos = cos_table_full(&t);
        let sin = sin_table_full(&t);
        assert_eq!(cos.shape(), &[3, 8]);
        assert_eq!(sin.shape(), &[3, 8]);
        // Halves are duplicated by concatenation.
        for ti in 0..3 {
            for k in 0..4 {
                assert_eq!(cos[(ti, k)], cos[(ti, k + 4)]);
                assert_eq!(sin[(ti, k)], sin[(ti, k + 4)]);
            }
        }
    }

    /// Interleaved tables are full-width with each pair's value duplicated
    /// side by side, and the `*_table_full` helpers return them unchanged.
    #[test]
    fn interleaved_tables_duplicate_pairs_and_full_tables_pass_through() {
        let cfg = RopeConfig {
            head_dim: 8,
            base: 10_000.0,
            interleaved: true,
        };
        let t = RopeTables::new(cfg, 3);
        assert_eq!(t.cos.shape(), &[3, 8]);
        assert_eq!(t.sin.shape(), &[3, 8]);
        for ti in 0..3 {
            for k in 0..4 {
                assert_eq!(t.cos[(ti, 2 * k)], t.cos[(ti, 2 * k + 1)]);
                assert_eq!(t.sin[(ti, 2 * k)], t.sin[(ti, 2 * k + 1)]);
            }
        }
        assert_eq!(cos_table_full(&t), t.cos);
        assert_eq!(sin_table_full(&t), t.sin);
    }
}