1use crate::error::{GpuError, GpuResult};
7use std::borrow::Cow;
8use tracing::debug;
9use wgpu::{
10 BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
11 BufferBindingType, ComputePipeline, ComputePipelineDescriptor, Device,
12 PipelineLayoutDescriptor, ShaderModule, ShaderModuleDescriptor, ShaderSource, ShaderStages,
13 naga,
14};
15
16pub struct WgslShader {
18 source: String,
20 entry_point: String,
22 module: Option<ShaderModule>,
24}
25
26impl WgslShader {
27 pub fn new(source: impl Into<String>, entry_point: impl Into<String>) -> Self {
29 Self {
30 source: source.into(),
31 entry_point: entry_point.into(),
32 module: None,
33 }
34 }
35
36 pub fn validate(&self) -> GpuResult<()> {
42 let module = naga::front::wgsl::parse_str(&self.source)
43 .map_err(|e| GpuError::shader_compilation(e.to_string()))?;
44
45 let mut validator = naga::valid::Validator::new(
47 naga::valid::ValidationFlags::all(),
48 naga::valid::Capabilities::all(),
49 );
50
51 validator
52 .validate(&module)
53 .map_err(|e| GpuError::shader_validation(e.to_string()))?;
54
55 debug!("Shader validation successful");
56 Ok(())
57 }
58
59 pub fn compile(&mut self, device: &Device) -> GpuResult<&ShaderModule> {
65 if self.module.is_none() {
66 self.validate()?;
68
69 let module = device.create_shader_module(ShaderModuleDescriptor {
70 label: Some(&format!("Shader: {}", self.entry_point)),
71 source: ShaderSource::Wgsl(Cow::Borrowed(&self.source)),
72 });
73
74 self.module = Some(module);
75 debug!("Shader compiled: {}", self.entry_point);
76 }
77
78 Ok(self
79 .module
80 .as_ref()
81 .ok_or_else(|| GpuError::internal("Module should be compiled"))?)
82 }
83
84 pub fn entry_point(&self) -> &str {
86 &self.entry_point
87 }
88
89 pub fn source(&self) -> &str {
91 &self.source
92 }
93}
94
95pub struct ComputePipelineBuilder<'a> {
97 device: &'a Device,
98 shader: &'a ShaderModule,
99 entry_point: String,
100 bind_group_layouts: Vec<Option<&'a BindGroupLayout>>,
101 label: Option<String>,
102}
103
104impl<'a> ComputePipelineBuilder<'a> {
105 pub fn new(
107 device: &'a Device,
108 shader: &'a ShaderModule,
109 entry_point: impl Into<String>,
110 ) -> Self {
111 Self {
112 device,
113 shader,
114 entry_point: entry_point.into(),
115 bind_group_layouts: Vec::new(),
116 label: None,
117 }
118 }
119
120 pub fn bind_group_layout(mut self, layout: &'a BindGroupLayout) -> Self {
122 self.bind_group_layouts.push(Some(layout));
123 self
124 }
125
126 pub fn label(mut self, label: impl Into<String>) -> Self {
128 self.label = Some(label.into());
129 self
130 }
131
132 pub fn build(self) -> GpuResult<ComputePipeline> {
138 let pipeline_layout = self
139 .device
140 .create_pipeline_layout(&PipelineLayoutDescriptor {
141 label: self.label.as_deref(),
142 bind_group_layouts: &self.bind_group_layouts,
143 immediate_size: 0,
144 });
145
146 let pipeline = self
147 .device
148 .create_compute_pipeline(&ComputePipelineDescriptor {
149 label: self.label.as_deref(),
150 layout: Some(&pipeline_layout),
151 module: self.shader,
152 entry_point: Some(&self.entry_point),
153 compilation_options: Default::default(),
154 cache: None,
155 });
156
157 debug!("Compute pipeline created: {:?}", self.label);
158 Ok(pipeline)
159 }
160}
161
162pub fn storage_buffer_layout(binding: u32, read_only: bool) -> BindGroupLayoutEntry {
164 BindGroupLayoutEntry {
165 binding,
166 visibility: ShaderStages::COMPUTE,
167 ty: BindingType::Buffer {
168 ty: BufferBindingType::Storage { read_only },
169 has_dynamic_offset: false,
170 min_binding_size: None,
171 },
172 count: None,
173 }
174}
175
176pub fn uniform_buffer_layout(binding: u32) -> BindGroupLayoutEntry {
178 BindGroupLayoutEntry {
179 binding,
180 visibility: ShaderStages::COMPUTE,
181 ty: BindingType::Buffer {
182 ty: BufferBindingType::Uniform,
183 has_dynamic_offset: false,
184 min_binding_size: None,
185 },
186 count: None,
187 }
188}
189
190pub fn create_compute_bind_group_layout(
196 device: &Device,
197 entries: &[BindGroupLayoutEntry],
198 label: Option<&str>,
199) -> GpuResult<BindGroupLayout> {
200 Ok(device.create_bind_group_layout(&BindGroupLayoutDescriptor { label, entries }))
201}
202
/// Collection of built-in WGSL sources.
///
/// Every kernel here shares the same dispatch pattern:
/// - it uses `@workgroup_size(256)`;
/// - it treats `global_invocation_id.x` as the element index;
/// - it returns early for indices at or beyond `arrayLength(&output)`.
pub struct ShaderLibrary;

