// trueno/backends/gpu/batch/execute/dispatch.rs
use super::super::GpuCommandBatch;
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
9
/// 1-D workgroup size used when dispatching elementwise (unary/binary) kernels.
/// NOTE(review): assumed to match the `@workgroup_size` declared in the WGSL
/// shaders passed to this module — confirm against the shader sources.
const WORKGROUP_SIZE: u32 = 256;
12
/// A compiled compute pipeline together with the bind-group layout it was
/// created from, cached so each distinct WGSL source is compiled only once.
pub struct CachedPipeline {
    // Compiled compute pipeline (entry point "main").
    pub(crate) pipeline: wgpu::ComputePipeline,
    // Layout for bind group 0; needed again when creating per-dispatch bind groups.
    pub(crate) bind_group_layout: wgpu::BindGroupLayout,
}
19
/// Pipeline cache, keyed by the value `cache_key` derives from the shader source.
pub type PipelineCache = HashMap<usize, CachedPipeline>;
26
/// Derives a pipeline-cache key from the shader source *content*.
///
/// The previous implementation used `shader_source.as_ptr() as usize`, i.e.
/// pointer identity. That only works for `&'static str` sources: a WGSL
/// string built at runtime gets a fresh allocation every call (cache never
/// hits), and worse, a freed allocation can be reused by a *different*
/// shader, aliasing a stale cache entry and dispatching the wrong pipeline.
/// Hashing the text makes the key stable and content-addressed.
fn cache_key(shader_source: &str) -> usize {
    let mut hasher = DefaultHasher::new();
    shader_source.hash(&mut hasher);
    hasher.finish() as usize
}
32
33impl GpuCommandBatch {
34 #[allow(clippy::map_entry)]
38 pub(crate) fn encode_unary_op<T: bytemuck::Pod>(
39 &self,
40 encoder: &mut wgpu::CommandEncoder,
41 cache: &mut PipelineCache,
42 shader_source: &str,
43 label: &str,
44 input_buffer: &wgpu::Buffer,
45 output_buffer: &wgpu::Buffer,
46 size: usize,
47 params: Option<&T>,
48 ) -> Result<(), String> {
49 let key = cache_key(shader_source);
50 let has_params = params.is_some();
51
52 if !cache.contains_key(&key) {
54 let shader = self.device.device.create_shader_module(wgpu::ShaderModuleDescriptor {
55 label: Some(&format!("{} Shader", label)),
56 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
57 });
58
59 let mut layout_entries = vec![
60 wgpu::BindGroupLayoutEntry {
61 binding: 0,
62 visibility: wgpu::ShaderStages::COMPUTE,
63 ty: wgpu::BindingType::Buffer {
64 ty: wgpu::BufferBindingType::Storage { read_only: true },
65 has_dynamic_offset: false,
66 min_binding_size: None,
67 },
68 count: None,
69 },
70 wgpu::BindGroupLayoutEntry {
71 binding: 1,
72 visibility: wgpu::ShaderStages::COMPUTE,
73 ty: wgpu::BindingType::Buffer {
74 ty: wgpu::BufferBindingType::Storage { read_only: false },
75 has_dynamic_offset: false,
76 min_binding_size: None,
77 },
78 count: None,
79 },
80 ];
81
82 if has_params {
83 layout_entries.push(wgpu::BindGroupLayoutEntry {
84 binding: 2,
85 visibility: wgpu::ShaderStages::COMPUTE,
86 ty: wgpu::BindingType::Buffer {
87 ty: wgpu::BufferBindingType::Uniform,
88 has_dynamic_offset: false,
89 min_binding_size: None,
90 },
91 count: None,
92 });
93 }
94
95 let bind_group_layout =
96 self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
97 label: Some(&format!("{} Layout", label)),
98 entries: &layout_entries,
99 });
100
101 let pipeline_layout =
102 self.device.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
103 label: Some(&format!("{} PipelineLayout", label)),
104 bind_group_layouts: &[&bind_group_layout],
105 push_constant_ranges: &[],
106 });
107
108 let pipeline =
109 self.device.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
110 label: Some(&format!("{} Pipeline", label)),
111 layout: Some(&pipeline_layout),
112 module: &shader,
113 entry_point: Some("main"),
114 compilation_options: Default::default(),
115 cache: None,
116 });
117
118 cache.insert(key, CachedPipeline { pipeline, bind_group_layout });
119 }
120
121 let cached = cache.get(&key).expect("pipeline just inserted");
122
123 let params_buffer = if let Some(params_data) = params {
125 let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
126 label: Some(&format!("{} Params", label)),
127 size: std::mem::size_of::<T>() as u64,
128 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
129 mapped_at_creation: false,
130 });
131 self.device.queue.write_buffer(&buffer, 0, bytemuck::bytes_of(params_data));
132 Some(buffer)
133 } else {
134 None
135 };
136
137 let mut bind_entries = vec![
139 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
140 wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
141 ];
142
143 if let Some(ref buffer) = params_buffer {
144 bind_entries
145 .push(wgpu::BindGroupEntry { binding: 2, resource: buffer.as_entire_binding() });
146 }
147
148 let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
149 label: Some(&format!("{} BindGroup", label)),
150 layout: &cached.bind_group_layout,
151 entries: &bind_entries,
152 });
153
154 {
156 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
157 label: Some(&format!("{} Pass", label)),
158 timestamp_writes: None,
159 });
160 pass.set_pipeline(&cached.pipeline);
161 pass.set_bind_group(0, &bind_group, &[]);
162 pass.dispatch_workgroups((size as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
163 }
164
165 Ok(())
166 }
167
168 #[allow(clippy::map_entry)]
172 pub(crate) fn encode_matmul_op(
173 &self,
174 encoder: &mut wgpu::CommandEncoder,
175 cache: &mut PipelineCache,
176 shader_source: &str,
177 label: &str,
178 a: &super::super::BufferId,
179 b: &super::super::BufferId,
180 output: &super::super::BufferId,
181 m: u32,
182 k: u32,
183 n: u32,
184 ) -> Result<(), String> {
185 contract_pre_tiled_naive_equivalence!();
186 let a_info = self.buffers.get(a).ok_or("Invalid buffer A ID")?;
187 let b_info = self.buffers.get(b).ok_or("Invalid buffer B ID")?;
188 let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
189
190 let a_buffer = a_info.gpu_buffer.as_ref().ok_or("Buffer A not created")?;
191 let b_buffer = b_info.gpu_buffer.as_ref().ok_or("Buffer B not created")?;
192 let output_buffer = output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
193
194 let key = cache_key(shader_source);
195
196 if !cache.contains_key(&key) {
198 let shader = self.device.device.create_shader_module(wgpu::ShaderModuleDescriptor {
199 label: Some(&format!("{} Shader", label)),
200 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
201 });
202
203 let bind_group_layout =
204 self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
205 label: Some(&format!("{} Layout", label)),
206 entries: &[
207 wgpu::BindGroupLayoutEntry {
208 binding: 0,
209 visibility: wgpu::ShaderStages::COMPUTE,
210 ty: wgpu::BindingType::Buffer {
211 ty: wgpu::BufferBindingType::Storage { read_only: true },
212 has_dynamic_offset: false,
213 min_binding_size: None,
214 },
215 count: None,
216 },
217 wgpu::BindGroupLayoutEntry {
218 binding: 1,
219 visibility: wgpu::ShaderStages::COMPUTE,
220 ty: wgpu::BindingType::Buffer {
221 ty: wgpu::BufferBindingType::Storage { read_only: true },
222 has_dynamic_offset: false,
223 min_binding_size: None,
224 },
225 count: None,
226 },
227 wgpu::BindGroupLayoutEntry {
228 binding: 2,
229 visibility: wgpu::ShaderStages::COMPUTE,
230 ty: wgpu::BindingType::Buffer {
231 ty: wgpu::BufferBindingType::Storage { read_only: false },
232 has_dynamic_offset: false,
233 min_binding_size: None,
234 },
235 count: None,
236 },
237 wgpu::BindGroupLayoutEntry {
238 binding: 3,
239 visibility: wgpu::ShaderStages::COMPUTE,
240 ty: wgpu::BindingType::Buffer {
241 ty: wgpu::BufferBindingType::Uniform,
242 has_dynamic_offset: false,
243 min_binding_size: None,
244 },
245 count: None,
246 },
247 ],
248 });
249
250 let pipeline_layout =
251 self.device.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
252 label: Some(&format!("{} PipelineLayout", label)),
253 bind_group_layouts: &[&bind_group_layout],
254 push_constant_ranges: &[],
255 });
256
257 let pipeline =
258 self.device.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
259 label: Some(&format!("{} Pipeline", label)),
260 layout: Some(&pipeline_layout),
261 module: &shader,
262 entry_point: Some("main"),
263 compilation_options: Default::default(),
264 cache: None,
265 });
266
267 cache.insert(key, CachedPipeline { pipeline, bind_group_layout });
268 }
269
270 let cached = cache.get(&key).expect("pipeline just inserted");
271
272 #[repr(C)]
274 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
275 struct MatmulDims {
276 m: u32,
277 k: u32,
278 n: u32,
279 _pad: u32,
280 }
281
282 let dims = MatmulDims { m, k, n, _pad: 0 };
283 let dims_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
284 label: Some(&format!("{} Dims", label)),
285 size: std::mem::size_of::<MatmulDims>() as u64,
286 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
287 mapped_at_creation: false,
288 });
289 self.device.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
290
291 let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
293 label: Some(&format!("{} BindGroup", label)),
294 layout: &cached.bind_group_layout,
295 entries: &[
296 wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
297 wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
298 wgpu::BindGroupEntry { binding: 2, resource: output_buffer.as_entire_binding() },
299 wgpu::BindGroupEntry { binding: 3, resource: dims_buffer.as_entire_binding() },
300 ],
301 });
302
303 {
305 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
306 label: Some(&format!("{} Pass", label)),
307 timestamp_writes: None,
308 });
309 pass.set_pipeline(&cached.pipeline);
310 pass.set_bind_group(0, &bind_group, &[]);
311 pass.dispatch_workgroups(m.div_ceil(16), n.div_ceil(16), 1);
312 }
313
314 Ok(())
315 }
316
317 #[allow(clippy::map_entry)]
321 pub(crate) fn encode_binary_op(
322 &self,
323 encoder: &mut wgpu::CommandEncoder,
324 cache: &mut PipelineCache,
325 shader_source: &str,
326 label: &str,
327 a_buffer: &wgpu::Buffer,
328 b_buffer: &wgpu::Buffer,
329 output_buffer: &wgpu::Buffer,
330 size: usize,
331 ) -> Result<(), String> {
332 let key = cache_key(shader_source);
333
334 if !cache.contains_key(&key) {
336 let shader = self.device.device.create_shader_module(wgpu::ShaderModuleDescriptor {
337 label: Some(&format!("{} Shader", label)),
338 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
339 });
340
341 let bind_group_layout =
342 self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
343 label: Some(&format!("{} Layout", label)),
344 entries: &[
345 wgpu::BindGroupLayoutEntry {
346 binding: 0,
347 visibility: wgpu::ShaderStages::COMPUTE,
348 ty: wgpu::BindingType::Buffer {
349 ty: wgpu::BufferBindingType::Storage { read_only: true },
350 has_dynamic_offset: false,
351 min_binding_size: None,
352 },
353 count: None,
354 },
355 wgpu::BindGroupLayoutEntry {
356 binding: 1,
357 visibility: wgpu::ShaderStages::COMPUTE,
358 ty: wgpu::BindingType::Buffer {
359 ty: wgpu::BufferBindingType::Storage { read_only: true },
360 has_dynamic_offset: false,
361 min_binding_size: None,
362 },
363 count: None,
364 },
365 wgpu::BindGroupLayoutEntry {
366 binding: 2,
367 visibility: wgpu::ShaderStages::COMPUTE,
368 ty: wgpu::BindingType::Buffer {
369 ty: wgpu::BufferBindingType::Storage { read_only: false },
370 has_dynamic_offset: false,
371 min_binding_size: None,
372 },
373 count: None,
374 },
375 ],
376 });
377
378 let pipeline_layout =
379 self.device.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
380 label: Some(&format!("{} PipelineLayout", label)),
381 bind_group_layouts: &[&bind_group_layout],
382 push_constant_ranges: &[],
383 });
384
385 let pipeline =
386 self.device.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
387 label: Some(&format!("{} Pipeline", label)),
388 layout: Some(&pipeline_layout),
389 module: &shader,
390 entry_point: Some("main"),
391 compilation_options: Default::default(),
392 cache: None,
393 });
394
395 cache.insert(key, CachedPipeline { pipeline, bind_group_layout });
396 }
397
398 let cached = cache.get(&key).expect("pipeline just inserted");
399
400 let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
402 label: Some(&format!("{} BindGroup", label)),
403 layout: &cached.bind_group_layout,
404 entries: &[
405 wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
406 wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
407 wgpu::BindGroupEntry { binding: 2, resource: output_buffer.as_entire_binding() },
408 ],
409 });
410
411 {
413 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
414 label: Some(&format!("{} Pass", label)),
415 timestamp_writes: None,
416 });
417 pass.set_pipeline(&cached.pipeline);
418 pass.set_bind_group(0, &bind_group, &[]);
419 pass.dispatch_workgroups((size as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
420 }
421
422 Ok(())
423 }
424}