native_neural_network 0.3.1

A `no_std` Rust library for native neural networks (.rnn)
Documentation
use super::errors::EmbeddingError;

/// Gathers rows of an embedding table into a contiguous output buffer.
///
/// `table` is read as `vocab_size` consecutive rows of `hidden_size` values each.
/// For every position `t` in `token_ids`, row `token_ids[t]` is copied into
/// `out[t * hidden_size..(t + 1) * hidden_size]`. Any elements of `table` beyond
/// `vocab_size * hidden_size`, and of `out` beyond `token_ids.len() * hidden_size`,
/// are not read / not written.
///
/// # Errors
///
/// - [`EmbeddingError::ShapeMismatch`] if `vocab_size * hidden_size` or
///   `token_ids.len() * hidden_size` overflows `usize`, or if `table` or `out` is
///   shorter than that required length.
/// - [`EmbeddingError::TokenOutOfRange`] if any id in `token_ids` is `>= vocab_size`.
///
/// All checks run before anything is written, so on error `out` is left unmodified.
pub fn gather_embeddings(
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    token_ids: &[usize],
    out: &mut [f32],
) -> Result<(), EmbeddingError> {
    let table_len = vocab_size
        .checked_mul(hidden_size)
        .ok_or(EmbeddingError::ShapeMismatch)?;
    let out_len = token_ids
        .len()
        .checked_mul(hidden_size)
        .ok_or(EmbeddingError::ShapeMismatch)?;
    if table.len() < table_len || out.len() < out_len {
        return Err(EmbeddingError::ShapeMismatch);
    }

    // Validate every id before copying, so a bad id late in the sequence cannot
    // leave `out` half-overwritten.
    if token_ids.iter().any(|&id| id >= vocab_size) {
        return Err(EmbeddingError::TokenOutOfRange);
    }

    // Zero-width rows: nothing to copy, and `chunks_exact_mut(0)` would panic.
    if hidden_size == 0 {
        return Ok(());
    }

    let table = &table[..table_len];
    for (dst, &id) in out[..out_len].chunks_exact_mut(hidden_size).zip(token_ids) {
        // `id < vocab_size` (checked above), so `src + hidden_size <= table_len`
        // and the multiplication cannot overflow.
        let src = id * hidden_size;
        dst.copy_from_slice(&table[src..src + hidden_size]);
    }
    Ok(())
}