use candle_core::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use crate::config::NormType;
use crate::error::Result;
/// A normalization layer selected at load time.
///
/// Wraps the two `candle_nn` built-ins plus the Gemma-flavored RMSNorm
/// defined below, so callers can hold any of them behind one type.
#[allow(clippy::exhaustive_enums)]
pub enum Norm {
/// Standard RMS normalization (`candle_nn::RmsNorm`).
Rms(candle_nn::RmsNorm),
/// Classic layer normalization (`candle_nn::LayerNorm`).
Layer(candle_nn::LayerNorm),
/// Gemma-style RMSNorm: scales by `(1 + weight)` instead of `weight`.
GemmaRms(GemmaRmsNorm),
}
impl Norm {
    /// Apply the wrapped normalization to `xs`, dispatching on the variant.
    ///
    /// # Errors
    /// Propagates any tensor-operation error from the underlying layer.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let out = match self {
            Self::Rms(inner) => inner.forward(xs)?,
            Self::Layer(inner) => inner.forward(xs)?,
            Self::GemmaRms(inner) => inner.forward(xs)?,
        };
        Ok(out)
    }
}
/// Build a [`Norm`] of the requested kind, loading its weights from `vb`.
///
/// `hidden_size` is the normalized feature dimension; `eps` is the
/// numerical-stability epsilon passed through to the chosen layer.
///
/// # Errors
/// Returns an error if the weights cannot be fetched from `vb`.
#[allow(clippy::needless_pass_by_value)] // `vb` is moved into the chosen constructor
pub fn create_norm(
    norm_type: NormType,
    hidden_size: usize,
    eps: f64,
    vb: VarBuilder<'_>,
) -> Result<Norm> {
    Ok(match norm_type {
        NormType::RmsNorm => Norm::Rms(candle_nn::rms_norm(hidden_size, eps, vb)?),
        NormType::LayerNorm => {
            let cfg = candle_nn::LayerNormConfig {
                eps,
                ..Default::default()
            };
            Norm::Layer(candle_nn::layer_norm(hidden_size, cfg, vb)?)
        }
        NormType::GemmaRmsNorm => Norm::GemmaRms(GemmaRmsNorm::load(hidden_size, eps, vb)?),
    })
}
/// Gemma-style RMS normalization.
///
/// Differs from the standard RMSNorm in that the stored `weight` is an
/// offset from 1: the forward pass multiplies by `(1 + weight)` and runs
/// the whole computation in f32 before casting back to the input dtype.
pub struct GemmaRmsNorm {
// Learned scale, stored WITHOUT the +1 offset; `forward` adds 1.0.
weight: Tensor,
// Numerical-stability epsilon added to the mean of squares.
eps: f64,
}
impl GemmaRmsNorm {
    /// Fetch the `weight` vector of length `hidden_size` from `vb`.
    #[allow(clippy::needless_pass_by_value)] // `vb` is consumed by the weight lookup
    fn load(hidden_size: usize, eps: f64, vb: VarBuilder<'_>) -> Result<Self> {
        Ok(Self {
            weight: vb.get(hidden_size, "weight")?,
            eps,
        })
    }
    /// Normalize `xs` over its last dimension, Gemma style.
    ///
    /// Computation runs in f32 regardless of the input dtype, then the
    /// result is cast back: `xs / sqrt(mean(xs^2) + eps) * (1 + weight)`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let dtype = xs.dtype();
        // Upcast to f32 so the reduction and division are done in full precision.
        let x = if dtype == DType::F32 {
            xs.clone()
        } else {
            xs.to_dtype(DType::F32)?
        };
        // mean(x^2) over the last axis, keeping the dim for broadcasting.
        let mean_sq = x.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
        let denom = (mean_sq + self.eps)?.sqrt()?;
        let normed = x.broadcast_div(&denom)?;
        // Gemma convention: the learned scale is stored as an offset from 1.
        let scale = (&self.weight.to_dtype(DType::F32)? + 1.0)?;
        let scaled = normed.broadcast_mul(&scale)?;
        // Cast back to the caller's dtype only when we actually upcast.
        if dtype == DType::F32 {
            Ok(scaled)
        } else {
            Ok(scaled.to_dtype(dtype)?)
        }
    }
}