//! GPU kernels for element-wise raster operations.
//!
//! This module provides GPU-accelerated element-wise operations on rasters,
//! including arithmetic, logical, and transformation operations.
use crate::buffer::GpuBuffer;
use crate::context::GpuContext;
use crate::error::{GpuError, GpuResult};
use crate::shaders::{
    ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
};
use bytemuck::Pod;
use tracing::debug;
use wgpu::{
    BindGroup, BindGroupDescriptor, BindGroupEntry, CommandEncoderDescriptor,
    ComputePassDescriptor, ComputePipeline,
};

/// Element-wise operation type.
///
/// Each variant selects one binary `f32` operation applied per raster cell
/// by [`RasterKernel`]; the inline WGSL shader is generated from the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementWiseOp {
    /// Addition: a + b
    Add,
    /// Subtraction: a - b
    Subtract,
    /// Multiplication: a * b
    Multiply,
    /// Division: a / b (the inline shader guards a near-zero divisor and
    /// yields 0.0 instead of inf/NaN)
    Divide,
    /// Power: a ^ b
    Power,
    /// Minimum: min(a, b)
    Min,
    /// Maximum: max(a, b)
    Max,
    /// Modulo: a % b
    Modulo,
}

40impl ElementWiseOp {
41    /// Get the WGSL shader source for this operation.
42    fn shader_source(&self) -> &'static str {
43        match self {
44            Self::Add => include_str!("shaders/add.wgsl"),
45            Self::Subtract => include_str!("shaders/subtract.wgsl"),
46            Self::Multiply => include_str!("shaders/multiply.wgsl"),
47            Self::Divide => include_str!("shaders/divide.wgsl"),
48            Self::Power => include_str!("shaders/power.wgsl"),
49            Self::Min => include_str!("shaders/min.wgsl"),
50            Self::Max => include_str!("shaders/max.wgsl"),
51            Self::Modulo => include_str!("shaders/modulo.wgsl"),
52        }
53    }
54
55    /// Get the entry point name for this operation.
56    fn entry_point(&self) -> &'static str {
57        match self {
58            Self::Add => "add",
59            Self::Subtract => "subtract",
60            Self::Multiply => "multiply",
61            Self::Divide => "divide",
62            Self::Power => "power",
63            Self::Min => "min_op",
64            Self::Max => "max_op",
65            Self::Modulo => "modulo",
66        }
67    }
68
69    /// Get a fallback inline shader if external shader not available.
70    fn inline_shader(&self) -> String {
71        let op_expr = match self {
72            Self::Add => "input_a[idx] + input_b[idx]",
73            Self::Subtract => "input_a[idx] - input_b[idx]",
74            Self::Multiply => "input_a[idx] * input_b[idx]",
75            Self::Divide => "safe_div(input_a[idx], input_b[idx])",
76            Self::Power => "pow(input_a[idx], input_b[idx])",
77            Self::Min => "min(input_a[idx], input_b[idx])",
78            Self::Max => "max(input_a[idx], input_b[idx])",
79            Self::Modulo => "input_a[idx] % input_b[idx]",
80        };
81
82        format!(
83            r#"
84@group(0) @binding(0) var<storage, read> input_a: array<f32>;
85@group(0) @binding(1) var<storage, read> input_b: array<f32>;
86@group(0) @binding(2) var<storage, read_write> output: array<f32>;
87
88fn safe_div(num: f32, denom: f32) -> f32 {{
89    if (abs(denom) < 1e-10) {{
90        return 0.0;
91    }}
92    return num / denom;
93}}
94
95@compute @workgroup_size(256)
96fn {entry}(@builtin(global_invocation_id) global_id: vec3<u32>) {{
97    let idx = global_id.x;
98    if (idx >= arrayLength(&output)) {{
99        return;
100    }}
101    output[idx] = {op_expr};
102}}
103"#,
104            entry = self.entry_point(),
105            op_expr = op_expr
106        )
107    }
108}
109
/// GPU kernel for element-wise raster operations.
///
/// Holds a compute pipeline compiled for one [`ElementWiseOp`]; create one
/// kernel per operation and reuse it across executions.
pub struct RasterKernel {
    // Cloned handle to the GPU device/queue used for dispatch.
    context: GpuContext,
    // Pipeline compiled from the operation's generated WGSL shader.
    pipeline: ComputePipeline,
    // Layout: two read-only input buffers plus one read-write output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; must match `@workgroup_size(256)` in the shader.
    workgroup_size: u32,
}

