Skip to main content

oxigdal_gpu/shaders/
mod.rs

1//! WGSL shader management for OxiGDAL GPU operations.
2//!
3//! This module provides utilities for loading, compiling, and validating
4//! WGSL compute shaders used in GPU-accelerated raster operations.
5
6use crate::error::{GpuError, GpuResult};
7use std::borrow::Cow;
8use tracing::debug;
9use wgpu::{
10    BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
11    BufferBindingType, ComputePipeline, ComputePipelineDescriptor, Device,
12    PipelineLayoutDescriptor, ShaderModule, ShaderModuleDescriptor, ShaderSource, ShaderStages,
13    naga,
14};
15
/// WGSL shader wrapper with validation and compilation.
///
/// Holds the raw WGSL source plus its entry-point name, and lazily caches
/// the compiled [`ShaderModule`] after the first successful `compile` call.
pub struct WgslShader {
    /// Raw WGSL source code.
    source: String,
    /// Name of the entry-point function inside `source`.
    entry_point: String,
    /// Compiled shader module; `None` until `compile` succeeds.
    module: Option<ShaderModule>,
}
25
26impl WgslShader {
27    /// Create a new WGSL shader from source.
28    pub fn new(source: impl Into<String>, entry_point: impl Into<String>) -> Self {
29        Self {
30            source: source.into(),
31            entry_point: entry_point.into(),
32            module: None,
33        }
34    }
35
36    /// Validate the shader source without compiling.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if shader validation fails.
41    pub fn validate(&self) -> GpuResult<()> {
42        let module = naga::front::wgsl::parse_str(&self.source)
43            .map_err(|e| GpuError::shader_compilation(e.to_string()))?;
44
45        // Validate the module
46        let mut validator = naga::valid::Validator::new(
47            naga::valid::ValidationFlags::all(),
48            naga::valid::Capabilities::all(),
49        );
50
51        validator
52            .validate(&module)
53            .map_err(|e| GpuError::shader_validation(e.to_string()))?;
54
55        debug!("Shader validation successful");
56        Ok(())
57    }
58
59    /// Compile the shader for the given device.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if shader compilation fails.
64    pub fn compile(&mut self, device: &Device) -> GpuResult<&ShaderModule> {
65        if self.module.is_none() {
66            // Validate first
67            self.validate()?;
68
69            let module = device.create_shader_module(ShaderModuleDescriptor {
70                label: Some(&format!("Shader: {}", self.entry_point)),
71                source: ShaderSource::Wgsl(Cow::Borrowed(&self.source)),
72            });
73
74            self.module = Some(module);
75            debug!("Shader compiled: {}", self.entry_point);
76        }
77
78        Ok(self
79            .module
80            .as_ref()
81            .ok_or_else(|| GpuError::internal("Module should be compiled"))?)
82    }
83
84    /// Get the shader entry point.
85    pub fn entry_point(&self) -> &str {
86        &self.entry_point
87    }
88
89    /// Get the shader source.
90    pub fn source(&self) -> &str {
91        &self.source
92    }
93}
94
/// Builder for compute pipelines with common configurations.
///
/// Borrows the device and shader module for the lifetime `'a`; consumed by
/// [`ComputePipelineBuilder::build`].
pub struct ComputePipelineBuilder<'a> {
    // Device used to create the pipeline layout and pipeline.
    device: &'a Device,
    // Compiled shader module containing `entry_point`.
    shader: &'a ShaderModule,
    // Name of the compute entry-point function in `shader`.
    entry_point: String,
    // Bind group layouts, in binding order; `Option` as required by wgpu's
    // `PipelineLayoutDescriptor`.
    bind_group_layouts: Vec<Option<&'a BindGroupLayout>>,
    // Optional debug label applied to both the layout and the pipeline.
    label: Option<String>,
}
103
104impl<'a> ComputePipelineBuilder<'a> {
105    /// Create a new compute pipeline builder.
106    pub fn new(
107        device: &'a Device,
108        shader: &'a ShaderModule,
109        entry_point: impl Into<String>,
110    ) -> Self {
111        Self {
112            device,
113            shader,
114            entry_point: entry_point.into(),
115            bind_group_layouts: Vec::new(),
116            label: None,
117        }
118    }
119
120    /// Add a bind group layout.
121    pub fn bind_group_layout(mut self, layout: &'a BindGroupLayout) -> Self {
122        self.bind_group_layouts.push(Some(layout));
123        self
124    }
125
126    /// Set the pipeline label.
127    pub fn label(mut self, label: impl Into<String>) -> Self {
128        self.label = Some(label.into());
129        self
130    }
131
132    /// Build the compute pipeline.
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if pipeline creation fails.
137    pub fn build(self) -> GpuResult<ComputePipeline> {
138        let pipeline_layout = self
139            .device
140            .create_pipeline_layout(&PipelineLayoutDescriptor {
141                label: self.label.as_deref(),
142                bind_group_layouts: &self.bind_group_layouts,
143                immediate_size: 0,
144            });
145
146        let pipeline = self
147            .device
148            .create_compute_pipeline(&ComputePipelineDescriptor {
149                label: self.label.as_deref(),
150                layout: Some(&pipeline_layout),
151                module: self.shader,
152                entry_point: Some(&self.entry_point),
153                compilation_options: Default::default(),
154                cache: None,
155            });
156
157        debug!("Compute pipeline created: {:?}", self.label);
158        Ok(pipeline)
159    }
160}
161
162/// Create a storage buffer bind group layout entry.
163pub fn storage_buffer_layout(binding: u32, read_only: bool) -> BindGroupLayoutEntry {
164    BindGroupLayoutEntry {
165        binding,
166        visibility: ShaderStages::COMPUTE,
167        ty: BindingType::Buffer {
168            ty: BufferBindingType::Storage { read_only },
169            has_dynamic_offset: false,
170            min_binding_size: None,
171        },
172        count: None,
173    }
174}
175
176/// Create a uniform buffer bind group layout entry.
177pub fn uniform_buffer_layout(binding: u32) -> BindGroupLayoutEntry {
178    BindGroupLayoutEntry {
179        binding,
180        visibility: ShaderStages::COMPUTE,
181        ty: BindingType::Buffer {
182            ty: BufferBindingType::Uniform,
183            has_dynamic_offset: false,
184            min_binding_size: None,
185        },
186        count: None,
187    }
188}
189
190/// Create a bind group layout for common compute patterns.
191///
192/// # Errors
193///
194/// Returns an error if layout creation fails.
195pub fn create_compute_bind_group_layout(
196    device: &Device,
197    entries: &[BindGroupLayoutEntry],
198    label: Option<&str>,
199) -> GpuResult<BindGroupLayout> {
200    Ok(device.create_bind_group_layout(&BindGroupLayoutDescriptor { label, entries }))
201}
202
/// Shader library with common WGSL functions.
///
/// Stateless namespace: every method returns a `'static` WGSL source string.
pub struct ShaderLibrary;
205
impl ShaderLibrary {
    /// Get common utility functions for compute shaders.
    ///
    /// Returns a WGSL snippet (2D/1D index conversion, clamping, linear and
    /// bilinear interpolation, NaN/Inf-safe math helpers) intended to be
    /// prepended to other shader sources.
    pub fn common_utils() -> &'static str {
        r#"
// Common utility functions for WGSL shaders

// Convert 2D coordinates to 1D index
fn coord_to_index(x: u32, y: u32, width: u32) -> u32 {
    return y * width + x;
}

// Convert 1D index to 2D coordinates
fn index_to_coord(index: u32, width: u32) -> vec2<u32> {
    return vec2<u32>(index % width, index / width);
}

// Clamp value to range [min, max]
fn clamp_value(value: f32, min_val: f32, max_val: f32) -> f32 {
    return clamp(value, min_val, max_val);
}

// Linear interpolation
fn lerp(a: f32, b: f32, t: f32) -> f32 {
    return a + (b - a) * t;
}

// Bilinear interpolation
fn bilinear_interp(
    v00: f32, v10: f32,
    v01: f32, v11: f32,
    tx: f32, ty: f32
) -> f32 {
    let v0 = lerp(v00, v10, tx);
    let v1 = lerp(v01, v11, tx);
    return lerp(v0, v1, ty);
}

// Safe division (returns 0 if denominator is 0)
fn safe_div(num: f32, denom: f32) -> f32 {
    if (abs(denom) < 1e-10) {
        return 0.0;
    }
    return num / denom;
}

// Check if value is NaN
fn is_nan(value: f32) -> bool {
    return value != value;
}

// Check if value is infinite
fn is_inf(value: f32) -> bool {
    return abs(value) > 1e38;
}

// Safe value (replace NaN/Inf with 0)
fn safe_value(value: f32) -> f32 {
    if (is_nan(value) || is_inf(value)) {
        return 0.0;
    }
    return value;
}
"#
    }

    /// Get NDVI (Normalized Difference Vegetation Index) shader.
    ///
    /// Entry point `ndvi`, workgroup size 256. Bindings (group 0):
    /// 0 = NIR band (read), 1 = red band (read), 2 = output (read_write).
    /// Computes `(nir - red) / (nir + red)` per element, writing 0 when the
    /// denominator is near zero.
    pub fn ndvi_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> nir: array<f32>;
