use super::errors::EmbeddingError;
/// Projects a hidden-state vector onto the vocabulary by taking its dot
/// product with every row of `table`.
///
/// `table` is read as a row-major `vocab_size × hidden_size` matrix. For each
/// row `v`, `logits_out[v]` receives `dot(hidden, table[v * hidden_size..][..hidden_size])`.
/// The function name suggests `table` is the (weight-tied) token-embedding
/// matrix. NOTE(review): confirm with callers.
///
/// Only the first `vocab_size * hidden_size` elements of `table` are read, and
/// only the first `vocab_size` elements of `logits_out` are written. Any
/// trailing elements in either slice are ignored or left unchanged.
///
/// # Errors
///
/// Returns [`EmbeddingError::ShapeMismatch`] if any of the following hold:
/// - `hidden.len() != hidden_size`
/// - `vocab_size * hidden_size` overflows `usize`
/// - `table` holds fewer than `vocab_size * hidden_size` elements
/// - `logits_out` holds fewer than `vocab_size` elements
///
/// All checks run before any write, so on error `logits_out` is untouched.
pub fn tied_output_projection(
    hidden: &[f32],
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    logits_out: &mut [f32],
) -> Result<(), EmbeddingError> {
    if hidden.len() != hidden_size {
        return Err(EmbeddingError::ShapeMismatch);
    }
    let table_len = vocab_size
        .checked_mul(hidden_size)
        .ok_or(EmbeddingError::ShapeMismatch)?;
    if table.len() < table_len || logits_out.len() < vocab_size {
        return Err(EmbeddingError::ShapeMismatch);
    }

    let logits = &mut logits_out[..vocab_size];
    if hidden_size == 0 {
        // `chunks_exact(0)` would panic. With an empty hidden vector every
        // dot product is the empty sum, i.e. 0.0.
        logits.fill(0.0);
        return Ok(());
    }

    // Exact-size rows zipped with the output slice: no per-element bounds
    // checks, and the zip stops after `vocab_size` rows.
    let rows = table[..table_len].chunks_exact(hidden_size);
    for (logit, row) in logits.iter_mut().zip(rows) {
        // Fold from +0.0 in index order so the result is bit-identical to a
        // plain `acc += h * w` accumulation loop.
        *logit = hidden
            .iter()
            .zip(row)
            .fold(0.0f32, |acc, (&h, &w)| acc + h * w);
    }
    Ok(())
}