use crate::error::{GpuError, GpuResult};
use std::borrow::Cow;
use tracing::debug;
use wgpu::{
BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
BufferBindingType, ComputePipeline, ComputePipelineDescriptor, Device,
PipelineLayoutDescriptor, ShaderModule, ShaderModuleDescriptor, ShaderSource, ShaderStages,
naga,
};
/// WGSL shader source paired with its entry-point name and a lazily created
/// [`ShaderModule`].
pub struct WgslShader {
    source: String,
    entry_point: String,
    /// `None` until [`WgslShader::compile`] succeeds; afterwards the cached module is reused.
    module: Option<ShaderModule>,
}

impl WgslShader {
    /// Creates an uncompiled shader from WGSL `source` and the name of its entry point.
    pub fn new(source: impl Into<String>, entry_point: impl Into<String>) -> Self {
        Self {
            source: source.into(),
            entry_point: entry_point.into(),
            module: None,
        }
    }

    /// Parses and validates the WGSL source with naga, without touching any GPU device.
    ///
    /// Validation enables every validation flag and every capability, so a shader that
    /// passes here may still rely on features a particular device does not support.
    ///
    /// # Errors
    ///
    /// - `GpuError::shader_compilation` if the source fails to parse.
    /// - `GpuError::shader_validation` if the parsed module fails validation.
    pub fn validate(&self) -> GpuResult<()> {
        let module = naga::front::wgsl::parse_str(&self.source)
            .map_err(|e| GpuError::shader_compilation(e.to_string()))?;
        let mut validator = naga::valid::Validator::new(
            naga::valid::ValidationFlags::all(),
            naga::valid::Capabilities::all(),
        );
        validator
            .validate(&module)
            .map_err(|e| GpuError::shader_validation(e.to_string()))?;
        debug!("Shader validation successful");
        Ok(())
    }

    /// Returns the compiled shader module, validating and creating it on first use.
    ///
    /// After the first successful call the module is cached, and later calls return it
    /// unchanged. The `device` argument is then ignored, even if it differs from the
    /// device that created the module.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WgslShader::validate`]. On error nothing is cached,
    /// and the next call retries.
    pub fn compile(&mut self, device: &Device) -> GpuResult<&ShaderModule> {
        // Move the cached module out (if any) so it can be re-inserted below. That
        // yields a reference without a second, never-failing `Option` lookup.
        let module = match self.module.take() {
            Some(module) => module,
            None => {
                self.validate()?;
                let module = device.create_shader_module(ShaderModuleDescriptor {
                    label: Some(&format!("Shader: {}", self.entry_point)),
                    source: ShaderSource::Wgsl(Cow::Borrowed(&self.source)),
                });
                debug!("Shader compiled: {}", self.entry_point);
                module
            }
        };
        Ok(&*self.module.insert(module))
    }

    /// Name of the shader's entry-point function.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }

    /// The raw WGSL source text.
    pub fn source(&self) -> &str {
        &self.source
    }
}
/// Builder for a [`ComputePipeline`] with an explicit pipeline layout assembled from the
/// bind group layouts added through [`ComputePipelineBuilder::bind_group_layout`].
#[must_use = "a ComputePipelineBuilder does nothing until `.build()` is called"]
pub struct ComputePipelineBuilder<'a> {
    device: &'a Device,
    shader: &'a ShaderModule,
    entry_point: String,
    /// Layouts in the order they were added. wgpu treats the position in this slice as
    /// the bind group index.
    bind_group_layouts: Vec<Option<&'a BindGroupLayout>>,
    label: Option<String>,
}

impl<'a> ComputePipelineBuilder<'a> {
    /// Starts a builder for `entry_point` in `shader`, to be created on `device`.
    /// No bind group layouts and no label are set initially.
    pub fn new(
        device: &'a Device,
        shader: &'a ShaderModule,
        entry_point: impl Into<String>,
    ) -> Self {
        Self {
            device,
            shader,
            entry_point: entry_point.into(),
            bind_group_layouts: Vec::new(),
            label: None,
        }
    }

    /// Appends `layout` as the next bind group layout (index = number of layouts added so far).
    pub fn bind_group_layout(mut self, layout: &'a BindGroupLayout) -> Self {
        self.bind_group_layouts.push(Some(layout));
        self
    }

