// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only
use zyx::{DType, IntoShape, Tensor, ZyxError};
use zyx_derive::Module;
/// A Layer Normalization layer.
///
/// Layer Normalization normalizes the inputs across the specified dimensions (typically the last N dimensions)
/// for each example independently. It optionally supports learnable scale (`weight`) and bias (`bias_tensor`) parameters.
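///
/// The forward pass computes `(x - mean) / sqrt(var + eps)`, where the mean and
/// variance are taken over the last `normalized_shape.len()` dimensions, and then
/// applies `weight` and `bias_tensor` if they are present.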
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct LayerNorm {
    normalized_shape: Vec<u64>,
    eps: f64,
    weight: Option<Tensor>,
    bias_tensor: Option<Tensor>,
}

impl LayerNorm {
    /// Creates a new `LayerNorm` layer.
    ///
    /// # Arguments
    ///
    /// * `normalized_shape` - The shape of the dimensions to normalize. Usually corresponds to the last N dimensions of the input tensor.
    /// * `eps` - A small value added to the denominator for numerical stability.
    /// * `elementwise_affine` - If `true`, includes a learnable scale parameter (`weight`).
    /// * `bias` - If `true`, includes a learnable bias parameter (`bias_tensor`).
    /// * `dtype` - The data type of the optional learnable parameters.
    ///
    /// # Returns
    ///
    /// Returns `Ok(LayerNorm)` if initialization is successful, or a `ZyxError` if there is an issue with shape or tensor creation.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use zyx::{DType, Tensor};
    /// # use zyx_nn::LayerNorm;
    /// let layer_norm = LayerNorm::new([10, 20], 1e-5, true, true, DType::F32).unwrap();
    /// ```
    pub fn new(
        normalized_shape: impl IntoShape,
        eps: f64,
        elementwise_affine: bool,
        bias: bool,
        dtype: DType,
    ) -> Result<Self, ZyxError> {
        let normalized_shape: Vec<u64> = normalized_shape.into_shape().collect();
        // Optional learnable parameters
        let weight = if elementwise_affine {
            Some(Tensor::ones(&normalized_shape, dtype))
        } else {
            None
        };
        let bias_tensor = if bias {
            Some(Tensor::zeros(&normalized_shape, dtype))
        } else {
            None
        };
        Ok(Self {
            normalized_shape,
            eps,
            weight,
            bias_tensor,
        })
    }

    /// Performs the forward pass of the LayerNorm layer.
    ///
    /// # Arguments
    ///
    /// * `input` - The input tensor to normalize.
    ///
    /// # Returns
    ///
    /// Returns a new `Tensor` that is normalized along the last `normalized_shape.len()` dimensions.
    ///
    /// # Errors
    ///
    /// Returns a `ZyxError` if the input tensor rank is smaller than the rank of `normalized_shape`.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use zyx::{DType, Tensor};
    /// # use zyx_nn::LayerNorm;
    /// let layer_norm = LayerNorm::new([10, 20], 1e-5, true, true, DType::F32).unwrap();
    /// let input = Tensor::randn([2, 10, 20], DType::F32).unwrap();
    /// let output = layer_norm.forward(input).unwrap();
    /// ```
    pub fn forward(&self, input: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let input = input.into();
        let input_shape = input.shape();
        let input_rank = input_shape.len();
        let norm_rank = self.normalized_shape.len();
        if input_rank < norm_rank {
            return Err(ZyxError::shape_error(
                format!(
                    "LayerNorm: input rank ({}) smaller than normalized_shape rank ({})",
                    input_rank, norm_rank
                )
                .into(),
            ));
        }
        // Determine axes to normalize over (last `norm_rank` dims)
        let axes: Vec<i32> = (input_rank - norm_rank..input_rank)
            .map(|i| i as i32)
            .collect();
        // Compute mean and variance along those axes (keep dims for broadcasting)
        let mean = input.mean_keepdim(axes.clone())?;
        let variance = input.var_keepdim(axes)?;
        // Normalize: (x - mean) / sqrt(var + eps)
        let normalized = (input - &mean) / (variance + self.eps).sqrt();
        // Apply learnable affine transformation if enabled
        let mut output = normalized;
        if let Some(ref weight) = self.weight {
            output = output * weight;
        }
        if let Some(ref bias) = self.bias_tensor {
            output = output + bias;
        }
        Ok(output)
    }
}
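
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity-test sketch, assuming only the zyx APIs already used in
    // this module (`Tensor::randn`, `Tensor::shape`): layer normalization over
    // the trailing dimensions must not change the shape of its input.
    #[test]
    fn forward_preserves_shape() -> Result<(), ZyxError> {
        let layer_norm = LayerNorm::new([10, 20], 1e-5, true, true, DType::F32)?;
        let input = Tensor::randn([2, 10, 20], DType::F32)?;
        let expected_shape = input.shape();
        let output = layer_norm.forward(input)?;
        assert_eq!(output.shape(), expected_shape);
        Ok(())
    }

    // Normalizing over more dimensions than the input has must fail, per the
    // rank check at the top of `forward`.
    #[test]
    fn forward_rejects_low_rank_input() {
        let layer_norm = LayerNorm::new([10, 20], 1e-5, false, false, DType::F32).unwrap();
        let input = Tensor::randn([20], DType::F32).unwrap();
        assert!(layer_norm.forward(input).is_err());
    }
}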