oxgpu/
compute_kernel.rs

1use crate::Context;
2
3
/// The kind of buffer resource bound to one slot of a kernel's bind group.
///
/// Converted to the matching `wgpu::BufferBindingType` when the kernel's bind group
/// layout is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BindingType {
    /// A storage buffer (array of data), declared as `var<storage, ...>` in WGSL.
    ///
    /// When `read_only` is `true` the shader may only read it (`var<storage, read>`);
    /// when `false` it may also write to it (`var<storage, read_write>`).
    Storage { read_only: bool },
    /// A uniform buffer (read-only constant data), declared as `var<uniform>` in WGSL.
    Uniform,
}
12
13/// Helper struct to define a binding layout for a kernel.
14///
15/// Typically used internally by the `ComputeKernelBuilder`.
16#[derive(Debug, Clone)]
17pub struct KernelBinding {
18    /// The binding index (e.g., `@binding(0)`).
19    pub binding: u32,
20    /// The type of resource bound.
21    pub ty: BindingType,
22}
23
24impl KernelBinding {
25    pub fn new(binding: u32, ty: BindingType) -> Self {
26        Self { binding, ty }
27    }
28}
29
30/// A trait for types that can be passed as arguments to a compute kernel.
31///
32/// Implemented for `Buffer<T>`. This trait is sealed or effectively hidden 
33/// to simplify the public API, abstracting away `wgpu::BindingResource`.
34pub trait KernelArgument {
35    #[doc(hidden)]
36    fn as_binding_resource(&self) -> wgpu::BindingResource<'_>;
37}
38
39/// Represents a compiled compute, ready for execution.
40///
41/// Handles the pipeline creation and command encoding.
42pub struct ComputeKernel {
43    pipeline: wgpu::ComputePipeline,
44    bind_group_layout: wgpu::BindGroupLayout,
45}
46
47/// A builder for creating a `ComputeKernel`.
48///
49/// Allows specifying shader source, entry point, bindings, and label in a fluent manner.
50pub struct ComputeKernelBuilder<'a> {
51    source: Option<&'a str>,
52    entry_point: &'a str,
53    bindings: Vec<KernelBinding>,
54    label: &'a str,
55}
56
57impl<'a> ComputeKernelBuilder<'a> {
58    /// Creates a new, empty builder.
59    pub fn new() -> Self {
60        Self {
61            source: None,
62            entry_point: "main",
63            bindings: Vec::new(),
64            label: "compute_kernel",
65        }
66    }
67
68    /// Sets the WGSL shader source code.
69    pub fn source(mut self, source: &'a str) -> Self {
70        self.source = Some(source);
71        self
72    }
73
74    /// Sets the entry point function name (default: "main").
75    pub fn entry_point(mut self, entry_point: &'a str) -> Self {
76        self.entry_point = entry_point;
77        self
78    }
79
80    /// Sets a label for the kernel (used for debugging/profiling).
81    pub fn label(mut self, label: &'a str) -> Self {
82        self.label = label;
83        self
84    }
85
86    /// Manually adds a binding definition.
87    pub fn bind(mut self, binding: KernelBinding) -> Self {
88        self.bindings.push(binding);
89        self
90    }
91
92    /// Adds a read-only storage buffer binding (e.g., `var<storage, read>`).
93    pub fn add_storage_read(self, binding: u32) -> Self {
94        self.bind(KernelBinding::new(binding, BindingType::Storage { read_only: true }))
95    }
96
97    /// Adds a read-write storage buffer binding (e.g., `var<storage, read_write>`).
98    pub fn add_storage_read_write(self, binding: u32) -> Self {
99        self.bind(KernelBinding::new(binding, BindingType::Storage { read_only: false }))
100    }
101
102    /// Adds a uniform buffer binding (e.g., `var<uniform>`).
103    pub fn add_uniform(self, binding: u32) -> Self {
104        self.bind(KernelBinding::new(binding, BindingType::Uniform))
105    }
106
107    /// Consumes the builder and creates the `ComputeKernel`.
108    pub async fn build(self, ctx: &Context) -> Result<ComputeKernel, String> {
109        let source = self.source.ok_or("Shader source not provided")?;
110        ComputeKernel::new(ctx, source, self.entry_point, &self.bindings, self.label).await
111    }
112}
113
/// Trait for values that describe how many compute workgroups to dispatch.
///
/// Supported types: `u32` (1D), `(u32, u32)` and `[u32; 2]` (2D), `(u32, u32, u32)` and
/// `[u32; 3]` (3D). Dimensions that are not given default to `1`.
pub trait Dispatch {
    /// Returns the `(x, y, z)` workgroup counts for the dispatch.
    fn as_workgroups(&self) -> (u32, u32, u32);
}