    /// Sets the debug label used for both the pipeline layout and the pipeline.
    pub fn label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Creates the pipeline layout, then the compute pipeline, on the builder's device.
    ///
    /// # Errors
    ///
    /// Currently this always returns `Ok`. The `create_*` calls used here do not return
    /// errors; wgpu reports creation failures through its device-level error handling.
    /// The `GpuResult` return type is kept for API stability.
    pub fn build(self) -> GpuResult<ComputePipeline> {
        let pipeline_layout = self
            .device
            .create_pipeline_layout(&PipelineLayoutDescriptor {
                label: self.label.as_deref(),
                bind_group_layouts: &self.bind_group_layouts,
                immediate_size: 0,
            });
        let pipeline = self
            .device
            .create_compute_pipeline(&ComputePipelineDescriptor {
                label: self.label.as_deref(),
                layout: Some(&pipeline_layout),
                module: self.shader,
                entry_point: Some(&self.entry_point),
                compilation_options: Default::default(),
                cache: None,
            });
        debug!("Compute pipeline created: {:?}", self.label);
        Ok(pipeline)
    }
}
/// Builds a compute-only bind group layout entry for a storage buffer at `binding`.
///
/// `read_only` is forwarded to `BufferBindingType::Storage`. The entry uses no dynamic
/// offset, no minimum binding size, and is not an array binding (`count: None`).
pub fn storage_buffer_layout(binding: u32, read_only: bool) -> BindGroupLayoutEntry {
    let storage = BufferBindingType::Storage { read_only };
    let buffer = BindingType::Buffer {
        ty: storage,
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    BindGroupLayoutEntry {
        binding,
        visibility: ShaderStages::COMPUTE,
        ty: buffer,
        count: None,
    }
}
/// Builds a compute-only bind group layout entry for a uniform buffer at `binding`.
///
/// The entry uses no dynamic offset, no minimum binding size, and is not an array
/// binding (`count: None`).
pub fn uniform_buffer_layout(binding: u32) -> BindGroupLayoutEntry {
    let buffer = BindingType::Buffer {
        ty: BufferBindingType::Uniform,
        has_dynamic_offset: false,
        min_binding_size: None,
    };
    BindGroupLayoutEntry {
        binding,
        visibility: ShaderStages::COMPUTE,
        ty: buffer,
        count: None,
    }
}
/// Creates a bind group layout on `device` from `entries`, with an optional debug `label`.
///
/// This wraps `Device::create_bind_group_layout` and always returns `Ok`. The `GpuResult`
/// keeps the signature uniform with the other fallible helpers.
pub fn create_compute_bind_group_layout(
    device: &Device,
    entries: &[BindGroupLayoutEntry],
    label: Option<&str>,
) -> GpuResult<BindGroupLayout> {
    let descriptor = BindGroupLayoutDescriptor { label, entries };
    let layout = device.create_bind_group_layout(&descriptor);
    Ok(layout)
}
/// Collection of built-in WGSL shader sources.
pub struct ShaderLibrary;

impl ShaderLibrary {
    /// WGSL helper functions (index/coordinate conversion, interpolation, safe division,
    /// NaN/Inf checks). The source has no entry point; it is meant to be combined with
    /// other shader code.
    pub fn common_utils() -> &'static str {
        r#"
// Common utility functions for WGSL shaders
// Convert 2D coordinates to 1D index
fn coord_to_index(x: u32, y: u32, width: u32) -> u32 {
return y * width + x;
}
// Convert 1D index to 2D coordinates
fn index_to_coord(index: u32, width: u32) -> vec2<u32> {
return vec2<u32>(index % width, index / width);
}
// Clamp value to range [min, max]
fn clamp_value(value: f32, min_val: f32, max_val: f32) -> f32 {
return clamp(value, min_val, max_val);
}
// Linear interpolation
fn lerp(a: f32, b: f32, t: f32) -> f32 {
return a + (b - a) * t;
}
// Bilinear interpolation
fn bilinear_interp(
v00: f32, v10: f32,
v01: f32, v11: f32,
tx: f32, ty: f32
) -> f32 {
let v0 = lerp(v00, v10, tx);
let v1 = lerp(v01, v11, tx);
return lerp(v0, v1, ty);
}
// Safe division (returns 0 if denominator is 0)
fn safe_div(num: f32, denom: f32) -> f32 {
if (abs(denom) < 1e-10) {
return 0.0;
}
return num / denom;
}
// Check if value is NaN: all exponent bits set and a non-zero mantissa.
// Inspects the IEEE-754 bits instead of using `value != value`, because WGSL
// implementations may assume floats are never NaN and fold that comparison away.
fn is_nan(value: f32) -> bool {
return (bitcast<u32>(value) & 0x7fffffffu) > 0x7f800000u;
}
// Check if value is +/- infinity: all exponent bits set and a zero mantissa.
// An exact bit test; a magnitude threshold would misreport large finite values.
fn is_inf(value: f32) -> bool {
return (bitcast<u32>(value) & 0x7fffffffu) == 0x7f800000u;
}
// Safe value (replace NaN/Inf with 0)
fn safe_value(value: f32) -> f32 {
if (is_nan(value) || is_inf(value)) {
return 0.0;
}
return value;
}
"#
    }

    /// NDVI compute shader, entry point `ndvi`, workgroup size 256.
    ///
    /// Bindings (group 0): `0` = `nir` (read), `1` = `red` (read), `2` = `output`
    /// (read_write), all `array<f32>`. Each invocation handles one element,
    /// `(nir - red) / (nir + red)`, and writes `0.0` when `|nir + red| < 1e-10`.
    /// Invocations past `arrayLength(&output)` return early. Only `output` is
    /// bounds-checked, so presumably the inputs must be at least as long as `output`.
    pub fn ndvi_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> nir: array<f32>;
@group(0) @binding(1) var<storage, read> red: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn ndvi(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
if (idx >= arrayLength(&output)) {
return;
}
let nir_val = nir[idx];
let red_val = red[idx];
let sum = nir_val + red_val;
if (abs(sum) < 1e-10) {
output[idx] = 0.0;
} else {
output[idx] = (nir_val - red_val) / sum;
}
}
"#
    }

    /// Element-wise addition shader, entry point `add`, workgroup size 256.
    ///
    /// Bindings (group 0): `0` = `input_a` (read), `1` = `input_b` (read),
    /// `2` = `output` (read_write). Computes `output[i] = input_a[i] + input_b[i]` for
    /// `i < arrayLength(&output)`.
    pub fn add_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn add(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
if (idx >= arrayLength(&output)) {
return;
}
output[idx] = input_a[idx] + input_b[idx];
}
"#
    }

    /// Element-wise multiplication shader, entry point `multiply`, workgroup size 256.
    ///
    /// Bindings (group 0): `0` = `input_a` (read), `1` = `input_b` (read),
    /// `2` = `output` (read_write). Computes `output[i] = input_a[i] * input_b[i]` for
    /// `i < arrayLength(&output)`.
    pub fn multiply_shader() -> &'static str {
        r#"
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn multiply(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
if (idx >= arrayLength(&output)) {
return;
}
output[idx] = input_a[idx] * input_b[idx];
}
"#
    }

    /// Thresholding shader, entry point `threshold`, workgroup size 256.
    ///
    /// Bindings (group 0): `0` = `input` (read storage), `1` = `params` (uniform
    /// `Params { threshold, value_below, value_above }`, three `f32`), `2` = `output`
    /// (read_write). Writes `value_below` where `input < threshold`, otherwise
    /// `value_above`, so an input equal to the threshold gets `value_above`.
    pub fn threshold_shader() -> &'static str {
        r#"
