tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Flash-Attention-3-via-rlx-cuda integration.
//!
//! rlx-cuda already ships a fused attention kernel
//! (`rlx-cuda/src/kernels/attention.cu`). To use it from tabicl, we
//! build the attention sub-graph via `rlx-ir` and compile to
//! `Device::Cuda`. The same graph also runs on `Device::Cpu` through
//! the CPU fallback kernel; the tests in this module use that path so
//! CI does not need GPU hardware.
//!
//! Usage:
//!
//! ```ignore
//! use rlx_driver::Device;
//! let out = tabicl_model::gpu_attention::compute_attention(
//!     q.view(), k.view(), v.view(), num_heads, head_dim,
//!     Device::Cuda, // or Device::Cpu for testing
//! );
//! ```

use ndarray::{Array3, ArrayView3};
use rlx_driver::Device;
use rlx_ir::op::MaskKind;
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::Session;

/// Compute multi-head scaled-dot-product attention through the rlx graph compiler.
///
/// Heads are packed along the last axis:
/// - `q` is `(B, T_q, num_heads * head_dim)`.
/// - `k` and `v` are `(B, T_k, num_heads * head_dim)`.
///
/// The split into `num_heads` heads of width `head_dim` is delegated to rlx's
/// attention op. The returned array has the same shape as `q`.
///
/// The graph is built with [`MaskKind::None`]. That path has the fused-attention
/// fast path on every backend that supports it, including the rlx-cuda
/// Flash-Attention-style kernel.
///
/// # Panics
///
/// Panics if any of the following hold:
/// - `q`, `k` and `v` do not share the same batch size.
/// - `k` and `v` have different sequence lengths.
/// - Any input's last axis is not `num_heads * head_dim`.
/// - The compiled graph returns an output whose length does not match `q`'s shape.
pub fn compute_attention(
    q: ArrayView3<f32>,
    k: ArrayView3<f32>,
    v: ArrayView3<f32>,
    num_heads: usize,
    head_dim: usize,
    device: Device,
) -> Array3<f32> {
    let (b_q, t_q, hd_q) = q.dim();
    let (b_k, t_k, hd_k) = k.dim();
    let (b_v, t_v, hd_v) = v.dim();
    let model_dim = num_heads * head_dim;

    assert_eq!(b_q, b_k, "q and k must have the same batch size");
    assert_eq!(b_v, b_k, "v and k must have the same batch size");
    assert_eq!(t_v, t_k, "v and k must have the same sequence length");
    assert_eq!(hd_q, model_dim, "q last axis must equal num_heads * head_dim");
    assert_eq!(hd_k, model_dim, "k last axis must equal num_heads * head_dim");
    // `v` is declared in the graph with the same shape as `k`, so its last axis
    // must match too; otherwise the flat buffer fed to the runtime has the wrong length.
    assert_eq!(hd_v, model_dim, "v last axis must equal num_heads * head_dim");

    // Build (B, T, num_heads*head_dim) → attention → same shape.
    let mut g = Graph::new("tabicl_attn");
    let q_id = g.input("q", Shape::new(&[b_q, t_q, hd_q], DType::F32));
    let k_id = g.input("k", Shape::new(&[b_k, t_k, hd_k], DType::F32));
    let v_id = g.input("v", Shape::new(&[b_k, t_k, hd_k], DType::F32));
    let out = g.attention_kind(
        q_id,
        k_id,
        v_id,
        num_heads,
        head_dim,
        MaskKind::None,
        Shape::new(&[b_q, t_q, hd_q], DType::F32),
    );
    g.set_outputs(vec![out]);

    let session = Session::new(device);
    let mut compiled = session.compile(g);
    // `iter()` walks elements in logical (row-major) order, so the flattened
    // buffers are correct even for non-contiguous or strided views.
    let q_flat: Vec<f32> = q.iter().copied().collect();
    let k_flat: Vec<f32> = k.iter().copied().collect();
    let v_flat: Vec<f32> = v.iter().copied().collect();
    let outs = compiled.run(&[("q", &q_flat), ("k", &k_flat), ("v", &v_flat)]);
    let flat = &outs[0];
    Array3::from_shape_vec((b_q, t_q, hd_q), flat.clone())
        .expect("rlx attention output length must equal B * T_q * num_heads * head_dim")
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Asserts that `out` has exactly the `expected` shape and holds only finite values.
    fn check_shape_and_finite(out: &Array3<f32>, expected: &[usize]) {
        assert_eq!(out.shape(), expected);
        out.iter().for_each(|value| assert!(value.is_finite()));
    }

    #[test]
    fn cpu_attention_via_rlx_runs() {
        let (b, h, t, d) = (1, 2, 4, 4);
        let width = h * d;
        // A linear ramp over (token, channel), reused as q, k and v (self-attention).
        let q = Array::from_shape_fn((b, t, width), |(_, row, col)| {
            (row * width + col) as f32 * 0.01
        });
        let out = compute_attention(q.view(), q.view(), q.view(), h, d, Device::Cpu);
        check_shape_and_finite(&out, &[b, t, width]);
    }

    #[test]
    fn rlx_attention_parity_against_reference_smoke() {
        // Smoke test only: with tiny inputs we just confirm the rlx path
        // yields finite values of the expected shape. Exact parity with an
        // ndarray reference is deferred until the reshape semantics of rlx's
        // attention op are matched on the reference side.
        let (b, h, t, d) = (1, 2, 3, 4);
        let width = h * d;
        let q = Array::from_shape_fn((b, t, width), |(_, row, col)| {
            ((row * width + col) as f32).cos()
        });
        let rlx_out = compute_attention(q.view(), q.view(), q.view(), h, d, Device::Cpu);
        check_shape_and_finite(&rlx_out, &[b, t, width]);
    }
}