118impl RasterKernel {
119    /// Create a new raster kernel for the specified operation.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if shader compilation or pipeline creation fails.
124    pub fn new(context: &GpuContext, op: ElementWiseOp) -> GpuResult<Self> {
125        debug!("Creating raster kernel for operation: {:?}", op);
126
127        // Create shader - use inline shader as fallback
128        let shader_source = op.inline_shader();
129        let mut shader = WgslShader::new(shader_source, op.entry_point());
130        let shader_module = shader.compile(context.device())?;
131
132        // Create bind group layout
133        let bind_group_layout = create_compute_bind_group_layout(
134            context.device(),
135            &[
136                storage_buffer_layout(0, true),  // input_a (read-only)
137                storage_buffer_layout(1, true),  // input_b (read-only)
138                storage_buffer_layout(2, false), // output (read-write)
139            ],
140            Some("RasterKernel BindGroupLayout"),
141        )?;
142
143        // Create pipeline
144        let pipeline =
145            ComputePipelineBuilder::new(context.device(), shader_module, op.entry_point())
146                .bind_group_layout(&bind_group_layout)
147                .label(format!("RasterKernel Pipeline: {:?}", op))
148                .build()?;
149
150        Ok(Self {
151            context: context.clone(),
152            pipeline,
153            bind_group_layout,
154            workgroup_size: 256,
155        })
156    }
157
158    /// Execute the kernel on GPU buffers.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if buffer sizes don't match or execution fails.
163    pub fn execute<T: Pod>(
164        &self,
165        input_a: &GpuBuffer<T>,
166        input_b: &GpuBuffer<T>,
167        output: &mut GpuBuffer<T>,
168    ) -> GpuResult<()> {
169        // Validate buffer sizes
170        if input_a.len() != input_b.len() || input_a.len() != output.len() {
171            return Err(GpuError::invalid_kernel_params(format!(
172                "Buffer size mismatch: {} != {} != {}",
173                input_a.len(),
174                input_b.len(),
175                output.len()
176            )));
177        }
178
179        let num_elements = input_a.len();
180
181        // Create bind group
182        let bind_group = self.create_bind_group(input_a, input_b, output)?;
183
184        // Create command encoder
185        let mut encoder = self
186            .context
187            .device()
188            .create_command_encoder(&CommandEncoderDescriptor {
189                label: Some("RasterKernel Encoder"),
190            });
191
192        // Compute pass
193        {
194            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
195                label: Some("RasterKernel Pass"),
196                timestamp_writes: None,
197            });
198
199            compute_pass.set_pipeline(&self.pipeline);
200            compute_pass.set_bind_group(0, &bind_group, &[]);
201
202            let workgroup_count =
203                (num_elements as u32 + self.workgroup_size - 1) / self.workgroup_size;
204            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
205        }
206
207        // Submit commands
208        self.context.queue().submit(Some(encoder.finish()));
209
210        debug!("Executed raster kernel on {} elements", num_elements);
211        Ok(())
212    }
213
214    /// Create bind group for kernel execution.
215    fn create_bind_group<T: Pod>(
216        &self,
217        input_a: &GpuBuffer<T>,
218        input_b: &GpuBuffer<T>,
219        output: &GpuBuffer<T>,
220    ) -> GpuResult<BindGroup> {
221        let bind_group = self
222            .context
223            .device()
224            .create_bind_group(&BindGroupDescriptor {
225                label: Some("RasterKernel BindGroup"),
226                layout: &self.bind_group_layout,
227                entries: &[
228                    BindGroupEntry {
229                        binding: 0,
230                        resource: input_a.buffer().as_entire_binding(),
231                    },
232                    BindGroupEntry {
233                        binding: 1,
234                        resource: input_b.buffer().as_entire_binding(),
235                    },
236                    BindGroupEntry {
237                        binding: 2,
238                        resource: output.buffer().as_entire_binding(),
239                    },
240                ],
241            });
242
243        Ok(bind_group)
244    }
245}
246
/// Unary operation type.
///
/// Each variant selects one single-input `f32` operation applied per raster
/// cell by [`UnaryKernel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Negate: -a
    Negate,
    /// Absolute value: |a|
    Abs,
    /// Square root: √a (the inline shader clamps negative inputs to 0
    /// before the root, avoiding NaN)
    Sqrt,
    /// Square: a²
    Square,
    /// Natural logarithm: ln(a) (the inline shader clamps inputs below
    /// 1e-10, avoiding -inf/NaN)
    Log,
    /// Exponential: e^a
    Exp,
    /// Sine: sin(a)
    Sin,
    /// Cosine: cos(a)
    Cos,
    /// Tangent: tan(a)
    Tan,
}

