Skip to main content

oximedia_gpu/
shader.rs

1//! Shader compilation and pipeline management
2
3use crate::{GpuDevice, Result};
4use std::borrow::Cow;
5use wgpu::{
6    BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor,
7    BindGroupLayoutEntry, BindingType, BufferBindingType, ComputePipeline,
8    ComputePipelineDescriptor, PipelineLayoutDescriptor, ShaderModule, ShaderModuleDescriptor,
9    ShaderStages,
10};
11
/// Shader source type
///
/// Both variants carry shader text; they differ only in how the text is held.
/// The type is cheap to clone when it borrows, and can be compared and
/// debug-printed, which is convenient for logging and for tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShaderSource<'a> {
    /// WGSL source code, either borrowed or owned (e.g. generated at runtime)
    Wgsl(Cow<'a, str>),
    /// Embedded shader (included at compile time)
    Embedded(&'a str),
}
19
20/// Shader compiler and pipeline builder
21pub struct ShaderCompiler {
22    device: std::sync::Arc<wgpu::Device>,
23}
24
25impl ShaderCompiler {
26    /// Create a new shader compiler
27    #[must_use]
28    pub fn new(device: &GpuDevice) -> Self {
29        Self {
30            device: std::sync::Arc::clone(device.device()),
31        }
32    }
33
34    /// Compile a shader module from source
35    ///
36    /// # Arguments
37    ///
38    /// * `label` - Shader label for debugging
39    /// * `source` - Shader source code
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if shader compilation fails.
44    pub fn compile(&self, label: &str, source: ShaderSource<'_>) -> Result<ShaderModule> {
45        let source_str = match source {
46            ShaderSource::Wgsl(code) => code,
47            ShaderSource::Embedded(code) => Cow::Borrowed(code),
48        };
49
50        Ok(self.device.create_shader_module(ShaderModuleDescriptor {
51            label: Some(label),
52            source: wgpu::ShaderSource::Wgsl(source_str),
53        }))
54    }
55
56    /// Create a compute pipeline
57    ///
58    /// # Arguments
59    ///
60    /// * `label` - Pipeline label for debugging
61    /// * `shader` - Compiled shader module
62    /// * `entry_point` - Entry point function name
63    /// * `bind_group_layout` - Bind group layout for resources
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if pipeline creation fails.
68    pub fn create_pipeline(
69        &self,
70        label: &str,
71        shader: &ShaderModule,
72        entry_point: &str,
73        bind_group_layout: &BindGroupLayout,
74    ) -> Result<ComputePipeline> {
75        let pipeline_layout = self
76            .device
77            .create_pipeline_layout(&PipelineLayoutDescriptor {
78                label: Some(&format!("{label} Layout")),
79                bind_group_layouts: &[Some(bind_group_layout)],
80                immediate_size: 0,
81            });
82
83        Ok(self
84            .device
85            .create_compute_pipeline(&ComputePipelineDescriptor {
86                label: Some(label),
87                layout: Some(&pipeline_layout),
88                module: shader,
89                entry_point: Some(entry_point),
90                cache: None,
91                compilation_options: Default::default(),
92            }))
93    }
94
95    /// Create a bind group layout for compute operations
96    ///
97    /// # Arguments
98    ///
99    /// * `label` - Layout label for debugging
100    /// * `entries` - Bind group layout entries
101    #[must_use]
102    pub fn create_bind_group_layout(
103        &self,
104        label: &str,
105        entries: &[BindGroupLayoutEntry],
106    ) -> BindGroupLayout {
107        self.device
108            .create_bind_group_layout(&BindGroupLayoutDescriptor {
109                label: Some(label),
110                entries,
111            })
112    }
113
114    /// Create a bind group
115    ///
116    /// # Arguments
117    ///
118    /// * `label` - Bind group label for debugging
119    /// * `layout` - Bind group layout
120    /// * `entries` - Bind group entries
121    #[must_use]
122    pub fn create_bind_group(
123        &self,
124        label: &str,
125        layout: &BindGroupLayout,
126        entries: &[BindGroupEntry<'_>],
127    ) -> BindGroup {
128        self.device.create_bind_group(&BindGroupDescriptor {
129            label: Some(label),
130            layout,
131            entries,
132        })
133    }
134}
135
136/// Helper for creating standard bind group layouts
137pub struct BindGroupLayoutBuilder {
138    entries: Vec<BindGroupLayoutEntry>,
139}
140
141impl BindGroupLayoutBuilder {
142    /// Create a new bind group layout builder
143    #[must_use]
144    pub fn new() -> Self {
145        Self {
146            entries: Vec::new(),
147        }
148    }
149
150    /// Add a storage buffer binding (read-only)
151    #[must_use]
152    pub fn add_storage_buffer_read_only(mut self, binding: u32) -> Self {
153        self.entries.push(BindGroupLayoutEntry {
154            binding,
155            visibility: ShaderStages::COMPUTE,
156            ty: BindingType::Buffer {
157                ty: BufferBindingType::Storage { read_only: true },
158                has_dynamic_offset: false,
159                min_binding_size: None,
160            },
161            count: None,
162        });
163        self
164    }
165
166    /// Add a storage buffer binding (read-write)
167    #[must_use]
168    pub fn add_storage_buffer(mut self, binding: u32) -> Self {
169        self.entries.push(BindGroupLayoutEntry {
170            binding,
171            visibility: ShaderStages::COMPUTE,
172            ty: BindingType::Buffer {
173                ty: BufferBindingType::Storage { read_only: false },
174                has_dynamic_offset: false,
175                min_binding_size: None,
176            },
177            count: None,
178        });
179        self
180    }
181
182    /// Add a uniform buffer binding
183    #[must_use]
184    pub fn add_uniform_buffer(mut self, binding: u32) -> Self {
185        self.entries.push(BindGroupLayoutEntry {
186            binding,
187            visibility: ShaderStages::COMPUTE,
188            ty: BindingType::Buffer {
189                ty: BufferBindingType::Uniform,
190                has_dynamic_offset: false,
191                min_binding_size: None,
192            },
193            count: None,
194        });
195        self
196    }
197
198    /// Build the bind group layout entries
199    #[must_use]
200    pub fn build(self) -> Vec<BindGroupLayoutEntry> {
201        self.entries
202    }
203}
204
205impl Default for BindGroupLayoutBuilder {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
/// Shader sources embedded at compile time
///
/// Each constant holds the full text of one `.wgsl` file, pulled in with
/// `include_str!`. The paths are resolved relative to this source file, and a
/// missing file is a compile error. The text is stored as-is: nothing here
/// compiles or validates the shaders. Being plain `&str` values, the
/// constants can be wrapped in `ShaderSource::Embedded`.
pub mod embedded {
    /// Color space conversion shader source (`shaders/colorspace.wgsl`)
    pub const COLORSPACE_SHADER: &str = include_str!("shaders/colorspace.wgsl");

    /// Image scaling shader source (`shaders/scale.wgsl`)
    pub const SCALE_SHADER: &str = include_str!("shaders/scale.wgsl");

    /// Convolution filter shader source (`shaders/filter.wgsl`)
    pub const FILTER_SHADER: &str = include_str!("shaders/filter.wgsl");

    /// Transform operations shader source (`shaders/transform.wgsl`)
    pub const TRANSFORM_SHADER: &str = include_str!("shaders/transform.wgsl");

    /// Bilateral filter denoising shader source (`shaders/bilateral.wgsl`)
    pub const BILATERAL_SHADER: &str = include_str!("shaders/bilateral.wgsl");
}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder must emit one entry per `add_*` call, in call order, with
    /// the requested binding index and the buffer kind the method name promises.
    #[test]
    fn test_bind_group_layout_builder() {
        let entries = BindGroupLayoutBuilder::new()
            .add_storage_buffer_read_only(0)
            .add_storage_buffer(1)
            .add_uniform_buffer(2)
            .build();

        assert_eq!(entries.len(), 3);

        // Insertion order and binding indices are preserved.
        let bindings: Vec<u32> = entries.iter().map(|entry| entry.binding).collect();
        assert_eq!(bindings, [0, 1, 2]);

        // Each method produces the buffer kind it is named after.
        assert!(matches!(
            entries[0].ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            }
        ));
        assert!(matches!(
            entries[1].ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            }
        ));
        assert!(matches!(
            entries[2].ty,
            BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            }
        ));

        // All entries are compute-only and non-arrayed.
        assert!(entries
            .iter()
            .all(|entry| entry.visibility == ShaderStages::COMPUTE && entry.count.is_none()));
    }

    /// A freshly constructed builder, via `new` or `Default`, has no entries.
    #[test]
    fn test_empty_builder() {
        assert!(BindGroupLayoutBuilder::new().build().is_empty());
        assert!(BindGroupLayoutBuilder::default().build().is_empty());
    }
}
239            .build();
240
241        assert_eq!(layout.len(), 3);
242    }
243}