1use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10 ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11 uniform_buffer_layout,
12};
13use bytemuck::{Pod, Zeroable};
14use tracing::debug;
15use wgpu::{
16 BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
17 ComputePassDescriptor, ComputePipeline,
18};
19
/// Pairwise reduction operation applied across all elements of a buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
    /// Sum of all elements.
    Sum,
    /// Minimum element.
    Min,
    /// Maximum element.
    Max,
    /// Product of all elements.
    Product,
}

impl ReductionOp {
    /// Identity element for the operation (`x op identity == x` for all `x`).
    fn identity(&self) -> f32 {
        match self {
            Self::Sum => 0.0,
            Self::Min => f32::MAX,
            Self::Max => f32::MIN,
            Self::Product => 1.0,
        }
    }

    /// WGSL source literal for the identity element.
    ///
    /// This exists because `Display`-formatting `f32::MAX`/`f32::MIN`
    /// produces a ~39-digit string with no decimal point or exponent, which
    /// WGSL parses as an abstract *integer* literal; that overflows WGSL's
    /// 64-bit abstract-int range and makes shader creation fail for the
    /// `Min`/`Max` kernels. `3.4028234e38` is strictly inside the f32 range
    /// and rounds to exactly `f32::MAX` (verified in the unit test).
    fn identity_expr(&self) -> &'static str {
        match self {
            Self::Sum => "0.0",
            Self::Min => "3.4028234e38",
            Self::Max => "-3.4028234e38",
            Self::Product => "1.0",
        }
    }

    /// WGSL expression combining two partial results bound to `a` and `b`.
    fn operation_expr(&self) -> &'static str {
        match self {
            Self::Sum => "a + b",
            Self::Min => "min(a, b)",
            Self::Max => "max(a, b)",
            Self::Product => "a * b",
        }
    }

    /// Generates the WGSL source for one tree-reduction pass: each
    /// 256-thread workgroup collapses its chunk of `input` into a single
    /// element of `output` (indexed by workgroup id).
    fn reduction_shader(&self) -> String {
        format!(
            r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn reduce(@builtin(global_invocation_id) global_id: vec3<u32>,
          @builtin(local_invocation_id) local_id: vec3<u32>,
          @builtin(workgroup_id) workgroup_id: vec3<u32>) {{
    let idx = global_id.x;
    let local_idx = local_id.x;
    let n = arrayLength(&input);

    // Load this workgroup's slice into shared memory, padding past-the-end
    // threads with the identity so the tree below needs no bounds checks.
    if (idx < n) {{
        shared_data[local_idx] = input[idx];
    }} else {{
        shared_data[local_idx] = {identity};
    }}

    workgroupBarrier();

    // Tree reduction in shared memory: halve the active stride each step.
    // Combining with an identity-padded slot is a no-op, so no extra guard
    // against the tail of the buffer is required.
    var stride = 128u;
    while (stride > 0u) {{
        if (local_idx < stride) {{
            let a = shared_data[local_idx];
            let b = shared_data[local_idx + stride];
            shared_data[local_idx] = {op};
        }}
        stride = stride / 2u;
        workgroupBarrier();
    }}

    // Thread 0 publishes the workgroup's partial result.
    if (local_idx == 0u) {{
        output[workgroup_id.x] = shared_data[0];
    }}
}}
"#,
            identity = self.identity_expr(),
            op = self.operation_expr()
        )
    }
}
103
/// GPU kernel that reduces a buffer to a single value using a workgroup
/// tree reduction; the reduction operation is baked into the pipeline at
/// construction time.
pub struct ReductionKernel {
    /// Cloned handle to the device/queue the kernel was built on.
    context: GpuContext,
    /// Compute pipeline compiled from the operation-specific WGSL shader.
    pipeline: ComputePipeline,
    /// Layout for the (input, output) storage-buffer bind group.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Threads per workgroup; must match `@workgroup_size` and the
    /// shared-memory array length in the shader (256).
    workgroup_size: u32,
}
111
112impl ReductionKernel {
113 pub fn new(context: &GpuContext, op: ReductionOp) -> GpuResult<Self> {
119 debug!("Creating reduction kernel for operation: {:?}", op);
120
121 let shader_source = op.reduction_shader();
122 let mut shader = WgslShader::new(shader_source, "reduce");
123 let shader_module = shader.compile(context.device())?;
124
125 let bind_group_layout = create_compute_bind_group_layout(
126 context.device(),
127 &[
128 storage_buffer_layout(0, true), storage_buffer_layout(1, false), ],
131 Some("ReductionKernel BindGroupLayout"),
132 )?;
133
134 let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "reduce")
135 .bind_group_layout(&bind_group_layout)
136 .label(format!("ReductionKernel Pipeline: {:?}", op))
137 .build()?;
138
139 Ok(Self {
140 context: context.clone(),
141 pipeline,
142 bind_group_layout,
143 workgroup_size: 256,
144 })
145 }
146
147 pub async fn execute<T: Pod + Copy>(
153 &self,
154 input: &GpuBuffer<T>,
155 _op: ReductionOp,
156 ) -> GpuResult<T> {
157 let mut current_input = input.clone();
158 let mut iteration = 0;
159
160 loop {
161 let input_size = current_input.len();
162 let num_workgroups =
163 (input_size as u32 + self.workgroup_size - 1) / self.workgroup_size;
164
165 if num_workgroups == 1 && input_size <= self.workgroup_size as usize {
166 let output = GpuBuffer::new(
168 &self.context,
169 1,
170 BufferUsages::STORAGE | BufferUsages::COPY_SRC,
171 )?;
172
173 self.execute_pass(¤t_input, &output, num_workgroups)?;
174
175 let staging = GpuBuffer::staging(&self.context, 1)?;
177 let mut staging_mut = staging.clone();
178 staging_mut.copy_from(&output)?;
179
180 let result = staging.read().await?;
181 return Ok(result[0]);
182 }
183
184 let output = GpuBuffer::new(
186 &self.context,
187 num_workgroups as usize,
188 BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
189 )?;
190
191 self.execute_pass(¤t_input, &output, num_workgroups)?;
192
193 current_input = output;
194 iteration += 1;
195
196 if iteration > 10 {
197 return Err(GpuError::execution_failed(
198 "Reduction did not converge after 10 iterations",
199 ));
200 }
201 }
202 }
203
204 fn execute_pass<T: Pod>(
206 &self,
207 input: &GpuBuffer<T>,
208 output: &GpuBuffer<T>,
209 num_workgroups: u32,
210 ) -> GpuResult<()> {
211 let bind_group = self
212 .context
213 .device()
214 .create_bind_group(&BindGroupDescriptor {
215 label: Some("ReductionKernel BindGroup"),
216 layout: &self.bind_group_layout,
217 entries: &[
218 BindGroupEntry {
219 binding: 0,
220 resource: input.buffer().as_entire_binding(),
221 },
222 BindGroupEntry {
223 binding: 1,
224 resource: output.buffer().as_entire_binding(),
225 },
226 ],
227 });
228
229 let mut encoder = self
230 .context
231 .device()
232 .create_command_encoder(&CommandEncoderDescriptor {
233 label: Some("ReductionKernel Encoder"),
234 });
235
236 {
237 let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
238 label: Some("ReductionKernel Pass"),
239 timestamp_writes: None,
240 });
241
242 compute_pass.set_pipeline(&self.pipeline);
243 compute_pass.set_bind_group(0, &bind_group, &[]);
244 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
245 }
246
247 self.context.queue().submit(Some(encoder.finish()));
248 Ok(())
249 }
250
251 pub fn execute_blocking<T: Pod + Copy>(
257 &self,
258 input: &GpuBuffer<T>,
259 op: ReductionOp,
260 ) -> GpuResult<T> {
261 pollster::block_on(self.execute(input, op))
262 }
263}
264
/// Uniform parameters for [`HistogramKernel`]: bin count plus the value
/// range spread across the bins. `#[repr(C)]` with `Pod`/`Zeroable` so the
/// struct can be written directly into a GPU uniform buffer; the field
/// layout must match the WGSL `HistogramParams` struct in the shader.
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct HistogramParams {
    /// Number of histogram bins.
    pub num_bins: u32,
    /// Inclusive lower bound of the histogrammed value range.
    pub min_value: f32,
    /// Inclusive upper bound of the histogrammed value range.
    pub max_value: f32,
    /// Pads the struct to 16 bytes for uniform-buffer alignment.
    _padding: u32,
}
278
279impl HistogramParams {
280 pub fn new(num_bins: u32, min_value: f32, max_value: f32) -> Self {
282 Self {
283 num_bins,
284 min_value,
285 max_value,
286 _padding: 0,
287 }
288 }
289
290 pub fn auto(num_bins: u32) -> Self {
292 Self::new(num_bins, 0.0, 1.0)
293 }
294}
295
/// GPU kernel that bins values into a histogram using atomic counters.
pub struct HistogramKernel {
    /// Cloned handle to the device/queue the kernel was built on.
    context: GpuContext,
    /// Compute pipeline compiled from the histogram WGSL shader.
    pipeline: ComputePipeline,
    /// Layout for the (input, params, bins) bind group.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Threads per workgroup; must match `@workgroup_size` in the shader (256).
    workgroup_size: u32,
}
303
304impl HistogramKernel {
305 pub fn new(context: &GpuContext) -> GpuResult<Self> {
311 debug!("Creating histogram kernel");
312
313 let shader_source = Self::histogram_shader();
314 let mut shader = WgslShader::new(shader_source, "histogram");
315 let shader_module = shader.compile(context.device())?;
316
317 let bind_group_layout = create_compute_bind_group_layout(
318 context.device(),
319 &[
320 storage_buffer_layout(0, true), uniform_buffer_layout(1), storage_buffer_layout(2, false), ],
324 Some("HistogramKernel BindGroupLayout"),
325 )?;
326
327 let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "histogram")
328 .bind_group_layout(&bind_group_layout)
329 .label("HistogramKernel Pipeline")
330 .build()?;
331
332 Ok(Self {
333 context: context.clone(),
334 pipeline,
335 bind_group_layout,
336 workgroup_size: 256,
337 })
338 }
339
340 fn histogram_shader() -> String {
342 r#"
343struct HistogramParams {
344 num_bins: u32,
345 min_value: f32,
346 max_value: f32,
347 _padding: u32,
348}
349
350@group(0) @binding(0) var<storage, read> input: array<f32>;
351@group(0) @binding(1) var<uniform> params: HistogramParams;
352@group(0) @binding(2) var<storage, read_write> histogram: array<atomic<u32>>;
353
354@compute @workgroup_size(256)
355fn histogram(@builtin(global_invocation_id) global_id: vec3<u32>) {
356 let idx = global_id.x;
357 if (idx >= arrayLength(&input)) {
358 return;
359 }
360
361 let value = input[idx];
362 let range = params.max_value - params.min_value;
363
364 if (value >= params.min_value && value <= params.max_value && range > 0.0) {
365 let normalized = (value - params.min_value) / range;
366 var bin = u32(normalized * f32(params.num_bins));
367
368 // Clamp to valid bin range
369 if (bin >= params.num_bins) {
370 bin = params.num_bins - 1u;
371 }
372
373 atomicAdd(&histogram[bin], 1u);
374 }
375}
376"#
377 .to_string()
378 }
379
380 pub async fn execute<T: Pod>(
386 &self,
387 input: &GpuBuffer<T>,
388 params: HistogramParams,
389 ) -> GpuResult<Vec<u32>> {
390 let histogram = GpuBuffer::<u32>::new(
392 &self.context,
393 params.num_bins as usize,
394 BufferUsages::STORAGE | BufferUsages::COPY_SRC,
395 )?;
396
397 let params_buffer = GpuBuffer::from_data(
399 &self.context,
400 &[params],
401 BufferUsages::UNIFORM | BufferUsages::COPY_DST,
402 )?;
403
404 let bind_group = self
405 .context
406 .device()
407 .create_bind_group(&BindGroupDescriptor {
408 label: Some("HistogramKernel BindGroup"),
409 layout: &self.bind_group_layout,
410 entries: &[
411 BindGroupEntry {
412 binding: 0,
413 resource: input.buffer().as_entire_binding(),
414 },
415 BindGroupEntry {
416 binding: 1,
417 resource: params_buffer.buffer().as_entire_binding(),
418 },
419 BindGroupEntry {
420 binding: 2,
421 resource: histogram.buffer().as_entire_binding(),
422 },
423 ],
424 });
425
426 let mut encoder = self
427 .context
428 .device()
429 .create_command_encoder(&CommandEncoderDescriptor {
430 label: Some("HistogramKernel Encoder"),
431 });
432
433 {
434 let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
435 label: Some("HistogramKernel Pass"),
436 timestamp_writes: None,
437 });
438
439 compute_pass.set_pipeline(&self.pipeline);
440 compute_pass.set_bind_group(0, &bind_group, &[]);
441
442 let num_workgroups =
443 (input.len() as u32 + self.workgroup_size - 1) / self.workgroup_size;
444 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
445 }
446
447 self.context.queue().submit(Some(encoder.finish()));
448
449 let staging = GpuBuffer::staging(&self.context, params.num_bins as usize)?;
451 let mut staging_mut = staging.clone();
452 staging_mut.copy_from(&histogram)?;
453
454 let result = staging.read().await?;
455 debug!("Computed histogram with {} bins", params.num_bins);
456 Ok(result)
457 }
458
459 pub fn execute_blocking<T: Pod>(
465 &self,
466 input: &GpuBuffer<T>,
467 params: HistogramParams,
468 ) -> GpuResult<Vec<u32>> {
469 pollster::block_on(self.execute(input, params))
470 }
471}
472
/// Summary statistics for a buffer of `f32` values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Statistics {
    /// Smallest value observed.
    pub min: f32,
    /// Largest value observed.
    pub max: f32,
    /// Sum of all values.
    pub sum: f32,
    /// Number of values the statistics cover.
    pub count: usize,
}

