#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Reference (scalar) layer normalisation.
///
/// Centres `input` on its mean, scales it by `1 / sqrt(variance + eps)`, and
/// then applies the element-wise affine transform `gamma[i] * x_hat + beta[i]`,
/// writing the result into `output`.
///
/// The mean and the population (divide-by-`n`) variance are each computed in a
/// single left-to-right `f32` pass over `input`.
///
/// # Panics
///
/// Panics if `gamma`, `beta`, or `output` has a different length from `input`,
/// or if `input` is empty.
pub fn layernorm_scalar(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    output: &mut [f32],
) {
    let len = input.len();
    assert_eq!(len, gamma.len(), "input/gamma length mismatch");
    assert_eq!(len, beta.len(), "input/beta length mismatch");
    assert_eq!(len, output.len(), "input/output length mismatch");
    assert!(len > 0, "layernorm requires non-empty input");

    let count = len as f32;
    // Explicit folds seeded with +0.0 keep the exact accumulation order and
    // signed-zero behaviour of a plain `acc += x` loop.
    let mean = input.iter().fold(0.0_f32, |acc, &x| acc + x) / count;
    let variance = input.iter().fold(0.0_f32, |acc, &x| {
        let centered = x - mean;
        acc + centered * centered
    }) / count;
    let inv_std = 1.0 / (variance + eps).sqrt();

    // Evaluation order (gamma * centred) * inv_std + beta is part of the
    // contract other implementations compare against.
    let affine = gamma.iter().zip(beta);
    for ((dst, &x), (&g, &b)) in output.iter_mut().zip(input).zip(affine) {
        *dst = g * (x - mean) * inv_std + b;
    }
}
/// AVX2 layer normalisation with the same contract as `layernorm_scalar`.
///
/// The mean and population variance are reduced sequentially in `f32`, in the
/// same order as the scalar reference. The normalise + affine step then runs
/// 8 lanes at a time, and a scalar tail handles the last `n % 8` elements. Every
/// lane evaluates `(gamma * (x - mean)) * inv_std + beta` in the same order as
/// the scalar path, so the output is bit-identical to it.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by checking
/// `is_x86_feature_detected!("avx2")` first. Calling this function on a CPU
/// without AVX2 is undefined behaviour.
///
/// # Panics
///
/// Panics if `gamma`, `beta`, or `output` has a different length from `input`,
/// or if `input` is empty.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn layernorm_avx2(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    output: &mut [f32],
) {
    let n = input.len();
    assert_eq!(n, gamma.len(), "input/gamma length mismatch");
    assert_eq!(n, beta.len(), "input/beta length mismatch");
    assert_eq!(n, output.len(), "input/output length mismatch");
    assert!(n > 0, "layernorm requires non-empty input");

    // Statistics are reduced in the same sequential order as the scalar
    // reference, so `mean` and `inv_std` are bit-identical between the two.
    let mut sum = 0.0_f32;
    for &x in input {
        sum += x;
    }
    let mean = sum / n as f32;
    let mut var_sum = 0.0_f32;
    for &x in input {
        let diff = x - mean;
        var_sum += diff * diff;
    }
    let variance = var_sum / n as f32;
    let inv_std = 1.0 / (variance + eps).sqrt();

    let chunks = n / 8;
    let remainder_start = chunks * 8;

    // SAFETY: the caller guarantees AVX2 is available (see `# Safety`). Each
    // unaligned load/store touches indices `[i * 8, i * 8 + 8)` with
    // `i < n / 8`, so the highest index is `< n`. That range is in bounds for
    // all four slices because their lengths were asserted equal to `n` above.
    unsafe {
        let mean_vec = _mm256_set1_ps(mean);
        let inv_std_vec = _mm256_set1_ps(inv_std);
        for i in 0..chunks {
            let off = i * 8;
            let x = _mm256_loadu_ps(input.as_ptr().add(off));
            let g = _mm256_loadu_ps(gamma.as_ptr().add(off));
            let b = _mm256_loadu_ps(beta.as_ptr().add(off));
            let centered = _mm256_sub_ps(x, mean_vec);
            // Same order as the scalar path: (gamma * centered) * inv_std.
            // f32 multiplication is not associative, so scaling by inv_std
            // first would make these lanes drift from the reference by ULPs.
            let scaled = _mm256_mul_ps(_mm256_mul_ps(g, centered), inv_std_vec);
            let result = _mm256_add_ps(scaled, b);
            _mm256_storeu_ps(output.as_mut_ptr().add(off), result);
        }
    }

    // Scalar tail for the final `n % 8` elements.
    for i in remainder_start..n {
        output[i] = gamma[i] * (input[i] - mean) * inv_std + beta[i];
    }
}
include!("layernorm_ptx.rs");
#[cfg(test)]
mod tests {
    // Tests for the scalar reference, the AVX2 path (checked for parity against
    // the scalar reference), and the structure of the generated PTX text.
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // ------------------------------------------------------------------
    // Scalar reference: hand-computed examples
    // ------------------------------------------------------------------