impl ShaderLibrary {
    /// WGSL helper functions: index/coordinate conversion, interpolation, safe division,
    /// and NaN/infinity filtering.
    ///
    /// The source declares no bindings and no entry point, so it is not a complete shader
    /// on its own.
    pub fn common_utils() -> &'static str {
        r#"
// Common utility functions for WGSL shaders

// Convert 2D coordinates to 1D index
fn coord_to_index(x: u32, y: u32, width: u32) -> u32 {
    return y * width + x;
}

// Convert 1D index to 2D coordinates
fn index_to_coord(index: u32, width: u32) -> vec2<u32> {
    return vec2<u32>(index % width, index / width);
}

// Clamp value to range [min, max]
fn clamp_value(value: f32, min_val: f32, max_val: f32) -> f32 {
    return clamp(value, min_val, max_val);
}

// Linear interpolation
fn lerp(a: f32, b: f32, t: f32) -> f32 {
    return a + (b - a) * t;
}

// Bilinear interpolation
fn bilinear_interp(
    v00: f32, v10: f32,
    v01: f32, v11: f32,
    tx: f32, ty: f32
) -> f32 {
    let v0 = lerp(v00, v10, tx);
    let v1 = lerp(v01, v11, tx);
    return lerp(v0, v1, ty);
}

// Safe division (returns 0 if denominator is 0)
fn safe_div(num: f32, denom: f32) -> f32 {
    if (abs(denom) < 1e-10) {
        return 0.0;
    }
    return num / denom;
}

// Check if value is NaN.
// Tests the IEEE-754 bit pattern (all-ones exponent, non-zero mantissa) instead of
// `value != value`, which shader compilers may fold to `false`.
fn is_nan(value: f32) -> bool {
    return (bitcast<u32>(value) & 0x7fffffffu) > 0x7f800000u;
}

// Check if value is +/- infinity.
// Compares the exact bit pattern; a magnitude threshold would misclassify large but
// finite values (f32 stays finite up to about 3.4e38).
fn is_inf(value: f32) -> bool {
    return (bitcast<u32>(value) & 0x7fffffffu) == 0x7f800000u;
}

// Safe value (replace NaN/Inf with 0)
fn safe_value(value: f32) -> f32 {
    if (is_nan(value) || is_inf(value)) {
        return 0.0;
    }
    return value;
}
"#
    }

    /// NDVI kernel with entry point `ndvi`.
    ///
    /// Bindings (group 0):
    /// - 0: `nir`, read-only storage;
    /// - 1: `red`, read-only storage;
    /// - 2: `output`, read-write storage.
    ///
    /// Writes `(nir - red) / (nir + red)` per element, or `0.0` when `|nir + red| < 1e-10`.
    pub fn ndvi_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> nir: array<f32>;
