1use crate::error::TruenoError;
23
24pub fn rms_norm(
38 input: &[f32],
39 gamma: &[f32],
40 eps: f32,
41 output: &mut [f32],
42) -> Result<(), TruenoError> {
43 let n = input.len();
44 if n == 0 || n != gamma.len() || n != output.len() {
45 return Err(TruenoError::InvalidInput(format!(
46 "rms_norm size mismatch: input[{}], gamma[{}], output[{}]",
47 n,
48 gamma.len(),
49 output.len()
50 )));
51 }
52
53 contract_pre_rmsnorm!(input);
55
56 #[cfg(target_arch = "x86_64")]
57 {
58 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
59 unsafe {
61 rms_norm_avx2(input, gamma, eps, output);
62 }
63 contract_post_rmsnorm!(output);
64 return Ok(());
65 }
66 }
67
68 rms_norm_scalar(input, gamma, eps, output);
69 contract_post_rmsnorm!(output);
70 Ok(())
71}
72
/// Portable scalar RMSNorm kernel.
///
/// Computes the mean of squares sequentially (same accumulation order as the
/// SIMD-free reference), then scales each element by `1/rms` and `gamma`.
fn rms_norm_scalar(input: &[f32], gamma: &[f32], eps: f32, output: &mut [f32]) {
    let count = input.len() as f32;

    // Sequential left-to-right sum of squares.
    let mean_sq = input.iter().map(|&x| x * x).sum::<f32>() / count;
    let inv_rms = (mean_sq + eps).sqrt().recip();

    // output[i] = input[i] * inv_rms * gamma[i]
    for ((slot, &x), &g) in output.iter_mut().zip(input).zip(gamma) {
        *slot = x * inv_rms * g;
    }
}
91
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
/// AVX2+FMA RMSNorm kernel: `output[i] = input[i] * gamma[i] / sqrt(mean(input²) + eps)`.
///
/// # Safety
/// The caller must guarantee the CPU supports AVX2 and FMA (the `rms_norm`
/// dispatcher checks this via `is_x86_feature_detected!`). Slice lengths are
/// assumed equal and nonzero — validated by `rms_norm` before dispatch.
unsafe fn rms_norm_avx2(input: &[f32], gamma: &[f32], eps: f32, output: &mut [f32]) {
    use std::arch::x86_64::*;

    let n = input.len();
    // The reduction loop consumes 16 floats (two 8-lane AVX registers) per iteration.
    let chunks = n / 16; let remainder_16 = chunks * 16;

    unsafe {
        // Two independent accumulators break the FMA dependency chain across iterations.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        for i in 0..chunks {
            let v0 = _mm256_loadu_ps(input.as_ptr().add(i * 16));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(i * 16 + 8));
            acc0 = _mm256_fmadd_ps(v0, v0, acc0); // acc0 += v0 * v0
            acc1 = _mm256_fmadd_ps(v1, v1, acc1); // acc1 += v1 * v1
        }

        let mut sum_sq;
        if remainder_16 + 8 <= n {
            // At least one full 8-lane vector remains after the 16-wide loop:
            // fold it into acc0 before the horizontal reduction.
            let v = _mm256_loadu_ps(input.as_ptr().add(remainder_16));
            acc0 = _mm256_fmadd_ps(v, v, acc0);
            let combined = _mm256_add_ps(acc0, acc1);

            // Horizontal sum: 256-bit -> 128-bit, then shuffle-add 4 -> 2 -> 1 lanes.
            let hi = _mm256_extractf128_ps(combined, 1);
            let lo = _mm256_castps256_ps128(combined);
            let sum128 = _mm_add_ps(lo, hi);
            let shuf = _mm_movehdup_ps(sum128);
            let sums = _mm_add_ps(sum128, shuf);
            let shuf2 = _mm_movehl_ps(sums, sums);
            let sums2 = _mm_add_ss(sums, shuf2);
            sum_sq = _mm_cvtss_f32(sums2);

            // Scalar tail: fewer than 8 elements remain.
            for i in (remainder_16 + 8)..n {
                sum_sq += input[i] * input[i];
            }
        } else {
            // Fewer than 8 elements remain: reduce the accumulators as-is.
            let combined = _mm256_add_ps(acc0, acc1);
            let hi = _mm256_extractf128_ps(combined, 1);
            let lo = _mm256_castps256_ps128(combined);
            let sum128 = _mm_add_ps(lo, hi);
            let shuf = _mm_movehdup_ps(sum128);
            let sums = _mm_add_ps(sum128, shuf);
            let shuf2 = _mm_movehl_ps(sums, sums);
            let sums2 = _mm_add_ss(sums, shuf2);
            sum_sq = _mm_cvtss_f32(sums2);

            // Scalar tail covers everything past the 16-wide loop.
            for i in remainder_16..n {
                sum_sq += input[i] * input[i];
            }
        }

        let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();

        // Normalization pass: broadcast 1/rms, process 8 floats per iteration.
        let inv_rms_vec = _mm256_set1_ps(inv_rms);
        let chunks_out = n / 8;
        let remainder_out = chunks_out * 8;

        for i in 0..chunks_out {
            let x = _mm256_loadu_ps(input.as_ptr().add(i * 8));
            let g = _mm256_loadu_ps(gamma.as_ptr().add(i * 8));
            let normed = _mm256_mul_ps(x, inv_rms_vec);
            let scaled = _mm256_mul_ps(normed, g);
            _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), scaled);
        }

        // Scalar tail of the normalization pass.
        for i in remainder_out..n {
            output[i] = input[i] * inv_rms * gamma[i];
        }
    }
}
181
182pub fn layer_norm(
196 input: &[f32],
197 gamma: &[f32],
198 beta: &[f32],
199 eps: f32,
200 output: &mut [f32],
201) -> Result<(), TruenoError> {
202 let n = input.len();
203 if n == 0 || n != gamma.len() || n != beta.len() || n != output.len() {
204 return Err(TruenoError::InvalidInput(format!(
205 "layer_norm size mismatch: input[{}], gamma[{}], beta[{}], output[{}]",
206 n,
207 gamma.len(),
208 beta.len(),
209 output.len()
210 )));
211 }
212
213 #[cfg(target_arch = "x86_64")]
214 {
215 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
216 unsafe {
218 layer_norm_avx2(input, gamma, beta, eps, output);
219 }
220 return Ok(());
221 }
222 }
223
224 layer_norm_scalar(input, gamma, beta, eps, output);
225 Ok(())
226}
227
/// Portable scalar LayerNorm kernel.
///
/// Two sequential passes compute the mean and the (biased) variance, then each
/// element is centered, scaled by `1/std`, and affinely transformed by
/// `gamma`/`beta`.
fn layer_norm_scalar(input: &[f32], gamma: &[f32], beta: &[f32], eps: f32, output: &mut [f32]) {
    let count = input.len() as f32;

    // Pass 1: sequential left-to-right mean.
    let mean = input.iter().sum::<f32>() / count;

    // Pass 2: sum of squared deviations, same accumulation order as pass 1.
    let var_sum: f32 = input.iter().map(|&x| (x - mean) * (x - mean)).sum();
    let inv_std = (var_sum / count + eps).sqrt().recip();

    // output[i] = gamma[i] * (input[i] - mean) * inv_std + beta[i]
    for (i, slot) in output.iter_mut().enumerate() {
        *slot = gamma[i] * (input[i] - mean) * inv_std + beta[i];
    }
}
252
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
/// AVX2+FMA LayerNorm kernel:
/// `output[i] = gamma[i] * (input[i] - mean) / sqrt(var + eps) + beta[i]`.
///
/// # Safety
/// The caller must guarantee the CPU supports AVX2 and FMA (the `layer_norm`
/// dispatcher checks this via `is_x86_feature_detected!`). Slice lengths are
/// assumed equal and nonzero — validated by `layer_norm` before dispatch.
unsafe fn layer_norm_avx2(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    output: &mut [f32],
) {
    use std::arch::x86_64::*;

    let n = input.len();
    // 8-wide chunking is shared by the mean pass and the output pass below.
    let chunks = n / 8;
    let remainder = chunks * 8;

    unsafe {
        // Pass 1: vectorized sum for the mean, 8 floats per iteration.
        let mut sum_vec = _mm256_setzero_ps();
        for i in 0..chunks {
            let v = _mm256_loadu_ps(input.as_ptr().add(i * 8));
            sum_vec = _mm256_add_ps(sum_vec, v);
        }

        // Horizontal sum: 256-bit -> 128-bit, then shuffle-add 4 -> 2 -> 1 lanes.
        let hi = _mm256_extractf128_ps(sum_vec, 1);
        let lo = _mm256_castps256_ps128(sum_vec);
        let sum128 = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let sums2 = _mm_add_ss(sums, shuf2);
        let mut sum = _mm_cvtss_f32(sums2);

        // Scalar tail of the mean pass.
        for i in remainder..n {
            sum += input[i];
        }
        let mean = sum / n as f32;

        // Pass 2: vectorized sum of squared deviations, 16 floats per iteration
        // with two accumulators to break the FMA dependency chain.
        let mean_vec = _mm256_set1_ps(mean);
        let mut var_vec0 = _mm256_setzero_ps();
        let mut var_vec1 = _mm256_setzero_ps();
        let chunks2 = n / 16;
        let remainder2 = chunks2 * 16;

        for i in 0..chunks2 {
            let v0 = _mm256_loadu_ps(input.as_ptr().add(i * 16));
            let v1 = _mm256_loadu_ps(input.as_ptr().add(i * 16 + 8));
            let d0 = _mm256_sub_ps(v0, mean_vec);
            let d1 = _mm256_sub_ps(v1, mean_vec);
            var_vec0 = _mm256_fmadd_ps(d0, d0, var_vec0); // var_vec0 += d0 * d0
            var_vec1 = _mm256_fmadd_ps(d1, d1, var_vec1); // var_vec1 += d1 * d1
        }

        let mut var_sum;
        if remainder2 + 8 <= n {
            // At least one full 8-lane vector remains: fold it in before reducing.
            let v = _mm256_loadu_ps(input.as_ptr().add(remainder2));
            let d = _mm256_sub_ps(v, mean_vec);
            var_vec0 = _mm256_fmadd_ps(d, d, var_vec0);

            let combined = _mm256_add_ps(var_vec0, var_vec1);
            let hi2 = _mm256_extractf128_ps(combined, 1);
            let lo2 = _mm256_castps256_ps128(combined);
            let s128 = _mm_add_ps(lo2, hi2);
            let sh = _mm_movehdup_ps(s128);
            let ss = _mm_add_ps(s128, sh);
            let sh2 = _mm_movehl_ps(ss, ss);
            let ss2 = _mm_add_ss(ss, sh2);
            var_sum = _mm_cvtss_f32(ss2);

            // Scalar tail: fewer than 8 elements remain.
            for i in (remainder2 + 8)..n {
                let d = input[i] - mean;
                var_sum += d * d;
            }
        } else {
            // Fewer than 8 elements remain: reduce the accumulators as-is.
            let combined = _mm256_add_ps(var_vec0, var_vec1);
            let hi2 = _mm256_extractf128_ps(combined, 1);
            let lo2 = _mm256_castps256_ps128(combined);
            let s128 = _mm_add_ps(lo2, hi2);
            let sh = _mm_movehdup_ps(s128);
            let ss = _mm_add_ps(s128, sh);
            let sh2 = _mm_movehl_ps(ss, ss);
            let ss2 = _mm_add_ss(ss, sh2);
            var_sum = _mm_cvtss_f32(ss2);

            // Scalar tail covers everything past the 16-wide loop.
            for i in remainder2..n {
                let d = input[i] - mean;
                var_sum += d * d;
            }
        }

        let inv_std = 1.0 / (var_sum / n as f32 + eps).sqrt();

        // Pass 3: normalize + affine transform, 8 floats per iteration.
        let inv_std_vec = _mm256_set1_ps(inv_std);
        for i in 0..chunks {
            let x = _mm256_loadu_ps(input.as_ptr().add(i * 8));
            let g = _mm256_loadu_ps(gamma.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(beta.as_ptr().add(i * 8));
            let centered = _mm256_sub_ps(x, mean_vec);
            let normed = _mm256_mul_ps(centered, inv_std_vec);
            let result = _mm256_fmadd_ps(g, normed, b); // g * normed + b
            _mm256_storeu_ps(output.as_mut_ptr().add(i * 8), result);
        }

        // Scalar tail of the output pass.
        for i in remainder..n {
            output[i] = gamma[i] * (input[i] - mean) * inv_std + beta[i];
        }
    }
}
374
375#[must_use]
385pub fn rms_norm_alloc(input: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
386 let n = input.len();
387 let mut output = vec![0.0f32; n];
388 rms_norm(input, gamma, eps, &mut output).expect("rms_norm_alloc: length mismatch");
389 output
390}
391
392#[must_use]
398pub fn layer_norm_alloc(input: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
399 let n = input.len();
400 let mut output = vec![0.0f32; n];
401 layer_norm(input, gamma, beta, eps, &mut output).expect("layer_norm_alloc: length mismatch");
402 output
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    // All outputs finite across sizes that exercise both the SIMD main loops
    // and the scalar tails (including non-multiples of 8/16 via the 4-case).
    #[test]
    fn test_rmsnorm_finiteness() {
        for n in [4, 8, 16, 32, 64, 128, 4096] {
            // Deterministic pseudo-random values in [-0.5, 0.5).
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let mut output = vec![0.0f32; n];
            rms_norm(&input, &gamma, 1e-5, &mut output).unwrap();
            for (i, &o) in output.iter().enumerate() {
                assert!(o.is_finite(), "RMSNorm output[{i}] not finite for n={n}");
            }
        }
    }

    // RMSNorm is invariant under uniform scaling of the input (up to eps).
    #[test]
    fn test_rmsnorm_scale_invariance() {
        let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1 + 0.1).collect();
        let gamma = vec![1.0f32; 64];
        let mut out1 = vec![0.0f32; 64];
        let mut out2 = vec![0.0f32; 64];

        rms_norm(&input, &gamma, 1e-8, &mut out1).unwrap();

        let scaled: Vec<f32> = input.iter().map(|&x| x * 3.7).collect();
        rms_norm(&scaled, &gamma, 1e-8, &mut out2).unwrap();

        for i in 0..64 {
            assert!(
                (out1[i] - out2[i]).abs() < 1e-4,
                "Scale invariance failed at {i}: {} vs {}",
                out1[i],
                out2[i]
            );
        }
    }

    // The dispatched path (AVX2 when available) must match the scalar kernel
    // elementwise within FP tolerance; sizes include awkward tails (7, 31).
    #[test]
    fn test_rmsnorm_avx2_scalar_parity() {
        for n in [4, 7, 8, 16, 31, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma: Vec<f32> = (0..n).map(|i| 0.5 + (i % 5) as f32 * 0.2).collect();
            let mut scalar_out = vec![0.0f32; n];
            let mut dispatch_out = vec![0.0f32; n];

            rms_norm_scalar(&input, &gamma, 1e-5, &mut scalar_out);
            rms_norm(&input, &gamma, 1e-5, &mut dispatch_out).unwrap();

            for i in 0..n {
                let diff = (scalar_out[i] - dispatch_out[i]).abs();
                assert!(
                    diff < 1e-4,
                    "RMSNorm parity failed at [{i}] n={n}: scalar={} dispatch={} diff={}",
                    scalar_out[i],
                    dispatch_out[i],
                    diff
                );
            }
        }
    }

    // An all-zero input must not divide by zero: eps keeps 1/rms finite and
    // the output stays ~0.
    #[test]
    fn test_rmsnorm_zero_input() {
        let input = vec![0.0f32; 16];
        let gamma = vec![1.0f32; 16];
        let mut output = vec![0.0f32; 16];
        rms_norm(&input, &gamma, 1e-5, &mut output).unwrap();
        for (i, &o) in output.iter().enumerate() {
            assert!(o.is_finite(), "Zero input produced non-finite at {i}");
            assert!(o.abs() < 1e-2, "Zero input should produce ~0 at {i}, got {o}");
        }
    }

    // With gamma = 1 the output's own RMS must be ~1 by construction.
    #[test]
    fn test_rmsnorm_unit_gamma_normalized_rms() {
        let input: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 + 0.1).collect();
        let gamma = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 128];
        rms_norm(&input, &gamma, 1e-8, &mut output).unwrap();

        let sum_sq: f32 = output.iter().map(|x| x * x).sum();
        let rms_out = (sum_sq / output.len() as f32).sqrt();
        assert!((rms_out - 1.0).abs() < 1e-3, "RMS of output = {rms_out}, expected ~1.0");
    }

    // Mismatched gamma length must be rejected, not panic or UB.
    #[test]
    fn test_rmsnorm_error_on_mismatch() {
        let input = vec![1.0f32; 4];
        let gamma = vec![1.0f32; 3];
        let mut output = vec![0.0f32; 4];
        assert!(rms_norm(&input, &gamma, 1e-5, &mut output).is_err());
    }

    // Empty input is an error (mean over zero elements is undefined).
    #[test]
    fn test_rmsnorm_error_on_empty() {
        let input: Vec<f32> = vec![];
        let gamma: Vec<f32> = vec![];
        let mut output: Vec<f32> = vec![];
        assert!(rms_norm(&input, &gamma, 1e-5, &mut output).is_err());
    }

    // LayerNorm analogue of the finiteness sweep above.
    #[test]
    fn test_layernorm_finiteness() {
        for n in [4, 8, 16, 32, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let beta = vec![0.0f32; n];
            let mut output = vec![0.0f32; n];
            layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();
            for (i, &o) in output.iter().enumerate() {
                assert!(o.is_finite(), "LayerNorm output[{i}] not finite for n={n}");
            }
        }
    }

    // With gamma = 1, beta = 0 the output mean must be ~0 (centering property).
    #[test]
    fn test_layernorm_zero_mean() {
        for n in [16, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let beta = vec![0.0f32; n];
            let mut output = vec![0.0f32; n];
            layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();

            let mean: f32 = output.iter().sum::<f32>() / n as f32;
            assert!(mean.abs() < 1e-4, "LayerNorm output mean = {mean}, expected ~0 for n={n}");
        }
    }

    // With gamma = 1, beta = 0 the output variance must be ~1 (scaling property).
    #[test]
    fn test_layernorm_unit_variance() {
        for n in [16, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma = vec![1.0f32; n];
            let beta = vec![0.0f32; n];
            let mut output = vec![0.0f32; n];
            layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();

            let mean: f32 = output.iter().sum::<f32>() / n as f32;
            let var: f32 = output.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n as f32;
            assert!(
                (var - 1.0).abs() < 1e-2,
                "LayerNorm output var = {var}, expected ~1.0 for n={n}"
            );
        }
    }

    // LayerNorm is invariant under adding a constant to every input element.
    #[test]
    fn test_layernorm_shift_invariance() {
        let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
        let gamma = vec![1.0f32; 64];
        let beta = vec![0.0f32; 64];
        let mut out1 = vec![0.0f32; 64];
        let mut out2 = vec![0.0f32; 64];

        layer_norm(&input, &gamma, &beta, 1e-5, &mut out1).unwrap();

        let shifted: Vec<f32> = input.iter().map(|&x| x + 42.0).collect();
        layer_norm(&shifted, &gamma, &beta, 1e-5, &mut out2).unwrap();

        for i in 0..64 {
            assert!(
                (out1[i] - out2[i]).abs() < 1e-3,
                "Shift invariance failed at {i}: {} vs {}",
                out1[i],
                out2[i]
            );
        }
    }

    // Dispatched path must match the scalar kernel, including odd tails and
    // non-trivial gamma/beta.
    #[test]
    fn test_layernorm_avx2_scalar_parity() {
        for n in [4, 7, 8, 16, 31, 64, 128, 4096] {
            let input: Vec<f32> =
                (0..n).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
            let gamma: Vec<f32> = (0..n).map(|i| 0.5 + (i % 5) as f32 * 0.2).collect();
            let beta: Vec<f32> = (0..n).map(|i| (i % 3) as f32 * 0.1 - 0.1).collect();
            let mut scalar_out = vec![0.0f32; n];
            let mut dispatch_out = vec![0.0f32; n];

            layer_norm_scalar(&input, &gamma, &beta, 1e-5, &mut scalar_out);
            layer_norm(&input, &gamma, &beta, 1e-5, &mut dispatch_out).unwrap();

            for i in 0..n {
                let diff = (scalar_out[i] - dispatch_out[i]).abs();
                assert!(
                    diff < 1e-4,
                    "LayerNorm parity failed at [{i}] n={n}: scalar={} dispatch={} diff={}",
                    scalar_out[i],
                    dispatch_out[i],
                    diff
                );
            }
        }
    }

    // Constant input has zero variance; eps prevents a div-by-zero and the
    // centered term vanishes, so the output collapses to beta.
    #[test]
    fn test_layernorm_constant_input() {
        let input = vec![5.0f32; 32];
        let gamma = vec![1.0f32; 32];
        let beta: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let mut output = vec![0.0f32; 32];
        layer_norm(&input, &gamma, &beta, 1e-5, &mut output).unwrap();
        for (i, (&o, &b)) in output.iter().zip(beta.iter()).enumerate() {
            assert!((o - b).abs() < 1e-3, "Constant input: output[{i}]={o}, expected ~beta={b}");
        }
    }

    // Mismatched gamma length must be rejected.
    #[test]
    fn test_layernorm_error_on_mismatch() {
        let input = vec![1.0f32; 4];
        let gamma = vec![1.0f32; 3];
        let beta = vec![0.0f32; 4];
        let mut output = vec![0.0f32; 4];
        assert!(layer_norm(&input, &gamma, &beta, 1e-5, &mut output).is_err());
    }

    // Empty input is an error.
    #[test]
    fn test_layernorm_error_on_empty() {
        let input: Vec<f32> = vec![];
        let gamma: Vec<f32> = vec![];
        let beta: Vec<f32> = vec![];
        let mut output: Vec<f32> = vec![];
        assert!(layer_norm(&input, &gamma, &beta, 1e-5, &mut output).is_err());
    }
}