use ndarray::{Array3, ArrayView3};
use rlx_driver::Device;
use rlx_ir::op::MaskKind;
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::Session;
/// Computes unmasked multi-head attention of `q` against `k`/`v` using the rlx runtime.
///
/// A single-op graph (`attention_kind` with `MaskKind::None`) is built, compiled for
/// `device`, and executed. The graph is rebuilt and recompiled on every call, so callers
/// invoking this repeatedly pay the compile cost each time.
///
/// # Arguments
///
/// * `q` - queries, shape `(batch, t_q, num_heads * head_dim)`.
/// * `k` - keys, shape `(batch, t_k, num_heads * head_dim)`; `t_k` may differ from `t_q`.
/// * `v` - values; must have exactly the same shape as `k`.
/// * `num_heads`, `head_dim` - split of the last axis, forwarded to the graph op.
/// * `device` - device the compiled graph runs on.
///
/// # Returns
///
/// An owned array with the same shape as `q`.
///
/// # Panics
///
/// Panics in any of these cases:
///
/// * the batch sizes of `q`, `k` and `v` differ;
/// * `v`'s time length differs from `k`'s;
/// * any last axis is not `num_heads * head_dim`, or that product overflows `usize`;
/// * the runtime's first output cannot be reshaped to `q`'s shape.
///
/// Failures inside rlx compilation or execution are not handled here.
pub fn compute_attention(
    q: ArrayView3<f32>,
    k: ArrayView3<f32>,
    v: ArrayView3<f32>,
    num_heads: usize,
    head_dim: usize,
    device: Device,
) -> Array3<f32> {
    let (b_q, t_q, hd_q) = q.dim();
    let (b_k, t_k, hd_k) = k.dim();
    let (b_v, t_v, hd_v) = v.dim();
    let hidden = num_heads
        .checked_mul(head_dim)
        .expect("num_heads * head_dim overflows usize");

    assert_eq!(b_q, b_k, "q and k must have the same batch size");
    assert_eq!(b_v, b_k, "v and k must have the same batch size");
    assert_eq!(t_v, t_k, "v and k must have the same time length");
    assert_eq!(hd_q, hidden, "q last axis must equal num_heads * head_dim");
    assert_eq!(hd_k, hidden, "k last axis must equal num_heads * head_dim");
    // `v` is declared to the graph with `k`'s shape below. Its last axis must therefore be
    // checked too; otherwise the runtime would receive a buffer whose length disagrees
    // with the declared input shape.
    assert_eq!(hd_v, hidden, "v last axis must equal num_heads * head_dim");

    let mut g = Graph::new("tabicl_attn");
    let q_id = g.input("q", Shape::new(&[b_q, t_q, hd_q], DType::F32));
    let k_id = g.input("k", Shape::new(&[b_k, t_k, hd_k], DType::F32));
    // `v` shares `k`'s shape, as enforced by the assertions above.
    let v_id = g.input("v", Shape::new(&[b_k, t_k, hd_k], DType::F32));
    let out = g.attention_kind(
        q_id,
        k_id,
        v_id,
        num_heads,
        head_dim,
        MaskKind::None,
        Shape::new(&[b_q, t_q, hd_q], DType::F32),
    );
    g.set_outputs(vec![out]);

    let session = Session::new(device);
    let mut compiled = session.compile(g);

    // `iter()` visits elements in logical (last-index-fastest) order whatever the view's
    // memory layout is, so strided or transposed views still flatten to row-major buffers.
    let q_flat: Vec<f32> = q.iter().copied().collect();
    let k_flat: Vec<f32> = k.iter().copied().collect();
    let v_flat: Vec<f32> = v.iter().copied().collect();

    let outs = compiled.run(&[("q", &q_flat), ("k", &k_flat), ("v", &v_flat)]);
    let flat = &outs[0];
    Array3::from_shape_vec((b_q, t_q, hd_q), flat.clone()).expect("shape matches")
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array, Array3};

    /// Absolute tolerance for comparing f32 attention outputs to analytic expectations.
    const TOL: f32 = 1e-4;

    /// Builds a `(b, t, h * d)` tensor whose entry at `(_, time, chan)` is
    /// `f(time * h * d + chan)`.
    ///
    /// The batch index is ignored, so every batch slice is identical.
    fn seq_tensor(
        b: usize,
        t: usize,
        h: usize,
        d: usize,
        f: impl Fn(usize) -> f32,
    ) -> Array3<f32> {
        let hidden = h * d;
        Array::from_shape_fn((b, t, hidden), |(_, ti, ci)| f(ti * hidden + ci))
    }

    #[test]
    fn cpu_attention_via_rlx_runs() {
        let (b, h, t, d) = (1, 2, 4, 4);
        let q = seq_tensor(b, t, h, d, |i| i as f32 * 0.01);
        let k = q.clone();
        let v = q.clone();
        let out = compute_attention(q.view(), k.view(), v.view(), h, d, Device::Cpu);
        assert_eq!(out.shape(), &[b, t, h * d]);
        assert!(out.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn rlx_attention_parity_against_reference_smoke() {
        let (b, h, t, d) = (1, 2, 3, 4);
        let q = seq_tensor(b, t, h, d, |i| (i as f32).cos());
        let k = q.clone();
        let v = q.clone();
        let rlx_out = compute_attention(q.view(), k.view(), v.view(), h, d, Device::Cpu);
        assert_eq!(rlx_out.shape(), &[b, t, h * d]);

        // Scale-independent reference property: with non-negative weights summing to one and
        // no mask, each output element is a convex combination of `v` along the key/time axis
        // for the same batch and channel, so it must lie within that column's range.
        // NOTE(review): assumes rlx `attention_kind` is softmax attention in which output
        // channel c draws only from value channel c. Confirm against rlx.
        for ((bi, ti, ci), &o) in rlx_out.indexed_iter() {
            assert!(o.is_finite(), "non-finite output at ({}, {}, {})", bi, ti, ci);
            let (lo, hi) = (0..t)
                .map(|tj| v[[bi, tj, ci]])
                .fold((f32::INFINITY, f32::NEG_INFINITY), |(mn, mx), x| {
                    (mn.min(x), mx.max(x))
                });
            assert!(
                o >= lo - TOL && o <= hi + TOL,
                "output {} at ({}, {}, {}) outside value range [{}, {}]",
                o,
                bi,
                ti,
                ci,
                lo,
                hi
            );
        }
    }

    #[test]
    fn rlx_attention_time_constant_values_pass_through() {
        let (b, h, t, d) = (2, 2, 5, 4);
        let hidden = h * d;
        let q = seq_tensor(b, t, h, d, |i| (i as f32 * 0.37).sin());
        let k = seq_tensor(b, t, h, d, |i| (i as f32 * 0.11).cos());
        // Every key position carries the same value row. Any attention weights that sum to
        // one must therefore reproduce that row, whatever the score scaling.
        let v = Array::from_shape_fn((b, t, hidden), |(bi, _, ci)| {
            (bi * hidden + ci) as f32 * 0.1 - 0.5
        });
        let out = compute_attention(q.view(), k.view(), v.view(), h, d, Device::Cpu);
        assert_eq!(out.shape(), &[b, t, hidden]);
        for ((bi, ti, ci), &o) in out.indexed_iter() {
            let expected = v[[bi, 0, ci]];
            assert!(
                (o - expected).abs() < TOL,
                "output {} at ({}, {}, {}) differs from time-constant value {}",
                o,
                bi,
                ti,
                ci,
                expected
            );
        }
    }
}