#[non_exhaustive]
#[repr(u16)]
pub enum NormalizationKind {
    RMSNorm = 0,
    LayerNorm = 1,
    GroupNorm = 2,
    BatchNorm = 3,
    InstanceNorm = 4,
}
Normalization op discriminant — category G from the comprehensive plan.
Stored as u16 in crate::KernelSku::op when
category == OpCategory::Normalization. The variants differ in
which axes are reduced for the per-row statistics and how the
affine parameters (gamma / beta) are indexed.
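To make the u16 storage concrete, here is a minimal sketch of round-tripping the discriminant. The enum is a local mirror of the one documented here, and `normalization_kind_from_op` is a hypothetical helper invented for illustration; the crate may expose a different decoding path.

```rust
// Local mirror of the documented enum, for illustration only.
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NormalizationKind {
    RMSNorm = 0,
    LayerNorm = 1,
    GroupNorm = 2,
    BatchNorm = 3,
    InstanceNorm = 4,
}

/// Hypothetical helper (not part of the documented API): decode a raw
/// u16 `op` value into a variant, returning `None` for unknown values.
pub fn normalization_kind_from_op(op: u16) -> Option<NormalizationKind> {
    match op {
        0 => Some(NormalizationKind::RMSNorm),
        1 => Some(NormalizationKind::LayerNorm),
        2 => Some(NormalizationKind::GroupNorm),
        3 => Some(NormalizationKind::BatchNorm),
        4 => Some(NormalizationKind::InstanceNorm),
        _ => None,
    }
}
```

Encoding is a plain cast (`NormalizationKind::GroupNorm as u16`). Decoding returns an `Option` because the enum is non-exhaustive and a raw value may not correspond to any variant known to the reader.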
Currently wired: {RMSNorm, LayerNorm, BatchNorm, GroupNorm, InstanceNorm} × {f32, f16, bf16, f64}, forward (FW) and backward (BW).
RMSNorm and LayerNorm support multi-axis normalization via a bitmask
(PyTorch’s normalized_shape; the normalized axes must form a suffix
of the input shape). InstanceNorm is implemented as a thin wrapper
around GroupNorm with num_groups == c_extent and shares its kernel symbols.
BatchNorm is training-mode-only for the trailblazer: it
computes per-channel stats from the batch and saves them for BW.
Inference mode (using running statistics, which reduces to a
per-channel affine multiply) is reserved for a follow-up. WeightNorm
(a parameterization rather than a plain op) and LocalResponseNorm
(rarely used today) are explicitly deferred.
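The suffix requirement on the normalized axes can be sketched as a bitmask check. The bit layout is an assumption made for illustration (bit i set means axis i is normalized), and `suffix_axis_count` is a hypothetical name; the crate's actual encoding may differ.

```rust
/// Hypothetical check, assuming bit `i` of `mask` marks axis `i` as normalized.
/// Returns the number of normalized axes if `mask` selects a non-empty
/// suffix of a rank-`rank` shape, and `None` otherwise.
pub fn suffix_axis_count(mask: u32, rank: u32) -> Option<u32> {
    if rank == 0 || rank > 32 || mask == 0 {
        return None;
    }
    // Reject bits that refer to axes at or beyond `rank`.
    if rank < 32 && (mask >> rank) != 0 {
        return None;
    }
    let k = mask.count_ones();
    // A suffix of length k occupies exactly bits [rank - k, rank).
    let ones = if k == 32 { u32::MAX } else { (1u32 << k) - 1 };
    let expected = ones << (rank - k);
    if mask == expected {
        Some(k)
    } else {
        None
    }
}
```

For a rank-3 input, `0b110` (axes 1 and 2) and `0b100` (axis 2 only) are valid suffixes, while `0b101` is rejected because axis 1 is skipped.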
Variants (Non-exhaustive)
This enum is marked as non-exhaustive
RMSNorm = 0
y = x / sqrt(mean(x², over norm_axes) + eps) * gamma.
Llama / Mistral / Gemma block-pre-norm. Trailblazer SKU.
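As a CPU reference for the formula above, a minimal sketch that normalizes a single row, treating the whole slice as the normalized axes. The function name and the f32-only signature are illustrative assumptions and do not reflect the crate's kernel interface.

```rust
/// Reference RMSNorm over one row: y = x / sqrt(mean(x^2) + eps) * gamma.
/// Illustrative only; the slice stands in for the normalized axes.
pub fn rms_norm_row(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), gamma.len());
    assert!(!x.is_empty());
    // Mean of squares over the normalized axes.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    // Reciprocal RMS; eps sits inside the square root, as in the formula.
    let rstd = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gamma).map(|(v, g)| v * rstd * g).collect()
}
```

With gamma set to all ones and eps close to zero, the output has a mean square of approximately 1. Note that no mean is subtracted and there is no beta term.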
LayerNorm = 1
y = (x - mean) / sqrt(var + eps) * gamma + beta. PyTorch’s
torch.nn.LayerNorm with biased / “population” variance.
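The biased variance mentioned above divides by the element count n rather than n - 1. A minimal single-row sketch, with an illustrative function name and f32-only signature that are assumptions rather than the crate's API:

```rust
/// Reference LayerNorm over one row, using the biased ("population") variance.
pub fn layer_norm_row(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), gamma.len());
    assert_eq!(x.len(), beta.len());
    assert!(!x.is_empty());
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    // Biased variance: divide by n, not n - 1.
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let rstd = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(gamma.iter().zip(beta))
        .map(|(v, (g, b))| (v - mean) * rstd * g + b)
        .collect()
}
```

Compared with RMSNorm, the row is centered before scaling and an additive beta is applied afterwards.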
GroupNorm = 2
Per-group-of-channels statistics. y[n, c, ...] = (x[n, c, ...] - mean[n, g]) / sqrt(var[n, g] + eps) * gamma[c] + beta[c],
where g = c / (C / num_groups) using integer division. PyTorch torch.nn.GroupNorm.
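The indexing above can be sketched on a contiguous [N, C, S] buffer, where S is the flattened spatial extent. The memory layout, the function name, and the use of biased variance (matching PyTorch's GroupNorm) are assumptions of this sketch, which is not the crate's kernel.

```rust
/// Reference GroupNorm on a contiguous [N, C, S] buffer (S = flattened spatial size).
/// Statistics are per (sample, group); gamma and beta are indexed per channel.
pub fn group_norm(
    x: &[f32], n: usize, c: usize, s: usize, num_groups: usize,
    gamma: &[f32], beta: &[f32], eps: f32,
) -> Vec<f32> {
    assert!(num_groups > 0 && c % num_groups == 0 && s > 0);
    assert_eq!(x.len(), n * c * s);
    assert_eq!(gamma.len(), c);
    assert_eq!(beta.len(), c);
    let cpg = c / num_groups; // channels per group
    let group_len = cpg * s; // elements reduced per (sample, group)
    let mut y = vec![0.0f32; x.len()];
    for ni in 0..n {
        for g in 0..num_groups {
            // Channels of one group are adjacent, so the group is one contiguous block.
            let start = (ni * c + g * cpg) * s;
            let block = &x[start..start + group_len];
            let mean = block.iter().sum::<f32>() / group_len as f32;
            let var = block.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>()
                / group_len as f32;
            let rstd = 1.0 / (var + eps).sqrt();
            for (i, v) in block.iter().enumerate() {
                let ch = g * cpg + i / s; // channel index; g == ch / cpg
                y[start + i] = (v - mean) * rstd * gamma[ch] + beta[ch];
            }
        }
    }
    y
}
```

The statistics have shape [N, num_groups], while the affine parameters keep shape [C], which is the distinction the formula draws between mean[n, g] and gamma[c].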
BatchNorm = 3
Per-channel statistics across batch + spatial. Training-mode
only — saves (saved_mean, saved_rstd) of shape [C]. Inference
mode (running stats) deferred. PyTorch torch.nn.BatchNormNd.
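A minimal sketch of the training-mode statistics, computed per channel over batch and spatial positions of a contiguous [N, C, S] buffer. It assumes that rstd means 1 / sqrt(var + eps) with biased variance; the layout and the function name are likewise illustrative assumptions.

```rust
/// Per-channel training-mode statistics for a contiguous [N, C, S] buffer.
/// Returns (saved_mean, saved_rstd), each of length C.
/// Assumes rstd = 1 / sqrt(biased_var + eps).
pub fn batch_norm_stats(
    x: &[f32], n: usize, c: usize, s: usize, eps: f32,
) -> (Vec<f32>, Vec<f32>) {
    assert!(n > 0 && s > 0);
    assert_eq!(x.len(), n * c * s);
    let count = (n * s) as f32; // elements reduced per channel
    let mut saved_mean = vec![0.0f32; c];
    let mut saved_rstd = vec![0.0f32; c];
    for ch in 0..c {
        // First pass: mean over every sample and spatial position of this channel.
        let mut sum = 0.0f32;
        for ni in 0..n {
            let start = (ni * c + ch) * s;
            sum += x[start..start + s].iter().sum::<f32>();
        }
        let mean = sum / count;
        // Second pass: biased variance around that mean.
        let mut sq = 0.0f32;
        for ni in 0..n {
            let start = (ni * c + ch) * s;
            sq += x[start..start + s]
                .iter()
                .map(|v| (v - mean) * (v - mean))
                .sum::<f32>();
        }
        saved_mean[ch] = mean;
        saved_rstd[ch] = 1.0 / (sq / count + eps).sqrt();
    }
    (saved_mean, saved_rstd)
}
```

Unlike GroupNorm, the reduction crosses the batch dimension, so a channel's elements are strided across samples rather than stored in one contiguous block.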
InstanceNorm = 4
Per-(sample, channel) statistics across spatial only. PyTorch
torch.nn.InstanceNormNd. Equivalent to GroupNorm with
num_groups == num_channels; same kernel symbols.
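The equivalence follows from the GroupNorm group index: with num_groups == num_channels, each group holds exactly one channel, so the statistics become per (sample, channel) over the spatial extent only. A small sketch with hypothetical helper names:

```rust
/// Group index of channel `c`, following g = c / (C / num_groups).
pub fn group_of_channel(c: usize, num_channels: usize, num_groups: usize) -> usize {
    assert!(num_groups > 0 && num_channels % num_groups == 0);
    assert!(c < num_channels);
    c / (num_channels / num_groups)
}

/// Number of elements reduced per (sample, group) for a given spatial size.
pub fn elems_per_group(num_channels: usize, num_groups: usize, spatial: usize) -> usize {
    assert!(num_groups > 0 && num_channels % num_groups == 0);
    (num_channels / num_groups) * spatial
}
```

With 8 channels and 8 groups, `group_of_channel(c, 8, 8)` returns `c` for every channel, and `elems_per_group(8, 8, 16)` returns 16, the spatial size alone.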
Trait Implementations
impl Clone for NormalizationKind
fn clone(&self) -> NormalizationKind
Returns a duplicate of the value.
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Copy for NormalizationKind
impl Debug for NormalizationKind
impl Eq for NormalizationKind
impl Hash for NormalizationKind
impl PartialEq for NormalizationKind
fn eq(&self, other: &NormalizationKind) -> bool
Tests for self and other values to be equal, and is used by ==.