
trueno/blis/norms.rs

//! SIMD-accelerated normalization kernels.
//!
//! AVX2 implementations of RMSNorm and LayerNorm with scalar fallback.
//! Uses FMA for sum-of-squares accumulation and fused normalize+scale.
//!
//! # Algorithm
//!
//! RMSNorm: output_i = x_i / sqrt(mean(x^2) + eps) * gamma_i
//! LayerNorm: output_i = gamma_i * (x_i - mean) / sqrt(var + eps) + beta_i
//!
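//! For example, with gamma = 1 and eps ≈ 0, x = [1, 2, 3] gives
//! mean(x^2) = 14/3 and RMS = sqrt(14/3) ≈ 2.160, so
//! RMSNorm(x) ≈ [0.463, 0.926, 1.389].
//!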
//! Both use two-accumulator SIMD reduction for sum/sum-of-squares to hide
//! FMA latency, then a vectorized normalize+scale pass.
//!
//! Contract: provable-contracts/contracts/rmsnorm-kernel-v1.yaml
//! Contract: provable-contracts/contracts/layernorm-kernel-v1.yaml
//!
//! # References
//!
//! - Zhang & Sennrich (2019). Root Mean Square Layer Normalization.
//! - Ba, Kiros & Hinton (2016). Layer Normalization.

use crate::error::TruenoError;

// ============================================================================
// RMSNorm
// ============================================================================

/// RMSNorm: output_i = x_i / sqrt(mean(x^2) + eps) * gamma_i
///
/// Uses AVX2 SIMD with FMA when available, scalar fallback otherwise.
///
/// Contract: rmsnorm-kernel-v1, equation "rmsnorm"
///
/// # Errors
///
/// Returns `Err` if input/gamma/output lengths don't match or are empty.
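///
/// # Examples
///
/// A minimal usage sketch (the exact crate path is assumed; shown here as
/// `trueno::blis::norms`):
///
/// ```ignore
/// use trueno::blis::norms::rms_norm;
///
/// let input = [1.0_f32, 2.0, 3.0, 4.0];
/// let gamma = [1.0_f32; 4];
/// let mut output = [0.0_f32; 4];
/// rms_norm(&input, &gamma, 1e-5, &mut output).unwrap();
/// // With unit gamma, the output has RMS ≈ 1.
/// ```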
pub fn rms_norm(
    input: &[f32],
    gamma: &[f32],
    eps: f32,
    output: &mut [f32],
) -> Result<(), TruenoError> {
    let n = input.len();
    if n == 0 || n != gamma.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "rms_norm size mismatch: input[{}], gamma[{}], output[{}]",
            n,
            gamma.len(),
            output.len()
        )));
    }

    // Contract: rmsnorm-kernel-v1.yaml precondition (pv codegen)
    contract_pre_rmsnorm!(input);

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: AVX2+FMA verified by feature detection above.
            unsafe {
                rms_norm_avx2(input, gamma, eps, output);
            }
            contract_post_rmsnorm!(output);
            return Ok(());
        }
    }

    rms_norm_scalar(input, gamma, eps, output);
    contract_post_rmsnorm!(output);
    Ok(())
}

/// Scalar RMSNorm implementation.
fn rms_norm_scalar(input: &[f32], gamma: &[f32], eps: f32, output: &mut [f32]) {
    let n = input.len();

    // Phase 1: sum of squares
    let mut sum_sq = 0.0_f32;
    for &x in input {
        sum_sq += x * x;
    }

    // Phase 2: inverse RMS
    let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();

    // Phase 3: normalize and scale
    for i in 0..n {
        output[i] = input[i] * inv_rms * gamma[i];
    }
}

/// AVX2+FMA RMSNorm implementation.
///
/// Two-accumulator reduction hides FMA latency (typically 4-5 cycles on
/// modern x86 cores). Single vectorized pass for normalize+scale.
///
/// # Safety
///
/// Requires AVX2 and FMA support.
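///
/// The reduction is equivalent to the scalar sum of squares, just split
/// across two independent accumulator registers so back-to-back FMAs don't
/// serialize on one register:
///
/// ```text
/// for each 16-element chunk: acc0 += chunk[0..8]^2; acc1 += chunk[8..16]^2
/// sum_sq = horizontal_sum(acc0 + acc1) + scalar tail
/// ```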
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn rms_norm_avx2(input: &[f32], gamma: &[f32], eps: f32, output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    let chunks = n / 16; // Process 16 elements (2×8) per iteration
    let remainder_16 = chunks * 16;

    unsafe {
        // Phase 1: sum of squares with two accumulators
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        for i in 0..chunks {
            let v0 = _mm256_loadu_ps(input.as_ptr().add(i * 16));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(i * 16 + 8));
            acc0 = _mm256_fmadd_ps(v0, v0, acc0);
            acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        }

        // Handle 8-element remainder chunk
        let mut sum_sq;
        if remainder_16 + 8 <= n {
            let v = _mm256_loadu_ps(input.as_ptr().add(remainder_16));
            acc0 = _mm256_fmadd_ps(v, v, acc0);
            let combined = _mm256_add_ps(acc0, acc1);

            // Horizontal sum: 128-bit halves, then pairs, then scalar
            let hi = _mm256_extractf128_ps(combined, 1);
            let lo = _mm256_castps256_ps128(combined);
            let sum128 = _mm_add_ps(lo, hi);
            let shuf = _mm_movehdup_ps(sum128);
            let sums = _mm_add_ps(sum128, shuf);
            let shuf2 = _mm_movehl_ps(sums, sums);
            let sums2 = _mm_add_ss(sums, shuf2);
            sum_sq = _mm_cvtss_f32(sums2);

            // Scalar tail after remainder_16 + 8
            for i in (remainder_16 + 8)..n {
                sum_sq += input[i] * input[i];
            }
        } else {
            let combined = _mm256_add_ps(acc0, acc1);
            let hi = _mm256_extractf128_ps(combined, 1);
            let lo = _mm256_castps256_ps128(combined);
            let sum128 = _mm_add_ps(lo, hi);
            let shuf = _mm_movehdup_ps(sum128);
            let sums = _mm_add_ps(sum128, shuf);
            let shuf2 = _mm_movehl_ps(sums, sums);
            let sums2 = _mm_add_ss(sums, shuf2);
            sum_sq = _mm_cvtss_f32(sums2);

            for i in remainder_16..n {
                sum_sq += input[i] * input[i];
            }
        }

        // Phase 2: inverse RMS (scalar; a single sqrt is not worth vectorizing)
        let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();

        // Phase 3: normalize and scale using AVX2
        let inv_rms_vec = _mm256_set1_ps(inv_rms);
        let chunks_out = n / 8;
        let remainder_out = chunks_out * 8;

        for i in 0..chunks_out {
            let x = _mm256_loadu_ps(input.as_ptr().add(i * 8));
            let g = _mm256_loadu_ps(gamma.as_ptr().add(i * 8));
            // output = x * inv_rms * gamma = (x * inv_rms) * gamma
            let normed = _mm256_mul_ps(x, inv_rms_vec);
            let scaled = _mm256_mul_ps(normed, g);
            _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), scaled);
        }

        // Scalar tail
        for i in remainder_out..n {
            output[i] = input[i] * inv_rms * gamma[i];
        }
    }
}

// ============================================================================
// LayerNorm
// ============================================================================

/// LayerNorm: output_i = gamma_i * (x_i - mean) / sqrt(var + eps) + beta_i
///
/// Uses AVX2 SIMD when available, scalar fallback otherwise.
///
/// Contract: layernorm-kernel-v1, equation "layernorm"
///
/// # Errors
///
/// Returns `Err` if input/gamma/beta/output lengths don't match or are empty.
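///
/// # Examples
///
/// A minimal usage sketch (the exact crate path is assumed; shown here as
/// `trueno::blis::norms`):
///
/// ```ignore
/// use trueno::blis::norms::layer_norm;
///
/// let input = [1.0_f32, 2.0, 3.0, 4.0];
/// let gamma = [1.0_f32; 4];
/// let beta = [0.0_f32; 4];
/// let mut output = [0.0_f32; 4];
/// layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();
/// // With gamma = 1 and beta = 0, the output has mean ≈ 0 and variance ≈ 1.
/// ```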
pub fn layer_norm(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    output: &mut [f32],
) -> Result<(), TruenoError> {
    let n = input.len();
    if n == 0 || n != gamma.len() || n != beta.len() || n != output.len() {
        return Err(TruenoError::InvalidInput(format!(
            "layer_norm size mismatch: input[{}], gamma[{}], beta[{}], output[{}]",
            n,
            gamma.len(),
            beta.len(),
            output.len()
        )));
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: AVX2+FMA verified by feature detection above.
            unsafe {
                layer_norm_avx2(input, gamma, beta, eps, output);
            }
            return Ok(());
        }
    }

    layer_norm_scalar(input, gamma, beta, eps, output);
    Ok(())
}

/// Scalar LayerNorm implementation.
fn layer_norm_scalar(input: &[f32], gamma: &[f32], beta: &[f32], eps: f32, output: &mut [f32]) {
    let n = input.len();

    // Phase 1: mean
    let mut sum = 0.0_f32;
    for &x in input {
        sum += x;
    }
    let mean = sum / n as f32;

    // Phase 2: variance
    let mut var_sum = 0.0_f32;
    for &x in input {
        let d = x - mean;
        var_sum += d * d;
    }
    let inv_std = 1.0 / (var_sum / n as f32 + eps).sqrt();

    // Phase 3: normalize + affine
    for i in 0..n {
        output[i] = gamma[i] * (input[i] - mean) * inv_std + beta[i];
    }
}

/// AVX2+FMA LayerNorm implementation.
///
/// Two-pass: (1) mean via SIMD sum, (2) variance via SIMD FMA,
/// then a vectorized normalize + affine transform.
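///
/// The two-pass form computes Var = E[(x - mean)^2] directly, avoiding the
/// cancellation the one-pass E[x^2] - E[x]^2 formula can suffer in f32.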
///
/// # Safety
///
/// Requires AVX2 and FMA support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn layer_norm_avx2(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    output: &mut [f32],
) {
    use std::arch::x86_64::*;

    let n = input.len();
    let chunks = n / 8;
    let remainder = chunks * 8;

    unsafe {
        // Phase 1: compute mean with AVX2
        let mut sum_vec = _mm256_setzero_ps();
        for i in 0..chunks {
            let v = _mm256_loadu_ps(input.as_ptr().add(i * 8));
            sum_vec = _mm256_add_ps(sum_vec, v);
        }

        // Horizontal sum
        let hi = _mm256_extractf128_ps(sum_vec, 1);
        let lo = _mm256_castps256_ps128(sum_vec);
        let sum128 = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let sums2 = _mm_add_ss(sums, shuf2);
        let mut sum = _mm_cvtss_f32(sums2);

        for i in remainder..n {
            sum += input[i];
        }
        let mean = sum / n as f32;

        // Phase 2: compute variance with AVX2 FMA
        let mean_vec = _mm256_set1_ps(mean);
        let mut var_vec0 = _mm256_setzero_ps();
        let mut var_vec1 = _mm256_setzero_ps();
        let chunks2 = n / 16;
        let remainder2 = chunks2 * 16;

        for i in 0..chunks2 {
            let v0 = _mm256_loadu_ps(input.as_ptr().add(i * 16));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(i * 16 + 8));
            let d0 = _mm256_sub_ps(v0, mean_vec);
            let d1 = _mm256_sub_ps(v1, mean_vec);
            var_vec0 = _mm256_fmadd_ps(d0, d0, var_vec0);
            var_vec1 = _mm256_fmadd_ps(d1, d1, var_vec1);
        }

        // Handle 8-element remainder
        let mut var_sum;
        if remainder2 + 8 <= n {
            let v = _mm256_loadu_ps(input.as_ptr().add(remainder2));
            let d = _mm256_sub_ps(v, mean_vec);
            var_vec0 = _mm256_fmadd_ps(d, d, var_vec0);

            let combined = _mm256_add_ps(var_vec0, var_vec1);
            let hi2 = _mm256_extractf128_ps(combined, 1);
            let lo2 = _mm256_castps256_ps128(combined);
            let s128 = _mm_add_ps(lo2, hi2);
            let sh = _mm_movehdup_ps(s128);
            let ss = _mm_add_ps(s128, sh);
            let sh2 = _mm_movehl_ps(ss, ss);
            let ss2 = _mm_add_ss(ss, sh2);
            var_sum = _mm_cvtss_f32(ss2);

            for i in (remainder2 + 8)..n {
                let d = input[i] - mean;
                var_sum += d * d;
            }
        } else {
            let combined = _mm256_add_ps(var_vec0, var_vec1);
            let hi2 = _mm256_extractf128_ps(combined, 1);
            let lo2 = _mm256_castps256_ps128(combined);
            let s128 = _mm_add_ps(lo2, hi2);
            let sh = _mm_movehdup_ps(s128);
            let ss = _mm_add_ps(s128, sh);
            let sh2 = _mm_movehl_ps(ss, ss);
            let ss2 = _mm_add_ss(ss, sh2);
            var_sum = _mm_cvtss_f32(ss2);

            for i in remainder2..n {
                let d = input[i] - mean;
                var_sum += d * d;
            }
        }

        let inv_std = 1.0 / (var_sum / n as f32 + eps).sqrt();

        // Phase 3: normalize + affine with AVX2
        let inv_std_vec = _mm256_set1_ps(inv_std);
        for i in 0..chunks {
            let x = _mm256_loadu_ps(input.as_ptr().add(i * 8));
            let g = _mm256_loadu_ps(gamma.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(beta.as_ptr().add(i * 8));
            let centered = _mm256_sub_ps(x, mean_vec);
            let normed = _mm256_mul_ps(centered, inv_std_vec);
            // output = gamma * normed + beta
            let result = _mm256_fmadd_ps(g, normed, b);
            _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), result);
        }

        // Scalar tail
        for i in remainder..n {
            output[i] = gamma[i] * (input[i] - mean) * inv_std + beta[i];
        }
    }
}

// ============================================================================
// Allocating variants
// ============================================================================

/// RMSNorm convenience wrapper that allocates a zero-initialized output buffer.
///
/// # Panics
///
/// Panics if input and gamma have different lengths, or if input is empty.
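///
/// # Examples
///
/// A minimal sketch (crate path assumed):
///
/// ```ignore
/// let y = trueno::blis::norms::rms_norm_alloc(&[1.0, 2.0, 3.0], &[1.0; 3], 1e-5);
/// assert_eq!(y.len(), 3);
/// ```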
#[must_use]
pub fn rms_norm_alloc(input: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let n = input.len();
    let mut output = vec![0.0f32; n];
    rms_norm(input, gamma, eps, &mut output).expect("rms_norm_alloc: length mismatch");
    output
}

/// LayerNorm convenience wrapper that allocates a zero-initialized output buffer.
///
/// # Panics
///
/// Panics if input, gamma, and beta have different lengths, or if input is empty.
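///
/// # Examples
///
/// A minimal sketch (crate path assumed):
///
/// ```ignore
/// let y = trueno::blis::norms::layer_norm_alloc(&[1.0, 2.0, 3.0], &[1.0; 3], &[0.0; 3], 1e-5);
/// assert_eq!(y.len(), 3);
/// ```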
#[must_use]
pub fn layer_norm_alloc(input: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = input.len();
    let mut output = vec![0.0f32; n];
    layer_norm(input, gamma, beta, eps, &mut output).expect("layer_norm_alloc: length mismatch");
    output
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // ── RMSNorm tests ─────────────────────────────────────────────────────

    /// FALSIFY-RN-001: Finiteness
    #[test]
    fn test_rmsnorm_finiteness() {
        for n in [4, 8, 16, 32, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let mut output = vec![0.0f32; n];
            rms_norm(&input, &gamma, 1e-5, &mut output).unwrap();
            for (i, &o) in output.iter().enumerate() {
                assert!(o.is_finite(), "RMSNorm output[{i}] not finite for n={n}");
            }
        }
    }

    /// FALSIFY-RN-002: Scale invariance
    #[test]
    fn test_rmsnorm_scale_invariance() {
        let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1 + 0.1).collect();
        let gamma = vec![1.0f32; 64];
        let mut out1 = vec![0.0f32; 64];
        let mut out2 = vec![0.0f32; 64];

        rms_norm(&input, &gamma, 1e-8, &mut out1).unwrap();

        let scaled: Vec<f32> = input.iter().map(|&x| x * 3.7).collect();
        rms_norm(&scaled, &gamma, 1e-8, &mut out2).unwrap();

        for i in 0..64 {
            assert!(
                (out1[i] - out2[i]).abs() < 1e-4,
                "Scale invariance failed at {i}: {} vs {}",
                out1[i],
                out2[i]
            );
        }
    }

    /// FALSIFY-RN-003: AVX2 vs scalar parity
    #[test]
    fn test_rmsnorm_avx2_scalar_parity() {
        for n in [4, 7, 8, 16, 31, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma: Vec<f32> = (0..n).map(|i| 0.5 + (i % 5) as f32 * 0.2).collect();
            let mut scalar_out = vec![0.0f32; n];
            let mut dispatch_out = vec![0.0f32; n];

            rms_norm_scalar(&input, &gamma, 1e-5, &mut scalar_out);
            rms_norm(&input, &gamma, 1e-5, &mut dispatch_out).unwrap();

            for i in 0..n {
                let diff = (scalar_out[i] - dispatch_out[i]).abs();
                assert!(
                    diff < 1e-4,
                    "RMSNorm parity failed at [{i}] n={n}: scalar={} dispatch={} diff={}",
                    scalar_out[i],
                    dispatch_out[i],
                    diff
                );
            }
        }
    }

    /// FALSIFY-RN-004: Zero vector
    #[test]
    fn test_rmsnorm_zero_input() {
        let input = vec![0.0f32; 16];
        let gamma = vec![1.0f32; 16];
        let mut output = vec![0.0f32; 16];
        rms_norm(&input, &gamma, 1e-5, &mut output).unwrap();
        for (i, &o) in output.iter().enumerate() {
            assert!(o.is_finite(), "Zero input produced non-finite at {i}");
            assert!(o.abs() < 1e-2, "Zero input should produce ~0 at {i}, got {o}");
        }
    }

    /// FALSIFY-RN-005: Unit gamma normalized RMS ≈ 1
    #[test]
    fn test_rmsnorm_unit_gamma_normalized_rms() {
        let input: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 + 0.1).collect();
        let gamma = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 128];
        rms_norm(&input, &gamma, 1e-8, &mut output).unwrap();

        let sum_sq: f32 = output.iter().map(|x| x * x).sum();
        let rms_out = (sum_sq / output.len() as f32).sqrt();
        assert!((rms_out - 1.0).abs() < 1e-3, "RMS of output = {rms_out}, expected ~1.0");
    }

    #[test]
    fn test_rmsnorm_error_on_mismatch() {
        let input = vec![1.0f32; 4];
        let gamma = vec![1.0f32; 3];
        let mut output = vec![0.0f32; 4];
        assert!(rms_norm(&input, &gamma, 1e-5, &mut output).is_err());
    }

    #[test]
    fn test_rmsnorm_error_on_empty() {
        let input: Vec<f32> = vec![];
        let gamma: Vec<f32> = vec![];
        let mut output: Vec<f32> = vec![];
        assert!(rms_norm(&input, &gamma, 1e-5, &mut output).is_err());
    }

    // ── LayerNorm tests ───────────────────────────────────────────────────

    /// FALSIFY-LN-001: Finiteness
    #[test]
    fn test_layernorm_finiteness() {
        for n in [4, 8, 16, 32, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let beta = vec![0.0f32; n];
            let mut output = vec![0.0f32; n];
            layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();
            for (i, &o) in output.iter().enumerate() {
                assert!(o.is_finite(), "LayerNorm output[{i}] not finite for n={n}");
            }
        }
    }

    /// FALSIFY-LN-002: Zero mean (with gamma=1, beta=0)
    #[test]
    fn test_layernorm_zero_mean() {
        for n in [16, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let beta = vec![0.0f32; n];
            let mut output = vec![0.0f32; n];
            layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();

            let mean: f32 = output.iter().sum::<f32>() / n as f32;
            assert!(mean.abs() < 1e-4, "LayerNorm output mean = {mean}, expected ~0 for n={n}");
        }
    }

    /// FALSIFY-LN-003: Unit variance (with gamma=1, beta=0)
    #[test]
    fn test_layernorm_unit_variance() {
        for n in [16, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let beta = vec![0.0f32; n];
            let mut output = vec![0.0f32; n];
            layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();

            let mean: f32 = output.iter().sum::<f32>() / n as f32;
            let var: f32 = output.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n as f32;
            assert!(
                (var - 1.0).abs() < 1e-2,
                "LayerNorm output var = {var}, expected ~1.0 for n={n}"
            );
        }
    }

    /// FALSIFY-LN-004: Shift invariance
    #[test]
    fn test_layernorm_shift_invariance() {
        let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
        let gamma = vec![1.0f32; 64];
        let beta = vec![0.0f32; 64];
        let mut out1 = vec![0.0f32; 64];
        let mut out2 = vec![0.0f32; 64];

        layer_norm(&input, &gamma, &beta, 1e-5, &mut out1).unwrap();

        let shifted: Vec<f32> = input.iter().map(|&x| x + 42.0).collect();
        layer_norm(&shifted, &gamma, &beta, 1e-5, &mut out2).unwrap();

        for i in 0..64 {
            assert!(
                (out1[i] - out2[i]).abs() < 1e-3,
                "Shift invariance failed at {i}: {} vs {}",
                out1[i],
                out2[i]
            );
        }
    }

    /// FALSIFY-LN-005: AVX2 vs scalar parity
    #[test]
    fn test_layernorm_avx2_scalar_parity() {
        for n in [4, 7, 8, 16, 31, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma: Vec<f32> = (0..n).map(|i| 0.5 + (i % 5) as f32 * 0.2).collect();
            let beta: Vec<f32> = (0..n).map(|i| (i % 3) as f32 * 0.1 - 0.1).collect();
            let mut scalar_out = vec![0.0f32; n];
            let mut dispatch_out = vec![0.0f32; n];

            layer_norm_scalar(&input, &gamma, &beta, 1e-5, &mut scalar_out);
            layer_norm(&input, &gamma, &beta, 1e-5, &mut dispatch_out).unwrap();

            for i in 0..n {
                let diff = (scalar_out[i] - dispatch_out[i]).abs();
                assert!(
                    diff < 1e-4,
                    "LayerNorm parity failed at [{i}] n={n}: scalar={} dispatch={} diff={}",
                    scalar_out[i],
                    dispatch_out[i],
                    diff
                );
            }
        }
    }

    /// FALSIFY-LN-006: Constant input → output = beta
    #[test]
    fn test_layernorm_constant_input() {
        let input = vec![5.0f32; 32];
        let gamma = vec![1.0f32; 32];
        let beta: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let mut output = vec![0.0f32; 32];
        layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();
        for (i, (&o, &b)) in output.iter().zip(beta.iter()).enumerate() {
            assert!((o - b).abs() < 1e-3, "Constant input: output[{i}]={o}, expected ~beta={b}");
        }
    }

    #[test]
    fn test_layernorm_error_on_mismatch() {
        let input = vec![1.0f32; 4];
        let gamma = vec![1.0f32; 3];
        let beta = vec![0.0f32; 4];
        let mut output = vec![0.0f32; 4];
        assert!(layer_norm(&input, &gamma, &beta, 1e-5, &mut output).is_err());
    }

    #[test]
    fn test_layernorm_error_on_empty() {
        let input: Vec<f32> = vec![];
        let gamma: Vec<f32> = vec![];
        let beta: Vec<f32> = vec![];
        let mut output: Vec<f32> = vec![];
        assert!(layer_norm(&input, &gamma, &beta, 1e-5, &mut output).is_err());
    }
}
656}