1
2use crate::array::Array;
3use anyhow::{Result, anyhow};
4use std::borrow::Cow;
5use wgpu::util::DeviceExt;
6use crate::backend::webgpu::get_gpu_device;
7
8pub fn batch_norm_1d_training_webgpu(
9 _input: &Array,
10 _running_mean: &mut Array,
11 _running_var: &mut Array,
12 _weight: &Array,
13 _bias: &Array,
14 _momentum: f32,
15 _eps: f32,
16) -> Result<Array> {
17 Err(anyhow!("WebGPU BatchNorm Training not yet implemented"))
20}
21
22pub fn batch_norm_1d_inference_webgpu(
23 input: &Array,
24 running_mean: &Array,
25 running_var: &Array,
26 weight: &Array,
27 bias: &Array,
28 eps: f32,
29) -> Result<Array> {
30 let _batch_size = input.shape[0];
37 let channels = input.shape[1];
38 let spatial = if input.shape.len() > 2 { input.shape[2] } else { 1 };
39
40 let output_shape = input.shape.clone();
42 let num_elements = input.len();
43
44 let shader_src = format!(r#"
45 struct Uniforms {{
46 spatial_size: u32,
47 channels: u32,
48 eps: f32,
49 }};
50
51 @group(0) @binding(0) var<storage, read> input: array<f32>;
52 @group(0) @binding(1) var<storage, read> mean: array<f32>;
53 @group(0) @binding(2) var<storage, read> var_: array<f32>;
54 @group(0) @binding(3) var<storage, read> weight: array<f32>;
55 @group(0) @binding(4) var<storage, read> bias: array<f32>;
56 @group(0) @binding(5) var<storage, read_write> output: array<f32>;
57 @group(0) @binding(6) var<uniform> uniforms: Uniforms;
58
59 @compute @workgroup_size(64)
60 fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
61 let idx = global_id.x;
62 let total_elements = u32(arrayLength(&input));
63
64 if (idx >= total_elements) {{
65 return;
66 }}
67
68 // Map 1D index to (b, c, s)
69 // input is [Batch, Channel, Spatial] - Row Major
70 // idx = b * (C*S) + c * S + s
71
72 // spatial index s = idx % Spatial
73 // c = (idx / Spatial) % Channels
74 // b = idx / (Spatial * Channels)
75
76 let s = idx % uniforms.spatial_size;
77 let tmp = idx / uniforms.spatial_size;
78 let c = tmp % uniforms.channels;
79
80 // Read params for channel c
81 let m = mean[c];
82 let v = var_[c];
83 let w = weight[c];
84 let b = bias[c];
85 let x = input[idx];
86
87 // y = (x - mean) / sqrt(var + eps) * w + b
88 let norm = (x - m) / sqrt(v + uniforms.eps);
89 output[idx] = norm * w + b;
90 }}
91 "#);
92
93 #[repr(C)]
95 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
96 struct Uniforms {
97 spatial_size: u32,
98 channels: u32,
99 eps: f32,
100 }
101
102 let uniforms_data = Uniforms {
103 spatial_size: spatial as u32,
104 channels: channels as u32,
105 eps,
106 };
107
108 #[cfg(target_arch = "wasm32")]
109 let dq_owned = get_gpu_device()?;
110 #[cfg(target_arch = "wasm32")]
111 let dq = &dq_owned;
112
113 #[cfg(not(target_arch = "wasm32"))]
114 let dq = get_gpu_device()?;
115
116 let device = &dq.device;
117 let queue = &dq.queue;
118
119 let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
121 label: Some("BN Input"),
122 contents: bytemuck::cast_slice(&input.data),
123 usage: wgpu::BufferUsages::STORAGE,
124 });
125
126 let mean_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
127 label: Some("BN Mean"),
128 contents: bytemuck::cast_slice(&running_mean.data),
129 usage: wgpu::BufferUsages::STORAGE,
130 });
131
132 let var_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
133 label: Some("BN Var"),
134 contents: bytemuck::cast_slice(&running_var.data),
135 usage: wgpu::BufferUsages::STORAGE,
136 });
137
138 let weight_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
139 label: Some("BN Weight"),
140 contents: bytemuck::cast_slice(&weight.data),
141 usage: wgpu::BufferUsages::STORAGE,
142 });
143
144 let bias_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
145 label: Some("BN Bias"),
146 contents: bytemuck::cast_slice(&bias.data),
147 usage: wgpu::BufferUsages::STORAGE,
148 });
149
150 let output_size = (num_elements * std::mem::size_of::<f32>()) as u64;
151 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
152 label: Some("BN Output"),
153 size: output_size,
154 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
155 mapped_at_creation: false,
156 });
157
158 let uniforms_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
159 label: Some("BN Uniforms"),
160 contents: bytemuck::bytes_of(&uniforms_data),
161 usage: wgpu::BufferUsages::UNIFORM,
162 });
163
164 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
166 label: Some("BN Shader"),
167 source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_src)),
168 });
169
170 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
171 label: Some("BN Pipeline"),
172 layout: None,
173 module: &shader,
174 entry_point: Some("main"),
175 compilation_options: Default::default(),
176 cache: None,
177 });
178
179 let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
180 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
181 label: Some("BN Bind Group"),
182 layout: &bind_group_layout,
183 entries: &[
184 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
185 wgpu::BindGroupEntry { binding: 1, resource: mean_buffer.as_entire_binding() },
186 wgpu::BindGroupEntry { binding: 2, resource: var_buffer.as_entire_binding() },
187 wgpu::BindGroupEntry { binding: 3, resource: weight_buffer.as_entire_binding() },
188 wgpu::BindGroupEntry { binding: 4, resource: bias_buffer.as_entire_binding() },
189 wgpu::BindGroupEntry { binding: 5, resource: output_buffer.as_entire_binding() },
190 wgpu::BindGroupEntry { binding: 6, resource: uniforms_buffer.as_entire_binding() },
191 ],
192 });
193
194 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("BN Encoder") });
196 {
197 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("BN Pass"), timestamp_writes: None });
198 cpass.set_pipeline(&compute_pipeline);
199 cpass.set_bind_group(0, &bind_group, &[]);
200
201 let workgroup_size = 64;
202 let groups = (num_elements as u32 + workgroup_size - 1) / workgroup_size;
203 cpass.dispatch_workgroups(groups, 1, 1);
204 }
205
206 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
208 label: Some("Staging Buffer"),
209 size: output_size,
210 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
211 mapped_at_creation: false,
212 });
213
214 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);
215 queue.submit(Some(encoder.finish()));
216
217 let buffer_slice = staging_buffer.slice(..);
218 let (sender, receiver) = futures::channel::oneshot::channel();
219 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
220
221 device.poll(wgpu::Maintain::Wait);
222
223 if let Ok(Ok(())) = pollster::block_on(receiver) {
224 let data = buffer_slice.get_mapped_range();
225 let result_vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
226 drop(data);
227 staging_buffer.unmap();
228
229 Ok(Array::new(output_shape, result_vec))
230 } else {
231 Err(anyhow!("Failed to read GPU buffer"))
232 }
233}