use super::mask_ops::{apply_causal_mask_row_f32, apply_causal_mask_row_f64};
use super::shape::{AttentionError, AttentionMask, AttentionShape};
use super::softmax_ops::{stable_softmax_row_f32, stable_softmax_row_f64};
/// Single-precision scaled dot-product attention.
///
/// The buffers are indexed as row-major matrices: `q` is `q_len x d_k`,
/// `k` is `k_len x d_k`, `v` is `k_len x d_v`, `out` is `q_len x d_v`, and
/// `scratch_scores` provides one `k_len`-wide row per query. Buffers may be
/// longer than required; only the leading elements are indexed by the CPU
/// path.
///
/// The request is first handed to
/// `crate::engine::try_invoke_gpu_attention_f32`. When that call returns
/// `true` this function returns `Ok(())` immediately and the CPU loop below
/// is skipped.
///
/// # Errors
///
/// * `AttentionError::InvalidDim` - `shape.validate()` returned `false`, or
///   the softmax helper returned `None` for a score row.
/// * `AttentionError::ShapeMismatch` - one of the `q`/`k`/`v` element counts
///   overflowed, `shape.output_len()` returned `None`, or `q`, `k`, `v` or
///   `out` is shorter than the shape requires.
/// * `AttentionError::BufferTooSmall` - `shape.score_len()` returned `None`
///   or `scratch_scores` is shorter than the score count.
pub fn scaled_dot_product_attention_f32(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    shape: AttentionShape,
    out: &mut [f32],
    scratch_scores: &mut [f32],
    mask: AttentionMask,
) -> Result<(), AttentionError> {
    if !shape.validate() {
        return Err(AttentionError::InvalidDim);
    }
    let (q_len, k_len, d_k, d_v) = (shape.q_len, shape.k_len, shape.d_k, shape.d_v);

    // Overflow-checked element counts; evaluated in the same order as the
    // error precedence callers observe.
    let elems = |rows: usize, cols: usize| {
        rows.checked_mul(cols).ok_or(AttentionError::ShapeMismatch)
    };
    let q_need = elems(q_len, d_k)?;
    let k_need = elems(k_len, d_k)?;
    let v_need = elems(k_len, d_v)?;
    let out_need = shape.output_len().ok_or(AttentionError::ShapeMismatch)?;
    let scores_need = shape.score_len().ok_or(AttentionError::BufferTooSmall)?;

    let tensors_fit = q.len() >= q_need
        && k.len() >= k_need
        && v.len() >= v_need
        && out.len() >= out_need;
    if !tensors_fit {
        return Err(AttentionError::ShapeMismatch);
    }
    if scratch_scores.len() < scores_need {
        return Err(AttentionError::BufferTooSmall);
    }

    // Offer the work to the engine first; `true` means nothing is left to do.
    let handled_by_engine =
        crate::engine::try_invoke_gpu_attention_f32(crate::engine::AttentionInvokeF32 {
            q,
            k,
            v,
            out,
            scratch_scores,
            q_len,
            k_len,
            d_k,
            d_v,
            mask: match mask {
                AttentionMask::None => 0,
                AttentionMask::Causal => 1,
            },
        });
    if handled_by_engine {
        return Ok(());
    }

    // CPU path: one query row at a time.
    let scale = 1.0 / crate::math::sqrtf(d_k as f32);
    let causal = matches!(mask, AttentionMask::Causal);
    for qi in 0..q_len {
        let q_row = &q[qi * d_k..(qi + 1) * d_k];
        let weights = &mut scratch_scores[qi * k_len..(qi + 1) * k_len];

        // Scaled dot products of this query against every key.
        for (kj, slot) in weights.iter_mut().enumerate() {
            let k_row = &k[kj * d_k..(kj + 1) * d_k];
            // Explicit fold from 0.0 keeps the left-to-right summation order.
            let dot = q_row
                .iter()
                .zip(k_row)
                .fold(0.0f32, |acc, (&a, &b)| acc + a * b);
            *slot = dot * scale;
        }

        if causal {
            apply_causal_mask_row_f32(weights, qi);
        }
        stable_softmax_row_f32(weights).ok_or(AttentionError::InvalidDim)?;

        // Output row = sum over keys of weight * value row.
        let out_row = &mut out[qi * d_v..(qi + 1) * d_v];
        out_row.iter_mut().for_each(|o| *o = 0.0);
        for (kj, &w) in weights.iter().enumerate() {
            let v_row = &v[kj * d_v..(kj + 1) * d_v];
            for (o, &x) in out_row.iter_mut().zip(v_row) {
                *o += w * x;
            }
        }
    }
    Ok(())
}
/// Double-precision scaled dot-product attention.
///
/// The buffers are indexed as row-major matrices: `q` is `q_len x d_k`,
/// `k` is `k_len x d_k`, `v` is `k_len x d_v`, `out` is `q_len x d_v`, and
/// `scratch_scores` provides one `k_len`-wide row per query. Buffers may be
/// longer than required; only the leading elements are indexed by the CPU
/// path.
///
/// The request is first handed to
/// `crate::engine::try_invoke_gpu_attention_f64`. When that call returns
/// `true` this function returns `Ok(())` immediately and the CPU loop below
/// is skipped.
///
/// On the CPU path, for every query row the function:
/// 1. fills the matching `scratch_scores` row with `dot(q_row, k_row) / sqrt(d_k)`,
/// 2. applies `apply_causal_mask_row_f64` when `mask` is `AttentionMask::Causal`,
/// 3. runs `stable_softmax_row_f64` on the row, and
/// 4. writes the sum of `weight * v_row` over all keys into the `out` row.
///
/// # Errors
///
/// * `AttentionError::InvalidDim` - `shape.validate()` returned `false`, or
///   the softmax helper returned `None` for a score row.
/// * `AttentionError::ShapeMismatch` - one of the `q`/`k`/`v` element counts
///   overflowed, `shape.output_len()` returned `None`, or `q`, `k`, `v` or
///   `out` is shorter than the shape requires.
/// * `AttentionError::BufferTooSmall` - `shape.score_len()` returned `None`
///   or `scratch_scores` is shorter than the score count.
pub fn scaled_dot_product_attention_f64(
    q: &[f64],
    k: &[f64],
    v: &[f64],
    shape: AttentionShape,
    out: &mut [f64],
    scratch_scores: &mut [f64],
    mask: AttentionMask,
) -> Result<(), AttentionError> {
    if !shape.validate() {
        return Err(AttentionError::InvalidDim);
    }

    // Overflow-checked element counts for each buffer.
    let q_size = shape
        .q_len
        .checked_mul(shape.d_k)
        .ok_or(AttentionError::ShapeMismatch)?;
    let k_size = shape
        .k_len
        .checked_mul(shape.d_k)
        .ok_or(AttentionError::ShapeMismatch)?;
    let v_size = shape
        .k_len
        .checked_mul(shape.d_v)
        .ok_or(AttentionError::ShapeMismatch)?;
    let out_size = shape.output_len().ok_or(AttentionError::ShapeMismatch)?;
    let score_size = shape.score_len().ok_or(AttentionError::BufferTooSmall)?;

    if q.len() < q_size || k.len() < k_size || v.len() < v_size || out.len() < out_size {
        return Err(AttentionError::ShapeMismatch);
    }
    if scratch_scores.len() < score_size {
        return Err(AttentionError::BufferTooSmall);
    }

    // The engine interface takes the mask as a small integer code.
    let mask_code = match mask {
        AttentionMask::None => 0,
        AttentionMask::Causal => 1,
    };
    if crate::engine::try_invoke_gpu_attention_f64(crate::engine::AttentionInvokeF64 {
        q,
        k,
        v,
        out,
        scratch_scores,
        q_len: shape.q_len,
        k_len: shape.k_len,
        d_k: shape.d_k,
        d_v: shape.d_v,
        mask: mask_code,
    }) {
        // The engine reported that it handled the request.
        return Ok(());
    }

    let (q_len, k_len, d_k, d_v) = (shape.q_len, shape.k_len, shape.d_k, shape.d_v);
    let inv_sqrt = 1.0 / crate::math::sqrtd(d_k as f64);
    let causal = matches!(mask, AttentionMask::Causal);

    for qi in 0..q_len {
        let q_row = &q[qi * d_k..(qi + 1) * d_k];
        let score_row = &mut scratch_scores[qi * k_len..(qi + 1) * k_len];

        // Scaled dot product of this query with every key.
        for (kj, score) in score_row.iter_mut().enumerate() {
            let k_row = &k[kj * d_k..(kj + 1) * d_k];
            let mut dot = 0.0f64;
            for (&a, &b) in q_row.iter().zip(k_row) {
                dot += a * b;
            }
            *score = dot * inv_sqrt;
        }

        if causal {
            apply_causal_mask_row_f64(score_row, qi);
        }
        stable_softmax_row_f64(score_row).ok_or(AttentionError::InvalidDim)?;

        // Weighted sum of value rows. `v` is traversed one contiguous row at a
        // time instead of column by column (stride `d_v`), which keeps memory
        // access sequential. Every output element still receives its terms in
        // ascending key order starting from 0.0, so the floating-point result
        // is the same as a per-element accumulator would produce.
        let out_row = &mut out[qi * d_v..(qi + 1) * d_v];
        out_row.fill(0.0);
        for (kj, &w) in score_row.iter().enumerate() {
            let v_row = &v[kj * d_v..(kj + 1) * d_v];
            for (acc, &x) in out_row.iter_mut().zip(v_row) {
                *acc += w * x;
            }
        }
    }
    Ok(())
}