native_neural_network 0.3.1

A no_std Rust library for native neural networks (.rnn)
Documentation
use super::mask_ops::{apply_causal_mask_row_f32, apply_causal_mask_row_f64};
use super::shape::{AttentionError, AttentionMask, AttentionShape};
use super::softmax_ops::{stable_softmax_row_f32, stable_softmax_row_f64};

/// Scaled dot-product attention for `f32`: `out = softmax(Q Kᵀ / √d_k) V`.
///
/// All matrices are row-major slices: `q` is `q_len × d_k`, `k` is
/// `k_len × d_k`, `v` is `k_len × d_v`, and `out` receives `q_len × d_v`.
/// `scratch_scores` must hold at least `q_len * k_len` elements and is
/// clobbered with the softmax-normalized attention weights.
///
/// A GPU engine path is attempted first; when the engine declines, a scalar
/// CPU fallback computes the same result one query row at a time.
///
/// # Errors
/// * `InvalidDim` — `shape.validate()` failed, or a softmax row could not be
///   normalized.
/// * `ShapeMismatch` — an input/output slice is shorter than the shape
///   requires, or a required size product overflowed `usize`.
/// * `BufferTooSmall` — `scratch_scores` is too short (or its required size
///   overflowed).
pub fn scaled_dot_product_attention_f32(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    shape: AttentionShape,
    out: &mut [f32],
    scratch_scores: &mut [f32],
    mask: AttentionMask,
) -> Result<(), AttentionError> {
    if !shape.validate() {
        return Err(AttentionError::InvalidDim);
    }

    // Minimum element counts each buffer must provide; checked_mul guards
    // against usize overflow on pathological shapes.
    let need_q = shape
        .q_len
        .checked_mul(shape.d_k)
        .ok_or(AttentionError::ShapeMismatch)?;
    let need_k = shape
        .k_len
        .checked_mul(shape.d_k)
        .ok_or(AttentionError::ShapeMismatch)?;
    let need_v = shape
        .k_len
        .checked_mul(shape.d_v)
        .ok_or(AttentionError::ShapeMismatch)?;
    let need_out = shape.output_len().ok_or(AttentionError::ShapeMismatch)?;
    let need_scores = shape.score_len().ok_or(AttentionError::BufferTooSmall)?;

    if q.len() < need_q || k.len() < need_k || v.len() < need_v || out.len() < need_out {
        return Err(AttentionError::ShapeMismatch);
    }
    if scratch_scores.len() < need_scores {
        return Err(AttentionError::BufferTooSmall);
    }

    let causal = matches!(mask, AttentionMask::Causal);

    // Hand off to the GPU engine when it accepts this problem.
    let gpu_handled = crate::engine::try_invoke_gpu_attention_f32(crate::engine::AttentionInvokeF32 {
        q,
        k,
        v,
        out,
        scratch_scores,
        q_len: shape.q_len,
        k_len: shape.k_len,
        d_k: shape.d_k,
        d_v: shape.d_v,
        mask: u8::from(causal),
    });
    if gpu_handled {
        return Ok(());
    }

    // CPU fallback.
    let scale = 1.0 / crate::math::sqrtf(shape.d_k as f32);

    for row in 0..shape.q_len {
        let q_row = &q[row * shape.d_k..(row + 1) * shape.d_k];
        let scores = &mut scratch_scores[row * shape.k_len..(row + 1) * shape.k_len];

        // Raw scores: (q_row · k_row) scaled by 1/sqrt(d_k).
        for (col, slot) in scores.iter_mut().enumerate() {
            let k_row = &k[col * shape.d_k..(col + 1) * shape.d_k];
            let dot: f32 = q_row.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
            *slot = dot * scale;
        }

        if causal {
            apply_causal_mask_row_f32(scores, row);
        }

        // Normalize in place; a degenerate row reports InvalidDim.
        stable_softmax_row_f32(scores).ok_or(AttentionError::InvalidDim)?;

        // out_row = scores · V, accumulated key-by-key.
        let out_row = &mut out[row * shape.d_v..(row + 1) * shape.d_v];
        out_row.iter_mut().for_each(|e| *e = 0.0);
        for (col, &w) in scores.iter().enumerate() {
            let v_row = &v[col * shape.d_v..(col + 1) * shape.d_v];
            for (acc, &x) in out_row.iter_mut().zip(v_row.iter()) {
                *acc += w * x;
            }
        }
    }

    Ok(())
}

