1use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10 ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11};
12use bytemuck::Pod;
13use tracing::debug;
14use wgpu::{
15 BindGroup, BindGroupDescriptor, BindGroupEntry, CommandEncoderDescriptor,
16 ComputePassDescriptor, ComputePipeline,
17};
18
/// Element-wise binary operations combining two f32 buffers pairwise.
///
/// Each variant maps to a generated WGSL compute shader; see
/// `inline_shader` for the exact per-element expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementWiseOp {
    /// `output[i] = a[i] + b[i]`
    Add,
    /// `output[i] = a[i] - b[i]`
    Subtract,
    /// `output[i] = a[i] * b[i]`
    Multiply,
    /// `output[i] = a[i] / b[i]`, returning 0.0 when `|b[i]| < 1e-10`
    /// (see `safe_div` in the generated shader).
    Divide,
    /// `output[i] = pow(a[i], b[i])`
    Power,
    /// `output[i] = min(a[i], b[i])`
    Min,
    /// `output[i] = max(a[i], b[i])`
    Max,
    /// `output[i] = a[i] % b[i]` (WGSL `%` on f32)
    Modulo,
}
39
40impl ElementWiseOp {
41 fn shader_source(&self) -> &'static str {
43 match self {
44 Self::Add => include_str!("shaders/add.wgsl"),
45 Self::Subtract => include_str!("shaders/subtract.wgsl"),
46 Self::Multiply => include_str!("shaders/multiply.wgsl"),
47 Self::Divide => include_str!("shaders/divide.wgsl"),
48 Self::Power => include_str!("shaders/power.wgsl"),
49 Self::Min => include_str!("shaders/min.wgsl"),
50 Self::Max => include_str!("shaders/max.wgsl"),
51 Self::Modulo => include_str!("shaders/modulo.wgsl"),
52 }
53 }
54
55 fn entry_point(&self) -> &'static str {
57 match self {
58 Self::Add => "add",
59 Self::Subtract => "subtract",
60 Self::Multiply => "multiply",
61 Self::Divide => "divide",
62 Self::Power => "power",
63 Self::Min => "min_op",
64 Self::Max => "max_op",
65 Self::Modulo => "modulo",
66 }
67 }
68
69 fn inline_shader(&self) -> String {
71 let op_expr = match self {
72 Self::Add => "input_a[idx] + input_b[idx]",
73 Self::Subtract => "input_a[idx] - input_b[idx]",
74 Self::Multiply => "input_a[idx] * input_b[idx]",
75 Self::Divide => "safe_div(input_a[idx], input_b[idx])",
76 Self::Power => "pow(input_a[idx], input_b[idx])",
77 Self::Min => "min(input_a[idx], input_b[idx])",
78 Self::Max => "max(input_a[idx], input_b[idx])",
79 Self::Modulo => "input_a[idx] % input_b[idx]",
80 };
81
82 format!(
83 r#"
84@group(0) @binding(0) var<storage, read> input_a: array<f32>;
85@group(0) @binding(1) var<storage, read> input_b: array<f32>;
86@group(0) @binding(2) var<storage, read_write> output: array<f32>;
87
88fn safe_div(num: f32, denom: f32) -> f32 {{
89 if (abs(denom) < 1e-10) {{
90 return 0.0;
91 }}
92 return num / denom;
93}}
94
95@compute @workgroup_size(256)
96fn {entry}(@builtin(global_invocation_id) global_id: vec3<u32>) {{
97 let idx = global_id.x;
98 if (idx >= arrayLength(&output)) {{
99 return;
100 }}
101 output[idx] = {op_expr};
102}}
103"#,
104 entry = self.entry_point(),
105 op_expr = op_expr
106 )
107 }
108}
109
/// GPU compute kernel applying an [`ElementWiseOp`] to two input buffers.
///
/// Built once per operation; `execute` can then be called repeatedly.
pub struct RasterKernel {
    // Cloned handle to the device/queue this kernel dispatches on.
    context: GpuContext,
    // Compiled pipeline for the operation chosen at construction time.
    pipeline: ComputePipeline,
    // Layout: binding 0 = input_a, 1 = input_b, 2 = output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; matches @workgroup_size in the generated WGSL.
    workgroup_size: u32,
}
117
118impl RasterKernel {
119 pub fn new(context: &GpuContext, op: ElementWiseOp) -> GpuResult<Self> {
125 debug!("Creating raster kernel for operation: {:?}", op);
126
127 let shader_source = op.inline_shader();
129 let mut shader = WgslShader::new(shader_source, op.entry_point());
130 let shader_module = shader.compile(context.device())?;
131
132 let bind_group_layout = create_compute_bind_group_layout(
134 context.device(),
135 &[
136 storage_buffer_layout(0, true), storage_buffer_layout(1, true), storage_buffer_layout(2, false), ],
140 Some("RasterKernel BindGroupLayout"),
141 )?;
142
143 let pipeline =
145 ComputePipelineBuilder::new(context.device(), shader_module, op.entry_point())
146 .bind_group_layout(&bind_group_layout)
147 .label(format!("RasterKernel Pipeline: {:?}", op))
148 .build()?;
149
150 Ok(Self {
151 context: context.clone(),
152 pipeline,
153 bind_group_layout,
154 workgroup_size: 256,
155 })
156 }
157
158 pub fn execute<T: Pod>(
164 &self,
165 input_a: &GpuBuffer<T>,
166 input_b: &GpuBuffer<T>,
167 output: &mut GpuBuffer<T>,
168 ) -> GpuResult<()> {
169 if input_a.len() != input_b.len() || input_a.len() != output.len() {
171 return Err(GpuError::invalid_kernel_params(format!(
172 "Buffer size mismatch: {} != {} != {}",
173 input_a.len(),
174 input_b.len(),
175 output.len()
176 )));
177 }
178
179 let num_elements = input_a.len();
180
181 let bind_group = self.create_bind_group(input_a, input_b, output)?;
183
184 let mut encoder = self
186 .context
187 .device()
188 .create_command_encoder(&CommandEncoderDescriptor {
189 label: Some("RasterKernel Encoder"),
190 });
191
192 {
194 let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
195 label: Some("RasterKernel Pass"),
196 timestamp_writes: None,
197 });
198
199 compute_pass.set_pipeline(&self.pipeline);
200 compute_pass.set_bind_group(0, &bind_group, &[]);
201
202 let workgroup_count =
203 (num_elements as u32 + self.workgroup_size - 1) / self.workgroup_size;
204 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
205 }
206
207 self.context.queue().submit(Some(encoder.finish()));
209
210 debug!("Executed raster kernel on {} elements", num_elements);
211 Ok(())
212 }
213
214 fn create_bind_group<T: Pod>(
216 &self,
217 input_a: &GpuBuffer<T>,
218 input_b: &GpuBuffer<T>,
219 output: &GpuBuffer<T>,
220 ) -> GpuResult<BindGroup> {
221 let bind_group = self
222 .context
223 .device()
224 .create_bind_group(&BindGroupDescriptor {
225 label: Some("RasterKernel BindGroup"),
226 layout: &self.bind_group_layout,
227 entries: &[
228 BindGroupEntry {
229 binding: 0,
230 resource: input_a.buffer().as_entire_binding(),
231 },
232 BindGroupEntry {
233 binding: 1,
234 resource: input_b.buffer().as_entire_binding(),
235 },
236 BindGroupEntry {
237 binding: 2,
238 resource: output.buffer().as_entire_binding(),
239 },
240 ],
241 });
242
243 Ok(bind_group)
244 }
245}
246
/// Element-wise unary operations applied to a single f32 buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// `output[i] = -input[i]`
    Negate,
    /// `output[i] = abs(input[i])`
    Abs,
    /// `output[i] = sqrt(max(input[i], 0.0))` — negatives clamp to 0 first.
    Sqrt,
    /// `output[i] = input[i] * input[i]`
    Square,
    /// `output[i] = log(max(input[i], 1e-10))` — input floored to avoid log(0).
    Log,
    /// `output[i] = exp(input[i])`
    Exp,
    /// `output[i] = sin(input[i])`
    Sin,
    /// `output[i] = cos(input[i])`
    Cos,
    /// `output[i] = tan(input[i])`
    Tan,
}
269
270impl UnaryOp {
271 fn inline_shader(&self) -> String {
273 let op_expr = match self {
274 Self::Negate => "-input[idx]",
275 Self::Abs => "abs(input[idx])",
276 Self::Sqrt => "sqrt(max(input[idx], 0.0))",
277 Self::Square => "input[idx] * input[idx]",
278 Self::Log => "log(max(input[idx], 1e-10))",
279 Self::Exp => "exp(input[idx])",
280 Self::Sin => "sin(input[idx])",
281 Self::Cos => "cos(input[idx])",
282 Self::Tan => "tan(input[idx])",
283 };
284
285 format!(
286 r#"
287@group(0) @binding(0) var<storage, read> input: array<f32>;
288@group(0) @binding(1) var<storage, read_write> output: array<f32>;
289
290@compute @workgroup_size(256)
291fn unary_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
292 let idx = global_id.x;
293 if (idx >= arrayLength(&output)) {{
294 return;
295 }}
296 output[idx] = {op_expr};
297}}
298"#,
299 op_expr = op_expr
300 )
301 }
302}
303
/// GPU compute kernel applying a [`UnaryOp`] to a single input buffer.
///
/// Built once per operation; `execute` can then be called repeatedly.
pub struct UnaryKernel {
    // Cloned handle to the device/queue this kernel dispatches on.
    context: GpuContext,
    // Compiled pipeline for the operation chosen at construction time.
    pipeline: ComputePipeline,
    // Layout: binding 0 = input, 1 = output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; matches @workgroup_size in the generated WGSL.
    workgroup_size: u32,
}
311
312impl UnaryKernel {
313 pub fn new(context: &GpuContext, op: UnaryOp) -> GpuResult<Self> {
319 debug!("Creating unary kernel for operation: {:?}", op);
320
321 let shader_source = op.inline_shader();
322 let mut shader = WgslShader::new(shader_source, "unary_op");
323 let shader_module = shader.compile(context.device())?;
324
325 let bind_group_layout = create_compute_bind_group_layout(
326 context.device(),
327 &[
328 storage_buffer_layout(0, true), storage_buffer_layout(1, false), ],
331 Some("UnaryKernel BindGroupLayout"),
332 )?;
333
334 let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "unary_op")
335 .bind_group_layout(&bind_group_layout)
336 .label(format!("UnaryKernel Pipeline: {:?}", op))
337 .build()?;
338
339 Ok(Self {
340 context: context.clone(),
341 pipeline,
342 bind_group_layout,
343 workgroup_size: 256,
344 })
345 }
346
347 pub fn execute<T: Pod>(
353 &self,
354 input: &GpuBuffer<T>,
355 output: &mut GpuBuffer<T>,
356 ) -> GpuResult<()> {
357 if input.len() != output.len() {
358 return Err(GpuError::invalid_kernel_params(format!(
359 "Buffer size mismatch: {} != {}",
360 input.len(),
361 output.len()
362 )));
363 }
364
365 let num_elements = input.len();
366
367 let bind_group = self
368 .context
369 .device()
370 .create_bind_group(&BindGroupDescriptor {
371 label: Some("UnaryKernel BindGroup"),
372 layout: &self.bind_group_layout,
373 entries: &[
374 BindGroupEntry {
375 binding: 0,
376 resource: input.buffer().as_entire_binding(),
377 },
378 BindGroupEntry {
379 binding: 1,
380 resource: output.buffer().as_entire_binding(),
381 },
382 ],
383 });
384
385 let mut encoder = self
386 .context
387 .device()
388 .create_command_encoder(&CommandEncoderDescriptor {
389 label: Some("UnaryKernel Encoder"),
390 });
391
392 {
393 let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
394 label: Some("UnaryKernel Pass"),
395 timestamp_writes: None,
396 });
397
398 compute_pass.set_pipeline(&self.pipeline);
399 compute_pass.set_bind_group(0, &bind_group, &[]);
400
401 let workgroup_count =
402 (num_elements as u32 + self.workgroup_size - 1) / self.workgroup_size;
403 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
404 }
405
406 self.context.queue().submit(Some(encoder.finish()));
407
408 debug!("Executed unary kernel on {} elements", num_elements);
409 Ok(())
410 }
411}
412
/// Element-wise operations combining a buffer with compile-time scalar
/// parameters. The scalars are baked into the generated WGSL as literals,
/// so each distinct parameter value requires compiling a new kernel.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScalarOp {
    /// `output[i] = input[i] + value`
    Add(f32),
    /// `output[i] = input[i] * value`
    Multiply(f32),
    /// `output[i] = clamp(input[i], min, max)`
    Clamp { min: f32, max: f32 },
    /// `output[i] = above` when `input[i] > threshold`, else `below`.
    Threshold {
        threshold: f32,
        above: f32,
        below: f32,
    },
}
429
430impl ScalarOp {
431 fn inline_shader(&self) -> String {
433 match self {
434 Self::Add(value) => format!(
435 r#"
436@group(0) @binding(0) var<storage, read> input: array<f32>;
437@group(0) @binding(1) var<storage, read_write> output: array<f32>;
438
439@compute @workgroup_size(256)
440fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
441 let idx = global_id.x;
442 if (idx >= arrayLength(&output)) {{
443 return;
444 }}
445 output[idx] = input[idx] + {value};
446}}
447"#,
448 value = value
449 ),
450 Self::Multiply(value) => format!(
451 r#"
452@group(0) @binding(0) var<storage, read> input: array<f32>;
453@group(0) @binding(1) var<storage, read_write> output: array<f32>;
454
455@compute @workgroup_size(256)
456fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
457 let idx = global_id.x;
458 if (idx >= arrayLength(&output)) {{
459 return;
460 }}
461 output[idx] = input[idx] * {value};
462}}
463"#,
464 value = value
465 ),
466 Self::Clamp { min, max } => format!(
467 r#"
468@group(0) @binding(0) var<storage, read> input: array<f32>;
469@group(0) @binding(1) var<storage, read_write> output: array<f32>;
470
471@compute @workgroup_size(256)
472fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
473 let idx = global_id.x;
474 if (idx >= arrayLength(&output)) {{
475 return;
476 }}
477 output[idx] = clamp(input[idx], {min}, {max});
478}}
479"#,
480 min = min,
481 max = max
482 ),
483 Self::Threshold {
484 threshold,
485 above,
486 below,
487 } => format!(
488 r#"
489@group(0) @binding(0) var<storage, read> input: array<f32>;
490@group(0) @binding(1) var<storage, read_write> output: array<f32>;
491
492@compute @workgroup_size(256)
493fn scalar_op(@builtin(global_invocation_id) global_id: vec3<u32>) {{
494 let idx = global_id.x;
495 if (idx >= arrayLength(&output)) {{
496 return;
497 }}
498 if (input[idx] > {threshold}) {{
499 output[idx] = {above};
500 }} else {{
501 output[idx] = {below};
502 }}
503}}
504"#,
505 threshold = threshold,
506 above = above,
507 below = below
508 ),
509 }
510 }
511}
512
/// GPU compute kernel applying a [`ScalarOp`] to a single input buffer.
///
/// The scalar parameters are baked into the compiled shader, so a new
/// kernel must be constructed for each distinct parameter value.
pub struct ScalarKernel {
    // Cloned handle to the device/queue this kernel dispatches on.
    context: GpuContext,
    // Compiled pipeline for the operation chosen at construction time.
    pipeline: ComputePipeline,
    // Layout: binding 0 = input, 1 = output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; matches @workgroup_size in the generated WGSL.
    workgroup_size: u32,
}
520
521impl ScalarKernel {
522 pub fn new(context: &GpuContext, op: ScalarOp) -> GpuResult<Self> {
528 debug!("Creating scalar kernel for operation: {:?}", op);
529
530 let shader_source = op.inline_shader();
531 let mut shader = WgslShader::new(shader_source, "scalar_op");
532 let shader_module = shader.compile(context.device())?;
533
534 let bind_group_layout = create_compute_bind_group_layout(
535 context.device(),
536 &[
537 storage_buffer_layout(0, true), storage_buffer_layout(1, false), ],
540 Some("ScalarKernel BindGroupLayout"),
541 )?;
542
543 let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "scalar_op")
544 .bind_group_layout(&bind_group_layout)
545 .label(format!("ScalarKernel Pipeline: {:?}", op))
546 .build()?;
547
548 Ok(Self {
549 context: context.clone(),
550 pipeline,
551 bind_group_layout,
552 workgroup_size: 256,
553 })
554 }
555
556 pub fn execute<T: Pod>(
562 &self,
563 input: &GpuBuffer<T>,
564 output: &mut GpuBuffer<T>,
565 ) -> GpuResult<()> {
566 if input.len() != output.len() {
567 return Err(GpuError::invalid_kernel_params(format!(
568 "Buffer size mismatch: {} != {}",
569 input.len(),
570 output.len()
571 )));
572 }
573
574 let num_elements = input.len();
575
576 let bind_group = self
577 .context
578 .device()
579 .create_bind_group(&BindGroupDescriptor {
580 label: Some("ScalarKernel BindGroup"),
581 layout: &self.bind_group_layout,
582 entries: &[
583 BindGroupEntry {
584 binding: 0,
585 resource: input.buffer().as_entire_binding(),
586 },
587 BindGroupEntry {
588 binding: 1,
589 resource: output.buffer().as_entire_binding(),
590 },
591 ],
592 });
593
594 let mut encoder = self
595 .context
596 .device()
597 .create_command_encoder(&CommandEncoderDescriptor {
598 label: Some("ScalarKernel Encoder"),
599 });
600
601 {
602 let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
603 label: Some("ScalarKernel Pass"),
604 timestamp_writes: None,
605 });
606
607 compute_pass.set_pipeline(&self.pipeline);
608 compute_pass.set_bind_group(0, &bind_group, &[]);
609
610 let workgroup_count =
611 (num_elements as u32 + self.workgroup_size - 1) / self.workgroup_size;
612 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
613 }
614
615 self.context.queue().submit(Some(encoder.finish()));
616
617 debug!("Executed scalar kernel on {} elements", num_elements);
618 Ok(())
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_element_wise_op_shader() {
        // Every binary op must generate a complete compute shader whose
        // entry point matches `entry_point()`.
        for op in [
            ElementWiseOp::Add,
            ElementWiseOp::Subtract,
            ElementWiseOp::Multiply,
            ElementWiseOp::Divide,
            ElementWiseOp::Power,
            ElementWiseOp::Min,
            ElementWiseOp::Max,
            ElementWiseOp::Modulo,
        ] {
            let shader = op.inline_shader();
            assert!(shader.contains("@compute"), "{op:?} missing @compute");
            assert!(
                shader.contains(&format!("fn {}(", op.entry_point())),
                "{op:?} entry point not found in shader"
            );
        }
    }

    #[test]
    fn test_unary_op_shader() {
        let shader = UnaryOp::Sqrt.inline_shader();
        assert!(shader.contains("sqrt"));
        assert!(shader.contains("@compute"));
    }

    #[test]
    fn test_scalar_op_shader() {
        let shader = ScalarOp::Add(5.0).inline_shader();
        assert!(shader.contains("5"));
        let clamp = ScalarOp::Clamp { min: 0.0, max: 1.0 }.inline_shader();
        assert!(clamp.contains("clamp"));
    }

    #[tokio::test]
    async fn test_raster_kernel_execution() {
        // Skip silently on machines with no usable GPU adapter; when a
        // context is available, kernel construction (shader compile +
        // pipeline build) must succeed.
        if let Ok(context) = GpuContext::new().await {
            let kernel = RasterKernel::new(&context, ElementWiseOp::Add);
            assert!(kernel.is_ok(), "kernel creation failed: {:?}", kernel.err());
        }
    }
}