270impl UnaryOp {
271    /// Get inline shader for this operation.
272    fn inline_shader(&self) -> String {
273        let op_expr = match self {
274            Self::Negate => "-input[idx]",
275            Self::Abs => "abs(input[idx])",
276            Self::Sqrt => "sqrt(max(input[idx], 0.0))",
277            Self::Square => "input[idx] * input[idx]",
278            Self::Log => "log(max(input[idx], 1e-10))",
279            Self::Exp => "exp(input[idx])",
280            Self::Sin => "sin(input[idx])",
281            Self::Cos => "cos(input[idx])",
282            Self::Tan => "tan(input[idx])",
283        };
284
285        format!(
286            r#"
287@group(0) @binding(0) var<storage, read> input: array<f32>;
288@group(0) @binding(1) var<storage, read_write> output: array<f32>;
289
290@compute @workgroup_size(256)
291fn unary_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
292    let idx = global_id.x;
293    if (idx >= arrayLength(&output)) {{
294        return;
295    }}
296    output[idx] = {op_expr};
297}}
298"#,
299            op_expr = op_expr
300        )
301    }
302}
303
/// GPU kernel for unary raster operations.
///
/// Holds a compute pipeline compiled for one [`UnaryOp`]; create one kernel
/// per operation and reuse it across executions.
pub struct UnaryKernel {
    // Cloned handle to the GPU device/queue used for dispatch.
    context: GpuContext,
    // Pipeline compiled from the operation's generated WGSL shader.
    pipeline: ComputePipeline,
    // Layout: one read-only input buffer plus one read-write output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; must match `@workgroup_size(256)` in the shader.
    workgroup_size: u32,
}

312impl UnaryKernel {
313    /// Create a new unary kernel for the specified operation.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if shader compilation or pipeline creation fails.
318    pub fn new(context: &GpuContext, op: UnaryOp) -> GpuResult<Self> {
319        debug!("Creating unary kernel for operation: {:?}", op);
320
321        let shader_source = op.inline_shader();
322        let mut shader = WgslShader::new(shader_source, "unary_op");
323        let shader_module = shader.compile(context.device())?;
324
325        let bind_group_layout = create_compute_bind_group_layout(
326            context.device(),
327            &[
328                storage_buffer_layout(0, true),  // input (read-only)
329                storage_buffer_layout(1, false), // output (read-write)
330            ],
331            Some("UnaryKernel BindGroupLayout"),
332        )?;
333
334        let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "unary_op")
335            .bind_group_layout(&bind_group_layout)
336            .label(format!("UnaryKernel Pipeline: {:?}", op))
337            .build()?;
338
339        Ok(Self {
340            context: context.clone(),
341            pipeline,
342            bind_group_layout,
343            workgroup_size: 256,
344        })
345    }
346
347    /// Execute the kernel on GPU buffer.
348    ///
349    /// # Errors
350    ///
351    /// Returns an error if buffer sizes don't match or execution fails.
352    pub fn execute<T: Pod>(
353        &self,
354        input: &GpuBuffer<T>,
355        output: &mut GpuBuffer<T>,
356    ) -> GpuResult<()> {
357        if input.len() != output.len() {
358            return Err(GpuError::invalid_kernel_params(format!(
359                "Buffer size mismatch: {} != {}",
360                input.len(),
361                output.len()
362            )));
363        }
364
365        let num_elements = input.len();
366
367        let bind_group = self
368            .context
369            .device()
370            .create_bind_group(&BindGroupDescriptor {
371                label: Some("UnaryKernel BindGroup"),
372                layout: &self.bind_group_layout,
373                entries: &[
374                    BindGroupEntry {
375                        binding: 0,
376                        resource: input.buffer().as_entire_binding(),
377                    },
378                    BindGroupEntry {
379                        binding: 1,
380                        resource: output.buffer().as_entire_binding(),
381                    },
382                ],
383            });
384
385        let mut encoder = self
386            .context
387            .device()
388            .create_command_encoder(&CommandEncoderDescriptor {
389                label: Some("UnaryKernel Encoder"),
390            });
391
392        {
393            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
394                label: Some("UnaryKernel Pass"),
395                timestamp_writes: None,
396            });
397
398            compute_pass.set_pipeline(&self.pipeline);
399            compute_pass.set_bind_group(0, &bind_group, &[]);
400
401            let workgroup_count =
402                (num_elements as u32 + self.workgroup_size - 1) / self.workgroup_size;
403            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
404        }
405
406        self.context.queue().submit(Some(encoder.finish()));
407
408        debug!("Executed unary kernel on {} elements", num_elements);
409        Ok(())
410    }
411}
412
/// Scalar operation type.
///
/// The scalar parameters are baked into the generated WGSL at
/// pipeline-creation time, so one kernel is specialized to one parameter
/// set. (No `Eq` derive: variants carry `f32` fields.)
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScalarOp {
    /// Add scalar: a + c
    Add(f32),
    /// Multiply by scalar: a * c
    Multiply(f32),
    /// Clamp to range: clamp(a, min, max)
    Clamp { min: f32, max: f32 },
    /// Threshold: a > threshold ? above : below (strict comparison; values
    /// equal to `threshold` map to `below`)
    Threshold {
        threshold: f32,
        above: f32,
        below: f32,
    },
}

