1use crate::{GpuDevice, Result};
4use std::borrow::Cow;
5use wgpu::{
6 BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor,
7 BindGroupLayoutEntry, BindingType, BufferBindingType, ComputePipeline,
8 ComputePipelineDescriptor, PipelineLayoutDescriptor, ShaderModule, ShaderModuleDescriptor,
9 ShaderStages,
10};
11
/// WGSL shader text handed to `ShaderCompiler::compile`.
///
/// Both variants are passed to wgpu as WGSL; they differ only in how the text is held.
#[derive(Debug, Clone)]
pub enum ShaderSource<'a> {
    /// WGSL source that may be either borrowed or owned.
    Wgsl(Cow<'a, str>),
    /// WGSL source borrowed from a string slice (e.g. one of the constants in the
    /// `embedded` module).
    Embedded(&'a str),
}
19
20pub struct ShaderCompiler {
22 device: std::sync::Arc<wgpu::Device>,
23}
24
25impl ShaderCompiler {
26 #[must_use]
28 pub fn new(device: &GpuDevice) -> Self {
29 Self {
30 device: std::sync::Arc::clone(device.device()),
31 }
32 }
33
34 pub fn compile(&self, label: &str, source: ShaderSource<'_>) -> Result<ShaderModule> {
45 let source_str = match source {
46 ShaderSource::Wgsl(code) => code,
47 ShaderSource::Embedded(code) => Cow::Borrowed(code),
48 };
49
50 Ok(self.device.create_shader_module(ShaderModuleDescriptor {
51 label: Some(label),
52 source: wgpu::ShaderSource::Wgsl(source_str),
53 }))
54 }
55
56 pub fn create_pipeline(
69 &self,
70 label: &str,
71 shader: &ShaderModule,
72 entry_point: &str,
73 bind_group_layout: &BindGroupLayout,
74 ) -> Result<ComputePipeline> {
75 let pipeline_layout = self
76 .device
77 .create_pipeline_layout(&PipelineLayoutDescriptor {
78 label: Some(&format!("{label} Layout")),
79 bind_group_layouts: &[Some(bind_group_layout)],
80 immediate_size: 0,
81 });
82
83 Ok(self
84 .device
85 .create_compute_pipeline(&ComputePipelineDescriptor {
86 label: Some(label),
87 layout: Some(&pipeline_layout),
88 module: shader,
89 entry_point: Some(entry_point),
90 cache: None,
91 compilation_options: Default::default(),
92 }))
93 }
94
95 #[must_use]
102 pub fn create_bind_group_layout(
103 &self,
104 label: &str,
105 entries: &[BindGroupLayoutEntry],
106 ) -> BindGroupLayout {
107 self.device
108 .create_bind_group_layout(&BindGroupLayoutDescriptor {
109 label: Some(label),
110 entries,
111 })
112 }
113
114 #[must_use]
122 pub fn create_bind_group(
123 &self,
124 label: &str,
125 layout: &BindGroupLayout,
126 entries: &[BindGroupEntry<'_>],
127 ) -> BindGroup {
128 self.device.create_bind_group(&BindGroupDescriptor {
129 label: Some(label),
130 layout,
131 entries,
132 })
133 }
134}
135
136pub struct BindGroupLayoutBuilder {
138 entries: Vec<BindGroupLayoutEntry>,
139}
140
141impl BindGroupLayoutBuilder {
142 #[must_use]
144 pub fn new() -> Self {
145 Self {
146 entries: Vec::new(),
147 }
148 }
149
150 #[must_use]
152 pub fn add_storage_buffer_read_only(mut self, binding: u32) -> Self {
153 self.entries.push(BindGroupLayoutEntry {
154 binding,
155 visibility: ShaderStages::COMPUTE,
156 ty: BindingType::Buffer {
157 ty: BufferBindingType::Storage { read_only: true },
158 has_dynamic_offset: false,
159 min_binding_size: None,
160 },
161 count: None,
162 });
163 self
164 }
165
166 #[must_use]
168 pub fn add_storage_buffer(mut self, binding: u32) -> Self {
169 self.entries.push(BindGroupLayoutEntry {
170 binding,
171 visibility: ShaderStages::COMPUTE,
172 ty: BindingType::Buffer {
173 ty: BufferBindingType::Storage { read_only: false },
174 has_dynamic_offset: false,
175 min_binding_size: None,
176 },
177 count: None,
178 });
179 self
180 }
181
182 #[must_use]
184 pub fn add_uniform_buffer(mut self, binding: u32) -> Self {
185 self.entries.push(BindGroupLayoutEntry {
186 binding,
187 visibility: ShaderStages::COMPUTE,
188 ty: BindingType::Buffer {
189 ty: BufferBindingType::Uniform,
190 has_dynamic_offset: false,
191 min_binding_size: None,
192 },
193 count: None,
194 });
195 self
196 }
197
198 #[must_use]
200 pub fn build(self) -> Vec<BindGroupLayoutEntry> {
201 self.entries
202 }
203}
204
205impl Default for BindGroupLayoutBuilder {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211pub mod embedded {
213 pub const COLORSPACE_SHADER: &str = include_str!("shaders/colorspace.wgsl");
215
216 pub const SCALE_SHADER: &str = include_str!("shaders/scale.wgsl");
218
219 pub const FILTER_SHADER: &str = include_str!("shaders/filter.wgsl");
221
222 pub const TRANSFORM_SHADER: &str = include_str!("shaders/transform.wgsl");
224
225 pub const BILATERAL_SHADER: &str = include_str!("shaders/bilateral.wgsl");
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the builder records each entry's binding, buffer kind, and the shared
    /// compute-only settings, not just the number of entries.
    #[test]
    fn test_bind_group_layout_builder() {
        let layout = BindGroupLayoutBuilder::new()
            .add_storage_buffer_read_only(0)
            .add_storage_buffer(1)
            .add_uniform_buffer(2)
            .build();

        assert_eq!(layout.len(), 3);

        // Entries keep insertion order and the binding numbers that were passed in.
        let bindings: Vec<u32> = layout.iter().map(|entry| entry.binding).collect();
        assert_eq!(bindings, [0, 1, 2]);

        // Every entry is compute-only, non-arrayed, with no dynamic offset or size hint.
        for entry in &layout {
            assert_eq!(entry.visibility, ShaderStages::COMPUTE);
            assert_eq!(entry.count, None);
            assert!(matches!(
                entry.ty,
                BindingType::Buffer {
                    has_dynamic_offset: false,
                    min_binding_size: None,
                    ..
                }
            ));
        }

        assert!(matches!(
            layout[0].ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: true },
                ..
            }
        ));
        assert!(matches!(
            layout[1].ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                ..
            }
        ));
        assert!(matches!(
            layout[2].ty,
            BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                ..
            }
        ));
    }

    /// A freshly created builder, via `new` or `default`, produces no entries.
    #[test]
    fn test_empty_builder() {
        assert!(BindGroupLayoutBuilder::new().build().is_empty());
        assert!(BindGroupLayoutBuilder::default().build().is_empty());
    }
}