// trueno/backends/gpu/batch/execute/dispatch.rs

//! Shader dispatch infrastructure for GPU compute operations
//!
//! Contains pipeline-cached dispatch functions that encode operations into
//! a shared command encoder.  Pipelines are compiled once per unique shader
//! and reused across all operations in a batch (KAIZEN-022).

use std::collections::HashMap;

use super::super::GpuCommandBatch;

/// Workgroup size for element-wise compute shaders (threads per workgroup).
/// Must match the `@workgroup_size` declared in the element-wise WGSL shaders.
const WORKGROUP_SIZE: u32 = 256;

13/// Cached GPU pipeline: shader module → compute pipeline + bind group layout.
14/// Created once per unique shader source, reused for all operations using that shader.
15pub struct CachedPipeline {
16    pub(crate) pipeline: wgpu::ComputePipeline,
17    pub(crate) bind_group_layout: wgpu::BindGroupLayout,
18}
19
20/// Pipeline cache keyed by shader source pointer address.
21///
22/// Safe because all shader sources are `&'static str` constants with stable addresses.
23/// Create one per training session and pass to `GpuCommandBatch::execute_with_cache()`
24/// to avoid recompiling shaders across batch executions (KAIZEN-023).
25pub type PipelineCache = HashMap<usize, CachedPipeline>;
26
/// Compute a cache key from a shader source string.
///
/// Uses the address of the string data. This is only a valid cache key because
/// all shader sources are `&'static str` constants, whose addresses are stable
/// for the life of the program; do not call this with heap-built strings.
fn cache_key(shader_source: &str) -> usize {
    shader_source.as_ptr() as usize
}