impl ScalarOp {
    /// Get inline shader for this operation.
    ///
    /// The scalar parameters are interpolated directly into the WGSL source
    /// via `Display` formatting, specializing the shader to one parameter
    /// set; the entry point is always named `scalar_op`.
    ///
    /// NOTE(review): non-finite values format as `NaN`/`inf`, which is not
    /// valid WGSL — presumably callers only pass finite values; confirm
    /// validation upstream.
    fn inline_shader(&self) -> String {
        match self {
            // output = input + value
            Self::Add(value) => format!(
                r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {{
        return;
    }}
    output[idx] = input[idx] + {value};
}}
"#,
                value = value
            ),
            // output = input * value
            Self::Multiply(value) => format!(
                r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {{
        return;
    }}
    output[idx] = input[idx] * {value};
}}
"#,
                value = value
            ),
            // output = clamp(input, min, max) — WGSL clamp expects min <= max.
            Self::Clamp { min, max } => format!(
                r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {{
        return;
    }}
    output[idx] = clamp(input[idx], {min}, {max});
}}
"#,
                min = min,
                max = max
            ),
            // output = above when input > threshold (strict), else below.
            Self::Threshold {
                threshold,
                above,
                below,
            } => format!(
                r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    if (idx >= arrayLength(&output)) {{
        return;
    }}
    if (input[idx] > {threshold}) {{
        output[idx] = {above};
    }} else {{
        output[idx] = {below};
    }}
}}
"#,
                threshold = threshold,
                above = above,
                below = below
            ),
        }
    }
}

/// GPU kernel for scalar raster operations.
///
/// Holds a compute pipeline specialized to one [`ScalarOp`] and its baked-in
/// scalar parameters; create one kernel per operation/parameter set.
pub struct ScalarKernel {
    // Cloned handle to the GPU device/queue used for dispatch.
    context: GpuContext,
    // Pipeline compiled from the operation's generated WGSL shader.
    pipeline: ComputePipeline,
    // Layout: one read-only input buffer plus one read-write output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; must match `@workgroup_size(256)` in the shader.
    workgroup_size: u32,
}

