Skip to main content

entrenar/transformer/
norm.rs

1//! RMS Normalization module
2//!
3//! This module provides RMS normalization layers for transformer models.
4
5use crate::autograd::scale;
6use crate::Tensor;
7use std::collections::HashMap;
8
9/// RMS Normalization layer
10pub struct RMSNorm {
11    /// Weight (scale) parameter
12    pub weight: Tensor,
13    /// Epsilon for numerical stability
14    eps: f32,
15}
16
17impl RMSNorm {
18    /// Create new RMS normalization layer
19    pub fn new(hidden_size: usize, eps: f32) -> Self {
20        Self { weight: Tensor::ones(hidden_size, true), eps }
21    }
22
23    /// Create from parameters
24    ///
25    /// # Contract (PMAT-332 norm)
26    /// Validates weight.len() == hidden_size.
27    /// Returns None if key is missing or length is wrong.
28    pub fn from_params(
29        params: &HashMap<String, Tensor>,
30        prefix: &str,
31        eps: f32,
32        hidden_size: usize,
33    ) -> Option<Self> {
34        let weight = params.get(&format!("{prefix}.weight"))?.clone();
35        if weight.len() != hidden_size {
36            eprintln!(
37                "[PMAT-332] {prefix}.weight: length mismatch — got {}, expected {hidden_size}",
38                weight.len()
39            );
40            return None;
41        }
42        Some(Self { weight, eps })
43    }
44
45    /// Forward pass
46    ///
47    /// RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
48    pub fn forward(&self, x: &Tensor) -> Tensor {
49        contract_pre_rmsnorm!(x.data());
50        let n = x.len() as f32;
51
52        // Compute RMS
53        let sq_sum: f32 = x.data().iter().map(|v| v * v).sum();
54        let rms = (sq_sum / n + self.eps).sqrt();
55
56        // Normalize and scale
57        let normalized = scale(x, 1.0 / rms);
58        let result = crate::autograd::mul(&normalized, &self.weight);
59        contract_post_rmsnorm!(result.data().as_slice().unwrap_or(&[]));
60        result
61    }
62
63    /// Forward pass for batched input
64    ///
65    /// # Arguments
66    /// * `x` - Input tensor (seq_len * hidden_size, flattened)
67    /// * `seq_len` - Sequence length
68    /// * `hidden_size` - Hidden dimension
69    pub fn forward_batched(&self, x: &Tensor, seq_len: usize, hidden_size: usize) -> Tensor {
70        let mut output = vec![0.0; seq_len * hidden_size];
71        let mut rms_values = Vec::with_capacity(seq_len);
72
73        // KAIZEN-020: Hoist data borrows outside the sequence loop.
74        // Previously x.data().as_slice() was called per-position (seq_len times)
75        // and self.weight.data()[i] was called per-element (seq_len * hidden_size times).
76        let x_data = x.data();
77        let x_slice = x_data.as_slice().expect("norm input must be contiguous");
78        let w_data = self.weight.data();
79        let w_slice = w_data.as_slice().expect("norm weight must be contiguous");
80
81        for s in 0..seq_len {
82            let start = s * hidden_size;
83            let end = start + hidden_size;
84            let slice = &x_slice[start..end];
85
86            // Compute RMS for this position
87            let sq_sum: f32 = slice.iter().map(|v| v * v).sum();
88            let rms = (sq_sum / hidden_size as f32 + self.eps).sqrt();
89            rms_values.push(rms);
90
91            // Normalize and scale
92            for (i, &val) in slice.iter().enumerate() {
93                output[start + i] = (val / rms) * w_slice[i];
94            }
95        }
96
97        let requires_grad = x.requires_grad() || self.weight.requires_grad();
98        let mut result = Tensor::from_vec(output, requires_grad);
99
100        if requires_grad {
101            use crate::autograd::BackwardOp;
102            use ndarray::Array1;
103            use std::cell::RefCell;
104            use std::rc::Rc;
105
106            struct RMSNormBatchedBackward {
107                x: Tensor,
108                weight: Tensor,
109                rms_values: Vec<f32>,
110                seq_len: usize,
111                hidden_size: usize,
112                result_grad: Rc<RefCell<Option<Array1<f32>>>>,
113            }
114
115            impl BackwardOp for RMSNormBatchedBackward {
116                fn backward(&self) {
117                    if let Some(grad_output) = self.result_grad.borrow().as_ref() {
118                        let h = self.hidden_size;
119                        let x_data = self.x.data();
120                        let x_sl = x_data.as_slice().expect("x contiguous");
121                        let w_data = self.weight.data();
122                        let w_sl = w_data.as_slice().expect("weight contiguous");
123                        let go = grad_output.as_slice().expect("grad contiguous");
124
125                        if self.x.requires_grad() {
126                            // dx[s,j] = (go[s,j]*w[j] - x[s,j]*c_s) / rms_s
127                            // c_s = sum_i(go[s,i]*w[i]*x[s,i]) / (n * rms_s^2)
128                            let mut grad_x = vec![0.0_f32; self.seq_len * h];
129                            let n = h as f32;
130
131                            for s in 0..self.seq_len {
132                                let off = s * h;
133                                let rms = self.rms_values[s];
134
135                                let mut dot = 0.0_f32;
136                                for i in 0..h {
137                                    dot += go[off + i] * w_sl[i] * x_sl[off + i];
138                                }
139                                let c = dot / (n * rms * rms);
140
141                                for j in 0..h {
142                                    grad_x[off + j] =
143                                        (go[off + j] * w_sl[j] - x_sl[off + j] * c) / rms;
144                                }
145                            }
146
147                            self.x.accumulate_grad(Array1::from(grad_x));
148                        }
149
150                        if self.weight.requires_grad() {
151                            // dw[i] = sum_s(go[s,i] * x[s,i] / rms_s)
152                            let mut grad_w = vec![0.0_f32; h];
153
154                            for s in 0..self.seq_len {
155                                let off = s * h;
156                                let rms = self.rms_values[s];
157                                for i in 0..h {
158                                    grad_w[i] += go[off + i] * x_sl[off + i] / rms;
159                                }
160                            }
161
162                            self.weight.accumulate_grad(Array1::from(grad_w));
163                        }
164
165                        // Continue backward propagation through inputs
166                        if let Some(op) = self.x.backward_op() {
167                            op.backward();
168                        }
169                        if let Some(op) = self.weight.backward_op() {
170                            op.backward();
171                        }
172                    }
173                }
174            }
175
176            let backward_op = Rc::new(RMSNormBatchedBackward {
177                x: x.clone(),
178                weight: self.weight.clone(),
179                rms_values,
180                seq_len,
181                hidden_size,
182                result_grad: result.grad_cell(),
183            });
184            result.set_backward_op(backward_op);
185        }
186
187        contract_post_rmsnorm!(result.data().as_slice().unwrap_or(&[]));
188        result
189    }
190}
191
192/// Layer Normalization with bias (used by BERT/RoBERTa/CodeBERT encoders).
193///
194/// Unlike RMSNorm (used by decoders), LayerNorm:
195/// 1. Subtracts the mean (re-centering)
196/// 2. Divides by standard deviation (not RMS)
197/// 3. Has both weight (gamma) AND bias (beta) parameters
198///
199/// LayerNorm(x) = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias
200///
201/// # Contract (ENC-005)
202/// - Output has zero mean and unit variance (before affine transform)
203/// - weight.len() == bias.len() == hidden_size
204pub struct LayerNorm {
205    /// Scale parameter (gamma)
206    pub weight: Tensor,
207    /// Shift parameter (beta)
208    pub bias: Tensor,
209    /// Epsilon for numerical stability
210    eps: f32,
211    /// Hidden size
212    hidden_size: usize,
213}
214
215impl LayerNorm {
216    /// Create new LayerNorm (weight=1, bias=0)
217    pub fn new(hidden_size: usize, eps: f32) -> Self {
218        Self {
219            weight: Tensor::ones(hidden_size, true),
220            bias: Tensor::from_vec(vec![0.0; hidden_size], true),
221            eps,
222            hidden_size,
223        }
224    }
225
226    /// Create from pre-trained parameters
227    pub fn from_params(
228        params: &HashMap<String, Tensor>,
229        prefix: &str,
230        eps: f32,
231        hidden_size: usize,
232    ) -> Option<Self> {
233        let weight = params.get(&format!("{prefix}.weight"))?.clone();
234        let bias = params.get(&format!("{prefix}.bias"))?.clone();
235        if weight.len() != hidden_size || bias.len() != hidden_size {
236            eprintln!(
237                "[ENC-005] {prefix}: shape mismatch — weight={}, bias={}, expected {hidden_size}",
238                weight.len(),
239                bias.len()
240            );
241            return None;
242        }
243        Some(Self { weight, bias, eps, hidden_size })
244    }
245
246    /// Forward pass for batched input (seq_len positions, each of hidden_size)
247    pub fn forward_batched(&self, x: &Tensor, seq_len: usize, hidden_size: usize) -> Tensor {
248        let mut output = vec![0.0_f32; seq_len * hidden_size];
249        let x_data = x.data();
250        let x_slice = x_data.as_slice().expect("input contiguous");
251        let w_data = self.weight.data();
252        let w_slice = w_data.as_slice().expect("weight contiguous");
253        let b_data = self.bias.data();
254        let b_slice = b_data.as_slice().expect("bias contiguous");
255
256        for s in 0..seq_len {
257            let start = s * hidden_size;
258            let end = start + hidden_size;
259            let row = &x_slice[start..end];
260
261            // Mean
262            let mean: f32 = row.iter().sum::<f32>() / hidden_size as f32;
263
264            // Variance
265            let var: f32 =
266                row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / hidden_size as f32;
267            let inv_std = 1.0 / (var + self.eps).sqrt();
268
269            // Normalize, scale, shift
270            for (i, &val) in row.iter().enumerate() {
271                output[start + i] = (val - mean) * inv_std * w_slice[i] + b_slice[i];
272            }
273        }
274
275        let result = Tensor::from_vec(output, x.requires_grad() || self.weight.requires_grad());
276        contract_post_layernorm!(result.data().as_slice().unwrap_or(&[]));
277        result
278    }
279
280    /// Get hidden size
281    pub fn hidden_size(&self) -> usize {
282        self.hidden_size
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rms_norm_forward() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);
        let output = norm.forward(&x);
        assert_eq!(output.len(), 4);
        // With unit weights the output is x / sqrt(mean(x^2) + eps); mean(x^2) = 30 / 4.
        let rms = (7.5_f32 + 1e-6).sqrt();
        let data = output.data();
        assert!(data.iter().all(|&v| v.is_finite()));
        for (i, &v) in data.iter().enumerate() {
            let expected = (i + 1) as f32 / rms;
            assert!((v - expected).abs() < 1e-5, "output[{i}] = {v}, expected {expected}");
        }
    }

    #[test]
    fn test_rms_norm_batched() {
        let norm = RMSNorm::new(4, 1e-6);
        let values = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let x = Tensor::from_vec(values.clone(), true);
        let output = norm.forward_batched(&x, 2, 4);
        assert_eq!(output.len(), 8);

        // Each row is normalized independently, so the batched output must match the
        // single-vector forward() applied to that row.
        let out_data = output.data();
        let out = out_data.as_slice().expect("contiguous");
        for (s, row) in values.chunks(4).enumerate() {
            let single = norm.forward(&Tensor::from_vec(row.to_vec(), true));
            for (i, &expected) in single.data().iter().enumerate() {
                let got = out[s * 4 + i];
                assert!(
                    (got - expected).abs() < 1e-5,
                    "row {s}, col {i}: batched={got}, single={expected}"
                );
            }
        }
    }

    #[test]
    fn test_rms_norm_normalization_property() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], true);
        let output = norm.forward(&x);
        // After RMS normalization, if weights are 1, output should be x / rms(x)
        // rms(x) = sqrt(mean(x^2)) = sqrt(4) = 2
        // so output = [2/2, 2/2, 2/2, 2/2] = [1, 1, 1, 1]
        let data = output.data();
        for &val in data {
            assert!((val - 1.0).abs() < 1e-5, "Expected ~1.0, got {val}");
        }
    }

    #[test]
    fn test_rms_norm_with_zeros() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0], true);
        let output = norm.forward(&x);
        // mean(x^2) = 0, so rms = sqrt(eps) > 0 and every output is 0 * (1 / rms) = 0
        let data = output.data();
        assert!(data.iter().all(|&v| v.is_finite()));
        assert!(data.iter().all(|&v| v.abs() < 1e-12), "Zero input must map to zero output");
    }

    #[test]
    fn test_rms_norm_weight_requires_grad() {
        let norm = RMSNorm::new(4, 1e-6);
        assert!(norm.weight.requires_grad());
    }

    #[test]
    fn test_rms_norm_from_params() {
        let mut params = HashMap::new();
        params.insert("test.weight".to_string(), Tensor::from_vec(vec![1.0, 1.0, 1.0, 1.0], true));
        let norm = RMSNorm::from_params(&params, "test", 1e-6, 4);
        assert!(norm.is_some());
        let norm = norm.expect("operation should succeed");
        assert_eq!(norm.weight.len(), 4);
    }

    #[test]
    fn test_rms_norm_from_params_missing() {
        let params: HashMap<String, Tensor> = HashMap::new();
        let norm = RMSNorm::from_params(&params, "missing", 1e-6, 4);
        assert!(norm.is_none());
    }

    #[test]
    fn test_rms_norm_backward_gradient_exists() {
        let norm = RMSNorm::new(8, 1e-6);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
        let mut output = norm.forward(&x);

        let grad_out = ndarray::Array1::ones(8);
        crate::autograd::backward(&mut output, Some(grad_out));

        assert!(norm.weight.grad().is_some());
        let grad = norm.weight.grad().expect("gradient should be available");
        assert!(grad.iter().all(|&v| v.is_finite()));
    }

    /// ALB-038 fix: forward_batched must propagate gradients (was creating tensors with no backward op)
    #[test]
    fn test_rms_norm_batched_backward_gradient_exists() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
        let mut output = norm.forward_batched(&x, 2, 4);

        let grad_out = ndarray::Array1::ones(8);
        crate::autograd::backward(&mut output, Some(grad_out));

        // Weight must receive gradient (was broken before ALB-038 fix)
        assert!(norm.weight.grad().is_some(), "ALB-038: norm weight must have gradient");
        let wgrad = norm.weight.grad().expect("gradient available");
        assert!(wgrad.iter().all(|&v| v.is_finite()), "Weight gradients must be finite");
        assert!(wgrad.iter().any(|&v| v.abs() > 1e-10), "Weight gradients must be non-zero");

        // Input must receive gradient (enables gradient flow through model)
        assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
        let xgrad = x.grad().expect("gradient available");
        assert!(xgrad.iter().all(|&v| v.is_finite()), "Input gradients must be finite");
        assert!(xgrad.iter().any(|&v| v.abs() > 1e-10), "Input gradients must be non-zero");
    }

    /// ALB-038 fix: batched backward produces correct weight gradients
    ///
    /// Note: forward() uses scale(x, 1/rms) which treats rms as constant w.r.t. x,
    /// giving an approximate input gradient. forward_batched() computes the exact
    /// RMSNorm gradient including d(rms)/d(x). Weight gradients match exactly since
    /// dL/dw_i = go_i * x_i / rms regardless.
    #[test]
    fn test_rms_norm_batched_backward_weight_grad_matches() {
        let hidden = 4;
        let data = vec![1.0_f32, -2.0, 3.0, -0.5];

        // Non-batched path (uses autograd ops)
        let norm1 = RMSNorm::new(hidden, 1e-6);
        let x1 = Tensor::from_vec(data.clone(), true);
        let mut out1 = norm1.forward(&x1);
        crate::autograd::backward(&mut out1, Some(ndarray::Array1::ones(hidden)));
        let wgrad1 = norm1.weight.grad().expect("gradient available");

        // Batched path (new backward op)
        let norm2 = RMSNorm::new(hidden, 1e-6);
        let x2 = Tensor::from_vec(data, true);
        let mut out2 = norm2.forward_batched(&x2, 1, hidden);
        crate::autograd::backward(&mut out2, Some(ndarray::Array1::ones(hidden)));
        let wgrad2 = norm2.weight.grad().expect("gradient available");

        // Weight gradients should match exactly (dw = go * x / rms)
        for i in 0..hidden {
            assert!(
                (wgrad1[i] - wgrad2[i]).abs() < 1e-5,
                "Weight grad mismatch at [{i}]: unbatched={}, batched={}",
                wgrad1[i],
                wgrad2[i]
            );
        }
    }

    // =========================================================================
    // ENC-005: LayerNorm tests
    // =========================================================================

    #[test]
    fn enc_005_layernorm_output_shape() {
        let ln = LayerNorm::new(8, 1e-5);
        let x = Tensor::from_vec(vec![1.0; 3 * 8], true);
        let output = ln.forward_batched(&x, 3, 8);
        assert_eq!(output.len(), 3 * 8);
    }

    #[test]
    fn enc_005_layernorm_zero_mean_unit_var() {
        let ln = LayerNorm::new(8, 1e-12);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
        let output = ln.forward_batched(&x, 1, 8);
        let data = output.data();
        let slice = data.as_slice().expect("contiguous");

        // With weight=1, bias=0: output should have ~zero mean, ~unit variance
        let mean: f32 = slice.iter().sum::<f32>() / 8.0;
        assert!(mean.abs() < 1e-5, "LayerNorm output mean={mean}, expected ~0");

        let var: f32 = slice.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / 8.0;
        assert!((var - 1.0).abs() < 0.01, "LayerNorm output var={var}, expected ~1");
    }

    #[test]
    fn enc_005_layernorm_with_bias() {
        let mut ln = LayerNorm::new(4, 1e-12);
        // Set bias to shift output by 5.0
        ln.bias = Tensor::from_vec(vec![5.0; 4], true);
        let x = Tensor::from_vec(vec![1.0, 1.0, 1.0, 1.0], true);
        let output = ln.forward_batched(&x, 1, 4);
        // Constant input → normalized to 0, then shifted by bias → all 5.0
        let data = output.data();
        for &v in data {
            assert!((v - 5.0).abs() < 1e-3, "Expected ~5.0 with bias, got {v}");
        }
    }

    #[test]
    fn enc_005_layernorm_from_params() {
        let mut params = HashMap::new();
        params.insert("ln.weight".to_string(), Tensor::from_vec(vec![1.0; 32], true));
        params.insert("ln.bias".to_string(), Tensor::from_vec(vec![0.0; 32], true));
        let ln = LayerNorm::from_params(&params, "ln", 1e-5, 32);
        assert!(ln.is_some());
        assert_eq!(ln.expect("should succeed").hidden_size(), 32);
    }

    #[test]
    fn enc_005_layernorm_from_params_rejects_mismatch() {
        let mut params = HashMap::new();
        params.insert("ln.weight".to_string(), Tensor::from_vec(vec![1.0; 32], true));
        params.insert("ln.bias".to_string(), Tensor::from_vec(vec![0.0; 16], true)); // wrong size
        let ln = LayerNorm::from_params(&params, "ln", 1e-5, 32);
        assert!(ln.is_none());
    }

    #[test]
    fn enc_005_layernorm_finite_output() {
        let ln = LayerNorm::new(4, 1e-5);
        let x = Tensor::from_vec(vec![1e6, -1e6, 0.0, 1.0], true);
        let output = ln.forward_batched(&x, 1, 4);
        assert!(output.data().iter().all(|v| v.is_finite()));
    }

    // =========================================================================
    // FALSIFY-N: §2.1.5-6 Layer Norms — Five-Whys Gap Analysis (Refs PMAT-332)
    //
    // Contract: tensor-layout-v1.yaml §tensors.input_layernorm/post_attention_layernorm/final_norm
    //   apr_shape: "[hidden]"
    //   transpose: "false"
    //   kernel: "element-wise multiply"
    //
    // Five-Whys:
    //   Why 1: from_params accepts ANY tensor length without validation
    //   Why 2: RMSNorm stores raw Tensor with no length check
    //   Why 3: Wrong-length norm produces wrong-scale hidden states
    //   Why 4: Mismatched norm length panics at element-wise multiply
    //   Why 5: No constructor-time length check exists
    //
    // Popper (1959): "These tests attempt to falsify the claim that
    // entrenar's norm handling prevents shape-related runtime errors."
    // =========================================================================

    /// FALSIFY-N1e: from_params rejects wrong-length norm weight (PMAT-332 norm fix)
    ///
    /// RMSNorm.from_params now validates weight.len() == hidden_size.
    /// A wrong-length weight is rejected at construction time.
    #[test]
    fn falsify_n1e_from_params_rejects_wrong_length_norm() {
        let mut params = HashMap::new();
        // WRONG: weight has 7 elements, should be hidden_size=4
        params.insert("test.weight".to_string(), Tensor::from_vec(vec![1.0; 7], true));
        let norm = RMSNorm::from_params(&params, "test", 1e-6, 4);
        // FIXED: from_params now rejects wrong-length weight
        assert!(
            norm.is_none(),
            "FALSIFY-N1e: PMAT-332 fix — from_params MUST reject wrong-length norm weight"
        );
    }

    /// FALSIFY-N2e: RMSNorm forward produces finite output for valid input
    #[test]
    fn falsify_n2e_norm_output_finite() {
        let norm = RMSNorm::new(8, 1e-6);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
        let output = norm.forward(&x);
        assert!(
            output.data().iter().all(|v| v.is_finite()),
            "FALSIFY-N2e: RMSNorm output must be finite for valid input"
        );
    }

    /// FALSIFY-N3e: RMSNorm weight length matches hidden_size when constructed via new()
    #[test]
    fn falsify_n3e_new_constructor_correct_length() {
        let hidden_sizes = [64, 128, 256, 896, 4096];
        for &hidden in &hidden_sizes {
            let norm = RMSNorm::new(hidden, 1e-6);
            assert_eq!(
                norm.weight.len(),
                hidden,
                "FALSIFY-N3e: RMSNorm::new({hidden}) weight must have {hidden} elements"
            );
        }
    }

    /// FALSIFY-N4e: Batched forward preserves sequence*hidden dimension
    #[test]
    fn falsify_n4e_batched_forward_preserves_dims() {
        let hidden = 8;
        let seq_len = 4;
        let norm = RMSNorm::new(hidden, 1e-6);
        let x = Tensor::from_vec(vec![0.5; seq_len * hidden], true);
        let output = norm.forward_batched(&x, seq_len, hidden);
        assert_eq!(
            output.len(),
            seq_len * hidden,
            "FALSIFY-N4e: Batched norm must preserve seq_len * hidden dimension"
        );
        assert!(
            output.data().iter().all(|v| v.is_finite()),
            "FALSIFY-N4e: Batched norm output must be finite"
        );
    }

    /// FALSIFY-N5e: RMSNorm handles extreme but finite input values
    ///
    /// Output must contain no Inf/NaN. Note: at |x| = 1e30, x^2 overflows f32 to +inf,
    /// so rms = inf and the output collapses to 0 rather than ±1. This test pins only
    /// finiteness, not correct normalization at this magnitude.
    #[test]
    fn falsify_n5e_extreme_input_still_finite() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![1e30, -1e30, 1e30, -1e30], true);
        let output = norm.forward(&x);
        assert!(
            output.data().iter().all(|v| v.is_finite()),
            "FALSIFY-N5e: RMSNorm must handle extreme values without Inf/NaN"
        );
    }

    // =========================================================================
    // FALSIFY-RN: rmsnorm-kernel-v1.yaml contract (entrenar RMSNorm)
    //
    // Five-Whys (PMAT-354):
    //   Why 1: entrenar had 10+ RMSNorm tests but zero FALSIFY-RN-* tagged tests
    //   Why 2: existing tests verify API behavior, not mathematical invariants
    //   Why 3: no mapping from rmsnorm-kernel-v1.yaml to entrenar test names
    //   Why 4: entrenar predates the provable-contracts YAML convention
    //   Why 5: norm was "obviously correct" (divide by RMS, multiply by weight)
    //
    // References:
    //   - provable-contracts/contracts/rmsnorm-kernel-v1.yaml
    //   - Zhang & Sennrich (2019) "Root Mean Square Layer Normalization"
    // =========================================================================

    /// FALSIFY-RN-001: Finiteness — output must be finite for all finite input
    ///
    /// Contract: |RMSNorm(x)_i| < ∞ for all i when ε > 0
    #[test]
    fn falsify_rn_001_finiteness() {
        let norm = RMSNorm::new(4, 1e-6);

        let test_cases: Vec<(&str, Vec<f32>)> = vec![
            ("normal", vec![1.0, 2.0, 3.0, 4.0]),
            ("small", vec![1e-7, 1e-7, 1e-7, 1e-7]),
            ("large", vec![1e6, 1e6, 1e6, 1e6]),
            ("mixed_sign", vec![-3.0, 2.0, -1.0, 4.0]),
            ("near_zero", vec![1e-20, 0.0, 1e-20, 0.0]),
        ];

        for (name, data) in &test_cases {
            let x = Tensor::from_vec(data.clone(), true);
            let y = norm.forward(&x);

            for (i, &val) in y.data().iter().enumerate() {
                assert!(
                    val.is_finite(),
                    "FALSIFIED RN-001: output[{i}] = {val} not finite for case '{name}'"
                );
            }
        }
    }

    /// FALSIFY-RN-002: Scale invariance — RMSNorm(α·x) = sign(α)·RMSNorm(x)
    #[test]
    fn falsify_rn_002_scale_invariance() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, -0.5], true);
        let y_base = norm.forward(&x);

        for &alpha in &[2.0_f32, 0.5, -1.0, 10.0] {
            let x_scaled = Tensor::from_vec(x.data().iter().map(|&v| v * alpha).collect(), true);
            let y_scaled = norm.forward(&x_scaled);

            let sign = alpha.signum();
            for (i, (&ys, &yb)) in y_scaled.data().iter().zip(y_base.data().iter()).enumerate() {
                let expected = sign * yb;
                let diff = (ys - expected).abs();
                assert!(
                    diff < 1e-3,
                    "FALSIFIED RN-002: RMSNorm({alpha}·x)[{i}] = {ys}, expected {expected}"
                );
            }
        }
    }

    /// FALSIFY-RN-004: Zero vector — RMSNorm(0) = 0 (not NaN)
    #[test]
    fn falsify_rn_004_zero_vector() {
        let norm = RMSNorm::new(4, 1e-6);
        let x = Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0], true);
        let y = norm.forward(&x);

        for (i, &val) in y.data().iter().enumerate() {
            assert!(val.is_finite(), "FALSIFIED RN-004: RMSNorm(0)[{i}] = {val} (expected finite)");
            assert!(val.abs() < 1e-12, "FALSIFIED RN-004: RMSNorm(0)[{i}] = {val} (expected 0)");
        }
    }

    /// FALSIFY-RN-005: Unit γ normalized RMS ≈ 1
    #[test]
    fn falsify_rn_005_unit_gamma_normalized_rms() {
        let norm = RMSNorm::new(8, 1e-6);
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, -0.5, 4.0, -1.0, 2.5, -3.0], true);
        let y = norm.forward(&x);
        let y_data = y.data();

        let rms_out: f32 =
            (y_data.iter().map(|&v| v * v).sum::<f32>() / y_data.len() as f32).sqrt();

        assert!(
            (rms_out - 1.0).abs() < 0.01,
            "FALSIFIED RN-005: RMS(RMSNorm(x)) = {rms_out}, expected ≈ 1.0"
        );
    }

    // =========================================================================
    // PROPTEST FALSIFY: RMSNorm property-based falsification
    //
    // Five-Whys (PMAT-354, Phase 10):
    //   Why 1: RN-001..005 used fixed dimensions (d=4 or d=8)
    //   Why 2: Scale invariance (RN-002) could break at edge float ranges
    //   Why 3: proptest explores dimension/value combos humans miss
    //   Why 4: RMSNorm eps-dominated regime untested at scale
    //   Why 5: YAML rmsnorm-kernel-v1 calls for proptest on all claims
    // =========================================================================

    mod rn_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // RN-001-prop: finiteness for random vectors
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_rn_001_prop_finiteness(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32, 64]),
                scale in 0.001_f32..1000.0,
            ) {
                let norm = RMSNorm::new(dim, 1e-6);
                let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.13 * scale).sin()).collect();
                let x = Tensor::from_vec(data, true);
                let y = norm.forward(&x);
                for (i, &val) in y.data().iter().enumerate() {
                    prop_assert!(
                        val.is_finite(),
                        "FALSIFIED RN-001-prop: output[{}]={} not finite (d={}, scale={})",
                        i, val, dim, scale
                    );
                }
            }
        }

        // RN-002-prop: scale invariance for random vectors
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_rn_002_prop_scale_invariance(
                dim in prop::sample::select(vec![4_usize, 8, 16, 32]),
                alpha in prop::sample::select(vec![-10.0_f32, -1.0, 0.5, 2.0, 100.0]),
            ) {
                let norm = RMSNorm::new(dim, 1e-6);
                let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37).sin() * 5.0).collect();
                let x = Tensor::from_vec(data.clone(), true);
                let y_base = norm.forward(&x);

                let x_scaled = Tensor::from_vec(
                    data.iter().map(|&v| v * alpha).collect(),
                    true,
                );
                let y_scaled = norm.forward(&x_scaled);

                let sign = alpha.signum();
                for (i, (&ys, &yb)) in y_scaled.data().iter().zip(y_base.data().iter()).enumerate() {
                    let expected = sign * yb;
                    prop_assert!(
                        (ys - expected).abs() < 1e-3,
                        "FALSIFIED RN-002-prop: [{i}] got {ys}, expected {expected} (alpha={alpha}, d={dim})"
                    );
                }
            }
        }

        // RN-005-prop: unit gamma normalized RMS for random vectors
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_rn_005_prop_unit_gamma_rms(
                dim in prop::sample::select(vec![8_usize, 16, 32, 64]),
            ) {
                let norm = RMSNorm::new(dim, 1e-6);
                // Use values large enough that eps doesn't dominate
                let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.23).sin() * 10.0).collect();
                let x = Tensor::from_vec(data, true);
                let y = norm.forward(&x);
                let y_data = y.data();

                let rms_out: f32 = (y_data.iter().map(|&v| v * v).sum::<f32>() / y_data.len() as f32).sqrt();
                prop_assert!(
                    (rms_out - 1.0).abs() < 0.05,
                    "FALSIFIED RN-005-prop: RMS(output)={} != 1.0 (d={})",
                    rms_out, dim
                );
            }
        }
    }
}