use crate::dtype::Element;
/// RMS-normalizes `batch_size` contiguous rows of `hidden_size` elements from
/// `input`, scales each element by the corresponding `weight` entry, and writes
/// the result to `out` (see `rms_norm_kernel_scalar` in this file for the math).
///
/// On x86_64/aarch64 this dispatches to a dtype-specialized SIMD routine for
/// f32/f64 (and f16/bf16 when the `f16` feature is enabled); any other dtype —
/// and all other architectures — falls through to the scalar implementation.
///
/// # Safety
/// - `input` must be valid for `batch_size * hidden_size` reads of `T`.
/// - `weight` must be valid for `hidden_size` reads of `T`.
/// - `out` must be valid for `batch_size * hidden_size` writes of `T`.
/// - The regions behind `input`/`weight` must not overlap `out`.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn rms_norm_kernel<T: Element>(
input: *const T,
weight: *const T,
out: *mut T,
batch_size: usize,
hidden_size: usize,
eps: f32,
) {
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
{
use super::simd::norm;
use crate::dtype::DType;
// Runtime dtype tag selects the monomorphic SIMD kernel; the pointer casts
// are sound because T::DTYPE ties T to the concrete element type.
match T::DTYPE {
DType::F32 => {
norm::rms_norm_f32(
input as *const f32,
weight as *const f32,
out as *mut f32,
batch_size,
hidden_size,
eps,
);
return;
}
DType::F64 => {
norm::rms_norm_f64(
input as *const f64,
weight as *const f64,
out as *mut f64,
batch_size,
hidden_size,
eps as f64, // f64 kernel takes eps at full precision
);
return;
}
#[cfg(feature = "f16")]
DType::F16 => {
norm::rms_norm_f16(
input as *const half::f16,
weight as *const half::f16,
out as *mut half::f16,
batch_size,
hidden_size,
eps,
);
return;
}
#[cfg(feature = "f16")]
DType::BF16 => {
norm::rms_norm_bf16(
input as *const half::bf16,
weight as *const half::bf16,
out as *mut half::bf16,
batch_size,
hidden_size,
eps,
);
return;
}
// Dtypes without a SIMD kernel (and f16/bf16 when the feature is off)
// fall through to the scalar path below.
_ => {} }
}
rms_norm_kernel_scalar(input, weight, out, batch_size, hidden_size, eps);
}
/// Portable scalar RMS norm: for each row, `out[i] = x[i] / rms(x) * weight[i]`
/// with `rms(x) = sqrt(mean(x^2) + eps)`. Accumulation happens in f64 so the
/// result does not depend on the precision of `T`.
///
/// # Safety
/// `input`/`out` must cover `batch_size * hidden_size` elements and `weight`
/// must cover `hidden_size` elements; `out` must not alias the inputs.
#[inline]
unsafe fn rms_norm_kernel_scalar<T: Element>(
input: *const T,
weight: *const T,
out: *mut T,
batch_size: usize,
hidden_size: usize,
eps: f32,
) {
    let weights = std::slice::from_raw_parts(weight, hidden_size);
    let eps = f64::from(eps);
    for row in 0..batch_size {
        let base = row * hidden_size;
        let in_row = std::slice::from_raw_parts(input.add(base), hidden_size);
        let out_row = std::slice::from_raw_parts_mut(out.add(base), hidden_size);
        // Sum of squares over the row, widened to f64.
        let sum_sq: f64 = in_row
            .iter()
            .map(|v| {
                let x = v.to_f64();
                x * x
            })
            .sum();
        let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt();
        // Normalize and apply the per-element gain.
        for ((dst, src), w) in out_row.iter_mut().zip(in_row).zip(weights) {
            *dst = T::from_f64(src.to_f64() * inv_rms * w.to_f64());
        }
    }
}
/// Layer-normalizes `batch_size` contiguous rows of `hidden_size` elements:
/// each row is centered by its mean, scaled by the inverse standard deviation,
/// then transformed by per-element `weight` and `bias` (see
/// `layer_norm_kernel_scalar` in this file for the math).
///
/// On x86_64/aarch64 this dispatches to a dtype-specialized SIMD routine for
/// f32/f64 (and f16/bf16 when the `f16` feature is enabled); any other dtype —
/// and all other architectures — falls through to the scalar implementation.
///
/// # Safety
/// - `input` must be valid for `batch_size * hidden_size` reads of `T`.
/// - `weight` and `bias` must each be valid for `hidden_size` reads of `T`.
/// - `out` must be valid for `batch_size * hidden_size` writes of `T`.
/// - The regions behind `input`/`weight`/`bias` must not overlap `out`.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn layer_norm_kernel<T: Element>(
input: *const T,
weight: *const T,
bias: *const T,
out: *mut T,
batch_size: usize,
hidden_size: usize,
eps: f32,
) {
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
{
use super::simd::norm;
use crate::dtype::DType;
// Runtime dtype tag selects the monomorphic SIMD kernel; the pointer casts
// are sound because T::DTYPE ties T to the concrete element type.
match T::DTYPE {
DType::F32 => {
norm::layer_norm_f32(
input as *const f32,
weight as *const f32,
bias as *const f32,
out as *mut f32,
batch_size,
hidden_size,
eps,
);
return;
}
DType::F64 => {
norm::layer_norm_f64(
input as *const f64,
weight as *const f64,
bias as *const f64,
out as *mut f64,
batch_size,
hidden_size,
eps as f64, // f64 kernel takes eps at full precision
);
return;
}
#[cfg(feature = "f16")]
DType::F16 => {
norm::layer_norm_f16(
input as *const half::f16,
weight as *const half::f16,
bias as *const half::f16,
out as *mut half::f16,
batch_size,
hidden_size,
eps,
);
return;
}
#[cfg(feature = "f16")]
DType::BF16 => {
norm::layer_norm_bf16(
input as *const half::bf16,
weight as *const half::bf16,
bias as *const half::bf16,
out as *mut half::bf16,
batch_size,
hidden_size,
eps,
);
return;
}
// Dtypes without a SIMD kernel (and f16/bf16 when the feature is off)
// fall through to the scalar path below.
_ => {} }
}
layer_norm_kernel_scalar(input, weight, bias, out, batch_size, hidden_size, eps);
}
/// Portable scalar layer norm: for each row,
/// `out[i] = (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]`, using a
/// two-pass mean/variance computation accumulated in f64 for precision.
///
/// # Safety
/// `input`/`out` must cover `batch_size * hidden_size` elements; `weight` and
/// `bias` must cover `hidden_size` elements; `out` must not alias the inputs.
#[inline]
unsafe fn layer_norm_kernel_scalar<T: Element>(
input: *const T,
weight: *const T,
bias: *const T,
out: *mut T,
batch_size: usize,
hidden_size: usize,
eps: f32,
) {
    let weights = std::slice::from_raw_parts(weight, hidden_size);
    let biases = std::slice::from_raw_parts(bias, hidden_size);
    let eps = f64::from(eps);
    for row in 0..batch_size {
        let base = row * hidden_size;
        let in_row = std::slice::from_raw_parts(input.add(base), hidden_size);
        let out_row = std::slice::from_raw_parts_mut(out.add(base), hidden_size);
        // Pass 1: row mean.
        let mean = in_row.iter().map(|v| v.to_f64()).sum::<f64>() / hidden_size as f64;
        // Pass 2: population variance around that mean.
        let variance = in_row
            .iter()
            .map(|v| {
                let d = v.to_f64() - mean;
                d * d
            })
            .sum::<f64>()
            / hidden_size as f64;
        let inv_std = 1.0 / (variance + eps).sqrt();
        // Pass 3: normalize, then apply the per-element affine transform.
        for (((dst, src), w), b) in out_row.iter_mut().zip(in_row).zip(weights).zip(biases) {
            *dst = T::from_f64((src.to_f64() - mean) * inv_std * w.to_f64() + b.to_f64());
        }
    }
}
/// Scalar group norm over a `(batch, channels, spatial)`-contiguous layout:
/// for each batch item, channels are split into `num_groups` groups of
/// `channels_per_group`; each group is normalized by its own mean/variance
/// (accumulated in f64), then every channel applies its own `weight`/`bias`
/// affine transform. Caller is expected to pass
/// `channels == num_groups * channels_per_group` — TODO confirm at call sites.
///
/// # Safety
/// - `input`/`out` must cover `batch * channels * spatial` elements.
/// - `weight` and `bias` must each cover `channels` elements.
/// - `out` must not alias the inputs.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn group_norm_kernel<T: Element>(
input: *const T,
weight: *const T,
bias: *const T,
out: *mut T,
batch: usize,
channels: usize,
spatial: usize,
num_groups: usize,
channels_per_group: usize,
eps: f32,
) {
    let eps = f64::from(eps);
    let group_len = (channels_per_group * spatial) as f64;
    for b in 0..batch {
        for g in 0..num_groups {
            let first_ch = g * channels_per_group;
            let ch_range = first_ch..first_ch + channels_per_group;
            // Pass 1: mean over every element of the group.
            let mut total = 0.0f64;
            for ch in ch_range.clone() {
                let base = (b * channels + ch) * spatial;
                for s in 0..spatial {
                    total += (*input.add(base + s)).to_f64();
                }
            }
            let mean = total / group_len;
            // Pass 2: sum of squared deviations for the group variance.
            let mut sq_dev = 0.0f64;
            for ch in ch_range.clone() {
                let base = (b * channels + ch) * spatial;
                for s in 0..spatial {
                    let d = (*input.add(base + s)).to_f64() - mean;
                    sq_dev += d * d;
                }
            }
            let inv_std = 1.0 / (sq_dev / group_len + eps).sqrt();
            // Pass 3: normalize and apply the per-channel affine transform.
            for ch in ch_range {
                let scale = (*weight.add(ch)).to_f64();
                let shift = (*bias.add(ch)).to_f64();
                let base = (b * channels + ch) * spatial;
                for s in 0..spatial {
                    let x = (*input.add(base + s)).to_f64();
                    *out.add(base + s) = T::from_f64((x - mean) * inv_std * scale + shift);
                }
            }
        }
    }
}