provable_contracts/kernels/batchnorm.rs
1//! Batch normalization kernel.
2//!
3//! Matches `batchnorm-kernel-v1.yaml`.
4//!
5//! Training mode: computes per-channel mean/variance from the batch, normalizes,
6//! applies affine transform, and updates running statistics via EMA.
7//!
8//! Inference mode: uses running mean/variance directly for normalization.
9//!
10//! Input layout: N*C flattened, where N = batch size, C = channels.
11//! Element (n, c) is at index `n * c_count + c`.
12
13// ────────────────────────────────────────────────────────────────────────────
14// Scalar implementation
15// ────────────────────────────────────────────────────────────────────────────
16
/// Scalar reference implementation of BatchNorm over a flattened `[N, C]` tensor.
///
/// Every output element is computed as
/// `gamma[ch] * (x - mean) * (1 / sqrt(var + eps)) + beta[ch]`.
///
/// * Training mode: `mean` and `var` are the per-channel batch mean and the
///   biased (divide-by-`n`) batch variance. After a channel is normalized, its
///   running statistics are blended as
///   `running = (1 - momentum) * running + momentum * batch`.
/// * Inference mode: `mean` and `var` are read from `running_mean` and
///   `running_var`, which are left untouched.
///
/// # Arguments
///
/// * `input` - Flattened `[N, C]` tensor (row-major); element `(sample, ch)` is
///   at index `sample * c + ch`.
/// * `n` - Batch size (N).
/// * `c` - Number of channels (C).
/// * `gamma` - Per-channel scale, length `C`.
/// * `beta` - Per-channel bias, length `C`.
/// * `eps` - Small constant added to the variance for numerical stability.
/// * `running_mean` - Running mean per channel, length `C`. Updated in training.
/// * `running_var` - Running variance per channel, length `C`. Updated in training.
/// * `output` - Output buffer, same shape as `input`.
/// * `momentum` - EMA momentum for the running stats update.
/// * `training` - If true, compute batch stats and update running stats.
///   If false, use running stats for normalization.
///
/// # Panics
///
/// Panics if `n * c` overflows `usize`, if any buffer length is inconsistent
/// with `n` and `c`, or if `n == 0` or `c == 0`.
#[allow(clippy::too_many_arguments)]
pub fn batchnorm_scalar(
    input: &[f32],
    n: usize,
    c: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    running_mean: &mut [f32],
    running_var: &mut [f32],
    output: &mut [f32],
    momentum: f32,
    training: bool,
) {
    // A wrapped product could let mismatched buffers slip past the length
    // checks in builds without overflow checks, so reject overflow explicitly.
    let total = n.checked_mul(c).expect("n * c overflows usize");
    assert_eq!(input.len(), total, "input length must be n * c");
    assert_eq!(output.len(), total, "output length must be n * c");
    assert_eq!(gamma.len(), c, "gamma length must be c");
    assert_eq!(beta.len(), c, "beta length must be c");
    assert_eq!(running_mean.len(), c, "running_mean length must be c");
    assert_eq!(running_var.len(), c, "running_var length must be c");
    assert!(n > 0 && c > 0, "batchnorm requires n > 0 and c > 0");

    let batch_len = n as f32;
    let keep = 1.0 - momentum;

    for ch in 0..c {
        // Values of channel `ch` in sample order: indices ch, ch + c, ch + 2c, ...
        // `c > 0` was asserted above, so `step_by(c)` cannot panic.
        let channel = || input.iter().skip(ch).step_by(c).copied();

        let (mean, var) = if training {
            // Two-pass batch statistics: mean first, then squared deviations.
            let mean = channel().fold(0.0_f32, |acc, x| acc + x) / batch_len;
            let sq_dev = channel().fold(0.0_f32, |acc, x| {
                let diff = x - mean;
                acc + diff * diff
            });
            (mean, sq_dev / batch_len)
        } else {
            (running_mean[ch], running_var[ch])
        };

        // Normalize, scale, and shift this channel's column of the output.
        let inv_std = 1.0 / (var + eps).sqrt();
        let (scale, shift) = (gamma[ch], beta[ch]);
        for (out, x) in output.iter_mut().skip(ch).step_by(c).zip(channel()) {
            *out = scale * (x - mean) * inv_std + shift;
        }

        if training {
            // Exponential moving average of the batch statistics.
            running_mean[ch] = keep * running_mean[ch] + momentum * mean;
            running_var[ch] = keep * running_var[ch] + momentum * var;
        }
    }
}
99
100// ────────────────────────────────────────────────────────────────────────────
101// AVX2 implementation
102// ────────────────────────────────────────────────────────────────────────────
103
104/// AVX2 BatchNorm -- delegates to scalar.
105///
106/// Batch dimension reduction is irregular for SIMD (strided access across
107/// samples for each channel), so this delegates to the scalar implementation.
108///
109/// # Safety
110///
111/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
112///
113/// # Panics
114///
115/// Panics if buffer sizes are inconsistent with `n * c`.
116#[cfg(target_arch = "x86_64")]
117#[target_feature(enable = "avx2")]
118#[allow(clippy::too_many_arguments)]
119pub unsafe fn batchnorm_avx2(
120 input: &[f32],
121 n: usize,
122 c: usize,
123 gamma: &[f32],
124 beta: &[f32],
125 eps: f32,
126 running_mean: &mut [f32],
127 running_var: &mut [f32],
128 output: &mut [f32],
129 momentum: f32,
130 training: bool,
131) {
132 batchnorm_scalar(
133 input,
134 n,
135 c,
136 gamma,
137 beta,
138 eps,
139 running_mean,
140 running_var,
141 output,
142 momentum,
143 training,
144 );
145}
146
147// ────────────────────────────────────────────────────────────────────────────
148// PTX implementation
149// ────────────────────────────────────────────────────────────────────────────
150
151include!("batchnorm_ptx.rs");
152
153// ────────────────────────────────────────────────────────────────────────────
154// Tests
155// ────────────────────────────────────────────────────────────────────────────
156
#[cfg(test)]
mod tests {
    // Part 1: scalar and property-based tests.
    include!("batchnorm_tests.rs");
    // Part 2: AVX2 parity and PTX structural tests.
    include!("batchnorm_tests2.rs");
}