1use crate::Context;
2
3
/// The kind of buffer resource a kernel binding expects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BindingType {
    /// A storage buffer; `read_only` controls whether the shader may write to it.
    Storage { read_only: bool },
    /// A uniform buffer.
    Uniform,
}
12
13#[derive(Debug, Clone)]
17pub struct KernelBinding {
18 pub binding: u32,
20 pub ty: BindingType,
22}
23
24impl KernelBinding {
25 pub fn new(binding: u32, ty: BindingType) -> Self {
26 Self { binding, ty }
27 }
28}
29
30pub trait KernelArgument {
35 #[doc(hidden)]
36 fn as_binding_resource(&self) -> wgpu::BindingResource<'_>;
37}
38
39pub struct ComputeKernel {
43 pipeline: wgpu::ComputePipeline,
44 bind_group_layout: wgpu::BindGroupLayout,
45}
46
47pub struct ComputeKernelBuilder<'a> {
51 source: Option<&'a str>,
52 entry_point: &'a str,
53 bindings: Vec<KernelBinding>,
54 label: &'a str,
55}
56
57impl<'a> ComputeKernelBuilder<'a> {
58 pub fn new() -> Self {
60 Self {
61 source: None,
62 entry_point: "main",
63 bindings: Vec::new(),
64 label: "compute_kernel",
65 }
66 }
67
68 pub fn source(mut self, source: &'a str) -> Self {
70 self.source = Some(source);
71 self
72 }
73
74 pub fn entry_point(mut self, entry_point: &'a str) -> Self {
76 self.entry_point = entry_point;
77 self
78 }
79
80 pub fn label(mut self, label: &'a str) -> Self {
82 self.label = label;
83 self
84 }
85
86 pub fn bind(mut self, binding: KernelBinding) -> Self {
88 self.bindings.push(binding);
89 self
90 }
91
92 pub fn add_storage_read(self, binding: u32) -> Self {
94 self.bind(KernelBinding::new(binding, BindingType::Storage { read_only: true }))
95 }
96
97 pub fn add_storage_read_write(self, binding: u32) -> Self {
99 self.bind(KernelBinding::new(binding, BindingType::Storage { read_only: false }))
100 }
101
102 pub fn add_uniform(self, binding: u32) -> Self {
104 self.bind(KernelBinding::new(binding, BindingType::Uniform))
105 }
106
107 pub async fn build(self, ctx: &Context) -> Result<ComputeKernel, String> {
109 let source = self.source.ok_or("Shader source not provided")?;
110 ComputeKernel::new(ctx, source, self.entry_point, &self.bindings, self.label).await
111 }
112}
113
/// Conversion of convenient workgroup-count arguments into an `(x, y, z)`
/// triple for `dispatch_workgroups`. Implemented for `u32`, 2- and 3-tuples,
/// and `[u32; 3]`, so callers can write `kernel.run(ctx, 64, ...)` or
/// `kernel.run(ctx, (8, 8), ...)`.
pub trait Dispatch {
    /// Returns the workgroup counts as `(x, y, z)`.
    fn as_workgroups(&self) -> (u32, u32, u32);
}

impl Dispatch for u32 {
    /// A bare count dispatches along x only; y and z default to 1.
    fn as_workgroups(&self) -> (u32, u32, u32) {
        (*self, 1, 1)
    }
}

impl Dispatch for (u32, u32) {
    /// A pair dispatches along x and y; z defaults to 1.
    fn as_workgroups(&self) -> (u32, u32, u32) {
        (self.0, self.1, 1)
    }
}

impl Dispatch for (u32, u32, u32) {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        *self
    }
}

impl Dispatch for [u32; 3] {
    fn as_workgroups(&self) -> (u32, u32, u32) {
        (self[0], self[1], self[2])
    }
}
145
146impl ComputeKernel {
147 pub fn builder<'a>() -> ComputeKernelBuilder<'a> {
149 ComputeKernelBuilder::new()
150 }
151
152 pub async fn new(
154 ctx: &Context,
155 shader_src: &str,
156 entry_point: &str,
157 layout_bindings: &[KernelBinding],
158 label: &str,
159 ) -> Result<Self, String> {
160 let shader = ctx
162 .device
163 .create_shader_module(wgpu::ShaderModuleDescriptor {
164 label: Some(&format!("{}_shader", label)),
165 source: wgpu::ShaderSource::Wgsl(shader_src.into()),
166 });
167
168 let wgpu_entries: Vec<wgpu::BindGroupLayoutEntry> = layout_bindings
170 .iter()
171 .map(|kb| wgpu::BindGroupLayoutEntry {
172 binding: kb.binding,
173 visibility: wgpu::ShaderStages::COMPUTE,
174 ty: match kb.ty {
175 BindingType::Storage { read_only } => wgpu::BindingType::Buffer {
176 ty: wgpu::BufferBindingType::Storage { read_only },
177 has_dynamic_offset: false,
178 min_binding_size: None,
179 },
180 BindingType::Uniform => wgpu::BindingType::Buffer {
181 ty: wgpu::BufferBindingType::Uniform,
182 has_dynamic_offset: false,
183 min_binding_size: None,
184 },
185 },
186 count: None,
187 })
188 .collect();
189
190 let bind_group_layout =
192 ctx.device
193 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
194 label: Some(&format!("{}_layout", label)),
195 entries: &wgpu_entries,
196 });
197
198 let pipeline_layout = ctx
200 .device
201 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
202 label: Some(&format!("{}_pipeline_layout", label)),
203 bind_group_layouts: &[&bind_group_layout],
204 push_constant_ranges: &[],
205 });
206
207 let pipeline = ctx
209 .device
210 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
211 label: Some(label),
212 layout: Some(&pipeline_layout),
213 module: &shader,
214 entry_point: Some(entry_point),
215 compilation_options: Default::default(),
216 cache: None,
217 });
218
219 Ok(Self {
220 pipeline,
221 bind_group_layout,
222 })
223 }
224
225 pub fn run(
231 &self,
232 ctx: &Context,
233 workgroups: impl Dispatch,
234 args: &[&dyn KernelArgument],
235 ) {
236 let workgroups = workgroups.as_workgroups();
237
238 let entries: Vec<wgpu::BindGroupEntry> = args
240 .iter()
241 .enumerate()
242 .map(|(i, arg)| wgpu::BindGroupEntry {
243 binding: i as u32,
244 resource: arg.as_binding_resource(),
245 })
246 .collect();
247
248 let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
250 label: None,
251 layout: &self.bind_group_layout,
252 entries: &entries,
253 });
254
255 let mut encoder = ctx
257 .device
258 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
259 {
260 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
261 label: None,
262 timestamp_writes: None,
263 });
264 compute_pass.set_pipeline(&self.pipeline);
265 compute_pass.set_bind_group(0, &bind_group, &[]);
266 compute_pass.dispatch_workgroups(workgroups.0, workgroups.1, workgroups.2);
267 }
268
269 ctx.queue.submit(Some(encoder.finish()));
271 }
272}