Skip to main content

any_gpu/ops/
norm.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
4// Group normalization (two-pass: compute stats, then normalize+affine).
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
/// Uniform parameters shared by both group-norm shaders (`SHADER_GN_STATS` and
/// `SHADER_GN_NORM`).
///
/// Field order and sizes must match the WGSL `struct P` declared in both
/// shaders (`batch, channels, spatial, groups, cpg, eps, _p0, _p1`): eight
/// 4-byte fields, 32 bytes in total.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GNStatsParams {
    batch: u32,
    channels: u32,
    spatial: u32,
    groups: u32,
    /// `channels / groups`; this is the field named `cpg` in the shaders.
    channels_per_group: u32,
    /// Added to the variance before the square root in the statistics pass.
    eps: f32,
    /// Mirrors the WGSL `_p0` / `_p1` words so both sides are 32 bytes.
    _pad: [u32; 2],
}
20
/// Pass 1 of group normalization: one thread per `(batch, group)`.
///
/// For `i = n * groups + g` it writes `stats[2 * i] = mean` and
/// `stats[2 * i + 1] = 1 / sqrt(var + eps)` (the inverse standard deviation,
/// not the variance). A group's `cpg * spatial` elements are contiguous in the
/// input, so each thread walks a single index range.
///
/// Numerics: values are accumulated relative to the group's first element
/// (shifted-data variance). This greatly reduces cancellation in
/// `E[x^2] - E[x]^2` when the mean is large compared to the spread. The
/// variance is also clamped at zero, so rounding can never feed a negative
/// value into `sqrt` and produce NaN.
///
/// The flat thread index uses the same `gid.x + gid.y * 65535 * 256` scheme as
/// the normalize pass. Both passes are dispatched with `dispatch_1d`, so they
/// agree on how a large dispatch spills into the y dimension.
const SHADER_GN_STATS: &str = "
struct P {
    batch: u32, channels: u32, spatial: u32, groups: u32,
    cpg: u32, eps: f32, _p0: u32, _p1: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> stats: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.groups;
    if idx >= total { return; }

    let g = idx % p.groups;
    let n = idx / p.groups;

    let count = p.cpg * p.spatial;
    // Channels g*cpg .. (g+1)*cpg of sample n form one contiguous run.
    let start = n * (p.channels * p.spatial) + g * count;
    // Shift by the first element to limit cancellation in E[x^2] - E[x]^2.
    let shift = input[start];
    var sum: f32 = 0.0;
    var sum_sq: f32 = 0.0;
    for (var i: u32 = 0u; i < count; i++) {
        let d = input[start + i] - shift;
        sum += d;
        sum_sq += d * d;
    }
    let mean_shifted = sum / f32(count);
    let variance = max(sum_sq / f32(count) - mean_shifted * mean_shifted, 0.0);
    stats[idx * 2u] = shift + mean_shifted;
    stats[idx * 2u + 1u] = 1.0 / sqrt(variance + p.eps);
}
";
57
/// Pass 2 of group normalization: one thread per element.
///
/// Each thread looks up its `(batch, group)` pair's `mean` and `inv_std` in
/// `stats` (two floats per pair, at `2 * (n * groups + g)`), normalizes the
/// element, then applies the per-channel affine `* gamma[ch] + beta[ch]`.
///
/// Bindings: 0 = params (`struct P`, same field list as `GNStatsParams`),
/// 1 = input, 2 = stats, 3 = gamma, 4 = beta, 5 = out.
///
/// The flat element index is `gid.x + gid.y * 65535 * 256`. This assumes that a
/// dispatch spilling into the y dimension uses exactly 65535 workgroups along x
/// (presumably what `dispatch_1d` produces — TODO confirm).
const SHADER_GN_NORM: &str = "
struct P {
    batch: u32, channels: u32, spatial: u32, groups: u32,
    cpg: u32, eps: f32, _p0: u32, _p1: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> stats: array<f32>;
@group(0) @binding(3) var<storage, read> gamma: array<f32>;
@group(0) @binding(4) var<storage, read> beta: array<f32>;
@group(0) @binding(5) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.channels * p.spatial;
    if idx >= total { return; }