    #[test]
    fn test_layernorm_constant_input() {
        // Zero variance: every centred value is 0, so the output collapses to beta.
        let input = [5.0_f32, 5.0, 5.0, 5.0];
        let gamma = [1.0_f32, 1.0, 1.0, 1.0];
        let beta = [0.1_f32, 0.2, 0.3, 0.4];
        let mut output = [0.0_f32; 4];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
        for (i, (&o, &b)) in output.iter().zip(beta.iter()).enumerate() {
            assert!((o - b).abs() < 1e-4, "output[{i}] = {o}, expected ~{b}");
        }
    }

    #[test]
    fn test_layernorm_simple() {
        let input = [1.0_f32, 2.0, 3.0, 4.0];
        let gamma = [1.0_f32, 1.0, 1.0, 1.0];
        let beta = [0.0_f32, 0.0, 0.0, 0.0];
        let mut output = [0.0_f32; 4];
        layernorm_scalar(&input, &gamma, &beta, 1e-8, &mut output);
        // Population variance of 1..=4 is 1.25; eps = 1e-8 is negligible here.
        let mean = 2.5_f32;
        let std = 1.25_f32.sqrt();
        for (i, &x) in input.iter().enumerate() {
            let expected = (x - mean) / std;
            assert!(
                (output[i] - expected).abs() < 1e-4,
                "output[{i}] = {}, expected {expected}",
                output[i]
            );
        }
    }

    #[test]
    fn test_layernorm_with_affine() {
        let input = [1.0_f32, 3.0];
        let gamma = [2.0_f32, 0.5];
        let beta = [10.0_f32, -10.0];
        let mut output = [0.0_f32; 2];
        layernorm_scalar(&input, &gamma, &beta, 1e-8, &mut output);
        let mean = 2.0_f32;
        let var = 1.0_f32;
        let inv_std = 1.0 / (var + 1e-8_f32).sqrt();
        let expected0 = 2.0 * (1.0 - mean) * inv_std + 10.0;
        let expected1 = 0.5 * (3.0 - mean) * inv_std + (-10.0);
        assert!((output[0] - expected0).abs() < 1e-5);
        assert!((output[1] - expected1).abs() < 1e-5);
    }

    // ------------------------------------------------------------------
    // Scalar reference: precondition panics
    // ------------------------------------------------------------------

    #[test]
    #[should_panic(expected = "input/gamma length mismatch")]
    fn test_layernorm_gamma_mismatch() {
        let input = [1.0_f32, 2.0];
        let gamma = [1.0_f32];
        let beta = [0.0_f32, 0.0];
        let mut output = [0.0_f32; 2];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
    }

    #[test]
    #[should_panic(expected = "input/beta length mismatch")]
    fn test_layernorm_beta_mismatch() {
        let input = [1.0_f32, 2.0];
        let gamma = [1.0_f32, 1.0];
        let beta = [0.0_f32];
        let mut output = [0.0_f32; 2];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
    }

    #[test]
    #[should_panic(expected = "input/output length mismatch")]
    fn test_layernorm_output_mismatch() {
        let input = [1.0_f32, 2.0];
        let gamma = [1.0_f32, 1.0];
        let beta = [0.0_f32, 0.0];
        let mut output = [0.0_f32; 3];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
    }

