1use mlx_core::{MlxError, Result, Tensor};
4
5use crate::Module;
6
/// Layer-normalization module configuration.
///
/// Stores only the expected size of the input's last dimension (`dim`) and the
/// epsilon passed to `Tensor::layer_norm`. It holds no learnable scale/shift
/// parameters; `dim` is used solely to validate inputs in `forward`.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNorm {
    /// Required size of the input's last dimension.
    dim: usize,
    /// Epsilon forwarded unchanged to `Tensor::layer_norm`.
    eps: f32,
}

impl LayerNorm {
    /// Creates a `LayerNorm` that accepts inputs whose last dimension is `dim`
    /// and normalizes them with the given `eps`.
    ///
    /// No validation of `dim` or `eps` is performed here.
    #[must_use]
    pub fn new(dim: usize, eps: f32) -> Self {
        Self { dim, eps }
    }
}
22
23impl Module for LayerNorm {
24 fn forward(&self, input: &Tensor) -> Result<Tensor> {
25 let last_dim = *input.shape().0.last().ok_or_else(|| {
26 MlxError::InvalidArgument("LayerNorm requires at least 1D input".into())
27 })? as usize;
28 if last_dim != self.dim {
29 return Err(MlxError::InvalidArgument(format!(
30 "LayerNorm expected last dim {}, got {}",
31 self.dim, last_dim
32 )));
33 }
34 Ok(input.layer_norm(self.eps))
35 }
36}
37
/// RMS-normalization module configuration.
///
/// Stores only the expected size of the input's last dimension (`dim`) and the
/// epsilon passed to `Tensor::rms_norm`. It holds no learnable scale
/// parameter; `dim` is used solely to validate inputs in `forward`.
#[derive(Debug, Clone, PartialEq)]
pub struct RmsNorm {
    /// Required size of the input's last dimension.
    dim: usize,
    /// Epsilon forwarded unchanged to `Tensor::rms_norm`.
    eps: f32,
}

impl RmsNorm {
    /// Creates an `RmsNorm` that accepts inputs whose last dimension is `dim`
    /// and normalizes them with the given `eps`.
    ///
    /// No validation of `dim` or `eps` is performed here.
    #[must_use]
    pub fn new(dim: usize, eps: f32) -> Self {
        Self { dim, eps }
    }
}
53
54impl Module for RmsNorm {
55 fn forward(&self, input: &Tensor) -> Result<Tensor> {
56 let last_dim =
57 *input.shape().0.last().ok_or_else(|| {
58 MlxError::InvalidArgument("RmsNorm requires at least 1D input".into())
59 })? as usize;
60 if last_dim != self.dim {
61 return Err(MlxError::InvalidArgument(format!(
62 "RmsNorm expected last dim {}, got {}",
63 self.dim, last_dim
64 )));
65 }
66 Ok(input.rms_norm(self.eps))
67 }
68}