use super::errors::EmbeddingError;
/// Copies the embedding row for each token id from `table` into `out`.
///
/// `table` is read as a row-major `vocab_size x hidden_size` matrix: row `id`
/// occupies `table[id * hidden_size..(id + 1) * hidden_size]`. Row `t` of `out`
/// (`out[t * hidden_size..(t + 1) * hidden_size]`) receives the row selected by
/// `token_ids[t]`. Elements of `table` beyond `vocab_size * hidden_size` are
/// ignored, and elements of `out` beyond `token_ids.len() * hidden_size` are
/// left untouched.
///
/// # Errors
///
/// - [`EmbeddingError::ShapeMismatch`] if `vocab_size * hidden_size` or
///   `token_ids.len() * hidden_size` overflows `usize`, or if `table` / `out`
///   is shorter than the corresponding product.
/// - [`EmbeddingError::TokenOutOfRange`] if any id is `>= vocab_size`.
///
/// All inputs are validated before anything is written, so `out` is left
/// unmodified whenever an error is returned.
pub fn gather_embeddings(
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    token_ids: &[usize],
    out: &mut [f32],
) -> Result<(), EmbeddingError> {
    let table_len = vocab_size
        .checked_mul(hidden_size)
        .ok_or(EmbeddingError::ShapeMismatch)?;
    let out_len = token_ids
        .len()
        .checked_mul(hidden_size)
        .ok_or(EmbeddingError::ShapeMismatch)?;
    if table.len() < table_len || out.len() < out_len {
        return Err(EmbeddingError::ShapeMismatch);
    }
    // Validate every id up front: a bad id late in the batch must not leave
    // `out` partially overwritten with rows from earlier tokens.
    if token_ids.iter().any(|&id| id >= vocab_size) {
        return Err(EmbeddingError::TokenOutOfRange);
    }
    // Zero-width rows mean there is nothing to copy, and `chunks_exact_mut(0)`
    // would panic.
    if hidden_size == 0 {
        return Ok(());
    }
    let table = &table[..table_len];
    for (row, &id) in out[..out_len]
        .chunks_exact_mut(hidden_size)
        .zip(token_ids)
    {
        // Cannot overflow: id < vocab_size and vocab_size * hidden_size fit in usize.
        let src = id * hidden_size;
        row.copy_from_slice(&table[src..src + hidden_size]);
    }
    Ok(())
}