Function batchnorm_scalar 

pub fn batchnorm_scalar(
    input: &[f32],
    n: usize,
    c: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    running_mean: &mut [f32],
    running_var: &mut [f32],
    output: &mut [f32],
    momentum: f32,
    training: bool,
)

Scalar reference implementation of BatchNorm.

§Arguments

  • input - Flattened [N, C] tensor (row-major).
  • n - Batch size (N).
  • c - Number of channels (C).
  • gamma - Per-channel scale, length C.
  • beta - Per-channel bias, length C.
  • eps - Small constant for numerical stability.
  • running_mean - Running mean per channel, length C. Updated in training.
  • running_var - Running variance per channel, length C. Updated in training.
  • output - Output buffer, same shape as input.
  • momentum - Exponential moving average (EMA) momentum used when updating the running statistics.
  • training - If true, compute batch stats and update running stats. If false, use running stats for normalization.
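
To make the argument list concrete, here is a minimal sketch of how a scalar BatchNorm over a row-major [N, C] buffer could be written. It is not the crate's actual source: the name `batchnorm_sketch` is hypothetical, and two details are assumptions because the documentation does not state them. The running statistics are updated as `running = (1 - momentum) * running + momentum * batch`, and the biased (divide by N) variance is used both for normalization and for the running update.

```rust
/// Hypothetical sketch of a scalar BatchNorm over a row-major [N, C] buffer.
/// Not the crate's implementation; the EMA convention and the biased
/// variance are assumptions made for illustration.
#[allow(clippy::too_many_arguments)]
fn batchnorm_sketch(
    input: &[f32],
    n: usize,
    c: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    running_mean: &mut [f32],
    running_var: &mut [f32],
    output: &mut [f32],
    momentum: f32,
    training: bool,
) {
    assert_eq!(input.len(), n * c, "input must hold n * c values");
    assert_eq!(output.len(), n * c, "output must hold n * c values");
    assert!(gamma.len() == c && beta.len() == c);
    assert!(running_mean.len() == c && running_var.len() == c);

    for ch in 0..c {
        let (mean, var) = if training {
            // Batch statistics over the N samples of this channel.
            // Element (i, ch) lives at index i * c + ch in row-major order.
            let mean = (0..n).map(|i| input[i * c + ch]).sum::<f32>() / n as f32;
            let var = (0..n)
                .map(|i| (input[i * c + ch] - mean).powi(2))
                .sum::<f32>()
                / n as f32;

            // Assumed EMA convention: momentum weights the new batch value.
            running_mean[ch] = (1.0 - momentum) * running_mean[ch] + momentum * mean;
            running_var[ch] = (1.0 - momentum) * running_var[ch] + momentum * var;
            (mean, var)
        } else {
            // Inference: normalize with the stored running statistics.
            (running_mean[ch], running_var[ch])
        };

        let inv_std = 1.0 / (var + eps).sqrt();
        for i in 0..n {
            let idx = i * c + ch;
            output[idx] = gamma[ch] * (input[idx] - mean) * inv_std + beta[ch];
        }
    }
}
```

With `training` set to true the sketch reads and writes the running buffers; with `training` set to false it only reads them, which is why both modes can share the same `&mut` parameters.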

§Panics

Panics if buffer sizes are inconsistent with n * c.
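
As an illustration of what such a consistency check might look like, the sketch below asserts the lengths implied by the argument list. The helper name `assert_batchnorm_shapes` is hypothetical, and which buffers the real function actually validates is an assumption: here `input` and `output` are expected to hold n * c values, and the four per-channel buffers are expected to hold c values each.

```rust
/// Hypothetical shape check; the real function may validate a different
/// set of buffers. `per_channel_lens` carries the lengths of gamma, beta,
/// running_mean, and running_var.
fn assert_batchnorm_shapes(
    n: usize,
    c: usize,
    input_len: usize,
    output_len: usize,
    per_channel_lens: &[usize],
) {
    // Guard against overflow when computing the flattened element count.
    let total = n.checked_mul(c).expect("n * c overflows usize");
    assert_eq!(input_len, total, "input length must equal n * c");
    assert_eq!(output_len, total, "output length must equal n * c");
    for &len in per_channel_lens {
        assert_eq!(len, c, "per-channel buffers must have length c");
    }
}
```

A caller that cannot guarantee the shapes up front may prefer to run a check like this before invoking the function, so that a mismatch is reported at the call site.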