pub use native_neural_network::embeddings::EmbeddingError;
/// Error type returned by the std-facing embedding wrappers in this module.
///
/// Each variant mirrors the same-named variant of the upstream
/// `native_neural_network::embeddings::EmbeddingError` (see the `From` impl in
/// this module). The enum carries no payload, so it is cheap to copy, compare
/// and hash. The extra derives let callers `assert_eq!` on it or store it in
/// sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingsStdError {
    /// Produced from the upstream `EmbeddingError::ShapeMismatch`.
    /// NOTE(review): presumably a slice length disagrees with the declared
    /// dimensions — confirm against the upstream crate's documentation.
    ShapeMismatch,
    /// Produced from the upstream `EmbeddingError::TokenOutOfRange`.
    /// NOTE(review): presumably a token id is `>= vocab_size` — confirm
    /// against the upstream crate's documentation.
    TokenOutOfRange,
}
/// Converts an upstream embedding error into the std-facing error type.
///
/// The mapping is one-to-one by variant name. This impl is what lets the
/// wrapper functions in this module propagate upstream failures with `?` or
/// `.into()`.
impl From<native_neural_network::embeddings::EmbeddingError> for EmbeddingsStdError {
    fn from(source: native_neural_network::embeddings::EmbeddingError) -> Self {
        // Local alias keeps the match arms short without adding a module-level import.
        type Upstream = native_neural_network::embeddings::EmbeddingError;

        match source {
            Upstream::ShapeMismatch => Self::ShapeMismatch,
            Upstream::TokenOutOfRange => Self::TokenOutOfRange,
        }
    }
}
/// Std-facing wrapper around `native_neural_network::embeddings::gather_embeddings`.
///
/// All arguments are forwarded unchanged, and the upstream `Ok(())` is
/// returned as-is.
///
/// NOTE(review): presumably `table` is a row-major `vocab_size × hidden_size`
/// matrix and `out` receives one `hidden_size`-long row per entry of
/// `token_ids`. Confirm against the upstream crate's documentation.
///
/// # Errors
///
/// Any upstream `EmbeddingError` is converted into the same-named
/// [`EmbeddingsStdError`] variant.
pub fn gather_embeddings(
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    token_ids: &[usize],
    out: &mut [f32],
) -> Result<(), EmbeddingsStdError> {
    // `?` performs the error conversion via `From<EmbeddingError>`.
    native_neural_network::embeddings::gather_embeddings(
        table,
        vocab_size,
        hidden_size,
        token_ids,
        out,
    )?;
    Ok(())
}
/// Std-facing wrapper around `native_neural_network::embeddings::tied_output_projection`.
///
/// All arguments are forwarded unchanged, and the upstream `Ok(())` is
/// returned as-is.
///
/// NOTE(review): presumably this projects `hidden` onto the vocabulary using
/// the embedding `table` as the (tied) output weights, writing `vocab_size`
/// logits per hidden row into `logits_out`. Confirm against the upstream
/// crate's documentation.
///
/// # Errors
///
/// Any upstream `EmbeddingError` is converted into the same-named
/// [`EmbeddingsStdError`] variant.
pub fn tied_output_projection(
    hidden: &[f32],
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    logits_out: &mut [f32],
) -> Result<(), EmbeddingsStdError> {
    // `?` performs the error conversion via `From<EmbeddingError>`.
    native_neural_network::embeddings::tied_output_projection(
        hidden,
        table,
        vocab_size,
        hidden_size,
        logits_out,
    )?;
    Ok(())
}
/// Formats the error as `EmbeddingsStdError::<Variant>`, reusing the variant's
/// `Debug` name (e.g. `EmbeddingsStdError::ShapeMismatch`).
impl core::fmt::Display for EmbeddingsStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("EmbeddingsStdError::{:?}", self))
    }
}

/// Lets the error be used with `?` in `Box<dyn Error>` contexts.
///
/// All trait methods use their defaults, so `source()` returns `None`. The
/// upstream error is not retained by the conversion.
impl std::error::Error for EmbeddingsStdError {}