    #[test]
    #[should_panic(expected = "layernorm requires non-empty input")]
    fn test_layernorm_empty_input() {
        let input: [f32; 0] = [];
        let gamma: [f32; 0] = [];
        let beta: [f32; 0] = [];
        let mut output: [f32; 0] = [];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
    }

    // ------------------------------------------------------------------
    // Scalar reference: statistical properties
    // ------------------------------------------------------------------

    proptest! {
        #[test]
        fn prop_layernorm_zero_mean(
            v in proptest::collection::vec(-10.0_f32..10.0, 2..64)
        ) {
            let gamma = vec![1.0_f32; v.len()];
            let beta = vec![0.0_f32; v.len()];
            let mut output = vec![0.0_f32; v.len()];
            layernorm_scalar(&v, &gamma, &beta, 1e-5, &mut output);
            let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
            prop_assert!(
                mean.abs() < 1e-4,
                "output mean = {mean}, expected ~0.0"
            );
        }

        #[test]
        fn prop_layernorm_unit_variance(
            v in proptest::collection::vec(-10.0_f32..10.0, 4..64)
        ) {
            // The output variance is var / (var + eps), which is only within
            // 1e-3 of 1.0 when var >> eps. A min/max range check does not
            // guarantee that: a spread of ~0.01 gives var ~ 2.5e-5 and an output
            // variance of ~0.7. Skip inputs whose actual variance is small.
            let n_in = v.len() as f64;
            let mean_in = v.iter().map(|&x| f64::from(x)).sum::<f64>() / n_in;
            let var_in = v
                .iter()
                .map(|&x| (f64::from(x) - mean_in).powi(2))
                .sum::<f64>()
                / n_in;
            if var_in < 1e-1 {
                return Ok(());
            }
            let gamma = vec![1.0_f32; v.len()];
            let beta = vec![0.0_f32; v.len()];
            let mut output = vec![0.0_f32; v.len()];
            layernorm_scalar(&v, &gamma, &beta, 1e-5, &mut output);
            let n = output.len() as f32;
            let mean: f32 = output.iter().sum::<f32>() / n;
            let var: f32 = output.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n;
            prop_assert!(
                (var - 1.0).abs() < 1e-3,
                "output variance = {var}, expected ~1.0"
            );
        }

        #[test]
        fn prop_layernorm_shift_invariance(
            v in proptest::collection::vec(-10.0_f32..10.0, 2..32),
            c in -50.0_f32..50.0
        ) {
            let gamma = vec![1.0_f32; v.len()];
            let beta = vec![0.0_f32; v.len()];
            let mut out1 = vec![0.0_f32; v.len()];
            layernorm_scalar(&v, &gamma, &beta, 1e-5, &mut out1);
            let shifted: Vec<f32> = v.iter().map(|&x| x + c).collect();
            let mut out2 = vec![0.0_f32; v.len()];
            layernorm_scalar(&shifted, &gamma, &beta, 1e-5, &mut out2);
            for i in 0..v.len() {
                prop_assert!(
                    (out1[i] - out2[i]).abs() < 1e-3,
                    "shift invariance violated at {i}: {} vs {}",
                    out1[i], out2[i]
                );
            }
        }

        #[test]
        fn prop_layernorm_finite_output(
            v in proptest::collection::vec(-10.0_f32..10.0, 1..64)
        ) {
            let gamma = vec![1.0_f32; v.len()];
            let beta = vec![0.0_f32; v.len()];
            let mut output = vec![0.0_f32; v.len()];
            layernorm_scalar(&v, &gamma, &beta, 1e-5, &mut output);
            for (i, &o) in output.iter().enumerate() {
                prop_assert!(o.is_finite(), "output[{i}] = {o} is not finite");
            }
        }
    }

