1use crate::autograd::{BackwardOp, Tensor};
4use ndarray::Array1;
5use std::cell::RefCell;
6use std::rc::Rc;
7
8pub fn layer_norm(x: &Tensor, gamma: &Tensor, beta: &Tensor, epsilon: f32) -> Tensor {
13 let n = x.len() as f32;
14
15 let mean = x.data().sum() / n;
17
18 let variance = x.data().mapv(|val| (val - mean).powi(2)).sum() / n;
20 let std = (variance + epsilon).sqrt();
21
22 let normalized = x.data().mapv(|val| (val - mean) / std);
24
25 let data = &normalized * gamma.data() + beta.data();
27
28 let requires_grad = x.requires_grad() || gamma.requires_grad() || beta.requires_grad();
29 let mut result = Tensor::new(data, requires_grad);
30
31 if requires_grad {
32 let x_clone = x.clone();
33 let gamma_clone = gamma.clone();
34 let beta_clone = beta.clone();
35 let backward_op = Rc::new(LayerNormBackward {
36 x: x_clone,
37 gamma: gamma_clone,
38 beta: beta_clone,
39 normalized: normalized.clone(),
40 std,
41 result_grad: result.grad_cell(),
42 });
43 result.set_backward_op(backward_op);
44 }
45
46 contract_post_layernorm!(result.data().as_slice().unwrap_or(&[]));
47 result
48}
49
/// Backward op attached by `layer_norm`; stores what the backward pass reads.
struct LayerNormBackward {
    /// Clone of the input tensor; the input gradient is accumulated into it.
    x: Tensor,
    /// Clone of the scale tensor; its data is read and its gradient accumulated.
    gamma: Tensor,
    /// Clone of the shift tensor; its gradient is accumulated.
    beta: Tensor,
    /// Forward-pass values `(x - mean) / std`, before scale and shift.
    normalized: Array1<f32>,
    /// Forward-pass denominator `sqrt(variance + epsilon)`.
    std: f32,
    /// Gradient cell of the output tensor (from `result.grad_cell()`); `None`
    /// inside means no output gradient is available.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
58
59impl BackwardOp for LayerNormBackward {
60 fn backward(&self) {
61 if let Some(grad_output) = self.result_grad.borrow().as_ref() {
62 let n = self.x.len() as f32;
63
64 if self.beta.requires_grad() {
66 self.beta.accumulate_grad(grad_output.clone());
67 }
68
69 if self.gamma.requires_grad() {
71 let grad_gamma = grad_output * &self.normalized;
72 self.gamma.accumulate_grad(grad_gamma);
73 }
74
75 if self.x.requires_grad() {
77 let grad_normalized = grad_output * self.gamma.data();
79
80 let sum_grad = grad_normalized.sum();
82
83 let sum_grad_normalized = (&grad_normalized * &self.normalized).sum();
85
86 let grad_x: Vec<f32> = grad_normalized
89 .iter()
90 .zip(self.normalized.iter())
91 .map(|(&grad_norm, &norm)| {
92 (grad_norm - sum_grad / n - norm * sum_grad_normalized / n) / self.std
93 })
94 .collect();
95
96 self.x.accumulate_grad(Array1::from(grad_x));
97 }
98
99 if let Some(op) = self.x.backward_op() {
101 op.backward();
102 }
103 if let Some(op) = self.gamma.backward_op() {
104 op.backward();
105 }
106 if let Some(op) = self.beta.backward_op() {
107 op.backward();
108 }
109 }
110 }
111}
112
#[cfg(test)]
mod normalization_correctness_tests {
    use super::*;
    use crate::autograd::Tensor;

