/// Error type returned by the normalization wrappers in this module.
///
/// Each variant corresponds one-to-one to the same-named variant of
/// `native_neural_network::normalization::NormError` (see the `From` impl in
/// this file); the exact conditions that trigger each one are decided by that
/// crate.
///
/// The enum carries no data, so it is cheap to copy and can be compared,
/// hashed, and matched by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormStdError {
    /// Mirrors `NormError::Empty` (presumably an empty input — confirm upstream).
    Empty,
    /// Mirrors `NormError::ShapeMismatch` (presumably slice lengths disagree —
    /// confirm upstream).
    ShapeMismatch,
    /// Mirrors `NormError::InvalidEps` (presumably a rejected `eps` value —
    /// confirm upstream).
    InvalidEps,
}
impl From<native_neural_network::normalization::NormError> for NormStdError {
    /// Translates an upstream `NormError` into the same-named local variant.
    ///
    /// The `match` is exhaustive on purpose: if upstream gains a variant, this
    /// stops compiling instead of silently mis-mapping it.
    fn from(err: native_neural_network::normalization::NormError) -> Self {
        use native_neural_network::normalization::NormError as Upstream;

        match err {
            Upstream::Empty => Self::Empty,
            Upstream::ShapeMismatch => Self::ShapeMismatch,
            Upstream::InvalidEps => Self::InvalidEps,
        }
    }
}
pub fn layer_norm_in_place(
x: &mut [f32],
gamma: &[f32],
beta: &[f32],
eps: f32,
) -> Result<(), NormStdError> {
native_neural_network::normalization::layer_norm_in_place_f32(x, gamma, beta, eps)
.map_err(|e| e.into())
}
pub fn layer_norm(
input: &[f32],
gamma: &[f32],
beta: &[f32],
eps: f32,
out: &mut [f32],
) -> Result<(), NormStdError> {
native_neural_network::normalization::layer_norm_f32(input, gamma, beta, eps, out)
.map_err(|e| e.into())
}
pub fn rms_norm_in_place(x: &mut [f32], gamma: &[f32], eps: f32) -> Result<(), NormStdError> {
native_neural_network::normalization::rms_norm_in_place_f32(x, gamma, eps).map_err(|e| e.into())
}
pub fn rms_norm(
input: &[f32],
gamma: &[f32],
eps: f32,
out: &mut [f32],
) -> Result<(), NormStdError> {
native_neural_network::normalization::rms_norm_f32(input, gamma, eps, out).map_err(|e| e.into())
}
impl core::fmt::Display for NormStdError {
    /// Renders the error as `NormStdError::<Variant>`, reusing the derived
    /// `Debug` output for the variant name.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let variant: &dyn core::fmt::Debug = self;
        f.write_fmt(format_args!("NormStdError::{:?}", variant))
    }
}
impl std::error::Error for NormStdError {}