impl Statistics {
    /// Arithmetic mean of the values; `0.0` when `count` is zero.
    pub fn mean(&self) -> f32 {
        match self.count {
            0 => 0.0,
            n => self.sum / n as f32,
        }
    }

    /// Spread between the largest and smallest value.
    pub fn range(&self) -> f32 {
        self.max - self.min
    }
}
501
502pub async fn compute_statistics(
508 context: &GpuContext,
509 input: &GpuBuffer<f32>,
510) -> GpuResult<Statistics> {
511 let sum_kernel = ReductionKernel::new(context, ReductionOp::Sum)?;
512 let min_kernel = ReductionKernel::new(context, ReductionOp::Min)?;
513 let max_kernel = ReductionKernel::new(context, ReductionOp::Max)?;
514
515 let sum = sum_kernel.execute(input, ReductionOp::Sum).await?;
516 let min = min_kernel.execute(input, ReductionOp::Min).await?;
517 let max = max_kernel.execute(input, ReductionOp::Max).await?;
518
519 Ok(Statistics {
520 min,
521 max,
522 sum,
523 count: input.len(),
524 })
525}
526
527pub fn compute_statistics_blocking(
533 context: &GpuContext,
534 input: &GpuBuffer<f32>,
535) -> GpuResult<Statistics> {
536 pollster::block_on(compute_statistics(context, input))
537}
538
539pub use compute_statistics_blocking as compute_stats_blocking;
541
#[cfg(test)]
mod tests {
    use super::*;

