numrs/backend/webgpu/
conv.rs

1
2use crate::array::Array;
3use anyhow::{Result, anyhow};
4use std::borrow::Cow;
5use wgpu::util::DeviceExt;
6use crate::backend::webgpu::get_gpu_device;
7
8/// WebGPU Conv1D Implementation
9pub fn conv1d_webgpu(
10    input: &Array,
11    weight: &Array,
12    bias: Option<&Array>,
13    stride: usize,
14    padding: usize,
15) -> Result<Array> {
16    // 1. Validate shapes
17    let batch_size = input.shape[0];
18    let in_channels = input.shape[1];
19    let in_length = input.shape[2];
20    
21    let out_channels = weight.shape[0];
22    let kernel_size = weight.shape[2];
23    
24    // Output length calculation
25    let out_length = (in_length + 2 * padding - kernel_size) / stride + 1;
26    let output_shape = vec![batch_size, out_channels, out_length];
27    
28    // Shader Source
29    let shader_src = format!(r#"
30    struct Uniforms {{
31        batch_size: u32,
32        in_channels: u32,
33        in_length: u32,
34        out_channels: u32,
35        kernel_size: u32,
36        out_length: u32,
37        stride: u32,
38        padding: u32,
39    }};
40    
41    @group(0) @binding(0) var<storage, read> input: array<f32>;
42    @group(0) @binding(1) var<storage, read> weight: array<f32>;
43    @group(0) @binding(2) var<storage, read> bias: array<f32>;
44    @group(0) @binding(3) var<storage, read_write> output: array<f32>;
45    @group(0) @binding(4) var<uniform> uniforms: Uniforms;
46    
47    @compute @workgroup_size(64)
48    fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
49        // x: out_index (time step)
50        // y: out_channel
51        // z: batch_index
52        
53        let out_t = global_id.x;
54        let out_c = global_id.y;
55        let b = global_id.z;
56        
57        if (out_t >= uniforms.out_length || out_c >= uniforms.out_channels || b >= uniforms.batch_size) {{
58            return;
59        }}
60        
61        // Init accumulator with bias
62        var sum = bias[out_c];
63        
64        // Loop over InChannels and Kernel
65        for (var ic = 0u; ic < uniforms.in_channels; ic = ic + 1u) {{
66            for (var k = 0u; k < uniforms.kernel_size; k = k + 1u) {{
67                // Determine input time index
68                // in_t = out_t * stride + k - padding
69                // Signed arithmetic needed for padding check
70                let in_t_signed = i32(out_t * uniforms.stride + k) - i32(uniforms.padding);
71                
72                if (in_t_signed >= 0 && in_t_signed < i32(uniforms.in_length)) {{
73                    let in_t = u32(in_t_signed);
74                    
75                    // Input index: b * (C_in * LEN_in) + ic * LEN_in + t
76                    let in_idx = b * (uniforms.in_channels * uniforms.in_length) + ic * uniforms.in_length + in_t;
77                    
78                    // Weight index: out_c * (C_in * K) + ic * K + k
79                    let w_idx = out_c * (uniforms.in_channels * uniforms.kernel_size) + ic * uniforms.kernel_size + k;
80                    
81                    sum = sum + input[in_idx] * weight[w_idx];
82                }}
83            }}
84        }}
85        
86        // Output index: b * (C_out * LEN_out) + out_c * LEN_out + out_t
87        let out_idx = b * (uniforms.out_channels * uniforms.out_length) + out_c * uniforms.out_length + out_t;
88        output[out_idx] = sum;
89    }}
90    "#);
91
92    // 2. Uniforms Data
93    let uniforms_data = [
94        batch_size as u32,
95        in_channels as u32,
96        in_length as u32,
97        out_channels as u32,
98        kernel_size as u32,
99        out_length as u32,
100        stride as u32,
101        padding as u32,
102    ];
103
104    #[cfg(target_arch = "wasm32")]
105    let dq_owned = get_gpu_device()?;
106    #[cfg(target_arch = "wasm32")]
107    let dq = &dq_owned;
108    
109    #[cfg(not(target_arch = "wasm32"))]
110    let dq = get_gpu_device()?;
111
112    let device = &dq.device;
113    let queue = &dq.queue;
114
115    // 3. Create Buffers
116    let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
117        label: Some("Conv1D Input"),
118        contents: bytemuck::cast_slice(&input.data),
119        usage: wgpu::BufferUsages::STORAGE,
120    });
121    
122    let weight_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
123        label: Some("Conv1D Weight"),
124        contents: bytemuck::cast_slice(&weight.data),
125        usage: wgpu::BufferUsages::STORAGE,
126    });
127    
128    // Handle optional bias (create zero buffer if None)
129    let bias_content = if let Some(b) = bias {
130        Cow::Borrowed(&b.data)
131    } else {
132        Cow::Owned(vec![0.0f32; out_channels])
133    };
134    
135    let bias_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
136        label: Some("Conv1D Bias"),
137        contents: bytemuck::cast_slice(&bias_content),
138        usage: wgpu::BufferUsages::STORAGE,
139    });
140    
141    let output_size = (batch_size * out_channels * out_length * std::mem::size_of::<f32>()) as u64;
142    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
143        label: Some("Conv1D Output"),
144        size: output_size,
145        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146        mapped_at_creation: false,
147    });
148    
149    let uniforms_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
150        label: Some("Conv1D Uniforms"),
151        contents: bytemuck::cast_slice(&uniforms_data),
152        usage: wgpu::BufferUsages::UNIFORM,
153    });
154
155    // 4. Pipeline Setup
156    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
157        label: Some("Conv1D Shader"),
158        source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_src)),
159    });
160    
161    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
162        label: Some("Conv1D Pipeline"),
163        layout: None,
164        module: &shader,
165        entry_point: Some("main"),
166        compilation_options: Default::default(),
167        cache: None,
168    });
169    
170    let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
171    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
172        label: Some("Conv1D Bind Group"),
173        layout: &bind_group_layout,
174        entries: &[
175            wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
176            wgpu::BindGroupEntry { binding: 1, resource: weight_buffer.as_entire_binding() },
177            wgpu::BindGroupEntry { binding: 2, resource: bias_buffer.as_entire_binding() },
178            wgpu::BindGroupEntry { binding: 3, resource: output_buffer.as_entire_binding() },
179            wgpu::BindGroupEntry { binding: 4, resource: uniforms_buffer.as_entire_binding() },
180        ],
181    });
182
183    // 5. Dispatch
184    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Conv1D Encoder") });
185    {
186        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Conv1D Pass"), timestamp_writes: None });
187        cpass.set_pipeline(&compute_pipeline);
188        cpass.set_bind_group(0, &bind_group, &[]);
189        
190        let workgroup_size_x = 64;
191        let group_x = (out_length as u32 + workgroup_size_x - 1) / workgroup_size_x;
192        cpass.dispatch_workgroups(group_x, out_channels as u32, batch_size as u32);
193    }
194    
195    // 6. Readback
196    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
197        label: Some("Staging Buffer"),
198        size: output_size,
199        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
200        mapped_at_creation: false,
201    });
202    
203    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);
204    queue.submit(Some(encoder.finish()));
205    
206    let buffer_slice = staging_buffer.slice(..);
207    let (sender, receiver) = futures::channel::oneshot::channel();
208    buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
209    
210    device.poll(wgpu::Maintain::Wait);
211    
212    if let Ok(Ok(())) = pollster::block_on(receiver) {
213        let data = buffer_slice.get_mapped_range();
214        let result_vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
215        drop(data);
216        staging_buffer.unmap();
217        
218        Ok(Array::new(output_shape, result_vec))
219    } else {
220        Err(anyhow!("Failed to read GPU buffer"))
221    }
222}