native_neural_network_std 0.2.1

Ergonomic `std` wrapper around the `no_std` `native_neural_network` crate, providing std-friendly re-exports and utilities.
Documentation
pub use native_neural_network::embeddings::EmbeddingError;

/// Error type returned by the `std` embedding wrappers in this module.
///
/// Each variant corresponds one-to-one to a variant of the upstream
/// `native_neural_network::embeddings::EmbeddingError`; the conversion is
/// performed by this type's `From<EmbeddingError>` impl.
///
/// The variants carry no data, so the type is cheap to copy and can be
/// compared, hashed, and matched directly (e.g. `assert_eq!(err, EmbeddingsStdError::ShapeMismatch)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingsStdError {
    /// Mirrors the upstream `EmbeddingError::ShapeMismatch` variant.
    ///
    /// NOTE(review): presumably reported when slice lengths disagree with
    /// `vocab_size` / `hidden_size` — confirm against the upstream crate.
    ShapeMismatch,
    /// Mirrors the upstream `EmbeddingError::TokenOutOfRange` variant.
    ///
    /// NOTE(review): presumably reported for a token id `>= vocab_size` —
    /// confirm against the upstream crate.
    TokenOutOfRange,
}

/// Converts the upstream `no_std` embedding error into the `std`-side
/// [`EmbeddingsStdError`], preserving the variant.
impl From<native_neural_network::embeddings::EmbeddingError> for EmbeddingsStdError {
    fn from(err: native_neural_network::embeddings::EmbeddingError) -> Self {
        // There is no wildcard arm, so a new upstream variant causes a compile
        // error here instead of being silently mapped to the wrong variant.
        match err {
            native_neural_network::embeddings::EmbeddingError::TokenOutOfRange => {
                Self::TokenOutOfRange
            }
            native_neural_network::embeddings::EmbeddingError::ShapeMismatch => {
                Self::ShapeMismatch
            }
        }
    }
}

/// Gathers embeddings for `token_ids` from `table` into `out`.
///
/// This is a `std`-facing wrapper: every argument is forwarded unchanged to
/// `native_neural_network::embeddings::gather_embeddings`; only the error
/// type differs.
///
/// NOTE(review): the expected layout of `table` (presumably row-major
/// `vocab_size × hidden_size`) and the required length of `out` (presumably
/// `token_ids.len() * hidden_size`) are defined upstream — confirm there.
///
/// # Errors
///
/// Any upstream `EmbeddingError` is converted into the matching
/// [`EmbeddingsStdError`] variant.
pub fn gather_embeddings(
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    token_ids: &[usize],
    out: &mut [f32],
) -> Result<(), EmbeddingsStdError> {
    // `?` converts the upstream error via `From<EmbeddingError>`.
    native_neural_network::embeddings::gather_embeddings(
        table,
        vocab_size,
        hidden_size,
        token_ids,
        out,
    )?;
    Ok(())
}

/// Projects `hidden` into `logits_out` using the embedding `table`.
///
/// This is a `std`-facing wrapper: every argument is forwarded unchanged to
/// `native_neural_network::embeddings::tied_output_projection`; only the
/// error type differs.
///
/// NOTE(review): "tied" presumably means the embedding table is reused as the
/// output weight matrix (weight tying), with `logits_out` of length
/// `vocab_size` — the exact math and shape rules live upstream; confirm there.
///
/// # Errors
///
/// Any upstream `EmbeddingError` is converted into the matching
/// [`EmbeddingsStdError`] variant.
pub fn tied_output_projection(
    hidden: &[f32],
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    logits_out: &mut [f32],
) -> Result<(), EmbeddingsStdError> {
    // `?` converts the upstream error via `From<EmbeddingError>`.
    native_neural_network::embeddings::tied_output_projection(
        hidden,
        table,
        vocab_size,
        hidden_size,
        logits_out,
    )?;
    Ok(())
}

impl core::fmt::Display for EmbeddingsStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "EmbeddingsStdError::{:?}", self)
    }
}
impl std::error::Error for EmbeddingsStdError {}