1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9#[repr(C)]
10#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
11struct GNStatsParams {
12 batch: u32,
13 channels: u32,
14 spatial: u32,
15 groups: u32,
16 channels_per_group: u32,
17 eps: f32,
18 _pad: [u32; 2],
19}
20
/// Pass 1 of group norm: one invocation per (sample, group) pair.
///
/// Writes interleaved `[mean, 1 / sqrt(variance + eps)]` for pair `n * groups + g`
/// into `stats`. Bindings: 0 = params uniform, 1 = input, 2 = stats (output).
///
/// Variance uses a two-pass (mean, then centered squares) formulation: the
/// single-pass `E[x^2] - E[x]^2` cancels catastrophically when |mean| >> std and can
/// come out negative, which turns `sqrt(variance + eps)` into NaN.
const SHADER_GN_STATS: &str = "
struct P {
    batch: u32, channels: u32, spatial: u32, groups: u32,
    cpg: u32, eps: f32, _p0: u32, _p1: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> stats: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
    // Large dispatches are split across y; each y row spans nwg.x workgroups.
    let idx = gid.x + gid.y * nwg.x * 256u;
    let total = p.batch * p.groups;
    if idx >= total { return; }

    let g = idx % p.groups;
    let n = idx / p.groups;

    // The cpg channels of group g are adjacent, so the group is one contiguous run.
    let count = p.cpg * p.spatial;
    let start = n * (p.channels * p.spatial) + g * count;

    var sum: f32 = 0.0;
    for (var i: u32 = 0u; i < count; i++) {
        sum += input[start + i];
    }
    let mean = sum / f32(count);

    // Second pass over centered values: always >= 0, no cancellation.
    var sq: f32 = 0.0;
    for (var i: u32 = 0u; i < count; i++) {
        let d = input[start + i] - mean;
        sq += d * d;
    }
    let variance = sq / f32(count);

