numrs/backend/webgpu/
batchnorm.rs

1
2use crate::array::Array;
3use anyhow::{Result, anyhow};
4use std::borrow::Cow;
5use wgpu::util::DeviceExt;
6use crate::backend::webgpu::get_gpu_device;
7
8pub fn batch_norm_1d_training_webgpu(
9    _input: &Array,
10    _running_mean: &mut Array,
11    _running_var: &mut Array,
12    _weight: &Array,
13    _bias: &Array,
14    _momentum: f32,
15    _eps: f32,
16) -> Result<Array> {
17    // Training requires reduction (mean/var calc) which is complex in WGSL.
18    // Fallback to CPU/SIMD for training.
19    Err(anyhow!("WebGPU BatchNorm Training not yet implemented"))
20}
21
22pub fn batch_norm_1d_inference_webgpu(
23    input: &Array,
24    running_mean: &Array,
25    running_var: &Array,
26    weight: &Array,
27    bias: &Array,
28    eps: f32,
29) -> Result<Array> {
30    // Shapes
31    // Input: [Batch, Channels, Length] or [Batch, Channels]
32    // We treat everything as [Batch, Channels, Spatial]
33    // If dims=2 [N, C], Spatial=1.
34    // If dims=3 [N, C, L], Spatial=L.
35    
36    let _batch_size = input.shape[0];
37    let channels = input.shape[1];
38    let spatial = if input.shape.len() > 2 { input.shape[2] } else { 1 };
39    
40    // Output same shape
41    let output_shape = input.shape.clone();
42    let num_elements = input.len();
43
44    let shader_src = format!(r#"
45    struct Uniforms {{
46        spatial_size: u32,
47        channels: u32,
48        eps: f32,
49    }};
50    
51    @group(0) @binding(0) var<storage, read> input: array<f32>;
52    @group(0) @binding(1) var<storage, read> mean: array<f32>;
53    @group(0) @binding(2) var<storage, read> var_: array<f32>;
54    @group(0) @binding(3) var<storage, read> weight: array<f32>;
55    @group(0) @binding(4) var<storage, read> bias: array<f32>;
56    @group(0) @binding(5) var<storage, read_write> output: array<f32>;
57    @group(0) @binding(6) var<uniform> uniforms: Uniforms;
58    
59    @compute @workgroup_size(64)
60    fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
61        let idx = global_id.x;
62        let total_elements = u32(arrayLength(&input));
63        
64        if (idx >= total_elements) {{
65            return;
66        }}
67        
68        // Map 1D index to (b, c, s)
69        // input is [Batch, Channel, Spatial] - Row Major
70        // idx = b * (C*S) + c * S + s
71        
72        // spatial index s = idx % Spatial
73        // c = (idx / Spatial) % Channels
74        // b = idx / (Spatial * Channels)
75        
76        let s = idx % uniforms.spatial_size;
77        let tmp = idx / uniforms.spatial_size;
78        let c = tmp % uniforms.channels;
79        
80        // Read params for channel c
81        let m = mean[c];
82        let v = var_[c];
83        let w = weight[c];
84        let b = bias[c];
85        let x = input[idx];
86        
87        // y = (x - mean) / sqrt(var + eps) * w + b
88        let norm = (x - m) / sqrt(v + uniforms.eps);
89        output[idx] = norm * w + b;
90    }}
91    "#);
92    
93    // Uniforms
94    #[repr(C)]
95    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
96    struct Uniforms {
97        spatial_size: u32,
98        channels: u32,
99        eps: f32,
100    }
101    
102    let uniforms_data = Uniforms {
103        spatial_size: spatial as u32,
104        channels: channels as u32,
105        eps,
106    };
107    
108    #[cfg(target_arch = "wasm32")]
109    let dq_owned = get_gpu_device()?;
110    #[cfg(target_arch = "wasm32")]
111    let dq = &dq_owned;
112    
113    #[cfg(not(target_arch = "wasm32"))]
114    let dq = get_gpu_device()?;
115
116    let device = &dq.device;
117    let queue = &dq.queue;
118    
119    // Buffers
120    let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
121        label: Some("BN Input"),
122        contents: bytemuck::cast_slice(&input.data),
123        usage: wgpu::BufferUsages::STORAGE,
124    });
125    
126    let mean_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
127        label: Some("BN Mean"),
128        contents: bytemuck::cast_slice(&running_mean.data),
129        usage: wgpu::BufferUsages::STORAGE,
130    });
131    
132    let var_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
133        label: Some("BN Var"),
134        contents: bytemuck::cast_slice(&running_var.data),
135        usage: wgpu::BufferUsages::STORAGE,
136    });
137    
138    let weight_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
139        label: Some("BN Weight"),
140        contents: bytemuck::cast_slice(&weight.data),
141        usage: wgpu::BufferUsages::STORAGE,
142    });
143    
144    let bias_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
145        label: Some("BN Bias"),
146        contents: bytemuck::cast_slice(&bias.data),
147        usage: wgpu::BufferUsages::STORAGE,
148    });
149    
150    let output_size = (num_elements * std::mem::size_of::<f32>()) as u64;
151    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
152        label: Some("BN Output"),
153        size: output_size,
154        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
155        mapped_at_creation: false,
156    });
157    
158    let uniforms_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
159        label: Some("BN Uniforms"),
160        contents: bytemuck::bytes_of(&uniforms_data),
161        usage: wgpu::BufferUsages::UNIFORM,
162    });
163    
164    // Pipeline
165    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
166        label: Some("BN Shader"),
167        source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_src)),
168    });
169    
170    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
171        label: Some("BN Pipeline"),
172        layout: None,
173        module: &shader,
174        entry_point: Some("main"),
175        compilation_options: Default::default(),
176        cache: None,
177    });
178    
179    let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
180    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
181        label: Some("BN Bind Group"),
182        layout: &bind_group_layout,
183        entries: &[
184            wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
185            wgpu::BindGroupEntry { binding: 1, resource: mean_buffer.as_entire_binding() },
186            wgpu::BindGroupEntry { binding: 2, resource: var_buffer.as_entire_binding() },
187            wgpu::BindGroupEntry { binding: 3, resource: weight_buffer.as_entire_binding() },
188            wgpu::BindGroupEntry { binding: 4, resource: bias_buffer.as_entire_binding() },
189            wgpu::BindGroupEntry { binding: 5, resource: output_buffer.as_entire_binding() },
190            wgpu::BindGroupEntry { binding: 6, resource: uniforms_buffer.as_entire_binding() },
191        ],
192    });
193    
194    // Dispatch
195    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("BN Encoder") });
196    {
197        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("BN Pass"), timestamp_writes: None });
198        cpass.set_pipeline(&compute_pipeline);
199        cpass.set_bind_group(0, &bind_group, &[]);
200        
201        let workgroup_size = 64;
202        let groups = (num_elements as u32 + workgroup_size - 1) / workgroup_size;
203        cpass.dispatch_workgroups(groups, 1, 1);
204    }
205    
206    // Readback
207    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
208        label: Some("Staging Buffer"),
209        size: output_size,
210        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
211        mapped_at_creation: false,
212    });
213    
214    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);
215    queue.submit(Some(encoder.finish()));
216    
217    let buffer_slice = staging_buffer.slice(..);
218    let (sender, receiver) = futures::channel::oneshot::channel();
219    buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
220    
221    device.poll(wgpu::Maintain::Wait);
222    
223    if let Ok(Ok(())) = pollster::block_on(receiver) {
224        let data = buffer_slice.get_mapped_range();
225        let result_vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
226        drop(data);
227        staging_buffer.unmap();
228        
229        Ok(Array::new(output_shape, result_vec))
230    } else {
231        Err(anyhow!("Failed to read GPU buffer"))
232    }
233}