// entrenar/autograd/ops/normalize.rs

1//! Normalization autograd operations: layer_norm
2
3use crate::autograd::{BackwardOp, Tensor};
4use ndarray::Array1;
5use std::cell::RefCell;
6use std::rc::Rc;
7
8/// Layer Normalization
9///
10/// Normalizes input to have mean=0 and variance=1, then applies learned scale (gamma) and shift (beta)
11/// LayerNorm(x) = gamma * (x - mean) / sqrt(var + epsilon) + beta
12pub fn layer_norm(x: &Tensor, gamma: &Tensor, beta: &Tensor, epsilon: f32) -> Tensor {
13    let n = x.len() as f32;
14
15    // Compute mean
16    let mean = x.data().sum() / n;
17
18    // Compute variance
19    let variance = x.data().mapv(|val| (val - mean).powi(2)).sum() / n;
20    let std = (variance + epsilon).sqrt();
21
22    // Normalize
23    let normalized = x.data().mapv(|val| (val - mean) / std);
24
25    // Scale and shift
26    let data = &normalized * gamma.data() + beta.data();
27
28    let requires_grad = x.requires_grad() || gamma.requires_grad() || beta.requires_grad();
29    let mut result = Tensor::new(data, requires_grad);
30
31    if requires_grad {
32        let x_clone = x.clone();
33        let gamma_clone = gamma.clone();
34        let beta_clone = beta.clone();
35        let backward_op = Rc::new(LayerNormBackward {
36            x: x_clone,
37            gamma: gamma_clone,
38            beta: beta_clone,
39            normalized: normalized.clone(),
40            std,
41            result_grad: result.grad_cell(),
42        });
43        result.set_backward_op(backward_op);
44    }
45
46    contract_post_layernorm!(result.data().as_slice().unwrap_or(&[]));
47    result
48}
49
50struct LayerNormBackward {
51    x: Tensor,
52    gamma: Tensor,
53    beta: Tensor,
54    normalized: Array1<f32>,
55    std: f32,
56    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
57}
58
59impl BackwardOp for LayerNormBackward {
60    fn backward(&self) {
61        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
62            let n = self.x.len() as f32;
63
64            // ∂L/∂beta = ∂L/∂y (gradient flows directly through addition)
65            if self.beta.requires_grad() {
66                self.beta.accumulate_grad(grad_output.clone());
67            }
68
69            // ∂L/∂gamma = ∂L/∂y * x_normalized
70            if self.gamma.requires_grad() {
71                let grad_gamma = grad_output * &self.normalized;
72                self.gamma.accumulate_grad(grad_gamma);
73            }
74
75            // ∂L/∂x is complex due to mean and variance dependencies
76            if self.x.requires_grad() {
77                // Gradient through scale: grad_normalized = grad_output * gamma
78                let grad_normalized = grad_output * self.gamma.data();
79
80                // Sum of gradients (for mean term)
81                let sum_grad = grad_normalized.sum();
82
83                // Sum of gradients weighted by normalized values (for variance term)
84                let sum_grad_normalized = (&grad_normalized * &self.normalized).sum();
85
86                // Full gradient formula:
87                // ∂L/∂x_i = (1/std) * [grad_normalized_i - (1/n)*sum_grad - (1/n)*normalized_i*sum_grad_normalized]
88                let grad_x: Vec<f32> = grad_normalized
89                    .iter()
90                    .zip(self.normalized.iter())
91                    .map(|(&grad_norm, &norm)| {
92                        (grad_norm - sum_grad / n - norm * sum_grad_normalized / n) / self.std
93                    })
94                    .collect();
95
96                self.x.accumulate_grad(Array1::from(grad_x));
97            }
98
99            // Continue backward through the graph
100            if let Some(op) = self.x.backward_op() {
101                op.backward();
102            }
103            if let Some(op) = self.gamma.backward_op() {
104                op.backward();
105            }
106            if let Some(op) = self.beta.backward_op() {
107                op.backward();
108            }
109        }
110    }
111}
112
// =========================================================================
// FALSIFY-LN: layernorm-kernel-v1.yaml contract (entrenar layer_norm)
//
// Five-Whys (PMAT-354, Phase 10):
//   Why 1: entrenar had zero FALSIFY-LN-* tests despite a full LayerNorm impl
//   Why 2: autograd backward tests verify gradients, not output invariants
//   Why 3: no mapping from layernorm-kernel-v1.yaml to entrenar tests
//   Why 4: entrenar predates the provable-contracts YAML convention
//   Why 5: LayerNorm was "obviously correct" (y = (x-μ)/σ * γ + β)
//
// References:
//   - provable-contracts/contracts/layernorm-kernel-v1.yaml
//   - Ba et al. (2016) "Layer Normalization"
// =========================================================================
127
#[cfg(test)]
mod normalization_correctness_tests {
    use super::*;
    use crate::autograd::Tensor;