33impl GpuCommandBatch {
34    /// Encode a unary operation (one input, one output) into the command encoder.
35    ///
36    /// Pipeline is cached per shader source — first call compiles, subsequent calls reuse.
37    #[allow(clippy::map_entry)]
38    pub(crate) fn encode_unary_op<T: bytemuck::Pod>(
39        &self,
40        encoder: &mut wgpu::CommandEncoder,
41        cache: &mut PipelineCache,
42        shader_source: &str,
43        label: &str,
44        input_buffer: &wgpu::Buffer,
45        output_buffer: &wgpu::Buffer,
46        size: usize,
47        params: Option<&T>,
48    ) -> Result<(), String> {
49        let key = cache_key(shader_source);
50        let has_params = params.is_some();
51
52        // Get or create cached pipeline
53        if !cache.contains_key(&key) {
54            let shader = self.device.device.create_shader_module(wgpu::ShaderModuleDescriptor {
55                label: Some(&format!("{} Shader", label)),
56                source: wgpu::ShaderSource::Wgsl(shader_source.into()),
57            });
58
59            let mut layout_entries = vec![
60                wgpu::BindGroupLayoutEntry {
61                    binding: 0,
62                    visibility: wgpu::ShaderStages::COMPUTE,
63                    ty: wgpu::BindingType::Buffer {
64                        ty: wgpu::BufferBindingType::Storage { read_only: true },
65                        has_dynamic_offset: false,
66                        min_binding_size: None,
67                    },
68                    count: None,
69                },
70                wgpu::BindGroupLayoutEntry {
71                    binding: 1,
72                    visibility: wgpu::ShaderStages::COMPUTE,
73                    ty: wgpu::BindingType::Buffer {
74                        ty: wgpu::BufferBindingType::Storage { read_only: false },
75                        has_dynamic_offset: false,
76                        min_binding_size: None,
77                    },
78                    count: None,
79                },
80            ];
81
82            if has_params {
83                layout_entries.push(wgpu::BindGroupLayoutEntry {
84                    binding: 2,
85                    visibility: wgpu::ShaderStages::COMPUTE,
86                    ty: wgpu::BindingType::Buffer {
87                        ty: wgpu::BufferBindingType::Uniform,
88                        has_dynamic_offset: false,
89                        min_binding_size: None,
90                    },
91                    count: None,
92                });
93            }
94
95            let bind_group_layout =
96                self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
97                    label: Some(&format!("{} Layout", label)),
98                    entries: &layout_entries,
99                });
100
101            let pipeline_layout =
102                self.device.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
103                    label: Some(&format!("{} PipelineLayout", label)),
104                    bind_group_layouts: &[&bind_group_layout],
105                    push_constant_ranges: &[],
106                });
107
108            let pipeline =
109                self.device.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
110                    label: Some(&format!("{} Pipeline", label)),
111                    layout: Some(&pipeline_layout),
112                    module: &shader,
113                    entry_point: Some("main"),
114                    compilation_options: Default::default(),
115                    cache: None,
116                });
117
118            cache.insert(key, CachedPipeline { pipeline, bind_group_layout });
119        }
120
121        let cached = cache.get(&key).expect("pipeline just inserted");
122
123        // Create uniform buffer if params provided (per-call, not cached)
124        let params_buffer = if let Some(params_data) = params {
125            let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
126                label: Some(&format!("{} Params", label)),
127                size: std::mem::size_of::<T>() as u64,
128                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
129                mapped_at_creation: false,
130            });
131            self.device.queue.write_buffer(&buffer, 0, bytemuck::bytes_of(params_data));
132            Some(buffer)
133        } else {
134            None
135        };
136
137        // Create bind group (per-call — references specific buffers)
138        let mut bind_entries = vec![
139            wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
140            wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
141        ];
142
143        if let Some(ref buffer) = params_buffer {
144            bind_entries
145                .push(wgpu::BindGroupEntry { binding: 2, resource: buffer.as_entire_binding() });
146        }
147
148        let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
149            label: Some(&format!("{} BindGroup", label)),
150            layout: &cached.bind_group_layout,
151            entries: &bind_entries,
152        });
153
154        // Encode compute pass
155        {
156            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
157                label: Some(&format!("{} Pass", label)),
158                timestamp_writes: None,
159            });
160            pass.set_pipeline(&cached.pipeline);
161            pass.set_bind_group(0, &bind_group, &[]);
162            pass.dispatch_workgroups((size as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
163        }
164
165        Ok(())
166    }
167
168    /// Encode a matrix multiplication into the command encoder.
169    ///
170    /// Pipeline is cached — first matmul compiles the tiled shader, subsequent matmuls reuse it.
171    #[allow(clippy::map_entry)]
172    pub(crate) fn encode_matmul_op(
173        &self,
174        encoder: &mut wgpu::CommandEncoder,
175        cache: &mut PipelineCache,
176        shader_source: &str,
177        label: &str,
178        a: &super::super::BufferId,
179        b: &super::super::BufferId,
180        output: &super::super::BufferId,
181        m: u32,
182        k: u32,
183        n: u32,
184    ) -> Result<(), String> {
185        contract_pre_tiled_naive_equivalence!();
186        let a_info = self.buffers.get(a).ok_or("Invalid buffer A ID")?;
187        let b_info = self.buffers.get(b).ok_or("Invalid buffer B ID")?;
188        let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
189
190        let a_buffer = a_info.gpu_buffer.as_ref().ok_or("Buffer A not created")?;
191        let b_buffer = b_info.gpu_buffer.as_ref().ok_or("Buffer B not created")?;
192        let output_buffer = output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
193
194        let key = cache_key(shader_source);
195
196        // Get or create cached pipeline
197        if !cache.contains_key(&key) {
198            let shader = self.device.device.create_shader_module(wgpu::ShaderModuleDescriptor {
199                label: Some(&format!("{} Shader", label)),
200                source: wgpu::ShaderSource::Wgsl(shader_source.into()),
201            });
202
203            let bind_group_layout =
204                self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
205                    label: Some(&format!("{} Layout", label)),
206                    entries: &[
207                        wgpu::BindGroupLayoutEntry {
208                            binding: 0,
209                            visibility: wgpu::ShaderStages::COMPUTE,
210                            ty: wgpu::BindingType::Buffer {
211                                ty: wgpu::BufferBindingType::Storage { read_only: true },
212                                has_dynamic_offset: false,
213                                min_binding_size: None,
214                            },
215                            count: None,
216                        },
217                        wgpu::BindGroupLayoutEntry {
218                            binding: 1,
219                            visibility: wgpu::ShaderStages::COMPUTE,
220                            ty: wgpu::BindingType::Buffer {
221                                ty: wgpu::BufferBindingType::Storage { read_only: true },
222                                has_dynamic_offset: false,
223                                min_binding_size: None,
224                            },
225                            count: None,
226                        },
227                        wgpu::BindGroupLayoutEntry {
228                            binding: 2,
229                            visibility: wgpu::ShaderStages::COMPUTE,
230                            ty: wgpu::BindingType::Buffer {
231                                ty: wgpu::BufferBindingType::Storage { read_only: false },
232                                has_dynamic_offset: false,
233                                min_binding_size: None,
234                            },
235                            count: None,
236                        },
237                        wgpu::BindGroupLayoutEntry {
238                            binding: 3,
239                            visibility: wgpu::ShaderStages::COMPUTE,
240                            ty: wgpu::BindingType::Buffer {
241                                ty: wgpu::BufferBindingType::Uniform,
242                                has_dynamic_offset: false,
243                                min_binding_size: None,
244                            },
245                            count: None,
246                        },
247                    ],
248                });
249
250            let pipeline_layout =
251                self.device.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
252                    label: Some(&format!("{} PipelineLayout", label)),
253                    bind_group_layouts: &[&bind_group_layout],
254                    push_constant_ranges: &[],
255                });
256
257            let pipeline =
258                self.device.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
259                    label: Some(&format!("{} Pipeline", label)),
260                    layout: Some(&pipeline_layout),
261                    module: &shader,
262                    entry_point: Some("main"),
263                    compilation_options: Default::default(),
264                    cache: None,
265                });
266
267            cache.insert(key, CachedPipeline { pipeline, bind_group_layout });
268        }
269
270        let cached = cache.get(&key).expect("pipeline just inserted");
271
272        // Create dimensions uniform buffer (per-call — dimensions differ per matmul)
273        #[repr(C)]
274        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
275        struct MatmulDims {
276            m: u32,
277            k: u32,
278            n: u32,
279            _pad: u32,
280        }
281
282        let dims = MatmulDims { m, k, n, _pad: 0 };
283        let dims_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
284            label: Some(&format!("{} Dims", label)),
285            size: std::mem::size_of::<MatmulDims>() as u64,
286            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
287            mapped_at_creation: false,
288        });
289        self.device.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
290
291        // Create bind group (per-call — references specific buffers)
292        let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
293            label: Some(&format!("{} BindGroup", label)),
294            layout: &cached.bind_group_layout,
295            entries: &[
296                wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
297                wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
298                wgpu::BindGroupEntry { binding: 2, resource: output_buffer.as_entire_binding() },
299                wgpu::BindGroupEntry { binding: 3, resource: dims_buffer.as_entire_binding() },
300            ],
301        });
302
303        // Encode compute pass
304        {
305            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
306                label: Some(&format!("{} Pass", label)),
307                timestamp_writes: None,
308            });
309            pass.set_pipeline(&cached.pipeline);
310            pass.set_bind_group(0, &bind_group, &[]);
311            pass.dispatch_workgroups(m.div_ceil(16), n.div_ceil(16), 1);
312        }
313
314        Ok(())
315    }
316
317    /// Encode a binary operation (two inputs, one output) into the command encoder.
318    ///
319    /// Pipeline is cached per shader source.
320    #[allow(clippy::map_entry)]
321    pub(crate) fn encode_binary_op(
322        &self,
323        encoder: &mut wgpu::CommandEncoder,
324        cache: &mut PipelineCache,
325        shader_source: &str,
326        label: &str,
327        a_buffer: &wgpu::Buffer,
328        b_buffer: &wgpu::Buffer,
329        output_buffer: &wgpu::Buffer,
330        size: usize,
331    ) -> Result<(), String> {
332        let key = cache_key(shader_source);
333
334        // Get or create cached pipeline
335        if !cache.contains_key(&key) {
336            let shader = self.device.device.create_shader_module(wgpu::ShaderModuleDescriptor {
337                label: Some(&format!("{} Shader", label)),
338                source: wgpu::ShaderSource::Wgsl(shader_source.into()),
339            });
340
341            let bind_group_layout =
342                self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
343                    label: Some(&format!("{} Layout", label)),
344                    entries: &[
345                        wgpu::BindGroupLayoutEntry {
346                            binding: 0,
347                            visibility: wgpu::ShaderStages::COMPUTE,
348                            ty: wgpu::BindingType::Buffer {
349                                ty: wgpu::BufferBindingType::Storage { read_only: true },
350                                has_dynamic_offset: false,
351                                min_binding_size: None,
352                            },
353                            count: None,
354                        },
355                        wgpu::BindGroupLayoutEntry {
356                            binding: 1,
357                            visibility: wgpu::ShaderStages::COMPUTE,
358                            ty: wgpu::BindingType::Buffer {
359                                ty: wgpu::BufferBindingType::Storage { read_only: true },
360                                has_dynamic_offset: false,
361                                min_binding_size: None,
362                            },
363                            count: None,
364                        },
365                        wgpu::BindGroupLayoutEntry {
366                            binding: 2,
367                            visibility: wgpu::ShaderStages::COMPUTE,
368                            ty: wgpu::BindingType::Buffer {
369                                ty: wgpu::BufferBindingType::Storage { read_only: false },
370                                has_dynamic_offset: false,
371                                min_binding_size: None,
372                            },
373                            count: None,
374                        },
375                    ],
376                });
377
378            let pipeline_layout =
379                self.device.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
380                    label: Some(&format!("{} PipelineLayout", label)),
381                    bind_group_layouts: &[&bind_group_layout],
382                    push_constant_ranges: &[],
383                });
384
385            let pipeline =
386                self.device.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
387                    label: Some(&format!("{} Pipeline", label)),
388                    layout: Some(&pipeline_layout),
389                    module: &shader,
390                    entry_point: Some("main"),
391                    compilation_options: Default::default(),
392                    cache: None,
393                });
394
395            cache.insert(key, CachedPipeline { pipeline, bind_group_layout });
396        }
397
398        let cached = cache.get(&key).expect("pipeline just inserted");
399
400        // Create bind group (per-call)
401        let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
402            label: Some(&format!("{} BindGroup", label)),
403            layout: &cached.bind_group_layout,
404            entries: &[
405                wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
406                wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
407                wgpu::BindGroupEntry { binding: 2, resource: output_buffer.as_entire_binding() },
408            ],
409        });
410
411        // Encode compute pass
412        {
413            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
414                label: Some(&format!("{} Pass", label)),
415                timestamp_writes: None,
416            });
417            pass.set_pipeline(&cached.pipeline);
418            pass.set_bind_group(0, &bind_group, &[]);
419            pass.dispatch_workgroups((size as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
420        }
421
422        Ok(())
423    }
424}