native_neural_network_std 0.2.1

Ergonomic std wrapper for the `native_neural_network` crate (no_std) — std-friendly re-exports and utilities.
Documentation
pub use native_neural_network::attention::AttentionMask;
pub use native_neural_network::attention::AttentionShape;

/// Error returned by the std-facing attention wrappers in this crate
/// (e.g. `scaled_dot_product_attention`).
///
/// Each variant corresponds one-to-one with the same-named variant of
/// `native_neural_network::attention::AttentionError`; the conversion is done
/// by the `From` impl in this module.
///
/// The type is a plain fieldless enum, so it is `Copy` and can be compared,
/// hashed and matched directly (e.g. `assert_eq!(err, AttentionStdError::InvalidDim)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttentionStdError {
    /// The core routine reported `ShapeMismatch`.
    /// NOTE(review): presumably input slice lengths disagree with the
    /// supplied `AttentionShape` — confirm against the core crate docs.
    ShapeMismatch,
    /// The core routine reported `BufferTooSmall`.
    /// NOTE(review): presumably `out` or `scratch_scores` is shorter than
    /// required — confirm against the core crate docs.
    BufferTooSmall,
    /// The core routine reported `InvalidDim`.
    /// NOTE(review): presumably a dimension in the shape is invalid (e.g. zero)
    /// — confirm against the core crate docs.
    InvalidDim,
}

/// Maps the core (`no_std`) crate's attention error onto the std-facing
/// [`AttentionStdError`], variant for variant.
///
/// This impl is what lets `?` turn a core error into an `AttentionStdError`.
impl From<native_neural_network::attention::AttentionError> for AttentionStdError {
    fn from(err: native_neural_network::attention::AttentionError) -> Self {
        use native_neural_network::attention::AttentionError as Core;

        // Exhaustive on purpose: a new core variant should fail to compile here
        // rather than be silently folded into an existing one.
        match err {
            Core::ShapeMismatch => Self::ShapeMismatch,
            Core::BufferTooSmall => Self::BufferTooSmall,
            Core::InvalidDim => Self::InvalidDim,
        }
    }
}

/// Runs scaled dot-product attention by delegating to
/// `native_neural_network::attention::scaled_dot_product_attention_f32`.
///
/// Every argument is forwarded unchanged to the core routine; this wrapper
/// only converts the error type. NOTE(review): presumably `out` receives the
/// attention result and `scratch_scores` is caller-provided working space for
/// the score matrix — confirm required lengths against the core crate docs.
///
/// # Errors
///
/// If the core routine fails, its `AttentionError` is converted variant for
/// variant into [`AttentionStdError`] and returned.
pub fn scaled_dot_product_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    shape: AttentionShape,
    out: &mut [f32],
    scratch_scores: &mut [f32],
    mask: AttentionMask,
) -> Result<(), AttentionStdError> {
    use native_neural_network::attention::scaled_dot_product_attention_f32 as core_attention;

    // `?` routes the core error through `From<AttentionError> for AttentionStdError`.
    core_attention(q, k, v, shape, out, scratch_scores, mask)?;
    Ok(())
}

/// Renders the error as `AttentionStdError::<Variant>`, for example
/// `AttentionStdError::ShapeMismatch`.
///
/// Formatter options such as width or fill are not applied to the output.
impl core::fmt::Display for AttentionStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let variant_name = match self {
            Self::ShapeMismatch => "ShapeMismatch",
            Self::BufferTooSmall => "BufferTooSmall",
            Self::InvalidDim => "InvalidDim",
        };
        f.write_str("AttentionStdError::")?;
        f.write_str(variant_name)
    }
}

/// Lets [`AttentionStdError`] be used as a `std::error::Error`
/// (e.g. boxed as `Box<dyn Error>`); it reports no underlying `source`.
impl std::error::Error for AttentionStdError {}