521impl ScalarKernel {
522    /// Create a new scalar kernel for the specified operation.
523    ///
524    /// # Errors
525    ///
526    /// Returns an error if shader compilation or pipeline creation fails.
527    pub fn new(context: &GpuContext, op: ScalarOp) -> GpuResult<Self> {
528        debug!("Creating scalar kernel for operation: {:?}", op);
529
530        let shader_source = op.inline_shader();
531        let mut shader = WgslShader::new(shader_source, "scalar_op");
532        let shader_module = shader.compile(context.device())?;
533
534        let bind_group_layout = create_compute_bind_group_layout(
535            context.device(),
536            &[
537                storage_buffer_layout(0, true),  // input (read-only)
538                storage_buffer_layout(1, false), // output (read-write)
539            ],
540            Some("ScalarKernel BindGroupLayout"),
541        )?;
542
543        let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "scalar_op")
544            .bind_group_layout(&bind_group_layout)
545            .label(format!("ScalarKernel Pipeline: {:?}", op))
546            .build()?;
547
548        Ok(Self {
549            context: context.clone(),
550            pipeline,
551            bind_group_layout,
552            workgroup_size: 256,
553        })
554    }
555
556    /// Execute the kernel on GPU buffer.
557    ///
558    /// # Errors
559    ///
560    /// Returns an error if execution fails.
561    pub fn execute<T: Pod>(
562        &self,
563        input: &GpuBuffer<T>,
564        output: &mut GpuBuffer<T>,
565    ) -> GpuResult<()> {
566        if input.len() != output.len() {
567            return Err(GpuError::invalid_kernel_params(format!(
568                "Buffer size mismatch: {} != {}",
569                input.len(),
570                output.len()
571            )));
572        }
573
574        let num_elements = input.len();
575
576        let bind_group = self
577            .context
578            .device()
579            .create_bind_group(&BindGroupDescriptor {
580                label: Some("ScalarKernel BindGroup"),
581                layout: &self.bind_group_layout,
582                entries: &[
583                    BindGroupEntry {
584                        binding: 0,
585                        resource: input.buffer().as_entire_binding(),
586                    },
587                    BindGroupEntry {
588                        binding: 1,
589                        resource: output.buffer().as_entire_binding(),
590                    },
591                ],
592            });
593
594        let mut encoder = self
595            .context
596            .device()
597            .create_command_encoder(&CommandEncoderDescriptor {
598                label: Some("ScalarKernel Encoder"),
599            });
600
601        {
602            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
603                label: Some("ScalarKernel Pass"),
604                timestamp_writes: None,
605            });
606
607            compute_pass.set_pipeline(&self.pipeline);
608            compute_pass.set_bind_group(0, &bind_group, &[]);
609
610            let workgroup_count =
611                (num_elements as u32 + self.workgroup_size - 1) / self.workgroup_size;
612            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
613        }
614
615        self.context.queue().submit(Some(encoder.finish()));
616
617        debug!("Executed scalar kernel on {} elements", num_elements);
618        Ok(())
619    }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    /// Every binary op's generated shader declares a compute entry point
    /// whose name matches `entry_point()`.
    #[test]
    fn test_element_wise_op_shader() {
        let ops = [
            ElementWiseOp::Add,
            ElementWiseOp::Subtract,
            ElementWiseOp::Multiply,
            ElementWiseOp::Divide,
            ElementWiseOp::Power,
            ElementWiseOp::Min,
            ElementWiseOp::Max,
            ElementWiseOp::Modulo,
        ];
        for op in ops.iter().copied() {
            let shader = op.inline_shader();
            assert!(shader.contains("@compute"), "missing @compute for {:?}", op);
            assert!(
                shader.contains(&format!("fn {}(", op.entry_point())),
                "entry point `{}` not found for {:?}",
                op.entry_point(),
                op
            );
        }
    }

    /// Unary shaders always use the fixed `unary_op` entry point and embed
    /// the operation-specific expression.
    #[test]
    fn test_unary_op_shader() {
        let shader = UnaryOp::Sqrt.inline_shader();
        assert!(shader.contains("@compute"));
        assert!(shader.contains("fn unary_op("));
        assert!(shader.contains("sqrt"));

        let shader = UnaryOp::Negate.inline_shader();
        assert!(shader.contains("-input[idx]"));
    }

    /// Scalar parameters are baked into the generated WGSL source.
    #[test]
    fn test_scalar_op_shader() {
        let shader = ScalarOp::Add(5.0).inline_shader();
        assert!(shader.contains("input[idx] + 5"));

        let shader = ScalarOp::Clamp { min: 0.0, max: 1.0 }.inline_shader();
        assert!(shader.contains("clamp(input[idx]"));

        let shader = ScalarOp::Threshold {
            threshold: 0.5,
            above: 1.0,
            below: 0.0,
        }
        .inline_shader();
        assert!(shader.contains("else"));
    }

    /// Kernel creation must succeed whenever a GPU context is available;
    /// machines without a GPU adapter skip silently.
    #[tokio::test]
    async fn test_raster_kernel_execution() {
        if let Ok(context) = GpuContext::new().await {
            let kernel = RasterKernel::new(&context, ElementWiseOp::Add);
            assert!(
                kernel.is_ok(),
                "kernel creation failed on available GPU: {:?}",
                kernel.err()
            );
        }
    }
}