/// Scaled dot-product attention for `f64`: `out = softmax(Q Kᵀ / √d_k) V`.
///
/// Buffers are row-major: `q` is `q_len × d_k`, `k` is `k_len × d_k`, `v` is
/// `k_len × d_v`, and `out` receives `q_len × d_v`. `scratch_scores` must
/// hold at least `q_len * k_len` elements; on the CPU path it ends up holding
/// the normalized attention weights.
///
/// The GPU engine gets first refusal; if it declines, a scalar CPU loop
/// produces the same result.
///
/// # Errors
/// * `InvalidDim` — `shape.validate()` failed, or a softmax row could not be
///   normalized.
/// * `ShapeMismatch` — an input/output slice is shorter than the shape
///   requires, or a required size product overflowed `usize`.
/// * `BufferTooSmall` — `scratch_scores` is too short (or its required size
///   overflowed).
pub fn scaled_dot_product_attention_f64(
    q: &[f64],
    k: &[f64],
    v: &[f64],
    shape: AttentionShape,
    out: &mut [f64],
    scratch_scores: &mut [f64],
    mask: AttentionMask,
) -> Result<(), AttentionError> {
    if !shape.validate() {
        return Err(AttentionError::InvalidDim);
    }

    let (q_len, k_len, d_k, d_v) = (shape.q_len, shape.k_len, shape.d_k, shape.d_v);

    // Minimum lengths each slice must provide; overflow maps to the same
    // errors the caller would see for an outright size mismatch.
    let q_min = q_len.checked_mul(d_k).ok_or(AttentionError::ShapeMismatch)?;
    let k_min = k_len.checked_mul(d_k).ok_or(AttentionError::ShapeMismatch)?;
    let v_min = k_len.checked_mul(d_v).ok_or(AttentionError::ShapeMismatch)?;
    let out_min = shape.output_len().ok_or(AttentionError::ShapeMismatch)?;
    let score_min = shape.score_len().ok_or(AttentionError::BufferTooSmall)?;

    if q.len() < q_min || k.len() < k_min || v.len() < v_min || out.len() < out_min {
        return Err(AttentionError::ShapeMismatch);
    }
    if scratch_scores.len() < score_min {
        return Err(AttentionError::BufferTooSmall);
    }

    let mask_flag: u8 = match mask {
        AttentionMask::Causal => 1,
        AttentionMask::None => 0,
    };
    // Prefer the GPU engine; fall through to the CPU loop when it declines.
    if crate::engine::try_invoke_gpu_attention_f64(crate::engine::AttentionInvokeF64 {
        q,
        k,
        v,
        out,
        scratch_scores,
        q_len,
        k_len,
        d_k,
        d_v,
        mask: mask_flag,
    }) {
        return Ok(());
    }

    let scale = 1.0 / crate::math::sqrtd(d_k as f64);

    for qi in 0..q_len {
        let q_row = &q[qi * d_k..(qi + 1) * d_k];
        let scores = &mut scratch_scores[qi * k_len..(qi + 1) * k_len];

        // scores[kj] = (q_row · k_row_j) / sqrt(d_k)
        for kj in 0..k_len {
            let k_row = &k[kj * d_k..(kj + 1) * d_k];
            let mut dot = 0.0f64;
            for (a, b) in q_row.iter().zip(k_row.iter()) {
                dot += a * b;
            }
            scores[kj] = dot * scale;
        }

        if let AttentionMask::Causal = mask {
            apply_causal_mask_row_f64(scores, qi);
        }

        // Normalize in place; a degenerate row reports InvalidDim.
        stable_softmax_row_f64(scores).ok_or(AttentionError::InvalidDim)?;

        // One output element at a time: out[qi][d] = Σ_kj w_kj · v[kj][d].
        let out_row = &mut out[qi * d_v..(qi + 1) * d_v];
        for (d, slot) in out_row.iter_mut().enumerate() {
            let mut acc = 0.0f64;
            for (kj, &w) in scores.iter().enumerate() {
                acc += w * v[kj * d_v + d];
            }
            *slot = acc;
        }
    }

    Ok(())
}