use crate::dtype::Element;
#[inline]
#[allow(clippy::too_many_arguments)]
/// Fused residual-add + RMS-norm forward pass.
///
/// Writes the pre-normalization activations `input + residual` into
/// `pre_norm`, then writes the RMS-normalized, `weight`-scaled result into
/// `out`, row by row (`batch_size` rows of `hidden_size` elements).
///
/// On x86_64/aarch64 the call dispatches on `T::DTYPE` to dtype-specific
/// SIMD kernels in `super::simd::norm` (the `F16`/`BF16` arms exist only
/// when the `f16` feature is enabled). Any other dtype, or any other
/// architecture, falls through to the portable scalar implementation.
///
/// # Safety
/// Requirements observed from the scalar path (the SIMD paths presumably
/// match — confirm against `super::simd::norm`):
/// - `input` and `residual` must be valid for reads of
///   `batch_size * hidden_size` elements of `T`;
/// - `out` and `pre_norm` must be valid for writes of
///   `batch_size * hidden_size` elements of `T`;
/// - `weight` must be valid for reads of `hidden_size` elements of `T`.
pub unsafe fn fused_add_rms_norm_kernel<T: Element>(
    input: *const T,
    residual: *const T,
    weight: *const T,
    out: *mut T,
    pre_norm: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::norm;
        use crate::dtype::DType;
        match T::DTYPE {
            DType::F32 => {
                norm::fused_add_rms_norm_f32(
                    input as *const f32,
                    residual as *const f32,
                    weight as *const f32,
                    out as *mut f32,
                    pre_norm as *mut f32,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            DType::F64 => {
                norm::fused_add_rms_norm_f64(
                    input as *const f64,
                    residual as *const f64,
                    weight as *const f64,
                    out as *mut f64,
                    pre_norm as *mut f64,
                    batch_size,
                    hidden_size,
                    // The f64 kernel takes a widened epsilon.
                    eps as f64,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                norm::fused_add_rms_norm_f16(
                    input as *const half::f16,
                    residual as *const half::f16,
                    weight as *const half::f16,
                    out as *mut half::f16,
                    pre_norm as *mut half::f16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                norm::fused_add_rms_norm_bf16(
                    input as *const half::bf16,
                    residual as *const half::bf16,
                    weight as *const half::bf16,
                    out as *mut half::bf16,
                    pre_norm as *mut half::bf16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            // Remaining dtypes: fall through to the scalar path below.
            _ => {}
        }
    }
    fused_add_rms_norm_scalar(
        input,
        residual,
        weight,
        out,
        pre_norm,
        batch_size,
        hidden_size,
        eps,
    );
}
#[inline]
/// Portable scalar fallback for the fused residual-add + RMS-norm forward.
///
/// Per row: stores `input + residual` (rounded through `T`) into
/// `pre_norm_out`, then emits `out = pre_norm * inv_rms * weight` with
/// `inv_rms = 1 / sqrt(mean(pre_norm^2) + eps)`. Accumulation happens in
/// f64; the stored (rounded) pre-norm value is re-read when producing
/// `out`, so the output is consistent with what was persisted.
unsafe fn fused_add_rms_norm_scalar<T: Element>(
    input: *const T,
    residual: *const T,
    weight: *const T,
    out: *mut T,
    pre_norm_out: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    let eps64 = f64::from(eps);
    let gamma = std::slice::from_raw_parts(weight, hidden_size);
    for b in 0..batch_size {
        let base = b * hidden_size;
        // Pass 1: materialize the residual sum and accumulate its squares.
        let mut acc = 0.0f64;
        for j in 0..hidden_size {
            let v = (*input.add(base + j)).to_f64() + (*residual.add(base + j)).to_f64();
            *pre_norm_out.add(base + j) = T::from_f64(v);
            acc += v * v;
        }
        let scale = (acc / hidden_size as f64 + eps64).sqrt().recip();
        // Pass 2: normalize the stored pre-norm values and apply the gain.
        for (j, &g) in gamma.iter().enumerate() {
            let v = (*pre_norm_out.add(base + j)).to_f64();
            *out.add(base + j) = T::from_f64(v * scale * g.to_f64());
        }
    }
}
#[inline]
#[allow(clippy::too_many_arguments)]
/// Backward pass for the fused residual-add + RMS-norm.
///
/// Takes the upstream gradient `grad` (w.r.t. the normalized output), the
/// saved pre-normalization activations `pre_norm` (the `input + residual`
/// written by the forward pass) and `weight`; writes the gradient w.r.t.
/// the pre-norm activations to `d_input_residual` and a per-feature weight
/// gradient to `d_weight`. The scalar path *accumulates* into `d_weight`
/// across rows — callers presumably pass it zero-initialized; confirm the
/// SIMD kernels follow the same convention.
///
/// On x86_64/aarch64 the call dispatches on `T::DTYPE` to SIMD kernels in
/// `super::simd::norm` (`F16`/`BF16` arms require the `f16` feature);
/// otherwise it falls through to the scalar implementation.
///
/// # Safety
/// Requirements observed from the scalar path (SIMD paths presumably
/// match — confirm against `super::simd::norm`):
/// - `grad` and `pre_norm` must be valid for reads of
///   `batch_size * hidden_size` elements of `T`;
/// - `d_input_residual` must be valid for writes of
///   `batch_size * hidden_size` elements of `T`;
/// - `weight` must be valid for reads, and `d_weight` for reads and
///   writes, of `hidden_size` elements of `T`.
pub unsafe fn fused_add_rms_norm_bwd_kernel<T: Element>(
    grad: *const T,
    pre_norm: *const T,
    weight: *const T,
    d_input_residual: *mut T,
    d_weight: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::norm;
        use crate::dtype::DType;
        match T::DTYPE {
            DType::F32 => {
                norm::fused_add_rms_norm_bwd_f32(
                    grad as *const f32,
                    pre_norm as *const f32,
                    weight as *const f32,
                    d_input_residual as *mut f32,
                    d_weight as *mut f32,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            DType::F64 => {
                norm::fused_add_rms_norm_bwd_f64(
                    grad as *const f64,
                    pre_norm as *const f64,
                    weight as *const f64,
                    d_input_residual as *mut f64,
                    d_weight as *mut f64,
                    batch_size,
                    hidden_size,
                    // The f64 kernel takes a widened epsilon.
                    eps as f64,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                norm::fused_add_rms_norm_bwd_f16(
                    grad as *const half::f16,
                    pre_norm as *const half::f16,
                    weight as *const half::f16,
                    d_input_residual as *mut half::f16,
                    d_weight as *mut half::f16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                norm::fused_add_rms_norm_bwd_bf16(
                    grad as *const half::bf16,
                    pre_norm as *const half::bf16,
                    weight as *const half::bf16,
                    d_input_residual as *mut half::bf16,
                    d_weight as *mut half::bf16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            // Remaining dtypes: fall through to the scalar path below.
            _ => {}
        }
    }
    fused_add_rms_norm_bwd_scalar(
        grad,
        pre_norm,
        weight,
        d_input_residual,
        d_weight,
        batch_size,
        hidden_size,
        eps,
    );
}
#[inline]
/// Portable scalar fallback for the fused residual-add + RMS-norm backward.
///
/// With `pn` the saved pre-norm activations, `g` the upstream gradient,
/// `w` the scale, `H = hidden_size`, `ms = mean(pn^2)` and
/// `inv_rms = 1/sqrt(ms + eps)`, this computes per element:
///
/// - `d_input_residual = inv_rms * (g*w - pn * dot / (H * (ms + eps)))`
///   where `dot = sum_i g_i * w_i * pn_i`
///   (equivalently `inv_rms*g*w - pn * dot * inv_rms^3 / H`, the analytic
///   gradient of `y = w * pn * inv_rms`);
/// - `d_weight += g * pn * inv_rms`, accumulated across rows (callers are
///   expected to zero-initialize `d_weight`).
unsafe fn fused_add_rms_norm_bwd_scalar<T: Element>(
    grad: *const T,
    pre_norm: *const T,
    weight: *const T,
    d_input_residual: *mut T,
    d_weight: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    let eps = eps as f64;
    let weight_slice = std::slice::from_raw_parts(weight, hidden_size);
    for batch in 0..batch_size {
        let row = batch * hidden_size;
        // Recompute mean-square / inv_rms from the saved activations.
        let mut sum_sq = 0.0f64;
        for i in 0..hidden_size {
            let pn = (*pre_norm.add(row + i)).to_f64();
            sum_sq += pn * pn;
        }
        let mean_sq = sum_sq / hidden_size as f64;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // Row-level reduction: dot = sum_i g_i * w_i * pn_i.
        let mut dot = 0.0f64;
        for i in 0..hidden_size {
            let g = (*grad.add(row + i)).to_f64();
            let w = weight_slice[i].to_f64();
            let pn = (*pre_norm.add(row + i)).to_f64();
            dot += g * w * pn;
        }
        // BUG FIX: the projection coefficient is dot / (H * (ms + eps)).
        // The previous `dot * inv_rms / (H * (ms + eps))` carried an extra
        // inv_rms factor, so (after the * inv_rms below) the second term was
        // pn*dot*inv_rms^4/H instead of the correct pn*dot*inv_rms^3/H.
        // (Check: H=1, eps=0, w=g=1 -> y = pn/|pn| is constant, so the
        // gradient must be 0; the old formula gave (1 - 1/pn)/pn.)
        let coeff = dot / (hidden_size as f64 * (mean_sq + eps));
        for i in 0..hidden_size {
            let g = (*grad.add(row + i)).to_f64();
            let w = weight_slice[i].to_f64();
            let pn = (*pre_norm.add(row + i)).to_f64();
            let d_ir = (g * w - pn * coeff) * inv_rms;
            *d_input_residual.add(row + i) = T::from_f64(d_ir);
            // Accumulate the per-feature weight gradient across rows.
            let dw_old = (*d_weight.add(i)).to_f64();
            *d_weight.add(i) = T::from_f64(dw_old + g * pn * inv_rms);
        }
    }
}
#[inline]
#[allow(clippy::too_many_arguments)]
/// Fused residual-add + LayerNorm forward pass.
///
/// Writes the pre-normalization activations `input + residual` into
/// `pre_norm`, then writes the mean/variance-normalized, `weight`-scaled,
/// `bias`-shifted result into `out`, row by row (`batch_size` rows of
/// `hidden_size` elements).
///
/// On x86_64/aarch64 the call dispatches on `T::DTYPE` to SIMD kernels in
/// `super::simd::norm` (`F16`/`BF16` arms require the `f16` feature);
/// otherwise it falls through to the scalar implementation.
///
/// # Safety
/// Requirements observed from the scalar path (SIMD paths presumably
/// match — confirm against `super::simd::norm`):
/// - `input` and `residual` must be valid for reads of
///   `batch_size * hidden_size` elements of `T`;
/// - `out` and `pre_norm` must be valid for writes of
///   `batch_size * hidden_size` elements of `T`;
/// - `weight` and `bias` must be valid for reads of `hidden_size`
///   elements of `T`.
pub unsafe fn fused_add_layer_norm_kernel<T: Element>(
    input: *const T,
    residual: *const T,
    weight: *const T,
    bias: *const T,
    out: *mut T,
    pre_norm: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::norm;
        use crate::dtype::DType;
        match T::DTYPE {
            DType::F32 => {
                norm::fused_add_layer_norm_f32(
                    input as *const f32,
                    residual as *const f32,
                    weight as *const f32,
                    bias as *const f32,
                    out as *mut f32,
                    pre_norm as *mut f32,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            DType::F64 => {
                norm::fused_add_layer_norm_f64(
                    input as *const f64,
                    residual as *const f64,
                    weight as *const f64,
                    bias as *const f64,
                    out as *mut f64,
                    pre_norm as *mut f64,
                    batch_size,
                    hidden_size,
                    // The f64 kernel takes a widened epsilon.
                    eps as f64,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                norm::fused_add_layer_norm_f16(
                    input as *const half::f16,
                    residual as *const half::f16,
                    weight as *const half::f16,
                    bias as *const half::f16,
                    out as *mut half::f16,
                    pre_norm as *mut half::f16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                norm::fused_add_layer_norm_bf16(
                    input as *const half::bf16,
                    residual as *const half::bf16,
                    weight as *const half::bf16,
                    bias as *const half::bf16,
                    out as *mut half::bf16,
                    pre_norm as *mut half::bf16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            // Remaining dtypes: fall through to the scalar path below.
            _ => {}
        }
    }
    fused_add_layer_norm_scalar(
        input,
        residual,
        weight,
        bias,
        out,
        pre_norm,
        batch_size,
        hidden_size,
        eps,
    );
}
#[inline]
/// Portable scalar fallback for the fused residual-add + LayerNorm forward.
///
/// Per row: stores `input + residual` (rounded through `T`) into
/// `pre_norm_out`, then emits
/// `out = (pre_norm - mean) * inv_std * weight + bias`, where `mean` and
/// the (biased) variance are computed over the row and
/// `inv_std = 1/sqrt(var + eps)`. Accumulation happens in f64; the stored
/// (rounded) pre-norm values are re-read for the variance and output
/// passes, so results are consistent with what was persisted.
unsafe fn fused_add_layer_norm_scalar<T: Element>(
    input: *const T,
    residual: *const T,
    weight: *const T,
    bias: *const T,
    out: *mut T,
    pre_norm_out: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    let eps64 = f64::from(eps);
    let gamma = std::slice::from_raw_parts(weight, hidden_size);
    let beta = std::slice::from_raw_parts(bias, hidden_size);
    let n = hidden_size as f64;
    for b in 0..batch_size {
        let base = b * hidden_size;
        // Pass 1: materialize the residual sum and accumulate its total.
        let mut total = 0.0f64;
        for j in 0..hidden_size {
            let v = (*input.add(base + j)).to_f64() + (*residual.add(base + j)).to_f64();
            *pre_norm_out.add(base + j) = T::from_f64(v);
            total += v;
        }
        let mu = total / n;
        // Pass 2: biased variance of the stored pre-norm values.
        let mut sq_dev = 0.0f64;
        for j in 0..hidden_size {
            let d = (*pre_norm_out.add(base + j)).to_f64() - mu;
            sq_dev += d * d;
        }
        let inv_std = (sq_dev / n + eps64).sqrt().recip();
        // Pass 3: affine-normalized output.
        for j in 0..hidden_size {
            let v = (*pre_norm_out.add(base + j)).to_f64();
            let y = (v - mu) * inv_std * gamma[j].to_f64() + beta[j].to_f64();
            *out.add(base + j) = T::from_f64(y);
        }
    }
}
#[inline]
#[allow(clippy::too_many_arguments)]
/// Backward pass for the fused residual-add + LayerNorm.
///
/// Takes the upstream gradient `grad` (w.r.t. the normalized output), the
/// saved pre-normalization activations `pre_norm` (the `input + residual`
/// written by the forward pass) and `weight`; writes the gradient w.r.t.
/// the pre-norm activations to `d_input_residual` and per-feature
/// gradients to `d_weight` / `d_bias`. The forward bias is not needed for
/// the gradient, hence the unused `_bias` parameter (kept for signature
/// symmetry with the forward kernel). The scalar path *accumulates* into
/// `d_weight` and `d_bias` across rows — callers presumably pass them
/// zero-initialized; confirm the SIMD kernels follow the same convention.
///
/// On x86_64/aarch64 the call dispatches on `T::DTYPE` to SIMD kernels in
/// `super::simd::norm` (`F16`/`BF16` arms require the `f16` feature);
/// otherwise it falls through to the scalar implementation.
///
/// # Safety
/// Requirements observed from the scalar path (SIMD paths presumably
/// match — confirm against `super::simd::norm`):
/// - `grad` and `pre_norm` must be valid for reads of
///   `batch_size * hidden_size` elements of `T`;
/// - `d_input_residual` must be valid for writes of
///   `batch_size * hidden_size` elements of `T`;
/// - `weight` must be valid for reads, and `d_weight` / `d_bias` for
///   reads and writes, of `hidden_size` elements of `T`.
pub unsafe fn fused_add_layer_norm_bwd_kernel<T: Element>(
    grad: *const T,
    pre_norm: *const T,
    weight: *const T,
    _bias: *const T,
    d_input_residual: *mut T,
    d_weight: *mut T,
    d_bias: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::norm;
        use crate::dtype::DType;
        match T::DTYPE {
            DType::F32 => {
                norm::fused_add_layer_norm_bwd_f32(
                    grad as *const f32,
                    pre_norm as *const f32,
                    weight as *const f32,
                    d_input_residual as *mut f32,
                    d_weight as *mut f32,
                    d_bias as *mut f32,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            DType::F64 => {
                norm::fused_add_layer_norm_bwd_f64(
                    grad as *const f64,
                    pre_norm as *const f64,
                    weight as *const f64,
                    d_input_residual as *mut f64,
                    d_weight as *mut f64,
                    d_bias as *mut f64,
                    batch_size,
                    hidden_size,
                    // The f64 kernel takes a widened epsilon.
                    eps as f64,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                norm::fused_add_layer_norm_bwd_f16(
                    grad as *const half::f16,
                    pre_norm as *const half::f16,
                    weight as *const half::f16,
                    d_input_residual as *mut half::f16,
                    d_weight as *mut half::f16,
                    d_bias as *mut half::f16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                norm::fused_add_layer_norm_bwd_bf16(
                    grad as *const half::bf16,
                    pre_norm as *const half::bf16,
                    weight as *const half::bf16,
                    d_input_residual as *mut half::bf16,
                    d_weight as *mut half::bf16,
                    d_bias as *mut half::bf16,
                    batch_size,
                    hidden_size,
                    eps,
                );
                return;
            }
            // Remaining dtypes: fall through to the scalar path below.
            _ => {}
        }
    }
    fused_add_layer_norm_bwd_scalar(
        grad,
        pre_norm,
        weight,
        d_input_residual,
        d_weight,
        d_bias,
        batch_size,
        hidden_size,
        eps,
    );
}
#[inline]
/// Portable scalar fallback for the fused residual-add + LayerNorm backward.
///
/// With `nh = (pn - mean) * inv_std` the normalized activations and
/// `gs = g * w`, the per-element gradients are the standard LayerNorm
/// backward:
/// - `d_input_residual = inv_std * (gs - mean(gs) - nh * mean(gs * nh))`;
/// - `d_weight += g * nh` and `d_bias += g`, accumulated across rows
///   (callers are expected to zero-initialize both).
///
/// Mean and variance are recomputed in f64 from the saved `pre_norm`.
unsafe fn fused_add_layer_norm_bwd_scalar<T: Element>(
    grad: *const T,
    pre_norm: *const T,
    weight: *const T,
    d_input_residual: *mut T,
    d_weight: *mut T,
    d_bias: *mut T,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) {
    let eps64 = eps as f64;
    let gamma = std::slice::from_raw_parts(weight, hidden_size);
    let n = hidden_size as f64;
    for b in 0..batch_size {
        let base = b * hidden_size;
        // Recompute the row mean from the saved pre-norm activations.
        let mut total = 0.0f64;
        for j in 0..hidden_size {
            total += (*pre_norm.add(base + j)).to_f64();
        }
        let mu = total / n;
        // Recompute the (biased) variance and inverse standard deviation.
        let mut sq_dev = 0.0f64;
        for j in 0..hidden_size {
            let d = (*pre_norm.add(base + j)).to_f64() - mu;
            sq_dev += d * d;
        }
        let inv_std = 1.0 / (sq_dev / n + eps64).sqrt();
        // Row-level reductions: mean of g*w and of (g*w)*normalized.
        let mut gs_sum = 0.0f64;
        let mut gsn_sum = 0.0f64;
        for j in 0..hidden_size {
            let gs = (*grad.add(base + j)).to_f64() * gamma[j].to_f64();
            let nh = ((*pre_norm.add(base + j)).to_f64() - mu) * inv_std;
            gs_sum += gs;
            gsn_sum += gs * nh;
        }
        let gs_mean = gs_sum / n;
        let gsn_mean = gsn_sum / n;
        // Final pass: input gradient plus accumulated parameter gradients.
        for j in 0..hidden_size {
            let g = (*grad.add(base + j)).to_f64();
            let nh = ((*pre_norm.add(base + j)).to_f64() - mu) * inv_std;
            let gs = g * gamma[j].to_f64();
            let d_ir = inv_std * (gs - gs_mean - nh * gsn_mean);
            *d_input_residual.add(base + j) = T::from_f64(d_ir);
            let dw_old = (*d_weight.add(j)).to_f64();
            *d_weight.add(j) = T::from_f64(dw_old + g * nh);
            let db_old = (*d_bias.add(j)).to_f64();
            *d_bias.add(j) = T::from_f64(db_old + g);
        }
    }
}