    let s = idx % p.spatial;
    let ch = (idx / p.spatial) % p.channels;
    let n = idx / (p.spatial * p.channels);
    let g = ch / p.cpg;

    let stat_idx = n * p.groups + g;
    let mean = stats[stat_idx * 2u];
    let inv_std = stats[stat_idx * 2u + 1u];

    out[idx] = (input[idx] - mean) * inv_std * gamma[ch] + beta[ch];
}
";
88
89impl GpuDevice {
90    /// Group normalization: input[N,C,*spatial] with C/groups groups.
91    /// gamma[C] and beta[C] are learnable affine params.
92    pub fn group_norm(
93        &self,
94        input: &GpuBuffer,
95        gamma: &GpuBuffer,
96        beta: &GpuBuffer,
97        batch: u32, channels: u32, spatial: u32, groups: u32,
98        eps: f32,
99    ) -> Result<GpuBuffer> {
100        ensure!(input.len == (batch * channels * spatial) as usize);
101        ensure!(gamma.len == channels as usize);
102        ensure!(beta.len == channels as usize);
103        ensure!(channels % groups == 0, "channels must be divisible by groups");
104
105        let cpg = channels / groups;
106        let params = GNStatsParams { batch, channels, spatial, groups, channels_per_group: cpg, eps, _pad: [0; 2] };
107
108        // Pass 1: compute per-group mean and inv_std
109        let stats = self.alloc((batch * groups * 2) as usize);
110        self.dispatch_shader(
111            SHADER_GN_STATS, Some("gn_stats"),
112            &params, &[input], &stats,
113            super::dispatch_1d(batch * groups),
114        );
115
116        // Pass 2: normalize + affine
117        let total = batch * channels * spatial;
118        let out = self.alloc(total as usize);
119
120        // For pass 2 we need 5 storage bindings + params. Use raw dispatch.
121        let params_buf = self.upload_uniform(&params);
122        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
123            label: Some("gn_norm"),
124            source: wgpu::ShaderSource::Wgsl(SHADER_GN_NORM.into()),
125        });
126        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
127            label: Some("gn_norm"),
128            layout: None,
129            module: &shader,
130            entry_point: Some("main"),
131            compilation_options: Default::default(),
132            cache: None,
133        });
134        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
135            label: None,
136            layout: &pipeline.get_bind_group_layout(0),
137            entries: &[
138                wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
139                wgpu::BindGroupEntry { binding: 1, resource: input.buffer.as_entire_binding() },
140                wgpu::BindGroupEntry { binding: 2, resource: stats.buffer.as_entire_binding() },
141                wgpu::BindGroupEntry { binding: 3, resource: gamma.buffer.as_entire_binding() },
142                wgpu::BindGroupEntry { binding: 4, resource: beta.buffer.as_entire_binding() },
143                wgpu::BindGroupEntry { binding: 5, resource: out.buffer.as_entire_binding() },
144            ],
145        });
146        let (wx, wy, wz) = super::dispatch_1d(total);
147        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
148        {
149            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
150                label: Some("gn_norm"),
151                timestamp_writes: None,
152            });
153            pass.set_pipeline(&pipeline);
154            pass.set_bind_group(0, &bind_group, &[]);
155            pass.dispatch_workgroups(wx, wy, wz);
156        }
157        self.queue.submit(Some(encoder.finish()));
158
159        Ok(out)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }

    /// Epsilon used by every test in this module.
    const EPS: f32 = 1e-5;

    /// Uploads the host slices to the shared test device and runs `group_norm`.
    fn gpu_group_norm(
        data: &[f32], gamma: &[f32], beta: &[f32],
        batch: u32, channels: u32, spatial: u32, groups: u32,
    ) -> Result<GpuBuffer> {
        let d = dev();
        d.group_norm(
            &d.upload(data), &d.upload(gamma), &d.upload(beta),
            batch, channels, spatial, groups, EPS,
        )
    }

    /// CPU reference group norm over `[batch, channels, spatial]` data.
    ///
    /// Statistics use the single-pass `E[x^2] - E[x]^2` variance accumulated
    /// in f32. Elements beyond `batch * channels * spatial` are left at zero.
    fn cpu_group_norm(
        input: &[f32], gamma: &[f32], beta: &[f32],
        batch: usize, channels: usize, spatial: usize, groups: usize, eps: f32,
    ) -> Vec<f32> {
        let cpg = channels / groups;
        let group_len = cpg * spatial;
        let mut out = vec![0.0f32; input.len()];
        for n in 0..batch {
            for g in 0..groups {
                // Channels g*cpg .. (g+1)*cpg of sample n form one contiguous run.
                let start = n * channels * spatial + g * group_len;
                let end = start + group_len;

                let (sum, sum_sq) = input[start..end]
                    .iter()
                    .fold((0.0f32, 0.0f32), |(s, sq), &v| (s + v, sq + v * v));
                let count = group_len as f32;
                let mean = sum / count;
                let var = sum_sq / count - mean * mean;
                let inv_std = 1.0 / (var + eps).sqrt();

                for (i, (dst, &x)) in out[start..end].iter_mut().zip(&input[start..end]).enumerate() {
                    let ch = g * cpg + i / spatial;
                    *dst = (x - mean) * inv_std * gamma[ch] + beta[ch];
                }
            }
        }
        out
    }

    #[test]
    fn test_group_norm_per_channel() {
        // groups == channels: every channel is normalized on its own.
        let out = gpu_group_norm(&[1.0, 3.0, 2.0, 4.0], &[1.0, 1.0], &[0.0, 0.0], 1, 2, 2, 2).unwrap();
        assert_approx(&dev().read(&out).unwrap(), &[-1.0, 1.0, -1.0, 1.0], 1e-3);
    }

    #[test]
    fn test_group_norm_single_group() {
        // groups == 1 with 1 batch, 4 channels, 1 spatial: all four values share one group.
        let data = [1.0, 2.0, 3.0, 4.0];
        let gamma = [1.0; 4];
        let beta = [0.0; 4];
        let want = cpu_group_norm(&data, &gamma, &beta, 1, 4, 1, 1, EPS);
        let out = gpu_group_norm(&data, &gamma, &beta, 1, 4, 1, 1).unwrap();
        assert_approx(&dev().read(&out).unwrap(), &want, 1e-3);
    }

    #[test]
    fn test_group_norm_with_affine() {
        let out = gpu_group_norm(&[1.0, 3.0, 2.0, 4.0], &[2.0, 0.5], &[1.0, -1.0], 1, 2, 2, 2).unwrap();
        assert_approx(&dev().read(&out).unwrap(), &[-1.0, 3.0, -1.5, -0.5], 1e-3);
    }

    #[test]
    fn test_group_norm_batched_vs_cpu() {
        // batch=2, channels=4, spatial=3, groups=2
        let data: Vec<f32> = (0..24).map(|i| i as f32 * 0.1 - 0.5).collect();
        let gamma = [1.0, 2.0, 0.5, 1.5];
        let beta = [0.0, 1.0, -1.0, 0.5];
        let want = cpu_group_norm(&data, &gamma, &beta, 2, 4, 3, 2, EPS);
        let out = gpu_group_norm(&data, &gamma, &beta, 2, 4, 3, 2).unwrap();
        assert_approx(&dev().read(&out).unwrap(), &want, 1e-3);
    }

    #[test]
    fn test_group_norm_constant_input() {
        // Constant input has zero variance; eps keeps the division finite and the
        // output is all zeros. Layout: 1 batch, 2 channels, 4 spatial, 2 groups.
        let out = gpu_group_norm(&[5.0; 8], &[1.0, 1.0], &[0.0, 0.0], 1, 2, 4, 2).unwrap();
        assert_approx(&dev().read(&out).unwrap(), &[0.0; 8], 1e-3);
    }

    #[test]
    fn test_group_norm_channels_not_divisible() {
        // 5 channels cannot be split into 2 equal groups.
        assert!(gpu_group_norm(&[1.0; 5], &[1.0; 5], &[0.0; 5], 1, 5, 1, 2).is_err());
    }

    #[test]
    fn test_group_norm_input_size_mismatch() {
        // 10 elements supplied, but batch * channels * spatial = 1 * 4 * 4 = 16.
        assert!(gpu_group_norm(&[1.0; 10], &[1.0; 4], &[0.0; 4], 1, 4, 4, 2).is_err());
    }
}