@group(0) @binding(1) var<storage, read> red: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn ndvi(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }

    let nir_val = nir[idx];
    let red_val = red[idx];
    let sum = nir_val + red_val;

    if (abs(sum) < 1e-10) {
        output[idx] = 0.0;
    } else {
        output[idx] = (nir_val - red_val) / sum;
    }
}
"#
    }

    /// Get element-wise addition shader.
    ///
    /// Entry point `add`, workgroup size 256. Bindings (group 0):
    /// 0 = input_a (read), 1 = input_b (read), 2 = output (read_write).
    pub fn add_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn add(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }
    output[idx] = input_a[idx] + input_b[idx];
}
"#
    }

    /// Get element-wise multiplication shader.
    ///
    /// Entry point `multiply`, workgroup size 256. Bindings (group 0):
    /// 0 = input_a (read), 1 = input_b (read), 2 = output (read_write).
    pub fn multiply_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn multiply(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }
    output[idx] = input_a[idx] * input_b[idx];
}
"#
    }

    /// Get threshold shader.
    ///
    /// Entry point `threshold`, workgroup size 256. Bindings (group 0):
    /// 0 = input (read), 1 = `Params` uniform (threshold, value_below,
    /// value_above), 2 = output (read_write). Writes `value_below` where
    /// `input < threshold`, otherwise `value_above`.
    pub fn threshold_shader() -> &'static str {
        r#"