    // Identity elements must leave any value unchanged under the operation:
    // x + 0 == x and x * 1 == x.
    #[test]
    fn test_reduction_op_identity() {
        assert_eq!(ReductionOp::Sum.identity(), 0.0);
        assert_eq!(ReductionOp::Product.identity(), 1.0);
    }

    // The constructor should store bin count and range verbatim.
    #[test]
    fn test_histogram_params() {
        let params = HistogramParams::new(256, 0.0, 255.0);
        assert_eq!(params.num_bins, 256);
        assert_eq!(params.min_value, 0.0);
        assert_eq!(params.max_value, 255.0);
    }

    // End-to-end GPU sum of [1..5]; `#[ignore]`d because it needs a working
    // GPU adapter. Setup failures (no context/kernel/buffer) are skipped
    // silently rather than failing the test.
    #[tokio::test]
    #[ignore]
    async fn test_reduction_kernel() {
        if let Ok(context) = GpuContext::new().await {
            if let Ok(kernel) = ReductionKernel::new(&context, ReductionOp::Sum) {
                let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];

                if let Ok(buffer) = GpuBuffer::from_data(
                    &context,
                    &data,
                    BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
                ) {
                    if let Ok(result) = kernel.execute(&buffer, ReductionOp::Sum).await {
                        assert!((result - 15.0).abs() < 1e-5);
                    }
                }
            }
        }
    }

    // Pure-CPU checks of the derived statistics helpers.
    #[test]
    fn test_statistics_calculations() {
        let stats = Statistics {
            min: 0.0,
            max: 100.0,
            sum: 500.0,
            count: 10,
        };

        assert_eq!(stats.mean(), 50.0);
        assert_eq!(stats.range(), 100.0);
    }
}