    /// Layer norm evaluated in `f64` and cast back to `f32`; serves as the
    /// oracle the `f32` implementation is compared against.
    fn reference_layer_norm_f64(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
        let widened: Vec<f64> = x.iter().copied().map(f64::from).collect();
        let count = x.len() as f64;

        let mean = widened.iter().sum::<f64>() / count;
        let variance = widened
            .iter()
            .map(|&v| {
                let centered = v - mean;
                centered * centered
            })
            .sum::<f64>()
            / count;
        let std = (variance + f64::from(eps)).sqrt();

        let mut out = Vec::with_capacity(widened.len());
        for (i, &v) in widened.iter().enumerate() {
            // Indexing (not zipping) so a too-short gamma/beta still panics.
            let affine = (v - mean) / std * f64::from(gamma[i]) + f64::from(beta[i]);
            out.push(affine as f32);
        }
        out
    }

    /// Unit gamma and zero beta: output must match the `f64` reference.
    #[test]
    fn test_normalization_correctness_matches_reference() {
        let eps = 1e-5;
        let x_data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let gamma_data = vec![1.0_f32; 5];
        let beta_data = vec![0.0_f32; 5];
        let reference = reference_layer_norm_f64(&x_data, &gamma_data, &beta_data, eps);

        let x = Tensor::from_vec(x_data, false);
        let gamma = Tensor::from_vec(gamma_data, false);
        let beta = Tensor::from_vec(beta_data, false);
        let result = layer_norm(&x, &gamma, &beta, eps);

        result
            .data()
            .iter()
            .zip(reference.iter())
            .enumerate()
            .for_each(|(i, (&actual, &expected))| {
                let diff = (actual - expected).abs();
                assert!(
                    diff < 1e-5,
                    "LayerNorm correctness[{i}]: actual={actual}, ref={expected}, diff={diff}"
                );
            });
    }

    /// Non-trivial gamma (2) and beta (1): output must match the `f64` reference.
    #[test]
    fn test_normalization_correctness_with_scaling() {
        let eps = 1e-5;
        let x_data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let gamma_data = vec![2.0_f32; 4];
        let beta_data = vec![1.0_f32; 4];
        let reference = reference_layer_norm_f64(&x_data, &gamma_data, &beta_data, eps);

        let x = Tensor::from_vec(x_data, false);
        let gamma = Tensor::from_vec(gamma_data, false);
        let beta = Tensor::from_vec(beta_data, false);
        let result = layer_norm(&x, &gamma, &beta, eps);

        result
            .data()
            .iter()
            .zip(reference.iter())
            .enumerate()
            .for_each(|(i, (&actual, &expected))| {
                let diff = (actual - expected).abs();
                assert!(diff < 1e-5, "LayerNorm correctness scaled[{i}]: diff={diff}");
            });
    }
}
184
#[cfg(test)]
mod ln_contract_tests {
    use super::*;
    use crate::autograd::Tensor;

    /// Builds identity affine parameters of length `dim`: `gamma` all ones and
    /// `beta` all zeros, neither requiring gradients.
    fn make_unit_params(dim: usize) -> (Tensor, Tensor) {
        (
            Tensor::from_vec(vec![1.0; dim], false),
            Tensor::from_vec(vec![0.0; dim], false),
        )
    }

    /// LN-001: with unit parameters the output mean is approximately zero.
    #[test]
    fn falsify_ln_001_centering() {
        let (gamma, beta) = make_unit_params(8);
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, 0.5, -1.5, 2.5, -0.5, 1.5], false);
        let y = layer_norm(&x, &gamma, &beta, 1e-5);

