use rayon::prelude::*;
use sapient_core::error::{Result, SapientError};
use sapient_core::{Shape, Tensor};
/// Dot product of `a` and `b` using 4-lane NEON fused multiply-adds.
///
/// Only the first `min(a.len(), b.len())` elements take part, which mirrors the `zip`-based
/// portable fallback used on other architectures.
///
/// # Safety
///
/// The executing CPU must support NEON (enabled by default on the common aarch64 targets).
/// No requirement is placed on the slice lengths: the unchecked vector loads never read past
/// the end of the shorter slice.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_f32_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // Clamp to the shorter slice so the unchecked loads below stay in bounds for both inputs.
    let n = a.len().min(b.len());
    let mut acc = vdupq_n_f32(0.0);
    let mut i = 0;
    while i + 4 <= n {
        // SAFETY: i + 4 <= n <= min(a.len(), b.len()), so both 4-lane loads are in bounds.
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        acc = vfmaq_f32(acc, va, vb);
        i += 4;
    }
    // Horizontal sum of the four accumulator lanes, then the scalar tail.
    let mut s = vaddvq_f32(acc);
    while i < n {
        s += a[i] * b[i];
        i += 1;
    }
    s
}
/// Portable dot product for non-aarch64 targets.
///
/// It keeps the `_neon` name so that call sites are identical on every architecture. Elements
/// are paired up to the length of the shorter slice.
#[cfg(not(target_arch = "aarch64"))]
fn dot_f32_neon(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b.iter());
    pairs.map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>()
}
/// In-place `o[i] = alpha * o[i] + beta * v[i]` using 4-lane NEON arithmetic.
///
/// Only the first `min(o.len(), v.len())` elements of `o` are updated, which mirrors the
/// `zip`-based portable fallback. Lanes handled by the vector loop fuse the `beta * v` term
/// (`vfmaq_f32`), so they may differ in the last ulp from the scalar tail.
///
/// # Safety
///
/// The executing CPU must support NEON (enabled by default on the common aarch64 targets).
/// No requirement is placed on the slice lengths: the unchecked vector loads/stores never go
/// past the end of the shorter slice.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn saxpby_neon(o: &mut [f32], v: &[f32], alpha: f32, beta: f32) {
    use std::arch::aarch64::*;
    let va = vdupq_n_f32(alpha);
    let vb = vdupq_n_f32(beta);
    // Clamp to the shorter slice so the unchecked loads of `v` cannot run past its end.
    let n = o.len().min(v.len());
    let mut i = 0;
    while i + 4 <= n {
        // SAFETY: i + 4 <= n <= min(o.len(), v.len()), so every load/store is in bounds.
        let vo = vld1q_f32(o.as_ptr().add(i));
        let vv = vld1q_f32(v.as_ptr().add(i));
        let r = vfmaq_f32(vmulq_f32(vo, va), vv, vb);
        vst1q_f32(o.as_mut_ptr().add(i), r);
        i += 4;
    }
    while i < n {
        o[i] = alpha * o[i] + beta * v[i];
        i += 1;
    }
}
/// Portable in-place `o[i] = alpha * o[i] + beta * v[i]` for non-aarch64 targets.
///
/// It keeps the `_neon` name so that call sites are identical on every architecture. Only the
/// first `min(o.len(), v.len())` elements of `o` are touched.
#[cfg(not(target_arch = "aarch64"))]
fn saxpby_neon(o: &mut [f32], v: &[f32], alpha: f32, beta: f32) {
    o.iter_mut()
        .zip(v.iter())
        .for_each(|(dst, &src)| *dst = alpha * *dst + beta * src);
}
/// Computes one attention output row with a single-pass ("online") softmax.
///
/// For each key `ki` in `0..attend_len`, the score is `scale * dot(q_row, k_ki)` plus
/// `mask_row[ki]` (or `0.0` when there is no mask). `o_row` is overwritten with the
/// softmax-weighted sum of the matching `v` rows. A running maximum and normaliser are
/// rescaled whenever a larger score arrives, so no score buffer is needed.
///
/// Keys whose score is `-inf` (masked out) get zero weight and are skipped. If no key in
/// `0..attend_len` is visible, `o_row` is left as all zeros.
///
/// * `k_head` / `v_head` - row-major, `head_dim` floats per key, at least `attend_len` rows.
/// * `_seq_k` - unused; kept so call sites stay unchanged.
/// * `mask_row` - additive mask indexed by key, at least `attend_len` entries.
///
/// # Panics
///
/// Panics if `q_row` or `o_row` is not exactly `head_dim` long. This is checked up front so
/// the NEON helpers are always called on equal-length slices. Also panics if `k_head`,
/// `v_head` or `mask_row` is too short for `attend_len`.
#[inline(always)]
fn flash_attn_row(
    q_row: &[f32],
    k_head: &[f32],
    v_head: &[f32],
    o_row: &mut [f32],
    scale: f32,
    _seq_k: usize,
    head_dim: usize,
    attend_len: usize,
    mask_row: Option<&[f32]>,
) {
    assert!(
        q_row.len() == head_dim && o_row.len() == head_dim,
        "flash_attn_row: q_row and o_row must be exactly head_dim ({head_dim}) long"
    );
    // Running maximum of the visible scores, and the softmax denominator relative to it.
    let mut running_max = f32::NEG_INFINITY;
    let mut denom = 0.0f32;
    o_row.fill(0.0);
    for ki in 0..attend_len {
        let k_row = &k_head[ki * head_dim..(ki + 1) * head_dim];
        #[cfg(target_arch = "aarch64")]
        // SAFETY: NEON is enabled on the aarch64 targets this builds for, and q_row (asserted
        // above) and k_row (sliced) are both exactly head_dim long.
        let raw_s = unsafe { dot_f32_neon(q_row, k_row) } * scale;
        #[cfg(not(target_arch = "aarch64"))]
        let raw_s = dot_f32_neon(q_row, k_row) * scale;
        let score = raw_s + mask_row.map_or(0.0, |row| row[ki]);
        if score == f32::NEG_INFINITY {
            // A masked key has weight exp(-inf) = 0, so skipping it changes nothing. It also
            // avoids exp(-inf - -inf) = NaN while every earlier key was masked too, which
            // would otherwise poison the whole row.
            continue;
        }
        let new_max = if score > running_max { score } else { running_max };
        let weight = (score - new_max).exp();
        // Rescales the previous accumulation to the new maximum (0 for the first visible key).
        let correction = (running_max - new_max).exp();
        let v_row = &v_head[ki * head_dim..(ki + 1) * head_dim];
        #[cfg(target_arch = "aarch64")]
        // SAFETY: NEON is enabled on the aarch64 targets this builds for, and o_row (asserted
        // above) and v_row (sliced) are both exactly head_dim long.
        unsafe {
            saxpby_neon(o_row, v_row, correction, weight);
        }
        #[cfg(not(target_arch = "aarch64"))]
        saxpby_neon(o_row, v_row, correction, weight);
        denom = correction * denom + weight;
        running_max = new_max;
    }
    // `denom` is >= 1 once any key contributed. Zero means nothing was visible and `o_row` is
    // still all zeros.
    let inv_denom = if denom > 0.0 { 1.0 / denom } else { 0.0 };
    o_row.iter_mut().for_each(|x| *x *= inv_denom);
}
/// Scaled dot-product attention with grouped-query (GQA) support.
///
/// Each output row is computed with a single-pass online softmax, so the full
/// `seq_q x seq_k` score matrix is never materialised. `(batch, head)` pairs are processed in
/// parallel on the rayon pool.
///
/// # Arguments
///
/// * `q` - queries, shape `[batch, n_heads, seq_q, head_dim]`. It is read through
///   `q.strides()`.
/// * `k`, `v` - keys and values, shape `[batch, n_kv_heads, seq_k, head_dim]`.
/// * `mask` - optional additive mask (`0.0` keeps a key, `f32::NEG_INFINITY` drops it).
///   - A mask with exactly `seq_q * seq_k` elements (e.g. from `causal_mask`) is shared by
///     every batch and head.
///   - Otherwise its dims are right-aligned against `[batch, n_heads, seq_q, seq_k]`, and each
///     leading dim must be `1` (broadcast) or match.
///   - When a mask is given, every key is scored and the mask alone decides visibility.
/// * `scale` - score multiplier; defaults to `1 / sqrt(head_dim)`.
/// * `n_kv_heads` - number of K/V heads. Query head `h` reads K/V head
///   `h / (n_heads / n_kv_heads)`.
///
/// Without a mask, causal masking is applied implicitly, aligned to the end of the key
/// sequence. Query `qi` attends to keys `0..=qi + seq_k.saturating_sub(seq_q)`, so a single
/// query row attends to every key.
///
/// # Errors
///
/// Returns `SapientError::RankMismatch` in two cases:
/// - `q`, `k` or `v` is not rank 4;
/// - a mask that is not a single `seq_q * seq_k` plane has fewer than 2 or more than 4 dims.
///
/// Errors from `Tensor::from_f32` are propagated.
///
/// # Panics
///
/// Panics if:
/// - `n_kv_heads` is zero or does not divide `n_heads`;
/// - `k` or `v` is not shaped `[batch, n_kv_heads, seq_k, head_dim]`;
/// - the mask is not broadcastable as described above.
pub fn scaled_dot_product_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    scale: Option<f32>,
    n_kv_heads: usize,
) -> Result<Tensor> {
    let qs = q.shape().dims().to_vec();
    let ks = k.shape().dims().to_vec();
    let vs = v.shape().dims().to_vec();
    for dims in [&qs, &ks, &vs] {
        if dims.len() != 4 {
            return Err(SapientError::RankMismatch {
                expected: 4,
                got: dims.len(),
            });
        }
    }
    let (batch, n_heads, seq_q, head_dim) = (qs[0], qs[1], qs[2], qs[3]);
    let seq_k = ks[2];
    // These used to surface as a divide-by-zero / out-of-bounds panic inside a rayon worker,
    // or as silent reads of another batch's K/V heads.
    assert!(
        n_kv_heads > 0 && n_heads % n_kv_heads == 0,
        "n_heads ({n_heads}) must be a multiple of a non-zero n_kv_heads ({n_kv_heads})"
    );
    let kv_dims = [batch, n_kv_heads, seq_k, head_dim];
    assert!(
        ks == kv_dims && vs == kv_dims,
        "k {ks:?} and v {vs:?} must both be shaped {kv_dims:?}"
    );
    let scale = scale.unwrap_or(1.0 / (head_dim as f32).sqrt());
    let kv_rep = n_heads / n_kv_heads;
    let mask_plane_len = seq_q * seq_k;
    // Validate the mask shape up front, and resolve where each (batch, head) plane starts.
    let mask_strides = mask
        .map(|m| {
            let dims = m.shape().dims().to_vec();
            mask_plane_strides(&dims, batch, n_heads, seq_q, seq_k)
        })
        .transpose()?;

    let out_shape = Shape::new([batch, n_heads, seq_q, head_dim]);
    let head_out_size = seq_q * head_dim;
    let mut out = vec![0.0f32; batch * n_heads * head_out_size];
    if out.is_empty() {
        // Nothing to compute. Returning here also avoids `par_chunks_mut(0)`, which panics.
        return Tensor::from_f32(&out, out_shape);
    }

    let k_data: Vec<f32> = k.to_contiguous_f32_vec();
    let v_data: Vec<f32> = v.to_contiguous_f32_vec();
    let q_cow = q.to_f32_cow();
    let q_data = q_cow.as_ref();
    let q_strides = q.strides();
    let (q_sb, q_sh, q_sq, q_sd) = (q_strides[0], q_strides[1], q_strides[2], q_strides[3]);
    // NOTE(review): `q` is indexed through its strides, which suggests `to_f32_cow` exposes the
    // storage layout rather than a row-major copy. The mask is therefore materialised
    // contiguously so that a strided mask view is still read in logical row-major order.
    let mask_vec: Option<Vec<f32>> = mask.map(|m| m.to_contiguous_f32_vec());
    let mask_data: Option<&[f32]> = mask_vec.as_deref();
    let kv_offset = seq_k.saturating_sub(seq_q);
    let kv_head_size = seq_k * head_dim;

    out.par_chunks_mut(head_out_size)
        .enumerate()
        .for_each(|(bh, out_chunk)| {
            let (b, h) = (bh / n_heads, bh % n_heads);
            // GQA: `kv_rep` consecutive query heads share one K/V head. K and V have the same
            // validated shape, so one base offset serves both.
            let kv_base = (b * n_kv_heads + h / kv_rep) * kv_head_size;
            let k_head = &k_data[kv_base..kv_base + kv_head_size];
            let v_head = &v_data[kv_base..kv_base + kv_head_size];
            // The [seq_q, seq_k] mask plane that applies to this (batch, head).
            let mask_plane = mask_data
                .zip(mask_strides)
                .map(|(data, (b_stride, h_stride))| {
                    let start = b * b_stride + h * h_stride;
                    &data[start..start + mask_plane_len]
                });
            // Reused gather buffer for queries whose last dim is not unit-stride.
            let mut q_gather: Vec<f32> = Vec::new();
            for (qi, o_row) in out_chunk.chunks_exact_mut(head_dim).enumerate() {
                let q_base = b * q_sb + h * q_sh + qi * q_sq;
                let q_row: &[f32] = if q_sd == 1 {
                    &q_data[q_base..q_base + head_dim]
                } else {
                    q_gather.clear();
                    q_gather.extend((0..head_dim).map(|d| q_data[q_base + d * q_sd]));
                    &q_gather
                };
                let (mask_row, attend_len) = match mask_plane {
                    Some(plane) => (Some(&plane[qi * seq_k..(qi + 1) * seq_k]), seq_k),
                    // Implicit causal mask aligned to the end of the key sequence.
                    None => (None, (qi + kv_offset + 1).min(seq_k)),
                };
                flash_attn_row(
                    q_row, k_head, v_head, o_row, scale, seq_k, head_dim, attend_len, mask_row,
                );
            }
        });
    Tensor::from_f32(&out, out_shape)
}