@group(0) @binding(1) var<storage, read> red: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn ndvi(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }

    let nir_val = nir[idx];
    let red_val = red[idx];
    let sum = nir_val + red_val;

    if (abs(sum) < 1e-10) {
        output[idx] = 0.0;
    } else {
        output[idx] = (nir_val - red_val) / sum;
    }
}
"#
    }

    /// Element-wise addition kernel with entry point `add`.
    ///
    /// Bindings (group 0):
    /// - 0: `input_a`, read-only storage;
    /// - 1: `input_b`, read-only storage;
    /// - 2: `output`, read-write storage.
    pub fn add_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn add(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }
    output[idx] = input_a[idx] + input_b[idx];
}
"#
    }

    /// Element-wise multiplication kernel with entry point `multiply`.
    ///
    /// Bindings (group 0):
    /// - 0: `input_a`, read-only storage;
    /// - 1: `input_b`, read-only storage;
    /// - 2: `output`, read-write storage.
    pub fn multiply_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn multiply(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }
    output[idx] = input_a[idx] * input_b[idx];
}
"#
    }

    /// Threshold kernel with entry point `threshold`.
    ///
    /// Bindings (group 0):
    /// - 0: `input`, read-only storage;
    /// - 1: `params`, a uniform `Params { threshold, value_below, value_above }` (three `f32`s);
    /// - 2: `output`, read-write storage.
    ///
    /// Writes `value_below` where `input < threshold`, otherwise `value_above`.
    pub fn threshold_shader() -> &'static str {
        r#"
struct Params {
    threshold: f32,
    value_below: f32,
    value_above: f32,
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<uniform> params: Params;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn threshold(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {
        return;
    }

    if (input[idx] < params.threshold) {
        output[idx] = params.value_below;
    } else {
        output[idx] = params.value_above;
    }
}
"#
    }
}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shader_library() {
        let utils = ShaderLibrary::common_utils();
        assert!(utils.contains("coord_to_index"));
        assert!(utils.contains("bilinear_interp"));

        let ndvi = ShaderLibrary::ndvi_shader();
        assert!(ndvi.contains("@compute"));
        assert!(ndvi.contains("workgroup_size"));
    }

    #[test]
    fn test_shader_validation() {
        // Validation is pure naga (no GPU device needed), so every built-in kernel must
        // pass. The result is asserted rather than discarded, so a regression fails here.
        let kernels = [
            (ShaderLibrary::add_shader(), "add"),
            (ShaderLibrary::multiply_shader(), "multiply"),
            (ShaderLibrary::ndvi_shader(), "ndvi"),
            (ShaderLibrary::threshold_shader(), "threshold"),
        ];

        for &(source, entry_point) in &kernels {
            let shader = WgslShader::new(source, entry_point);
            assert!(
                shader.validate().is_ok(),
                "built-in `{}` shader failed validation",
                entry_point
            );
        }
    }

    #[test]
    fn test_bind_group_layout_helpers() {
        let storage_ro = storage_buffer_layout(0, true);
        assert_eq!(storage_ro.binding, 0);
        assert_eq!(storage_ro.visibility, ShaderStages::COMPUTE);
        assert!(matches!(
            storage_ro.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: true },
                ..
            }
        ));

        let storage_rw = storage_buffer_layout(1, false);
        assert_eq!(storage_rw.binding, 1);
        assert!(matches!(
            storage_rw.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                ..
            }
        ));

        let uniform = uniform_buffer_layout(2);
        assert_eq!(uniform.binding, 2);
        assert_eq!(uniform.visibility, ShaderStages::COMPUTE);
        assert!(matches!(
            uniform.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                ..
            }
        ));
    }
}