    stats[idx * 2u] = mean;
    stats[idx * 2u + 1u] = 1.0 / sqrt(variance + p.eps);
}
";
57
/// Pass 2 of group norm: one invocation per element of the
/// `[batch, channels, spatial]` tensor.
///
/// Looks up the `[mean, inv_std]` pair written by the stats pass for the element's
/// (sample, group) and writes `(x - mean) * inv_std * gamma[c] + beta[c]`.
/// Bindings: 0 = params, 1 = input, 2 = stats, 3 = gamma, 4 = beta, 5 = out.
const SHADER_GN_NORM: &str = "
struct P {
    batch: u32, channels: u32, spatial: u32, groups: u32,
    cpg: u32, eps: f32, _p0: u32, _p1: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> stats: array<f32>;
@group(0) @binding(3) var<storage, read> gamma: array<f32>;
@group(0) @binding(4) var<storage, read> beta: array<f32>;
@group(0) @binding(5) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>) {
    // Large dispatches are split across y; derive the row stride from the actual
    // x extent instead of assuming the host always uses 65535 workgroups per row.
    let idx = gid.x + gid.y * nwg.x * 256u;
    let total = p.batch * p.channels * p.spatial;
    if idx >= total { return; }

    let s = idx % p.spatial;
    let ch = (idx / p.spatial) % p.channels;
    let n = idx / (p.spatial * p.channels);
    let g = ch / p.cpg;

    let stat_idx = n * p.groups + g;
    let mean = stats[stat_idx * 2u];
    let inv_std = stats[stat_idx * 2u + 1u];

    out[idx] = (input[idx] - mean) * inv_std * gamma[ch] + beta[ch];
}
";
88
89impl GpuDevice {
90 pub fn group_norm(
93 &self,
94 input: &GpuBuffer,
95 gamma: &GpuBuffer,
96 beta: &GpuBuffer,
97 batch: u32, channels: u32, spatial: u32, groups: u32,
98 eps: f32,
99 ) -> Result<GpuBuffer> {
100 ensure!(input.len == (batch * channels * spatial) as usize);
101 ensure!(gamma.len == channels as usize);
102 ensure!(beta.len == channels as usize);
103 ensure!(channels % groups == 0, "channels must be divisible by groups");
104
105 let cpg = channels / groups;
106 let params = GNStatsParams { batch, channels, spatial, groups, channels_per_group: cpg, eps, _pad: [0; 2] };
107
108 let stats = self.alloc((batch * groups * 2) as usize);
110 self.dispatch_shader(
111 SHADER_GN_STATS, Some("gn_stats"),
112 ¶ms, &[input], &stats,
113 super::dispatch_1d(batch * groups),
114 );
115
116 let total = batch * channels * spatial;
118 let out = self.alloc(total as usize);
119
120 let params_buf = self.upload_uniform(¶ms);
122 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
123 label: Some("gn_norm"),
124 source: wgpu::ShaderSource::Wgsl(SHADER_GN_NORM.into()),
125 });
126 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
127 label: Some("gn_norm"),
128 layout: None,
129 module: &shader,
130 entry_point: Some("main"),
131 compilation_options: Default::default(),
132 cache: None,
133 });
134 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
135 label: None,
136 layout: &pipeline.get_bind_group_layout(0),
137 entries: &[
138 wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
139 wgpu::BindGroupEntry { binding: 1, resource: input.buffer.as_entire_binding() },
140 wgpu::BindGroupEntry { binding: 2, resource: stats.buffer.as_entire_binding() },
141 wgpu::BindGroupEntry { binding: 3, resource: gamma.buffer.as_entire_binding() },
142 wgpu::BindGroupEntry { binding: 4, resource: beta.buffer.as_entire_binding() },
143 wgpu::BindGroupEntry { binding: 5, resource: out.buffer.as_entire_binding() },
144 ],
145 });
146 let (wx, wy, wz) = super::dispatch_1d(total);
147 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
148 {
149 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
150 label: Some("gn_norm"),
151 timestamp_writes: None,
152 });
153 pass.set_pipeline(&pipeline);
154 pass.set_bind_group(0, &bind_group, &[]);
155 pass.dispatch_workgroups(wx, wy, wz);
156 }
157 self.queue.submit(Some(encoder.finish()));
158
159 Ok(out)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared test device.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// CPU reference for group norm over a contiguous `[batch, channels, spatial]` slice.
    ///
    /// Each (sample, group) occupies one contiguous run of `cpg * spatial` values.
    /// Statistics use the single-pass `E[x^2] - E[x]^2` variance, accumulated in
    /// storage order.
    fn cpu_group_norm(
        input: &[f32],
        gamma: &[f32],
        beta: &[f32],
        batch: usize,
        channels: usize,
        spatial: usize,
        groups: usize,
        eps: f32,
    ) -> Vec<f32> {
        let cpg = channels / groups;
        let group_len = cpg * spatial;
        let count = group_len as f32;
        let mut out = vec![0.0f32; input.len()];

        for n in 0..batch {
            for g in 0..groups {
                let start = n * channels * spatial + g * cpg * spatial;
                let group = &input[start..start + group_len];

                let (sum, sum_sq) = group
                    .iter()
                    .fold((0.0f32, 0.0f32), |(s, sq), &v| (s + v, sq + v * v));
                let mean = sum / count;
                let var = sum_sq / count - mean * mean;
                let inv_std = 1.0 / (var + eps).sqrt();

                for (offset, &x) in group.iter().enumerate() {
                    // offset = c * spatial + s within the group.
                    let ch = g * cpg + offset / spatial;
                    out[start + offset] = (x - mean) * inv_std * gamma[ch] + beta[ch];
                }
            }
        }
        out
    }

    #[test]
    fn test_group_norm_per_channel() {
        // groups == channels: every channel is normalized independently.
        let input = dev().upload(&[1.0, 3.0, 2.0, 4.0]);
        let gamma = dev().upload(&[1.0, 1.0]);
        let beta = dev().upload(&[0.0, 0.0]);
        let out = dev().group_norm(&input, &gamma, &beta, 1, 2, 2, 2, 1e-5).unwrap();
        let result = dev().read(&out).unwrap();
        assert_approx(&result, &[-1.0, 1.0, -1.0, 1.0], 1e-3);
    }

    #[test]
    fn test_group_norm_single_group() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0; 4];
        let beta = vec![0.0; 4];
        let expected = cpu_group_norm(&data, &gamma, &beta, 1, 4, 1, 1, 1e-5);

        let (d, g, b) = (dev().upload(&data), dev().upload(&gamma), dev().upload(&beta));
        let out = dev().group_norm(&d, &g, &b, 1, 4, 1, 1, 1e-5).unwrap();
        let result = dev().read(&out).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_group_norm_with_affine() {
        let input = dev().upload(&[1.0, 3.0, 2.0, 4.0]);
        let gamma = dev().upload(&[2.0, 0.5]);
        let beta = dev().upload(&[1.0, -1.0]);
        let out = dev().group_norm(&input, &gamma, &beta, 1, 2, 2, 2, 1e-5).unwrap();
        let result = dev().read(&out).unwrap();
        assert_approx(&result, &[-1.0, 3.0, -1.5, -0.5], 1e-3);
    }

    #[test]
    fn test_group_norm_batched_vs_cpu() {
        let data: Vec<f32> = (0..24).map(|i| (i as f32) * 0.1 - 0.5).collect();
        let gamma = vec![1.0, 2.0, 0.5, 1.5];
        let beta = vec![0.0, 1.0, -1.0, 0.5];
        let expected = cpu_group_norm(&data, &gamma, &beta, 2, 4, 3, 2, 1e-5);

        let (d, g, b) = (dev().upload(&data), dev().upload(&gamma), dev().upload(&beta));
        let out = dev().group_norm(&d, &g, &b, 2, 4, 3, 2, 1e-5).unwrap();
        let result = dev().read(&out).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_group_norm_constant_input() {
        // Zero variance: output must be exactly the (zero) shift, not NaN.
        let data = vec![5.0; 8];
        let gamma = vec![1.0, 1.0];
        let beta = vec![0.0, 0.0];

        let (d, g, b) = (dev().upload(&data), dev().upload(&gamma), dev().upload(&beta));
        let out = dev().group_norm(&d, &g, &b, 1, 2, 4, 2, 1e-5).unwrap();
        let result = dev().read(&out).unwrap();
        assert_approx(&result, &[0.0; 8], 1e-3);
    }

    #[test]
    fn test_group_norm_channels_not_divisible() {
        // 5 channels cannot be split into 2 equal groups.
        let data = vec![1.0; 5];
        let gamma = vec![1.0; 5];
        let beta = vec![0.0; 5];

        let (d, g, b) = (dev().upload(&data), dev().upload(&gamma), dev().upload(&beta));
        let outcome = dev().group_norm(&d, &g, &b, 1, 5, 1, 2, 1e-5);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_group_norm_input_size_mismatch() {
        // 10 input values vs. the 1 * 4 * 4 = 16 implied by the dimensions.
        let data = vec![1.0; 10];
        let gamma = vec![1.0; 4];
        let beta = vec![0.0; 4];

        let (d, g, b) = (dev().upload(&data), dev().upload(&gamma), dev().upload(&beta));
        let outcome = dev().group_norm(&d, &g, &b, 1, 4, 4, 2, 1e-5);
        assert!(outcome.is_err());
    }
}