native_neural_network 0.1.6

A no_std Rust library for native neural networks (.rnn)
Documentation
use super::errors::EmbeddingError;

/// Projects a hidden-state vector onto a tied embedding table, producing one logit per
/// vocabulary entry.
///
/// `table` is read as `vocab_size` contiguous rows of `hidden_size` values each. For every
/// `v` in `0..vocab_size`, `logits_out[v]` is overwritten with the dot product of `hidden`
/// and row `v`. Products are accumulated left to right, starting from `0.0`.
///
/// Only the first `vocab_size * hidden_size` elements of `table` are read. Only the first
/// `vocab_size` elements of `logits_out` are written, and any trailing elements are left
/// untouched.
///
/// # Errors
///
/// Returns [`EmbeddingError::ShapeMismatch`] without writing anything to `logits_out` when:
/// - `hidden.len() != hidden_size`,
/// - `vocab_size * hidden_size` overflows `usize`,
/// - `table` holds fewer than `vocab_size * hidden_size` elements, or
/// - `logits_out` holds fewer than `vocab_size` elements.
pub fn tied_output_projection(
    hidden: &[f32],
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    logits_out: &mut [f32],
) -> Result<(), EmbeddingError> {
    if hidden.len() != hidden_size {
        return Err(EmbeddingError::ShapeMismatch);
    }
    let weights_len = match vocab_size.checked_mul(hidden_size) {
        Some(len) => len,
        None => return Err(EmbeddingError::ShapeMismatch),
    };
    // Validate both buffers before touching the output so a failed call leaves it unchanged.
    let weights = match table.get(..weights_len) {
        Some(prefix) if logits_out.len() >= vocab_size => prefix,
        _ => return Err(EmbeddingError::ShapeMismatch),
    };
    let logits = &mut logits_out[..vocab_size];

    if hidden_size == 0 {
        // `chunks_exact` rejects a zero chunk size. With no features, every dot product is
        // the empty sum, which is `0.0`.
        logits.fill(0.0);
        return Ok(());
    }

    // `weights` holds exactly `vocab_size` full rows, so this pairs each logit with its row.
    for (logit, row) in logits.iter_mut().zip(weights.chunks_exact(hidden_size)) {
        *logit = hidden
            .iter()
            .zip(row)
            .fold(0.0f32, |acc, (&h, &w)| acc + h * w);
    }
    Ok(())
}