// Source: provable_contracts/kernels/batchnorm.rs

//! Batch normalization kernel.
//!
//! Matches `batchnorm-kernel-v1.yaml`.
//!
//! Training mode: computes per-channel mean/variance from the batch, normalizes,
//! applies the affine transform, and updates the running statistics via EMA.
//!
//! Inference mode: uses the running mean/variance directly for normalization.
//!
//! Input layout: `[N, C]` flattened row-major, where N = batch size and C = channels.
//! Element `(sample, ch)` is at index `sample * c + ch`, where `c` is the channel count.

// ────────────────────────────────────────────────────────────────────────────
// Scalar implementation
// ────────────────────────────────────────────────────────────────────────────

/// Scalar reference implementation of BatchNorm.
///
/// Each element `x` of channel `ch` is mapped to
/// `gamma[ch] * (x - mean) * inv_std + beta[ch]`, with
/// `inv_std = 1 / sqrt(var + eps)`. In training mode `mean`/`var` are the
/// statistics of that channel over the current batch. In inference mode they
/// are `running_mean[ch]`/`running_var[ch]`.
///
/// # Arguments
///
/// * `input`        - Flattened `[N, C]` tensor (row-major).
/// * `n`            - Batch size (N).
/// * `c`            - Number of channels (C).
/// * `gamma`        - Per-channel scale, length `C`.
/// * `beta`         - Per-channel bias, length `C`.
/// * `eps`          - Small constant for numerical stability.
/// * `running_mean` - Running mean per channel, length `C`. Updated in training.
/// * `running_var`  - Running variance per channel, length `C`. Updated in training.
/// * `output`       - Output buffer, same shape as `input`.
/// * `momentum`     - EMA momentum for the running stats update:
///   `running = (1 - momentum) * running + momentum * batch`.
/// * `training`     - If true, compute batch stats and update running stats.
///   If false, use running stats for normalization.
///
/// # Variance estimator
///
/// The batch variance is the biased (population) estimate: the summed squared
/// deviations are divided by `n`. That same value feeds both the normalization
/// and the `running_var` update.
/// NOTE(review): PyTorch's `BatchNorm` updates `running_var` with the unbiased
/// estimate instead; confirm which one `batchnorm-kernel-v1.yaml` specifies.
///
/// # Panics
///
/// Panics if `n * c` overflows `usize`, if any buffer length is inconsistent
/// with `n` and `c`, or if `n == 0` or `c == 0` (in training and inference).
#[allow(clippy::too_many_arguments)]
pub fn batchnorm_scalar(
    input: &[f32],
    n: usize,
    c: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    running_mean: &mut [f32],
    running_var: &mut [f32],
    output: &mut [f32],
    momentum: f32,
    training: bool,
) {
    // Checked multiply: a wrapping `n * c` in release builds could otherwise
    // let mismatched buffers slip past the length assertions below.
    let len = n.checked_mul(c).expect("n * c overflows usize");
    assert_eq!(input.len(), len, "input length must be n * c");
    assert_eq!(output.len(), len, "output length must be n * c");
    assert_eq!(gamma.len(), c, "gamma length must be c");
    assert_eq!(beta.len(), c, "beta length must be c");
    assert_eq!(running_mean.len(), c, "running_mean length must be c");
    assert_eq!(running_var.len(), c, "running_var length must be c");
    assert!(n > 0 && c > 0, "batchnorm requires n > 0 and c > 0");

    let batch = n as f32;
    for ch in 0..c {
        // Channel `ch` lives at indices ch, ch + c, ..., ch + (n - 1) * c;
        // `step_by(c)` is valid because `c > 0` was asserted above.
        let samples = input[ch..].iter().step_by(c);

        let (mean, var) = if training {
            // Left-to-right folds starting at +0.0 reproduce the summation
            // order (and therefore the rounding) of a plain loop over samples.
            let batch_mean = samples.clone().fold(0.0_f32, |acc, &x| acc + x) / batch;
            let batch_var = samples.clone().fold(0.0_f32, |acc, &x| {
                let diff = x - batch_mean;
                acc + diff * diff
            }) / batch;

            // The EMA update reads only the batch statistics, so it may run
            // before this channel's outputs are written.
            running_mean[ch] = (1.0 - momentum) * running_mean[ch] + momentum * batch_mean;
            running_var[ch] = (1.0 - momentum) * running_var[ch] + momentum * batch_var;
            (batch_mean, batch_var)
        } else {
            (running_mean[ch], running_var[ch])
        };

        // Normalize, scale, and shift: identical for both modes once the
        // statistics have been chosen.
        let inv_std = 1.0 / (var + eps).sqrt();
        let (scale, shift) = (gamma[ch], beta[ch]);
        for (out, &x) in output[ch..].iter_mut().step_by(c).zip(samples) {
            // Same association as the reference formula:
            // ((scale * (x - mean)) * inv_std) + shift.
            *out = scale * (x - mean) * inv_std + shift;
        }
    }
}

// ────────────────────────────────────────────────────────────────────────────
// AVX2 implementation
// ────────────────────────────────────────────────────────────────────────────

104/// AVX2 BatchNorm -- delegates to scalar.
105///
106/// Batch dimension reduction is irregular for SIMD (strided access across
107/// samples for each channel), so this delegates to the scalar implementation.
108///
109/// # Safety
110///
111/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
112///
113/// # Panics
114///
115/// Panics if buffer sizes are inconsistent with `n * c`.
116#[cfg(target_arch = "x86_64")]
117#[target_feature(enable = "avx2")]
118#[allow(clippy::too_many_arguments)]
119pub unsafe fn batchnorm_avx2(
120    input: &[f32],
121    n: usize,
122    c: usize,
123    gamma: &[f32],
124    beta: &[f32],
125    eps: f32,
126    running_mean: &mut [f32],
127    running_var: &mut [f32],
128    output: &mut [f32],
129    momentum: f32,
130    training: bool,
131) {
132    batchnorm_scalar(
133        input,
134        n,
135        c,
136        gamma,
137        beta,
138        eps,
139        running_mean,
140        running_var,
141        output,
142        momentum,
143        training,
144    );
145}

// ────────────────────────────────────────────────────────────────────────────
// PTX implementation
// ────────────────────────────────────────────────────────────────────────────

151include!("batchnorm_ptx.rs");

// ────────────────────────────────────────────────────────────────────────────
// Tests
// ────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    // Scalar reference checks plus property-based tests.
    include!("batchnorm_tests.rs");

    // AVX2-vs-scalar parity checks plus structural checks on the PTX kernel.
    include!("batchnorm_tests2.rs");
}