        let total: f32 = y.data().sum();
        let mean = total / y.len() as f32;
        assert!(mean.abs() < 1e-5, "FALSIFIED LN-001: mean(LN(x)) = {mean}, expected ≈ 0");
    }

    /// LN-002: with unit parameters the output variance is approximately one.
    #[test]
    fn falsify_ln_002_standardization() {
        let (gamma, beta) = make_unit_params(8);
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, 0.5, -1.5, 2.5, -0.5, 1.5], false);
        let y = layer_norm(&x, &gamma, &beta, 1e-5);

        let n = y.len() as f32;
        let values = y.data();
        let mean: f32 = values.sum() / n;
        let squared_dev = values.mapv(|v| (v - mean).powi(2));
        let var: f32 = squared_dev.sum() / n;
        assert!((var - 1.0).abs() < 0.05, "FALSIFIED LN-002: var(LN(x)) = {var}, expected ≈ 1.0");
    }

    /// LN-003: every output is finite, including zero-variance and tiny inputs.
    #[test]
    fn falsify_ln_003_denominator_safety() {
        let (gamma, beta) = make_unit_params(4);
        let test_cases: [(&str, Vec<f32>); 6] = [
            ("normal", vec![1.0, 2.0, 3.0, 4.0]),
            ("small", vec![1e-7, 1e-7, 1e-7, 1e-7]),
            ("large", vec![1e6, 1e6, 1e6, 1e6]),
            ("mixed_sign", vec![-3.0, 2.0, -1.0, 4.0]),
            ("near_zero", vec![1e-20, 0.0, 1e-20, 0.0]),
            ("all_zero", vec![0.0, 0.0, 0.0, 0.0]),
        ];

        for (name, data) in &test_cases {
            let x = Tensor::from_vec(data.clone(), false);
            let y = layer_norm(&x, &gamma, &beta, 1e-5);
            y.data().iter().enumerate().for_each(|(i, &val)| {
                assert!(
                    val.is_finite(),
                    "FALSIFIED LN-003: output[{i}] = {val} not finite for case '{name}'"
                );
            });
        }
    }

    /// LN-005: normalizing an already-normalized tensor changes it by < 1e-4.
    #[test]
    fn falsify_ln_005_idempotency() {
        let (gamma, beta) = make_unit_params(6);
        let x = Tensor::from_vec(vec![10.0, -5.0, 3.0, 7.0, -2.0, 0.5], false);
        let once = layer_norm(&x, &gamma, &beta, 1e-5);
        let twice = layer_norm(&once, &gamma, &beta, 1e-5);

        once.data()
            .iter()
            .zip(twice.data().iter())
            .enumerate()
            .for_each(|(i, (&a, &b))| {
                let diff = (a - b).abs();
                assert!(
                    diff < 1e-4,
                    "FALSIFIED LN-005: LN(LN(x))[{i}] = {b}, LN(x)[{i}] = {a}, diff = {diff}"
                );
            });
    }

    /// LN-006: adding a constant to every input leaves the output unchanged
    /// within a relative tolerance of 1e-3.
    #[test]
    fn falsify_ln_006_shift_invariance() {
        let (gamma, beta) = make_unit_params(5);
        let base = [1.0_f32, -2.0, 3.0, 0.5, -1.5];
        let x = Tensor::from_vec(base.to_vec(), false);
        let y_base = layer_norm(&x, &gamma, &beta, 1e-5);

        for c in [10.0_f32, -100.0, 0.001, 1000.0] {
            let shifted: Vec<f32> = base.iter().map(|&v| v + c).collect();
            let x_shifted = Tensor::from_vec(shifted, false);
            let y_shifted = layer_norm(&x_shifted, &gamma, &beta, 1e-5);

            y_base
                .data()
                .iter()
                .zip(y_shifted.data().iter())
                .enumerate()
                .for_each(|(i, (&a, &b))| {
                    let tol = 1e-3 * a.abs().max(1.0);
                    assert!(
                        (a - b).abs() < tol,
                        "FALSIFIED LN-006: LN(x)[{i}]={a}, LN(x+{c})[{i}]={b}"
                    );
                });
        }
    }

    /// LN-007: a constant input yields finite outputs of magnitude < 1e-3.
    #[test]
    fn falsify_ln_007_constant_input() {
        let (gamma, beta) = make_unit_params(4);
        for c in [0.0_f32, 1.0, -5.0, 1e6, 1e-6] {
            let x = Tensor::from_vec(vec![c; 4], false);
            let y = layer_norm(&x, &gamma, &beta, 1e-5);

            y.data().iter().enumerate().for_each(|(i, &val)| {
                assert!(val.is_finite(), "FALSIFIED LN-003 (via LN-007): NaN/Inf for constant {c}");
                assert!(
                    val.abs() < 1e-3,
                    "FALSIFIED LN-007: LN([{c};4])[{i}] = {val}, expected ≈ 0"
                );
            });
        }
    }

    /// Property-based variants of the contract checks above.
    mod ln_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            /// LN-001 over several sizes and scales of a sine-generated input.
            #[test]
            fn falsify_ln_001_prop_centering(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32, 64]),
                scale in 0.01_f32..100.0,
            ) {
                let (gamma, beta) = make_unit_params(dim);
                let data: Vec<f32> = (0..dim)
                    .map(|i| {
                        let phase = i as f32 * 0.37 * scale;
                        phase.sin() * scale
                    })
                    .collect();
                let y = layer_norm(&Tensor::from_vec(data, false), &gamma, &beta, 1e-5);

                let total: f32 = y.data().sum();
                let mean = total / dim as f32;
                prop_assert!(
                    mean.abs() < 1e-4,
                    "FALSIFIED LN-001-prop: mean(LN(x)) = {} (d={}, scale={})",
                    mean, dim, scale
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            /// LN-002 over several sizes and scales of a sine-generated input.
            #[test]
            fn falsify_ln_002_prop_standardization(
                dim in prop::sample::select(vec![8_usize, 16, 32, 64]),
                scale in 0.1_f32..100.0,
            ) {
                let (gamma, beta) = make_unit_params(dim);
                let data: Vec<f32> = (0..dim)
                    .map(|i| {
                        let phase = i as f32 * 0.23;
                        phase.sin() * scale
                    })
                    .collect();
                let y = layer_norm(&Tensor::from_vec(data, false), &gamma, &beta, 1e-5);

                let n = dim as f32;
                let values = y.data();
                let mean: f32 = values.sum() / n;
                let squared_dev = values.mapv(|v| (v - mean).powi(2));
                let var: f32 = squared_dev.sum() / n;
                prop_assert!(
                    (var - 1.0).abs() < 0.1,
                    "FALSIFIED LN-002-prop: var(LN(x)) = {} (d={}, scale={})",
                    var, dim, scale
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            /// LN-006 over several sizes and shift amounts.
            #[test]
            fn falsify_ln_006_prop_shift_invariance(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32]),
                shift in prop::sample::select(vec![-100.0_f32, -1.0, 0.5, 10.0, 1000.0]),
            ) {
                let (gamma, beta) = make_unit_params(dim);
                let data: Vec<f32> = (0..dim)
                    .map(|i| {
                        let phase = i as f32 * 0.37;
                        phase.sin() * 5.0
                    })
                    .collect();
                let shifted: Vec<f32> = data.iter().map(|&v| v + shift).collect();

                let x = Tensor::from_vec(data, false);
                let y_base = layer_norm(&x, &gamma, &beta, 1e-5);
                let x_shifted = Tensor::from_vec(shifted, false);
                let y_shifted = layer_norm(&x_shifted, &gamma, &beta, 1e-5);

                for (i, (&a, &b)) in y_base.data().iter().zip(y_shifted.data().iter()).enumerate() {
                    let tol = 1e-3 * a.abs().max(1.0);
                    prop_assert!(
                        (a - b).abs() < tol,
                        "FALSIFIED LN-006-prop: LN(x)[{i}]={a}, LN(x+{shift})[{i}]={b} (d={dim})"
                    );
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            /// LN-007 (and LN-003) over several sizes and constant values.
            #[test]
            fn falsify_ln_007_prop_constant_input(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32]),
                c in prop::sample::select(vec![-1e6_f32, -1.0, 0.0, 1.0, 1e6]),
            ) {
                let (gamma, beta) = make_unit_params(dim);
                let y = layer_norm(&Tensor::from_vec(vec![c; dim], false), &gamma, &beta, 1e-5);

                for (i, &val) in y.data().iter().enumerate() {
                    prop_assert!(
                        val.is_finite(),
                        "FALSIFIED LN-003-prop: NaN/Inf at [{i}] for constant {c} (d={dim})"
                    );
                    prop_assert!(
                        val.abs() < 1e-3,
                        "FALSIFIED LN-007-prop: LN([{c};{dim}])[{i}] = {val} (expected ≈ 0)"
                    );
                }
            }
        }
    }
}