Skip to main content

mlx_nn/
norm.rs

//! Normalization layers: [`LayerNorm`] and [`RmsNorm`].
2
3use mlx_core::{MlxError, Result, Tensor};
4
5use crate::Module;
6
/// Layer normalization over the last dimension.
///
/// Normalizes to zero mean and unit variance along the last axis, with `eps`
/// forwarded to [`Tensor::layer_norm`] for numerical stability. The layer
/// stores no learnable scale or shift parameters. The `dim` parameter
/// specifies the expected size of the last dimension and is validated during
/// `forward()`.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNorm {
    /// Expected size of the input's last dimension.
    dim: usize,
    /// Stability constant passed through to the normalization kernel.
    eps: f32,
}

impl LayerNorm {
    /// Creates a layer that expects inputs whose last dimension is `dim`.
    ///
    /// `eps` is stored as given; no validation is performed here.
    pub fn new(dim: usize, eps: f32) -> Self {
        Self { dim, eps }
    }

    /// Returns the expected size of the input's last dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the numerical-stability constant used during normalization.
    pub fn eps(&self) -> f32 {
        self.eps
    }
}
22
23impl Module for LayerNorm {
24    fn forward(&self, input: &Tensor) -> Result<Tensor> {
25        let last_dim = *input.shape().0.last().ok_or_else(|| {
26            MlxError::InvalidArgument("LayerNorm requires at least 1D input".into())
27        })? as usize;
28        if last_dim != self.dim {
29            return Err(MlxError::InvalidArgument(format!(
30                "LayerNorm expected last dim {}, got {}",
31                self.dim, last_dim
32            )));
33        }
34        Ok(input.layer_norm(self.eps))
35    }
36}
37
/// RMS normalization over the last dimension.
///
/// Like [`LayerNorm`] but skips the mean-centering step, normalizing only by
/// the root-mean-square; `eps` is forwarded to [`Tensor::rms_norm`] for
/// numerical stability. The layer stores no learnable scale parameter. The
/// `dim` parameter specifies the expected size of the last dimension and is
/// validated during `forward()`.
#[derive(Debug, Clone, PartialEq)]
pub struct RmsNorm {
    /// Expected size of the input's last dimension.
    dim: usize,
    /// Stability constant passed through to the normalization kernel.
    eps: f32,
}

impl RmsNorm {
    /// Creates a layer that expects inputs whose last dimension is `dim`.
    ///
    /// `eps` is stored as given; no validation is performed here.
    pub fn new(dim: usize, eps: f32) -> Self {
        Self { dim, eps }
    }

    /// Returns the expected size of the input's last dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the numerical-stability constant used during normalization.
    pub fn eps(&self) -> f32 {
        self.eps
    }
}
53
54impl Module for RmsNorm {
55    fn forward(&self, input: &Tensor) -> Result<Tensor> {
56        let last_dim =
57            *input.shape().0.last().ok_or_else(|| {
58                MlxError::InvalidArgument("RmsNorm requires at least 1D input".into())
59            })? as usize;
60        if last_dim != self.dim {
61            return Err(MlxError::InvalidArgument(format!(
62                "RmsNorm expected last dim {}, got {}",
63                self.dim, last_dim
64            )));
65        }
66        Ok(input.rms_norm(self.eps))
67    }
68}