// numrs/backend/webgpu/conv.rs
2use crate::array::Array;
3use anyhow::{Result, anyhow};
4use std::borrow::Cow;
5use wgpu::util::DeviceExt;
6use crate::backend::webgpu::get_gpu_device;
7
8pub fn conv1d_webgpu(
10 input: &Array,
11 weight: &Array,
12 bias: Option<&Array>,
13 stride: usize,
14 padding: usize,
15) -> Result<Array> {
16 let batch_size = input.shape[0];
18 let in_channels = input.shape[1];
19 let in_length = input.shape[2];
20
21 let out_channels = weight.shape[0];
22 let kernel_size = weight.shape[2];
23
24 let out_length = (in_length + 2 * padding - kernel_size) / stride + 1;
26 let output_shape = vec![batch_size, out_channels, out_length];
27
28 let shader_src = format!(r#"
30 struct Uniforms {{
31 batch_size: u32,
32 in_channels: u32,
33 in_length: u32,
34 out_channels: u32,
35 kernel_size: u32,
36 out_length: u32,
37 stride: u32,
38 padding: u32,
39 }};
40
41 @group(0) @binding(0) var<storage, read> input: array<f32>;
42 @group(0) @binding(1) var<storage, read> weight: array<f32>;
43 @group(0) @binding(2) var<storage, read> bias: array<f32>;
44 @group(0) @binding(3) var<storage, read_write> output: array<f32>;
45 @group(0) @binding(4) var<uniform> uniforms: Uniforms;
46
47 @compute @workgroup_size(64)
48 fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
49 // x: out_index (time step)
50 // y: out_channel
51 // z: batch_index
52
53 let out_t = global_id.x;
54 let out_c = global_id.y;
55 let b = global_id.z;
56
57 if (out_t >= uniforms.out_length || out_c >= uniforms.out_channels || b >= uniforms.batch_size) {{
58 return;
59 }}
60
61 // Init accumulator with bias
62 var sum = bias[out_c];
63
64 // Loop over InChannels and Kernel
65 for (var ic = 0u; ic < uniforms.in_channels; ic = ic + 1u) {{
66 for (var k = 0u; k < uniforms.kernel_size; k = k + 1u) {{
67 // Determine input time index
68 // in_t = out_t * stride + k - padding
69 // Signed arithmetic needed for padding check
70 let in_t_signed = i32(out_t * uniforms.stride + k) - i32(uniforms.padding);
71
72 if (in_t_signed >= 0 && in_t_signed < i32(uniforms.in_length)) {{
73 let in_t = u32(in_t_signed);
74
75 // Input index: b * (C_in * LEN_in) + ic * LEN_in + t
76 let in_idx = b * (uniforms.in_channels * uniforms.in_length) + ic * uniforms.in_length + in_t;
77
78 // Weight index: out_c * (C_in * K) + ic * K + k
79 let w_idx = out_c * (uniforms.in_channels * uniforms.kernel_size) + ic * uniforms.kernel_size + k;
80
81 sum = sum + input[in_idx] * weight[w_idx];
82 }}
83 }}
84 }}
85
86 // Output index: b * (C_out * LEN_out) + out_c * LEN_out + out_t
87 let out_idx = b * (uniforms.out_channels * uniforms.out_length) + out_c * uniforms.out_length + out_t;
88 output[out_idx] = sum;
89 }}
90 "#);
91
92 let uniforms_data = [
94 batch_size as u32,
95 in_channels as u32,
96 in_length as u32,
97 out_channels as u32,
98 kernel_size as u32,
99 out_length as u32,
100 stride as u32,
101 padding as u32,
102 ];
103
104 #[cfg(target_arch = "wasm32")]
105 let dq_owned = get_gpu_device()?;
106 #[cfg(target_arch = "wasm32")]
107 let dq = &dq_owned;
108
109 #[cfg(not(target_arch = "wasm32"))]
110 let dq = get_gpu_device()?;
111
112 let device = &dq.device;
113 let queue = &dq.queue;
114
115 let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
117 label: Some("Conv1D Input"),
118 contents: bytemuck::cast_slice(&input.data),
119 usage: wgpu::BufferUsages::STORAGE,
120 });
121
122 let weight_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
123 label: Some("Conv1D Weight"),
124 contents: bytemuck::cast_slice(&weight.data),
125 usage: wgpu::BufferUsages::STORAGE,
126 });
127
128 let bias_content = if let Some(b) = bias {
130 Cow::Borrowed(&b.data)
131 } else {
132 Cow::Owned(vec![0.0f32; out_channels])
133 };
134
135 let bias_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
136 label: Some("Conv1D Bias"),
137 contents: bytemuck::cast_slice(&bias_content),
138 usage: wgpu::BufferUsages::STORAGE,
139 });
140
141 let output_size = (batch_size * out_channels * out_length * std::mem::size_of::<f32>()) as u64;
142 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
143 label: Some("Conv1D Output"),
144 size: output_size,
145 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146 mapped_at_creation: false,
147 });
148
149 let uniforms_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
150 label: Some("Conv1D Uniforms"),
151 contents: bytemuck::cast_slice(&uniforms_data),
152 usage: wgpu::BufferUsages::UNIFORM,
153 });
154
155 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
157 label: Some("Conv1D Shader"),
158 source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader_src)),
159 });
160
161 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
162 label: Some("Conv1D Pipeline"),
163 layout: None,
164 module: &shader,
165 entry_point: Some("main"),
166 compilation_options: Default::default(),
167 cache: None,
168 });
169
170 let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
171 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
172 label: Some("Conv1D Bind Group"),
173 layout: &bind_group_layout,
174 entries: &[
175 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
176 wgpu::BindGroupEntry { binding: 1, resource: weight_buffer.as_entire_binding() },
177 wgpu::BindGroupEntry { binding: 2, resource: bias_buffer.as_entire_binding() },
178 wgpu::BindGroupEntry { binding: 3, resource: output_buffer.as_entire_binding() },
179 wgpu::BindGroupEntry { binding: 4, resource: uniforms_buffer.as_entire_binding() },
180 ],
181 });
182
183 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Conv1D Encoder") });
185 {
186 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Conv1D Pass"), timestamp_writes: None });
187 cpass.set_pipeline(&compute_pipeline);
188 cpass.set_bind_group(0, &bind_group, &[]);
189
190 let workgroup_size_x = 64;
191 let group_x = (out_length as u32 + workgroup_size_x - 1) / workgroup_size_x;
192 cpass.dispatch_workgroups(group_x, out_channels as u32, batch_size as u32);
193 }
194
195 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
197 label: Some("Staging Buffer"),
198 size: output_size,
199 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
200 mapped_at_creation: false,
201 });
202
203 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);
204 queue.submit(Some(encoder.finish()));
205
206 let buffer_slice = staging_buffer.slice(..);
207 let (sender, receiver) = futures::channel::oneshot::channel();
208 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
209
210 device.poll(wgpu::Maintain::Wait);
211
212 if let Ok(Ok(())) = pollster::block_on(receiver) {
213 let data = buffer_slice.get_mapped_range();
214 let result_vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
215 drop(data);
216 staging_buffer.unmap();
217
218 Ok(Array::new(output_shape, result_vec))
219 } else {
220 Err(anyhow!("Failed to read GPU buffer"))
221 }
222}