struct Params {
threshold: f32,
value_below: f32,
value_above: f32,
}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<uniform> params: Params;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn threshold(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
if (idx >= arrayLength(&output)) {
return;
}
if (input[idx] < params.threshold) {
output[idx] = params.value_below;
} else {
output[idx] = params.value_above;
}
}
"#
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shader_library() {
        let utils = ShaderLibrary::common_utils();
        assert!(utils.contains("coord_to_index"));
        assert!(utils.contains("bilinear_interp"));
        let ndvi = ShaderLibrary::ndvi_shader();
        assert!(ndvi.contains("@compute"));
        assert!(ndvi.contains("workgroup_size"));
    }

    #[test]
    fn test_shader_validation() {
        // Every bundled shader must parse and pass naga validation. The result is
        // asserted; discarding it would let an invalid shader pass this test.
        let cases = [
            ("common_utils", ShaderLibrary::common_utils()),
            ("ndvi", ShaderLibrary::ndvi_shader()),
            ("add", ShaderLibrary::add_shader()),
            ("multiply", ShaderLibrary::multiply_shader()),
            ("threshold", ShaderLibrary::threshold_shader()),
        ];
        for (name, source) in cases {
            let shader = WgslShader::new(source, name);
            assert!(
                shader.validate().is_ok(),
                "shader `{}` failed naga validation",
                name
            );
        }
    }

    #[test]
    fn test_bind_group_layout_helpers() {
        let storage_ro = storage_buffer_layout(0, true);
        assert_eq!(storage_ro.binding, 0);
        assert_eq!(storage_ro.visibility, ShaderStages::COMPUTE);
        assert!(storage_ro.count.is_none());
        assert!(matches!(
            storage_ro.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                ..
            }
        ));

        let storage_rw = storage_buffer_layout(1, false);
        assert_eq!(storage_rw.binding, 1);
        assert!(matches!(
            storage_rw.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                ..
            }
        ));

        let uniform = uniform_buffer_layout(2);
        assert_eq!(uniform.binding, 2);
        assert_eq!(uniform.visibility, ShaderStages::COMPUTE);
        assert!(matches!(
            uniform.ty,
            BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                ..
            }
        ));
    }
}