use candle_core::{D, Tensor};
use candle_nn::VarBuilder;
use crate::error::Result;
/// Layer normalization over the last dimension of the input.
///
/// Normalizes each sample to zero mean and unit variance along the final
/// axis, then applies a learned elementwise affine transform.
pub struct LayerNorm {
// Learned per-element scale (gamma); length matches the normalized dim.
weight: Tensor,
// Learned per-element shift (beta); length matches the normalized dim.
bias: Tensor,
// Small constant added to the variance for numerical stability.
eps: f64,
}
impl LayerNorm {
/// Loads the `weight` and `bias` parameters (each of length `size`) from
/// `vb` and returns a ready-to-use layer.
///
/// # Errors
/// Returns an error if either tensor is missing from the var builder.
#[allow(clippy::needless_pass_by_value)] pub fn load(size: usize, eps: f64, vb: VarBuilder<'_>) -> Result<Self> {
Ok(Self {
weight: vb.get(size, "weight")?,
bias: vb.get(size, "bias")?,
eps,
})
}
/// Applies layer normalization along the last dimension of `x`.
///
/// Computes `(x - mean) / sqrt(var + eps) * weight + bias`, where mean and
/// variance are taken over the final axis (biased variance, as is standard
/// for layer norm).
///
/// # Errors
/// Propagates any tensor-operation failure (e.g. shape mismatch).
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let mu = x.mean_keepdim(D::Minus1)?;
let centered = x.broadcast_sub(&mu)?;
// Biased variance: mean of squared deviations over the last axis.
let variance = centered.sqr()?.mean_keepdim(D::Minus1)?;
let denom = (variance + self.eps)?.sqrt()?;
let normed = centered.broadcast_div(&denom)?;
let scaled = normed.broadcast_mul(&self.weight)?;
Ok(scaled.broadcast_add(&self.bias)?)
}
}
/// Applies group normalization to a 2-D input of shape `(batch, channels)`.
///
/// Channels are split into `num_groups` contiguous groups; each group is
/// normalized to zero mean / unit variance (biased), then a per-channel
/// affine transform (`weight`, `bias`, each of length `channels`) is applied.
///
/// # Errors
/// Fails if `x` is not 2-D, if `channels` is not divisible by `num_groups`
/// (the reshape will reject the mismatched element count), or if any tensor
/// operation fails.
pub fn group_norm(
x: &Tensor,
num_groups: usize,
weight: &Tensor,
bias: &Tensor,
eps: f64,
) -> Result<Tensor> {
let (batch, channels) = x.dims2()?;
// View as (batch, groups, channels_per_group) so stats reduce over dim 2.
let grouped = x.reshape((batch, num_groups, channels / num_groups))?;
let mu = grouped.mean_keepdim(2)?;
let centered = grouped.broadcast_sub(&mu)?;
let variance = centered.sqr()?.mean_keepdim(2)?;
let normed = centered
.broadcast_div(&(variance + eps)?.sqrt()?)?
.reshape((batch, channels))?;
// Per-channel affine; unsqueeze to (1, channels) for row-wise broadcast.
let scaled = normed.broadcast_mul(&weight.unsqueeze(0)?)?;
Ok(scaled.broadcast_add(&bias.unsqueeze(0)?)?)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use candle_core::{DType, Device};
#[test]
fn layer_norm_basic() {
let dev = Device::Cpu;
// Identity affine (weight=1, bias=0) isolates the normalization itself.
let ln = LayerNorm {
weight: Tensor::ones(4, DType::F32, &dev).unwrap(),
bias: Tensor::zeros(4, DType::F32, &dev).unwrap(),
eps: 1e-5,
};
let input = Tensor::new(&[[1.0_f32, 2.0, 3.0, 4.0]], &dev).unwrap();
let result: Vec<f32> = ln
.forward(&input)
.unwrap()
.flatten_all()
.unwrap()
.to_vec1()
.unwrap();
assert_eq!(result.len(), 4);
// (1 - 2.5) / sqrt(1.25) ≈ -1.3416
assert!((result[0] - (-1.3416)).abs() < 0.01);
}
#[test]
fn group_norm_basic() {
let dev = Device::Cpu;
let gamma = Tensor::ones(4, DType::F32, &dev).unwrap();
let beta = Tensor::zeros(4, DType::F32, &dev).unwrap();
let input = Tensor::new(&[[1.0_f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &dev).unwrap();
let normed = group_norm(&input, 2, &gamma, &beta, 1e-5).unwrap();
assert_eq!(normed.dims2().unwrap(), (2, 4));
let result: Vec<f32> = normed.flatten_all().unwrap().to_vec1().unwrap();
// Group [1, 2]: mean 1.5, std 0.5 → normalized to -1, +1.
assert!((result[0] - (-1.0)).abs() < 0.01);
assert!((result[1] - 1.0).abs() < 0.01);
}
}