    /// Reference LayerNorm evaluated in f64, used to verify the f32 implementation.
    ///
    /// Panics if `gamma` or `beta` does not have the same length as `x`.
    fn reference_layer_norm_f64(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
        assert_eq!(gamma.len(), x.len(), "reference: gamma/x length mismatch");
        assert_eq!(beta.len(), x.len(), "reference: beta/x length mismatch");
        let n = x.len() as f64;
        let x_f64: Vec<f64> = x.iter().map(|&v| f64::from(v)).collect();
        let mean: f64 = x_f64.iter().sum::<f64>() / n;
        let variance: f64 = x_f64.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / n;
        let std = (variance + f64::from(eps)).sqrt();
        x_f64
            .iter()
            .zip(gamma.iter().zip(beta.iter()))
            .map(|(&v, (&g, &b))| ((v - mean) / std * f64::from(g) + f64::from(b)) as f32)
            .collect()
    }

    /// Runs `layer_norm` on the given values and checks every output element is
    /// within `tol` of the f64 reference.
    ///
    /// The output length is asserted first. The element-wise comparison uses `zip`,
    /// which would silently stop at the shorter side if the lengths differed.
    fn assert_matches_reference(
        x_data: Vec<f32>,
        gamma_data: Vec<f32>,
        beta_data: Vec<f32>,
        eps: f32,
        tol: f32,
        label: &str,
    ) {
        let reference = reference_layer_norm_f64(&x_data, &gamma_data, &beta_data, eps);
        let x = Tensor::from_vec(x_data, false);
        let gamma = Tensor::from_vec(gamma_data, false);
        let beta = Tensor::from_vec(beta_data, false);
        let result = layer_norm(&x, &gamma, &beta, eps);

        assert_eq!(
            result.len(),
            reference.len(),
            "{label}: output length differs from reference length"
        );
        for (i, (&actual, &expected)) in result.data().iter().zip(reference.iter()).enumerate() {
            let diff = (actual - expected).abs();
            assert!(
                diff < tol,
                "{label}[{i}]: actual={actual}, ref={expected}, diff={diff}"
            );
        }
    }

    #[test]
    fn test_normalization_correctness_matches_reference() {
        assert_matches_reference(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![1.0; 5],
            vec![0.0; 5],
            1e-5,
            1e-5,
            "LayerNorm correctness",
        );
    }

    #[test]
    fn test_normalization_correctness_with_scaling() {
        assert_matches_reference(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2.0; 4],
            vec![1.0; 4],
            1e-5,
            1e-5,
            "LayerNorm correctness scaled",
        );
    }

    /// Non-uniform gamma/beta: catches an affine step that is misaligned with
    /// the normalized values, which uniform parameters cannot detect.
    #[test]
    fn test_normalization_correctness_per_element_affine() {
        assert_matches_reference(
            vec![0.5, -1.25, 3.0, 2.0, -0.75, 1.0],
            vec![0.5, 2.0, -1.0, 1.5, 0.25, 1.0],
            vec![0.1, -0.2, 0.0, 1.0, -1.5, 0.3],
            1e-5,
            1e-5,
            "LayerNorm correctness per-element",
        );
    }
}
184
#[cfg(test)]
mod ln_contract_tests {
    use super::*;
    use crate::autograd::Tensor;

    /// Identity affine parameters (gamma = 1, beta = 0) of length `dim`, so the
    /// LayerNorm output is just the normalized input.
    fn make_unit_params(dim: usize) -> (Tensor, Tensor) {
        let gamma = Tensor::from_vec(vec![1.0; dim], false);
        let beta = Tensor::from_vec(vec![0.0; dim], false);
        (gamma, beta)
    }

