use super::mask_ops::apply_causal_mask_row;
use super::shape::{AttentionError, AttentionMask, AttentionShape};
use super::softmax_ops::stable_softmax_row;
/// Computes scaled dot-product attention: `softmax(Q·Kᵀ / sqrt(d_k)) · V`.
///
/// Buffer layouts are row-major: `q` is `[q_len, d_k]`, `k` is
/// `[k_len, d_k]`, `v` is `[k_len, d_v]`, and the result is written
/// into `out` as `[q_len, d_v]`. `scratch_scores` must hold at least
/// `q_len * k_len` elements and is clobbered with the (masked,
/// softmax-normalized) attention weights as a side effect.
///
/// # Errors
/// - `InvalidDim` if `shape.validate()` fails or a softmax row cannot
///   be normalized (`stable_softmax_row` returns `None`).
/// - `ShapeMismatch` if any input/output buffer is smaller than the
///   shape requires, or a size computation overflows `usize`.
/// - `BufferTooSmall` if `scratch_scores` is shorter than `q_len * k_len`.
pub fn scaled_dot_product_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    shape: AttentionShape,
    out: &mut [f32],
    scratch_scores: &mut [f32],
    mask: AttentionMask,
) -> Result<(), AttentionError> {
    if !shape.validate() {
        return Err(AttentionError::InvalidDim);
    }
    // Overflow-checked element counts for every buffer.
    let q_size = shape.q_len.checked_mul(shape.d_k).ok_or(AttentionError::ShapeMismatch)?;
    let k_size = shape.k_len.checked_mul(shape.d_k).ok_or(AttentionError::ShapeMismatch)?;
    let v_size = shape.k_len.checked_mul(shape.d_v).ok_or(AttentionError::ShapeMismatch)?;
    let out_size = shape.output_len().ok_or(AttentionError::ShapeMismatch)?;
    // A score-length overflow is a shape problem, not a caller-buffer
    // problem, so report ShapeMismatch (previously BufferTooSmall);
    // BufferTooSmall is reserved for the scratch-length check below.
    let score_size = shape.score_len().ok_or(AttentionError::ShapeMismatch)?;
    if q.len() < q_size || k.len() < k_size || v.len() < v_size || out.len() < out_size {
        return Err(AttentionError::ShapeMismatch);
    }
    if scratch_scores.len() < score_size {
        return Err(AttentionError::BufferTooSmall);
    }
    // shape.validate() passed above, so d_k > 0 is presumed here and the
    // scale factor is finite — TODO confirm validate() rejects d_k == 0.
    let inv_sqrt = 1.0 / crate::math::sqrtf(shape.d_k as f32);
    for qi in 0..shape.q_len {
        // Hoist the query row once; iterator zips below let the
        // compiler elide per-element bounds checks in the hot loops.
        let q_row = &q[qi * shape.d_k..(qi + 1) * shape.d_k];
        let score_row = &mut scratch_scores[qi * shape.k_len..(qi + 1) * shape.k_len];
        // scores[qi][kj] = (q_row · k_row) * 1/sqrt(d_k)
        for (kj, score) in score_row.iter_mut().enumerate() {
            let k_row = &k[kj * shape.d_k..(kj + 1) * shape.d_k];
            let dot: f32 = q_row.iter().zip(k_row).map(|(a, b)| a * b).sum();
            *score = dot * inv_sqrt;
        }
        if matches!(mask, AttentionMask::Causal) {
            apply_causal_mask_row(score_row, qi);
        }
        stable_softmax_row(score_row).ok_or(AttentionError::InvalidDim)?;
        // out[qi] = Σ_kj weight[kj] * v[kj] — accumulate row by row.
        let out_row = &mut out[qi * shape.d_v..(qi + 1) * shape.d_v];
        out_row.fill(0.0);
        for (kj, &w) in score_row.iter().enumerate() {
            let v_row = &v[kj * shape.d_v..(kj + 1) * shape.d_v];
            for (acc, &val) in out_row.iter_mut().zip(v_row) {
                *acc += w * val;
            }
        }
    }
    Ok(())
}