/// Scalar batch normalization over a row-major `n x c` buffer.
///
/// The value for `(sample, channel)` is read from `input[sample * c + channel]` and the
/// normalized result is written to the same position in `output`. Every channel is
/// processed independently:
///
/// * `training == true`: the mean and the biased variance (sum of squared deviations
///   divided by `n`) of the channel are computed from `input`, used to normalize it, and
///   then blended into `running_mean` / `running_var` as
///   `(1 - momentum) * running + momentum * batch`.
/// * `training == false`: `running_mean` / `running_var` are used as the statistics and
///   are left untouched.
///
/// In both modes each output element is `gamma[ch] * (x - mean) / sqrt(var + eps) + beta[ch]`.
///
/// # Panics
///
/// Panics if `input` or `output` does not hold exactly `n * c` elements, if `gamma`,
/// `beta`, `running_mean` or `running_var` does not hold exactly `c` elements, or if
/// `n` or `c` is zero.
#[allow(clippy::too_many_arguments)]
pub fn batchnorm_scalar(
    input: &[f32],
    n: usize,
    c: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    running_mean: &mut [f32],
    running_var: &mut [f32],
    output: &mut [f32],
    momentum: f32,
    training: bool,
) {
    let total = n * c;
    assert_eq!(input.len(), total, "input length must be n * c");
    assert_eq!(output.len(), total, "output length must be n * c");
    assert_eq!(gamma.len(), c, "gamma length must be c");
    assert_eq!(beta.len(), c, "beta length must be c");
    assert_eq!(running_mean.len(), c, "running_mean length must be c");
    assert_eq!(running_var.len(), c, "running_var length must be c");
    assert!(n != 0 && c != 0, "batchnorm requires n > 0 and c > 0");

    // Normalizes one channel of every row with the supplied statistics.
    let mut write_channel = |ch: usize, mean: f32, var: f32| {
        let inv_std = 1.0 / (var + eps).sqrt();
        let (scale, shift) = (gamma[ch], beta[ch]);
        for (src_row, dst_row) in input.chunks_exact(c).zip(output.chunks_exact_mut(c)) {
            dst_row[ch] = scale * (src_row[ch] - mean) * inv_std + shift;
        }
    };

    if !training {
        // Inference: the stored running statistics are read, never modified.
        for ch in 0..c {
            write_channel(ch, running_mean[ch], running_var[ch]);
        }
        return;
    }

    let count = n as f32;
    let keep = 1.0 - momentum;
    for ch in 0..c {
        // Strided view of this channel, visited in sample order.
        let column = input.iter().skip(ch).step_by(c).copied();

        let mean = column.clone().fold(0.0_f32, |acc, x| acc + x) / count;
        let var = column.fold(0.0_f32, |acc, x| {
            let delta = x - mean;
            acc + delta * delta
        }) / count;

        write_channel(ch, mean, var);

        // Exponential moving average of the batch statistics.
        running_mean[ch] = keep * running_mean[ch] + momentum * mean;
        running_var[ch] = keep * running_var[ch] + momentum * var;
    }
}
/// AVX2-targeted entry point for batch normalization (compiled on `x86_64` only).
///
/// The body contains no explicit SIMD intrinsics: it forwards every argument, unchanged
/// and in the same order, to [`batchnorm_scalar`]. The only difference from calling the
/// scalar function directly is that this function is compiled with the `avx2` target
/// feature enabled, which permits the compiler to emit AVX2 instructions for it.
/// NOTE(review): whether the scalar loops are actually inlined and vectorized here is up
/// to the optimizer and is not guaranteed by anything in this file.
///
/// # Safety
///
/// The caller must ensure the CPU executing this function supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` beforehand. Calling a function compiled
/// with a target feature the CPU lacks is undefined behavior.
///
/// # Panics
///
/// Panics in the same cases as [`batchnorm_scalar`], which asserts that `input` and
/// `output` hold `n * c` elements, that `gamma`, `beta`, `running_mean` and `running_var`
/// hold `c` elements, and that `n` and `c` are both non-zero.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(clippy::too_many_arguments)]
pub unsafe fn batchnorm_avx2(
input: &[f32],
n: usize,
c: usize,
gamma: &[f32],
beta: &[f32],
eps: f32,
running_mean: &mut [f32],
running_var: &mut [f32],
output: &mut [f32],
momentum: f32,
training: bool,
) {
// Pure delegation: the scalar implementation does all validation and arithmetic.
batchnorm_scalar(
input,
n,
c,
gamma,
beta,
eps,
running_mean,
running_var,
output,
momentum,
training,
);
}
// Textually splices `batchnorm_ptx.rs` (resolved relative to this source file) into the
// module at this point, so its items share this module's namespace.
// NOTE(review): the included file is not visible here; judging by its name it presumably
// holds a PTX/GPU variant of batch normalization — confirm before relying on that.
include!("batchnorm_ptx.rs");
// Unit tests; this module is compiled only for test builds (`cfg(test)`).
#[cfg(test)]
mod tests {
// The test code lives in two separate files that are textually included into this
// module. Nothing is imported here, so any `use super::...` items the tests rely on
// must come from the included files themselves — NOTE(review): confirm.
include!("batchnorm_tests.rs");
include!("batchnorm_tests2.rs");
}