native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (`.rnn`).
Documentation
use super::mask_ops::apply_causal_mask_row;
use super::shape::{AttentionError, AttentionMask, AttentionShape};
use super::softmax_ops::stable_softmax_row;

/// Computes single-head scaled dot-product attention, `softmax(Q·Kᵀ / √d_k) · V`.
///
/// All matrices are flat, row-major `f32` slices:
/// - `q`: `shape.q_len` rows of `shape.d_k` values,
/// - `k`: `shape.k_len` rows of `shape.d_k` values,
/// - `v`: `shape.k_len` rows of `shape.d_v` values,
/// - `out`: receives `shape.q_len` rows of `shape.d_v` values.
///
/// Slices may be longer than required; only the leading elements are read or written.
///
/// `scratch_scores` is a caller-provided workspace of at least `shape.score_len()`
/// elements. Row `qi` of the score matrix occupies
/// `scratch_scores[qi * k_len..(qi + 1) * k_len]`, and its contents are overwritten.
///
/// When `mask` is `AttentionMask::Causal`, each score row is passed to
/// `apply_causal_mask_row` together with its query index before the softmax.
/// This function applies no masking for any other `mask` value.
///
/// # Errors
///
/// - `AttentionError::InvalidDim` in two cases:
///   - `shape.validate()` fails.
///   - `stable_softmax_row` rejects a score row. In that case the `out` rows for
///     earlier queries have already been written.
/// - `AttentionError::ShapeMismatch` in two cases:
///   - A required length computation overflows (`q_len * d_k`, `k_len * d_k`,
///     `k_len * d_v`, or `shape.output_len()`).
///   - `q`, `k`, `v`, or `out` is shorter than required.
/// - `AttentionError::BufferTooSmall` if `shape.score_len()` is `None` or
///   `scratch_scores` is shorter than it.
pub fn scaled_dot_product_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    shape: AttentionShape,
    out: &mut [f32],
    scratch_scores: &mut [f32],
    mask: AttentionMask,
) -> Result<(), AttentionError> {
    if !shape.validate() {
        return Err(AttentionError::InvalidDim);
    }

    let q_size = shape.q_len.checked_mul(shape.d_k).ok_or(AttentionError::ShapeMismatch)?;
    let k_size = shape.k_len.checked_mul(shape.d_k).ok_or(AttentionError::ShapeMismatch)?;
    let v_size = shape.k_len.checked_mul(shape.d_v).ok_or(AttentionError::ShapeMismatch)?;
    let out_size = shape.output_len().ok_or(AttentionError::ShapeMismatch)?;
    let score_size = shape.score_len().ok_or(AttentionError::BufferTooSmall)?;

    if q.len() < q_size || k.len() < k_size || v.len() < v_size || out.len() < out_size {
        return Err(AttentionError::ShapeMismatch);
    }
    if scratch_scores.len() < score_size {
        return Err(AttentionError::BufferTooSmall);
    }

    let (q_len, k_len, d_k, d_v) = (shape.q_len, shape.k_len, shape.d_k, shape.d_v);
    let inv_sqrt = 1.0 / crate::math::sqrtf(d_k as f32);
    // The mask kind does not change per row, so test it once.
    let causal = matches!(mask, AttentionMask::Causal);

    for qi in 0..q_len {
        // Slice each row once so the inner loops iterate without per-element bounds checks.
        let q_row = &q[qi * d_k..(qi + 1) * d_k];
        let score_row = &mut scratch_scores[qi * k_len..(qi + 1) * k_len];

        // Raw scores: scaled dot product of this query row with every key row.
        for (kj, score) in score_row.iter_mut().enumerate() {
            let k_row = &k[kj * d_k..(kj + 1) * d_k];
            // Sequential left-to-right accumulation starting at +0.0, matching a plain
            // `dot += a * b` loop bit-for-bit.
            let dot = q_row
                .iter()
                .zip(k_row)
                .fold(0.0f32, |acc, (&a, &b)| acc + a * b);
            *score = dot * inv_sqrt;
        }

        if causal {
            apply_causal_mask_row(score_row, qi);
        }

        // Turns the row into attention weights in place; `None` means the row was rejected.
        stable_softmax_row(score_row).ok_or(AttentionError::InvalidDim)?;

        // Output row = weighted sum of value rows, accumulated in key order.
        let out_row = &mut out[qi * d_v..(qi + 1) * d_v];
        out_row.fill(0.0);
        for (kj, &w) in score_row.iter().enumerate() {
            let v_row = &v[kj * d_v..(kj + 1) * d_v];
            for (o, &x) in out_row.iter_mut().zip(v_row) {
                *o += w * x;
            }
        }
    }

    Ok(())
}