impl Dispatch for u32 {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        (*self, 1, 1)
    }
}

impl Dispatch for (u32, u32) {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        let (x, y) = *self;
        (x, y, 1)
    }
}

impl Dispatch for (u32, u32, u32) {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        *self
    }
}

impl Dispatch for [u32; 2] {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        let [x, y] = *self;
        (x, y, 1)
    }
}

impl Dispatch for [u32; 3] {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        let [x, y, z] = *self;
        (x, y, z)
    }
}
145
146impl ComputeKernel {
147    /// Returns a new `ComputeKernelBuilder`.
148    pub fn builder<'a>() -> ComputeKernelBuilder<'a> {
149        ComputeKernelBuilder::new()
150    }
151
152    /// Internal constructor used by the builder.
153    pub async fn new(
154        ctx: &Context,
155        shader_src: &str,
156        entry_point: &str,
157        layout_bindings: &[KernelBinding],
158        label: &str,
159    ) -> Result<Self, String> {
160        // 1. Load the shader module
161        let shader = ctx
162            .device
163            .create_shader_module(wgpu::ShaderModuleDescriptor {
164                label: Some(&format!("{}_shader", label)),
165                source: wgpu::ShaderSource::Wgsl(shader_src.into()),
166            });
167
168        // 2. Convert abstract bindings to wgpu entries
169        let wgpu_entries: Vec<wgpu::BindGroupLayoutEntry> = layout_bindings
170            .iter()
171            .map(|kb| wgpu::BindGroupLayoutEntry {
172                binding: kb.binding,
173                visibility: wgpu::ShaderStages::COMPUTE,
174                ty: match kb.ty {
175                    BindingType::Storage { read_only } => wgpu::BindingType::Buffer {
176                        ty: wgpu::BufferBindingType::Storage { read_only },
177                        has_dynamic_offset: false,
178                        min_binding_size: None,
179                    },
180                    BindingType::Uniform => wgpu::BindingType::Buffer {
181                        ty: wgpu::BufferBindingType::Uniform,
182                        has_dynamic_offset: false,
183                        min_binding_size: None,
184                    },
185                },
186                count: None,
187            })
188            .collect();
189
190        // 3. Create Bind Group Layout
191        let bind_group_layout =
192            ctx.device
193                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
194                    label: Some(&format!("{}_layout", label)),
195                    entries: &wgpu_entries,
196                });
197
198        // 4. Create Pipeline Layout
199        let pipeline_layout = ctx
200            .device
201            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
202                label: Some(&format!("{}_pipeline_layout", label)),
203                bind_group_layouts: &[&bind_group_layout],
204                push_constant_ranges: &[],
205            });
206
207        // 5. Create the actual Pipeline
208        let pipeline = ctx
209            .device
210            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
211                label: Some(label),
212                layout: Some(&pipeline_layout),
213                module: &shader,
214                entry_point: Some(entry_point),
215                compilation_options: Default::default(),
216                cache: None,
217            });
218
219        Ok(Self {
220            pipeline,
221            bind_group_layout,
222        })
223    }
224
225    /// Executes the compute kernel with the given arguments and workgroups.
226    ///
227    /// `workgroups` can be a single number (1D), or a 2D/3D tuple/array.
228    /// `args` is a slice of buffers implementing `KernelArgument`.
229    /// The number and order of arguments must match the bindings defined during build.
230    pub fn run(
231        &self,
232        ctx: &Context,
233        workgroups: impl Dispatch,
234        args: &[&dyn KernelArgument],
235    ) {
236        let workgroups = workgroups.as_workgroups();
237        
238        // 1. Create Bind Group entries dynamically
239        let entries: Vec<wgpu::BindGroupEntry> = args
240            .iter()
241            .enumerate()
242            .map(|(i, arg)| wgpu::BindGroupEntry {
243                binding: i as u32,
244                resource: arg.as_binding_resource(),
245            })
246            .collect();
247
248        // 2. Create Bind Group
249        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
250            label: None,
251            layout: &self.bind_group_layout,
252            entries: &entries,
253        });
254
255        // 3. Encode commands
256        let mut encoder = ctx
257            .device
258            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
259        {
260            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
261                label: None,
262                timestamp_writes: None,
263            });
264            compute_pass.set_pipeline(&self.pipeline);
265            compute_pass.set_bind_group(0, &bind_group, &[]);
266            compute_pass.dispatch_workgroups(workgroups.0, workgroups.1, workgroups.2);
267        }
268
269        // 4. Submit to Queue
270        ctx.queue.submit(Some(encoder.finish()));
271    }
272}