/// Resolves how an attention mask of shape `dims` maps onto the `(batch, head)` grid.
///
/// Returns `(batch_stride, head_stride)` in elements. The `[seq_q, seq_k]` plane used for
/// batch `b` and head `h` starts at `b * batch_stride + h * head_stride` in the contiguous
/// mask buffer.
///
/// - A mask with exactly `seq_q * seq_k` elements (whatever its rank) is shared by every batch
///   and head.
/// - Otherwise `dims` is right-aligned against `[batch, n_heads, seq_q, seq_k]`, and each
///   leading dim must be `1` or match.
///
/// Returns `RankMismatch` for ranks outside `2..=4`. Panics on non-broadcastable dims.
fn mask_plane_strides(
    dims: &[usize],
    batch: usize,
    n_heads: usize,
    seq_q: usize,
    seq_k: usize,
) -> Result<(usize, usize)> {
    let plane = seq_q * seq_k;
    if dims.iter().product::<usize>() == plane {
        return Ok((0, 0));
    }
    if !(2..=4).contains(&dims.len()) {
        return Err(SapientError::RankMismatch {
            expected: 4,
            got: dims.len(),
        });
    }
    // Left-pad with broadcast dims up to rank 4.
    let mut full = [1usize; 4];
    full[4 - dims.len()..].copy_from_slice(dims);
    let [mask_b, mask_h, mask_q, mask_k] = full;
    assert!(
        mask_q == seq_q
            && mask_k == seq_k
            && (mask_b == 1 || mask_b == batch)
            && (mask_h == 1 || mask_h == n_heads),
        "attention mask {dims:?} is not broadcastable to [{batch}, {n_heads}, {seq_q}, {seq_k}]"
    );
    let head_stride = if mask_h == 1 { 0 } else { plane };
    let batch_stride = if mask_b == 1 { 0 } else { mask_h * plane };
    Ok((batch_stride, head_stride))
}
/// Builds an additive causal mask of shape `[seq_q, seq_k]`.
///
/// Entry `(qi, ki)` is `0.0` when key `ki` is visible to query `qi`, and `f32::NEG_INFINITY`
/// otherwise. Queries are aligned to the end of the key sequence: with
/// `offset = seq_k.saturating_sub(seq_q)`, query `qi` sees keys `0..=qi + offset`. This is the
/// same visibility rule `scaled_dot_product_attention` applies when called without a mask.
pub fn causal_mask(seq_q: usize, seq_k: usize) -> Tensor {
    let offset = seq_k.saturating_sub(seq_q);
    let mut data = vec![0.0f32; seq_q * seq_k];
    for qi in 0..seq_q {
        // First hidden key in this row. Clamping makes rows that see every key mask nothing.
        let first_hidden = (qi + offset + 1).min(seq_k);
        data[qi * seq_k..(qi + 1) * seq_k][first_hidden..].fill(f32::NEG_INFINITY);
    }
    Tensor::from_f32(&data, vec![seq_q, seq_k])
        .expect("mask buffer holds exactly seq_q * seq_k elements, matching its shape")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mha_output_shape() {
        let q = Tensor::from_f32(&[0.1f32; 24], vec![1, 2, 3, 4]).unwrap();
        let k = Tensor::from_f32(&[0.1f32; 24], vec![1, 2, 3, 4]).unwrap();
        let v = Tensor::from_f32(&[0.1f32; 24], vec![1, 2, 3, 4]).unwrap();
        let out = scaled_dot_product_attention(&q, &k, &v, None, None, 2).unwrap();
        assert_eq!(out.shape().dims(), &[1, 2, 3, 4]);
    }

    #[test]
    fn gqa_kv_repeat() {
        let q = Tensor::from_f32(&[0.1f32; 32], vec![1, 4, 2, 4]).unwrap();
        let k = Tensor::from_f32(&[0.1f32; 16], vec![1, 2, 2, 4]).unwrap();
        let v = Tensor::from_f32(&[0.1f32; 16], vec![1, 2, 2, 4]).unwrap();
        let out = scaled_dot_product_attention(&q, &k, &v, None, None, 2).unwrap();
        assert_eq!(out.shape().dims(), &[1, 4, 2, 4]);
    }

    /// With `n_heads = 2 * n_kv_heads`, query heads `2j` and `2j + 1` must read K/V head `j`.
    #[test]
    fn gqa_maps_query_heads_to_shared_kv_heads() {
        let (n_heads, n_kv_heads, seq, dim) = (4, 2, 2, 4);
        let q = Tensor::from_f32(&vec![0.1f32; n_heads * seq * dim], vec![1, n_heads, seq, dim])
            .unwrap();
        let k = Tensor::from_f32(
            &vec![0.1f32; n_kv_heads * seq * dim],
            vec![1, n_kv_heads, seq, dim],
        )
        .unwrap();
        // K/V head `j` holds the constant `j + 1`, so any convex mix of its rows equals `j + 1`.
        let v_data: Vec<f32> = (0..n_kv_heads * seq * dim)
            .map(|i| (i / (seq * dim) + 1) as f32)
            .collect();
        let v = Tensor::from_f32(&v_data, vec![1, n_kv_heads, seq, dim]).unwrap();
        let out = scaled_dot_product_attention(&q, &k, &v, None, None, n_kv_heads).unwrap();
        for (i, &val) in out.as_f32_slice().iter().enumerate() {
            let q_head = i / (seq * dim);
            let expected = (q_head / (n_heads / n_kv_heads) + 1) as f32;
            assert!(
                (val - expected).abs() < 1e-5,
                "index {i}: expected {expected}, got {val}"
            );
        }
    }

    #[test]
    fn causal_mask_shape() {
        let m = causal_mask(3, 3);
        assert_eq!(m.shape().dims(), &[3, 3]);
        let d = m.as_f32_slice();
        assert!(d[1].is_infinite() && d[1] < 0.0);
        assert_eq!(d[3], 0.0);
    }

    #[test]
    fn causal_mask_aligns_queries_to_end_of_keys() {
        // seq_k - seq_q = 2: query 0 sees keys 0..=2, query 1 sees every key.
        let m = causal_mask(2, 4);
        assert_eq!(m.shape().dims(), &[2, 4]);
        let d = m.as_f32_slice();
        assert_eq!(&d[..3], &[0.0, 0.0, 0.0]);
        assert!(d[3].is_infinite() && d[3] < 0.0);
        assert_eq!(&d[4..], &[0.0; 4]);
    }

    #[test]
    fn uniform_attention_recovers_v() {
        let seq = 4usize;
        let dim = 8usize;
        let mut v_data = vec![0.0f32; seq * dim];
        for i in 0..seq {
            for d in 0..dim {
                v_data[i * dim + d] = (i + 1) as f32;
            }
        }
        let q = Tensor::from_f32(&vec![1.0f32; seq * dim], vec![1, 1, seq, dim]).unwrap();
        let k = Tensor::from_f32(&vec![1.0f32; seq * dim], vec![1, 1, seq, dim]).unwrap();
        let v = Tensor::from_f32(&v_data, vec![1, 1, seq, dim]).unwrap();
        let mask = Tensor::from_f32(&vec![0.0f32; seq * seq], vec![seq, seq]).unwrap();
        let out = scaled_dot_product_attention(&q, &k, &v, Some(&mask), None, 1).unwrap();
        let out_data = out.as_f32_slice();
        let expected = (1..=seq).map(|x| x as f32).sum::<f32>() / seq as f32;
        for &val in out_data.iter() {
            let diff = (val - expected).abs();
            assert!(diff < 1e-4, "Expected ~{expected}, got {val}");
        }
    }

    #[test]
    fn flash_matches_naive() {
        let batch = 1;
        let n_heads = 2;
        let seq_q = 4;
        let seq_k = 4;
        let head_dim = 8;
        let gen = |i: usize| (i as f32 * 1.3 + 0.7).sin() * 0.5 + 0.5;
        let q_data: Vec<f32> = (0..batch * n_heads * seq_q * head_dim)
            .map(gen)
            .collect();
        let k_data: Vec<f32> = (0..batch * n_heads * seq_k * head_dim)
            .map(|i| gen(i + 100))
            .collect();
        let v_data: Vec<f32> = (0..batch * n_heads * seq_k * head_dim)
            .map(|i| gen(i + 200))
            .collect();
        let q = Tensor::from_f32(&q_data, vec![batch, n_heads, seq_q, head_dim]).unwrap();
        let k = Tensor::from_f32(&k_data, vec![batch, n_heads, seq_k, head_dim]).unwrap();
        let v = Tensor::from_f32(&v_data, vec![batch, n_heads, seq_k, head_dim]).unwrap();
        let mask_t = causal_mask(seq_q, seq_k);
        let flash_out =
            scaled_dot_product_attention(&q, &k, &v, Some(&mask_t), None, n_heads).unwrap();
        let scale = 1.0 / (head_dim as f32).sqrt();
        let mask_data = mask_t.as_f32_slice();
        // Naive two-pass softmax reference.
        let mut ref_out = vec![0.0f32; batch * n_heads * seq_q * head_dim];
        for b in 0..batch {
            for h in 0..n_heads {
                let q_off = (b * n_heads + h) * seq_q * head_dim;
                let k_off = (b * n_heads + h) * seq_k * head_dim;
                let v_off = (b * n_heads + h) * seq_k * head_dim;
                let o_off = (b * n_heads + h) * seq_q * head_dim;
                for qi in 0..seq_q {
                    let mut scores = vec![0.0f32; seq_k];
                    for ki in 0..seq_k {
                        let dot: f32 = (0..head_dim)
                            .map(|d| {
                                q_data[q_off + qi * head_dim + d]
                                    * k_data[k_off + ki * head_dim + d]
                            })
                            .sum();
                        scores[ki] = dot * scale + mask_data[qi * seq_k + ki];
                    }
                    let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let max_s = if max_s.is_infinite() { 0.0 } else { max_s };
                    let mut sum = 0.0f32;
                    for s in scores.iter_mut() {
                        *s = (*s - max_s).exp();
                        sum += *s;
                    }
                    if sum < f32::EPSILON {
                        sum = f32::EPSILON;
                    }
                    for d in 0..head_dim {
                        let acc: f32 = (0..seq_k)
                            .map(|ki| scores[ki] / sum * v_data[v_off + ki * head_dim + d])
                            .sum();
                        ref_out[o_off + qi * head_dim + d] = acc;
                    }
                }
            }
        }
        let flash_data = flash_out.as_f32_slice();
        for (i, (&flash, &reference)) in flash_data.iter().zip(ref_out.iter()).enumerate() {
            let diff = (flash - reference).abs();
            assert!(
                diff < 1e-4,
                "Mismatch at index {i}: flash={flash} ref={reference} diff={diff}"
            );
        }
    }

    #[test]
    fn decode_mode_no_nan() {
        let batch = 1;
        let n_heads = 4;
        let seq_q = 1;
        let seq_k = 16;
        let head_dim = 8;
        let q = Tensor::from_f32(
            &vec![0.1f32; batch * n_heads * seq_q * head_dim],
            vec![batch, n_heads, seq_q, head_dim],
        )
        .unwrap();
        let k = Tensor::from_f32(
            &vec![0.1f32; batch * n_heads * seq_k * head_dim],
            vec![batch, n_heads, seq_k, head_dim],
        )
        .unwrap();
        let v = Tensor::from_f32(
            &vec![0.2f32; batch * n_heads * seq_k * head_dim],
            vec![batch, n_heads, seq_k, head_dim],
        )
        .unwrap();
        let out = scaled_dot_product_attention(&q, &k, &v, None, None, n_heads).unwrap();
        assert_eq!(out.shape().dims(), &[batch, n_heads, seq_q, head_dim]);
        for &val in out.as_f32_slice() {
            assert!(val.is_finite(), "NaN/Inf in decode output: {val}");
        }
    }
}