    // ------------------------------------------------------------------
    // AVX2 path: parity with the scalar reference.
    // Each test returns early (passes vacuously) when AVX2 is unavailable.
    // ------------------------------------------------------------------

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_layernorm_avx2_basic() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let gamma = vec![1.0_f32; 16];
        let beta = vec![0.0_f32; 16];
        let mut scalar_out = vec![0.0_f32; 16];
        let mut avx2_out = vec![0.0_f32; 16];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut scalar_out);
        unsafe { layernorm_avx2(&input, &gamma, &beta, 1e-5, &mut avx2_out) };
        assert_ulp_eq(&scalar_out, &avx2_out, 4);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_layernorm_avx2_non_multiple_of_8() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let gamma = [1.0_f32; 5];
        let beta = [0.0_f32; 5];
        let mut scalar_out = [0.0_f32; 5];
        let mut avx2_out = [0.0_f32; 5];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut scalar_out);
        unsafe { layernorm_avx2(&input, &gamma, &beta, 1e-5, &mut avx2_out) };
        assert_ulp_eq(&scalar_out, &avx2_out, 4);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_layernorm_avx2_affine() {
        // The other AVX2 tests use gamma = 1, beta = 0, which cannot catch a
        // wrong gamma/beta load offset in the vector loop. This one varies both.
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // 19 elements: two full 8-wide chunks plus a 3-element scalar tail.
        let n = 19;
        let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.75 - 4.0).collect();
        // Power-of-two gammas make the gamma multiply exact, so this comparison
        // does not depend on whether gamma or inv_std is applied first.
        let gamma: Vec<f32> = (0..n).map(|i| [0.5_f32, -1.0, 2.0, 4.0][i % 4]).collect();
        let beta: Vec<f32> = (0..n).map(|i| (i as f32) * 0.25 - 2.0).collect();
        let mut scalar_out = vec![0.0_f32; n];
        let mut avx2_out = vec![0.0_f32; n];
        layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut scalar_out);
        unsafe { layernorm_avx2(&input, &gamma, &beta, 1e-5, &mut avx2_out) };
        assert_ulp_eq(&scalar_out, &avx2_out, 4);
    }

    #[cfg(target_arch = "x86_64")]
    proptest! {
        #[test]
        fn prop_layernorm_avx2_parity(
            v in proptest::collection::vec(-10.0_f32..10.0, 1..64)
        ) {
            if !is_x86_feature_detected!("avx2") {
                return Ok(());
            }
            let gamma = vec![1.0_f32; v.len()];
            let beta = vec![0.0_f32; v.len()];
            let mut scalar_out = vec![0.0_f32; v.len()];
            let mut avx2_out = vec![0.0_f32; v.len()];
            layernorm_scalar(&v, &gamma, &beta, 1e-5, &mut scalar_out);
            unsafe { layernorm_avx2(&v, &gamma, &beta, 1e-5, &mut avx2_out) };
            assert_ulp_eq(&scalar_out, &avx2_out, 4);
        }
    }

    // ------------------------------------------------------------------
    // PTX text: structural checks on the output of `layernorm_ptx()`
    // ------------------------------------------------------------------

    #[test]
    fn test_layernorm_ptx_version() {
        let ptx = layernorm_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
    }

    #[test]
    fn test_layernorm_ptx_target() {
        let ptx = layernorm_ptx();
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
    }

    #[test]
    fn test_layernorm_ptx_entry() {
        let ptx = layernorm_ptx();
        assert!(
            ptx.contains(".entry layernorm_kernel"),
            "missing entry point"
        );
    }

    #[test]
    fn test_layernorm_ptx_ret() {
        let ptx = layernorm_ptx();
        assert!(ptx.contains("ret;"), "missing ret instruction");
    }

    #[test]
    fn test_layernorm_ptx_shared_memory() {
        let ptx = layernorm_ptx();
        assert!(ptx.contains(".shared"), "missing shared memory declaration");
    }

    #[test]
    fn test_layernorm_ptx_warp_shuffle() {
        let ptx = layernorm_ptx();
        assert!(
            ptx.contains("shfl.sync"),
            "missing warp shuffle instructions"
        );
    }

    #[test]
    fn test_layernorm_ptx_bar_sync() {
        let ptx = layernorm_ptx();
        assert!(
            ptx.contains("bar.sync"),
            "missing bar.sync for block synchronization"
        );
    }

    #[test]
    fn test_layernorm_ptx_balanced_braces() {
        let ptx = layernorm_ptx();
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }
}