use rayon::prelude::*;
use sapient_core::error::{Result, SapientError};
use sapient_core::{Shape, Tensor};
pub fn scaled_dot_product_attention(
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
scale: Option<f32>,
n_kv_heads: usize, ) -> Result<Tensor> {
let qs = q.shape().dims().to_vec();
let ks = k.shape().dims().to_vec();
if qs.len() != 4 {
return Err(SapientError::RankMismatch {
expected: 4,
got: qs.len(),
});
}
let (batch, n_heads, seq_q, head_dim) = (qs[0], qs[1], qs[2], qs[3]);
let seq_k = ks[2];
let scale = scale.unwrap_or(1.0 / (head_dim as f32).sqrt());
let q_cow = q.to_f32_cow();
let q_data = q_cow.as_ref();
let k_cow = k.to_f32_cow();
let k_data = k_cow.as_ref();
let v_cow = v.to_f32_cow();
let v_data = v_cow.as_ref();
let q_strides = q.strides();
let k_strides = k.strides();
let v_strides = v.strides();
let kv_rep = n_heads / n_kv_heads;
let head_out_size = seq_q * head_dim;
let mut out = vec![0.0f32; batch * n_heads * head_out_size];
let mask_cow = mask.map(|m| m.to_f32_cow());
let mask_data: Option<&[f32]> = mask_cow.as_deref();
out.par_chunks_mut(head_out_size)
.enumerate()
.for_each(|(bh, out_chunk)| {
let b = bh / n_heads;
let h = bh % n_heads;
let kv_h = h / kv_rep;
let mut scores = vec![0.0f32; seq_q * seq_k];
for qi in 0..seq_q {
for ki in 0..seq_k {
let q_base = b * q_strides[0] + h * q_strides[1] + qi * q_strides[2];
let k_base = b * k_strides[0] + kv_h * k_strides[1] + ki * k_strides[2];
let dot: f32 = (0..head_dim)
.map(|d| {
q_data[q_base + d * q_strides[3]] * k_data[k_base + d * k_strides[3]]
})
.sum();
scores[qi * seq_k + ki] = dot * scale;
}
}
if let Some(m) = mask_data {
for (s, &mv) in scores.iter_mut().zip(m.iter()) {
*s += mv;
}
}
for qi in 0..seq_q {
let row = &mut scores[qi * seq_k..(qi + 1) * seq_k];
let mut max_v = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
if max_v == f32::NEG_INFINITY {
max_v = 0.0;
}
let mut sum = 0.0f32;
for s in row.iter_mut() {
*s = (*s - max_v).exp();
sum += *s;
}
if sum == 0.0 {
sum = f32::EPSILON;
}
for s in row.iter_mut() {
*s /= sum;
}
}
for qi in 0..seq_q {
for d in 0..head_dim {
let acc: f32 = (0..seq_k)
.map(|ki| {
let v_idx = b * v_strides[0]
+ kv_h * v_strides[1]
+ ki * v_strides[2]
+ d * v_strides[3];
scores[qi * seq_k + ki] * v_data[v_idx]
})
.sum();
out_chunk[qi * head_dim + d] = acc;
}
}
});
Tensor::from_f32(&out, Shape::new([batch, n_heads, seq_q, head_dim]))
}
pub fn causal_mask(seq_q: usize, seq_k: usize) -> Tensor {
let mut data = vec![0.0f32; seq_q * seq_k];
let offset = seq_k.saturating_sub(seq_q);
for qi in 0..seq_q {
for ki in 0..seq_k {
if ki > qi + offset {
data[qi * seq_k + ki] = f32::NEG_INFINITY;
}
}
}
Tensor::from_f32(&data, vec![seq_q, seq_k]).unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mha_output_shape() {
let q = Tensor::from_f32(&[0.1f32; 24], vec![1, 2, 3, 4]).unwrap();
let k = Tensor::from_f32(&[0.1f32; 24], vec![1, 2, 3, 4]).unwrap();
let v = Tensor::from_f32(&[0.1f32; 24], vec![1, 2, 3, 4]).unwrap();
let out = scaled_dot_product_attention(&q, &k, &v, None, None, 2).unwrap();
assert_eq!(out.shape().dims(), &[1, 2, 3, 4]);
}
#[test]
fn gqa_kv_repeat() {
let q = Tensor::from_f32(&[0.1f32; 32], vec![1, 4, 2, 4]).unwrap();
let k = Tensor::from_f32(&[0.1f32; 16], vec![1, 2, 2, 4]).unwrap();
let v = Tensor::from_f32(&[0.1f32; 16], vec![1, 2, 2, 4]).unwrap();
let out = scaled_dot_product_attention(&q, &k, &v, None, None, 2).unwrap();
assert_eq!(out.shape().dims(), &[1, 4, 2, 4]);
}
#[test]
fn causal_mask_shape() {
let m = causal_mask(3, 3);
let d = m.as_f32_slice();
assert!(d[1].is_infinite() && d[1] < 0.0);
assert_eq!(d[3], 0.0);
}
}