    /// LayerNorm of `x` with unit parameters and `eps = 1e-5`.
    ///
    /// Also asserts that the output has the same length as the input. The checks
    /// below iterate over (and `zip`) the output, so a truncated or empty output
    /// would otherwise let them pass vacuously.
    fn unit_layer_norm(x: &Tensor) -> Tensor {
        let (gamma, beta) = make_unit_params(x.len());
        let y = layer_norm(x, &gamma, &beta, 1e-5);
        assert_eq!(
            y.len(),
            x.len(),
            "FALSIFIED LN shape: output length {} != input length {}",
            y.len(),
            x.len()
        );
        y
    }

    /// FALSIFY-LN-001: Centering — mean of LN output ≈ 0 (with beta=0)
    #[test]
    fn falsify_ln_001_centering() {
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, 0.5, -1.5, 2.5, -0.5, 1.5], false);
        let y = unit_layer_norm(&x);

        let mean: f32 = y.data().sum() / y.len() as f32;
        assert!(mean.abs() < 1e-5, "FALSIFIED LN-001: mean(LN(x)) = {mean}, expected ≈ 0");
    }

    /// FALSIFY-LN-002: Standardization — variance of LN output ≈ 1 (with gamma=1)
    #[test]
    fn falsify_ln_002_standardization() {
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, 0.5, -1.5, 2.5, -0.5, 1.5], false);
        let y = unit_layer_norm(&x);
        let y_data = y.data();
        let n = y.len() as f32;

        let mean: f32 = y_data.sum() / n;
        let var: f32 = y_data.mapv(|v| (v - mean).powi(2)).sum() / n;
        assert!((var - 1.0).abs() < 0.05, "FALSIFIED LN-002: var(LN(x)) = {var}, expected ≈ 1.0");
    }

    /// FALSIFY-LN-003: Denominator safety — output finite for all finite input
    #[test]
    fn falsify_ln_003_denominator_safety() {
        let test_cases: Vec<(&str, Vec<f32>)> = vec![
            ("normal", vec![1.0, 2.0, 3.0, 4.0]),
            ("small", vec![1e-7, 1e-7, 1e-7, 1e-7]),
            ("large", vec![1e6, 1e6, 1e6, 1e6]),
            ("mixed_sign", vec![-3.0, 2.0, -1.0, 4.0]),
            ("near_zero", vec![1e-20, 0.0, 1e-20, 0.0]),
            ("all_zero", vec![0.0, 0.0, 0.0, 0.0]),
        ];

        for (name, data) in &test_cases {
            let x = Tensor::from_vec(data.clone(), false);
            let y = unit_layer_norm(&x);
            for (i, &val) in y.data().iter().enumerate() {
                assert!(
                    val.is_finite(),
                    "FALSIFIED LN-003: output[{i}] = {val} not finite for case '{name}'"
                );
            }
        }
    }

    /// FALSIFY-LN-005: Idempotency — LN(LN(x)) ≈ LN(x)
    #[test]
    fn falsify_ln_005_idempotency() {
        let x = Tensor::from_vec(vec![10.0, -5.0, 3.0, 7.0, -2.0, 0.5], false);
        let y1 = unit_layer_norm(&x);
        let y2 = unit_layer_norm(&y1);

        for (i, (&a, &b)) in y1.data().iter().zip(y2.data().iter()).enumerate() {
            let diff = (a - b).abs();
            assert!(
                diff < 1e-4,
                "FALSIFIED LN-005: LN(LN(x))[{i}] = {b}, LN(x)[{i}] = {a}, diff = {diff}"
            );
        }
    }

    /// FALSIFY-LN-006: Shift invariance — LN(x + c) = LN(x)
    #[test]
    fn falsify_ln_006_shift_invariance() {
        let data = vec![1.0_f32, -2.0, 3.0, 0.5, -1.5];
        let y_base = unit_layer_norm(&Tensor::from_vec(data.clone(), false));

        for &c in &[10.0_f32, -100.0, 0.001, 1000.0] {
            let shifted: Vec<f32> = data.iter().map(|&v| v + c).collect();
            let y_shifted = unit_layer_norm(&Tensor::from_vec(shifted, false));

            for (i, (&a, &b)) in y_base.data().iter().zip(y_shifted.data().iter()).enumerate() {
                let tol = 1e-3 * a.abs().max(1.0);
                assert!(
                    (a - b).abs() < tol,
                    "FALSIFIED LN-006: LN(x)[{i}]={a}, LN(x+{c})[{i}]={b}"
                );
            }
        }
    }

    /// FALSIFY-LN-007: Constant input → output ≈ beta (0)
    #[test]
    fn falsify_ln_007_constant_input() {
        for &c in &[0.0_f32, 1.0, -5.0, 1e6, 1e-6] {
            let x = Tensor::from_vec(vec![c; 4], false);
            let y = unit_layer_norm(&x);

            for (i, &val) in y.data().iter().enumerate() {
                assert!(val.is_finite(), "FALSIFIED LN-003 (via LN-007): NaN/Inf for constant {c}");
                assert!(
                    val.abs() < 1e-3,
                    "FALSIFIED LN-007: LN([{c};4])[{i}] = {val}, expected ≈ 0"
                );
            }
        }
    }

    mod ln_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // LN-001-prop: centering
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_ln_001_prop_centering(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32, 64]),
                scale in 0.01_f32..100.0,
            ) {
                let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37 * scale).sin() * scale).collect();
                let x = Tensor::from_vec(data, false);
                let y = unit_layer_norm(&x);

                let mean: f32 = y.data().sum() / dim as f32;
                prop_assert!(
                    mean.abs() < 1e-4,
                    "FALSIFIED LN-001-prop: mean(LN(x)) = {} (d={}, scale={})",
                    mean, dim, scale
                );
            }
        }

        // LN-002-prop: standardization
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_ln_002_prop_standardization(
                dim in prop::sample::select(vec![8_usize, 16, 32, 64]),
                scale in 0.1_f32..100.0,
            ) {
                let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.23).sin() * scale).collect();
                let x = Tensor::from_vec(data, false);
                let y = unit_layer_norm(&x);
                let y_data = y.data();
                let n = dim as f32;

                let mean: f32 = y_data.sum() / n;
                let var: f32 = y_data.mapv(|v| (v - mean).powi(2)).sum() / n;
                prop_assert!(
                    (var - 1.0).abs() < 0.1,
                    "FALSIFIED LN-002-prop: var(LN(x)) = {} (d={}, scale={})",
                    var, dim, scale
                );
            }
        }

        // LN-006-prop: shift invariance
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_ln_006_prop_shift_invariance(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32]),
                shift in prop::sample::select(vec![-100.0_f32, -1.0, 0.5, 10.0, 1000.0]),
            ) {
                let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37).sin() * 5.0).collect();
                let x = Tensor::from_vec(data.clone(), false);
                let y_base = unit_layer_norm(&x);

                let shifted: Vec<f32> = data.iter().map(|&v| v + shift).collect();
                let x_shifted = Tensor::from_vec(shifted, false);
                let y_shifted = unit_layer_norm(&x_shifted);

                for (i, (&a, &b)) in y_base.data().iter().zip(y_shifted.data().iter()).enumerate() {
                    let tol = 1e-3 * a.abs().max(1.0);
                    prop_assert!(
                        (a - b).abs() < tol,
                        "FALSIFIED LN-006-prop: LN(x)[{i}]={a}, LN(x+{shift})[{i}]={b} (d={dim})"
                    );
                }
            }
        }

        // LN-007-prop: constant input
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_ln_007_prop_constant_input(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32]),
                c in prop::sample::select(vec![-1e6_f32, -1.0, 0.0, 1.0, 1e6]),
            ) {
                let x = Tensor::from_vec(vec![c; dim], false);
                let y = unit_layer_norm(&x);

                for (i, &val) in y.data().iter().enumerate() {
                    prop_assert!(
                        val.is_finite(),
                        "FALSIFIED LN-003-prop: NaN/Inf at [{i}] for constant {c} (d={dim})"
                    );
                    prop_assert!(
                        val.abs() < 1e-3,
                        "FALSIFIED LN-007-prop: LN([{c};{dim}])[{i}] = {val} (expected ≈ 0)"
                    );
                }
            }
        }
    }
}