//! Normalization op family — Phase 5 Category G.
//!
//! Length-preserving per-row normalization. Output shape equals input
//! shape across all variants. Currently wired:
//!
//! - **RMSNorm** (FW + BW) — `y = x / sqrt(mean(x², over norm_axes) + eps) * gamma`.
//! Pre-norm layer in Llama / Mistral / Gemma blocks. Multi-axis via the
//! `norm_axes_mask` bitmask (PyTorch `normalized_shape` convention: the
//! normalized axes must be a suffix of the input shape).
//!
//! - **LayerNorm** (FW + BW) — `y = (x - mean) / sqrt(var + eps) * gamma + beta`.
//! Same multi-axis spec as RMSNorm.
//!
//! - **BatchNorm** (FW + BW) — per-channel normalization across
//! `(N, *spatial*)`. Training mode only (inference mode using running
//! statistics is deferred).
//!
//! - **GroupNorm** (FW + BW) — splits channel axis into `num_groups`
//! groups, normalizes per `(sample, group, *spatial*)`.
//!
//! - **InstanceNorm** (FW + BW) — thin wrapper over GroupNorm with
//! `num_groups == num_channels` (same kernel symbols).
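/*!

A minimal CPU reference sketch of the RMSNorm and LayerNorm formulas above,
not the kernels themselves. The function names, the flat `&[f32]` layout, and
the biased (divide-by-N) variance are assumptions made for illustration.
Because the normalized axes form a suffix of the shape, each row is modeled
as `row_len` contiguous elements.

```rust
// Hypothetical reference: y = x / sqrt(mean(x^2) + eps) * gamma, per row.
fn rms_norm_ref(x: &[f32], gamma: &[f32], row_len: usize, eps: f32) -> Vec<f32> {
    assert!(row_len > 0);
    assert_eq!(gamma.len(), row_len);
    assert_eq!(x.len() % row_len, 0);
    let mut y = Vec::with_capacity(x.len());
    for row in x.chunks_exact(row_len) {
        // Pass 1: mean of squares over the row.
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / row_len as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // Pass 2: scale each element and apply gamma.
        y.extend(row.iter().zip(gamma).map(|(v, g)| v * inv_rms * g));
    }
    y
}

// Hypothetical reference: y = (x - mean) / sqrt(var + eps) * gamma + beta, per row.
fn layer_norm_ref(
    x: &[f32],
    gamma: &[f32],
    beta: &[f32],
    row_len: usize,
    eps: f32,
) -> Vec<f32> {
    assert!(row_len > 0);
    assert_eq!(gamma.len(), row_len);
    assert_eq!(beta.len(), row_len);
    assert_eq!(x.len() % row_len, 0);
    let n = row_len as f32;
    let mut y = Vec::with_capacity(x.len());
    for row in x.chunks_exact(row_len) {
        let mean = row.iter().sum::<f32>() / n;
        // Biased variance (divide by N); this is an assumption of the sketch.
        let var = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
        let inv_std = 1.0 / (var + eps).sqrt();
        for ((v, g), b) in row.iter().zip(gamma).zip(beta) {
            y.push((v - mean) * inv_std * g + b);
        }
    }
    y
}
```

Both functions return a vector with the same length as `x`, matching the
length-preserving contract stated at the top of this module.
*/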
//!
//! ## Deferred
//!
//! - `WeightNorm` — a parameterization, not a plain op.
//! - `LocalResponseNorm` — rarely used.
//! - BatchNorm inference mode (running statistics → per-channel affine
//! multiply).
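/*!

One way the deferred BatchNorm inference path could reduce to a per-channel
affine multiply, sketched on the CPU. `fold_batchnorm_inference`,
`apply_channel_affine`, and the contiguous `(N, C, spatial)` layout are
hypothetical. The folding `scale = gamma / sqrt(running_var + eps)`,
`shift = beta - running_mean * scale` is the standard identity, not something
this module implements today.

```rust
// Hypothetical helper: fold running statistics and affine parameters into
// one (scale, shift) pair per channel.
fn fold_batchnorm_inference(
    gamma: &[f32],
    beta: &[f32],
    running_mean: &[f32],
    running_var: &[f32],
    eps: f32,
) -> (Vec<f32>, Vec<f32>) {
    let channels = gamma.len();
    assert_eq!(beta.len(), channels);
    assert_eq!(running_mean.len(), channels);
    assert_eq!(running_var.len(), channels);
    let mut scale = Vec::with_capacity(channels);
    let mut shift = Vec::with_capacity(channels);
    for c in 0..channels {
        let s = gamma[c] / (running_var[c] + eps).sqrt();
        scale.push(s);
        shift.push(beta[c] - running_mean[c] * s);
    }
    (scale, shift)
}

// Hypothetical helper: y = x * scale[c] + shift[c] for a contiguous
// (N, C, spatial) tensor, where `spatial` is the product of spatial dims.
fn apply_channel_affine(x: &[f32], scale: &[f32], shift: &[f32], spatial: usize) -> Vec<f32> {
    let channels = scale.len();
    assert_eq!(shift.len(), channels);
    assert!(channels > 0 && spatial > 0);
    assert_eq!(x.len() % (channels * spatial), 0);
    x.iter()
        .enumerate()
        .map(|(i, v)| {
            // Channel index of flat element i in (N, C, spatial) order.
            let c = (i / spatial) % channels;
            v * scale[c] + shift[c]
        })
        .collect()
}
```

Folding once up front means the per-element work is a single multiply-add,
with no reduction at inference time.
*/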
//!
//! ## Design notes
//!
//! - **No atomic adds.** Affine-grad accumulators (`dgamma`, `dbeta`)
//! and group-stats reductions use one-block-per-feature kernels with
//! warp shuffles + smem — fully deterministic, no half / bf16
//! atomicAdd arch quirks.
//!
//! - **f16 / bf16 accumulate in f32** (mandatory: variance in half
//! precision is catastrophic). f64 uses double throughout, with one
//! exception: BatchNorm BW workspace partials stay f32 even for f64
//! inputs (acceptable precision loss on the partial-sum workspace for
//! the trailblazer).
//!
//! - **Two-pass per row, recomputed per output cell.** For RMSNorm /
//! LayerNorm each output cell redoes the row reduction, so total work
//! per row is the same naive O(extent²) as the softmax kernel; the
//! BN/GN three-stage scheme amortizes the per-group reduction.
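/*!

A CPU analogue of why a fixed-order reduction is deterministic. This is only
an illustration: the kernels use warp shuffles and shared memory, and
`tree_reduce_sum` is a hypothetical helper. The pairing order depends only on
the input length, so repeated runs give bit-identical sums, whereas atomic
adds commit in whatever order the scheduler produces.

```rust
// Hypothetical helper: pairwise tree reduction with a fixed pairing order.
fn tree_reduce_sum(values: &[f32]) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let mut buf = values.to_vec();
    let mut n = buf.len();
    while n > 1 {
        // Fold the upper half onto the lower half; with odd n the middle
        // element is carried over unchanged to the next round.
        let half = (n + 1) / 2;
        for i in 0..n / 2 {
            buf[i] += buf[i + half];
        }
        n = half;
    }
    buf[0]
}
```

The same structure applies to the `dgamma` / `dbeta` accumulators: one
reducer per feature, with a summation order that never varies between runs.
*/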
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;