struct Params {
    threshold: f32,
    value_below: f32,
    value_above: f32,
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<uniform> params: Params;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn threshold(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }

    if (input[idx] < params.threshold) {
        output[idx] = params.value_below;
    } else {
        output[idx] = params.value_above;
    }
}
"#
    }
}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shader_library() {
        let utils = ShaderLibrary::common_utils();
        assert!(utils.contains("coord_to_index"));
        assert!(utils.contains("bilinear_interp"));
        assert!(utils.contains("safe_div"));

        let ndvi = ShaderLibrary::ndvi_shader();
        assert!(ndvi.contains("@compute"));
        assert!(ndvi.contains("workgroup_size"));
    }

    #[test]
    fn test_shader_validation() {
        // `validate()` is pure CPU work (naga parse + validate) and needs no
        // GPU device, so a library shader must validate successfully — the
        // previous `let _ = …` silently discarded a real failure signal.
        let shader = WgslShader::new(ShaderLibrary::add_shader(), "add");
        shader
            .validate()
            .expect("library add shader should be valid WGSL");
        assert_eq!(shader.entry_point(), "add");
    }

    #[test]
    fn test_bind_group_layout_helpers() {
        let storage_ro = storage_buffer_layout(0, true);
        assert_eq!(storage_ro.binding, 0);
        assert!(matches!(
            storage_ro.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: true },
                ..
            }
        ));

        let storage_rw = storage_buffer_layout(1, false);
        assert_eq!(storage_rw.binding, 1);
        assert!(matches!(
            storage_rw.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                ..
            }
        ));

        let uniform = uniform_buffer_layout(2);
        assert_eq!(uniform.binding, 2);
        assert!(matches!(
            uniform.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                ..
            }
        ));
    }
}