Skip to main content

oxigdal_gpu/kernels/
statistics.rs

1//! GPU kernels for statistical raster operations.
2//!
3//! This module provides GPU-accelerated statistical operations including
4//! parallel reduction, histogram computation, and basic statistics.
5
6use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10    ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11    uniform_buffer_layout,
12};
13use bytemuck::{Pod, Zeroable};
14use tracing::debug;
15use wgpu::{
16    BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
17    ComputePassDescriptor, ComputePipeline,
18};
19
/// Reduction operation type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
    /// Sum of all values.
    Sum,
    /// Minimum value.
    Min,
    /// Maximum value.
    Max,
    /// Product of all values.
    Product,
}

impl ReductionOp {
    /// Get the identity value for this operation (combining any value with
    /// the identity leaves it unchanged; used to pad partial workgroups).
    fn identity(&self) -> f32 {
        match self {
            Self::Sum => 0.0,
            Self::Min => f32::MAX,
            Self::Max => f32::MIN,
            Self::Product => 1.0,
        }
    }

    /// Get the WGSL expression combining two values `a` and `b`.
    fn operation_expr(&self) -> &'static str {
        match self {
            Self::Sum => "a + b",
            Self::Min => "min(a, b)",
            Self::Max => "max(a, b)",
            Self::Product => "a * b",
        }
    }

    /// Get inline shader for parallel reduction.
    ///
    /// The identity is interpolated with `{:?}` (Debug) rather than `{}`
    /// (Display) on purpose: Display renders `f32::MAX` as a long digit
    /// string with no decimal point or exponent, which WGSL lexes as an
    /// abstract integer that overflows and fails shader compilation for
    /// `Min`/`Max`. Debug yields `3.4028235e38`, a valid WGSL float literal.
    fn reduction_shader(&self) -> String {
        format!(
            r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn reduce(@builtin(global_invocation_id) global_id: vec3<u32>,
          @builtin(local_invocation_id) local_id: vec3<u32>,
          @builtin(workgroup_id) workgroup_id: vec3<u32>) {{
    let idx = global_id.x;
    let local_idx = local_id.x;
    let n = arrayLength(&input);

    // Load data into shared memory, padding out-of-range slots with the
    // identity so partial (last) workgroups reduce correctly.
    if (idx < n) {{
        shared_data[local_idx] = input[idx];
    }} else {{
        shared_data[local_idx] = {identity:?};
    }}

    workgroupBarrier();

    // Tree reduction in shared memory: halve the active range each step.
    var stride = 128u;
    while (stride > 0u) {{
        if (local_idx < stride && idx + stride < n) {{
            let a = shared_data[local_idx];
            let b = shared_data[local_idx + stride];
            shared_data[local_idx] = {op};
        }}
        stride = stride / 2u;
        workgroupBarrier();
    }}

    // Write this workgroup's partial result from its first thread.
    if (local_idx == 0u) {{
        output[workgroup_id.x] = shared_data[0];
    }}
}}
"#,
            identity = self.identity(),
            op = self.operation_expr()
        )
    }
}
103
/// GPU kernel for parallel reduction operations.
///
/// Holds a compute pipeline compiled for exactly one [`ReductionOp`];
/// create one kernel per operation.
pub struct ReductionKernel {
    /// Cloned handle to the device/queue used for dispatch and readback.
    context: GpuContext,
    /// Pipeline compiled from the op-specific WGSL (entry point "reduce").
    pipeline: ComputePipeline,
    /// Layout with two storage buffers: binding 0 = read-only input,
    /// binding 1 = read-write output.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Threads per workgroup; must stay in sync with the hard-coded
    /// `@workgroup_size(256)` in the shader source.
    workgroup_size: u32,
}
111
impl ReductionKernel {
    /// Create a new reduction kernel.
    ///
    /// Compiles the op-specific WGSL from `ReductionOp::reduction_shader`
    /// and builds a compute pipeline with a two-binding layout
    /// (read-only input at binding 0, read-write output at binding 1).
    ///
    /// # Errors
    ///
    /// Returns an error if shader compilation or pipeline creation fails.
    pub fn new(context: &GpuContext, op: ReductionOp) -> GpuResult<Self> {
        debug!("Creating reduction kernel for operation: {:?}", op);

        let shader_source = op.reduction_shader();
        let mut shader = WgslShader::new(shader_source, "reduce");
        let shader_module = shader.compile(context.device())?;

        let bind_group_layout = create_compute_bind_group_layout(
            context.device(),
            &[
                storage_buffer_layout(0, true),  // input
                storage_buffer_layout(1, false), // output
            ],
            Some("ReductionKernel BindGroupLayout"),
        )?;

        let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "reduce")
            .bind_group_layout(&bind_group_layout)
            .label(format!("ReductionKernel Pipeline: {:?}", op))
            .build()?;

        Ok(Self {
            context: context.clone(),
            pipeline,
            bind_group_layout,
            // Must match the `@workgroup_size(256)` in the shader.
            workgroup_size: 256,
        })
    }

    /// Execute reduction on GPU buffer.
    ///
    /// Runs repeated passes: each pass collapses every 256-element
    /// workgroup into one partial result, shrinking the data by roughly
    /// a factor of `workgroup_size` per iteration until one value remains.
    ///
    /// NOTE(review): `_op` is ignored — the operation was fixed when the
    /// pipeline was compiled in [`ReductionKernel::new`]; callers should
    /// pass the same op the kernel was created with.
    /// NOTE(review): the shader reads the buffer as `array<f32>` regardless
    /// of `T` — presumably callers only use `GpuBuffer<f32>`; confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if execution fails, or if a single value is not
    /// reached within 10 passes (a zero-length input yields zero
    /// workgroups and can never converge, so it ends up here).
    pub async fn execute<T: Pod + Copy>(
        &self,
        input: &GpuBuffer<T>,
        _op: ReductionOp,
    ) -> GpuResult<T> {
        // Presumably a cheap handle clone of the underlying GPU buffer,
        // not a data copy — TODO confirm against GpuBuffer's Clone impl.
        let mut current_input = input.clone();
        let mut iteration = 0;

        loop {
            let input_size = current_input.len();
            // Ceiling division: one workgroup per `workgroup_size` elements.
            let num_workgroups =
                (input_size as u32 + self.workgroup_size - 1) / self.workgroup_size;

            if num_workgroups == 1 && input_size <= self.workgroup_size as usize {
                // Final reduction: a single workgroup produces the result.
                let output = GpuBuffer::new(
                    &self.context,
                    1,
                    BufferUsages::STORAGE | BufferUsages::COPY_SRC,
                )?;

                self.execute_pass(&current_input, &output, num_workgroups)?;

                // Read result back through a staging buffer. `staging` and
                // `staging_mut` presumably share the same underlying buffer,
                // so reading `staging` after the copy sees the data — TODO
                // confirm GpuBuffer clone-aliasing semantics.
                let staging = GpuBuffer::staging(&self.context, 1)?;
                let mut staging_mut = staging.clone();
                staging_mut.copy_from(&output)?;

                let result = staging.read().await?;
                return Ok(result[0]);
            }

            // Intermediate reduction: one partial result per workgroup.
            let output = GpuBuffer::new(
                &self.context,
                num_workgroups as usize,
                BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
            )?;

            self.execute_pass(&current_input, &output, num_workgroups)?;

            // This pass's partial results feed the next pass.
            current_input = output;
            iteration += 1;

            // Safety valve: 10 passes cover 256^10 elements, so reaching
            // this indicates a degenerate input (e.g. zero length).
            if iteration > 10 {
                return Err(GpuError::execution_failed(
                    "Reduction did not converge after 10 iterations",
                ));
            }
        }
    }

    /// Execute a single reduction pass: bind input/output, record one
    /// compute pass dispatching `num_workgroups` workgroups, and submit it
    /// to the queue (no explicit wait here; completion is observed later
    /// via the staging-buffer read).
    fn execute_pass<T: Pod>(
        &self,
        input: &GpuBuffer<T>,
        output: &GpuBuffer<T>,
        num_workgroups: u32,
    ) -> GpuResult<()> {
        let bind_group = self
            .context
            .device()
            .create_bind_group(&BindGroupDescriptor {
                label: Some("ReductionKernel BindGroup"),
                layout: &self.bind_group_layout,
                entries: &[
                    BindGroupEntry {
                        binding: 0,
                        resource: input.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 1,
                        resource: output.buffer().as_entire_binding(),
                    },
                ],
            });

        let mut encoder = self
            .context
            .device()
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("ReductionKernel Encoder"),
            });

        {
            // Scope ends the compute pass before the encoder is finished.
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("ReductionKernel Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&self.pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
        }

        self.context.queue().submit(Some(encoder.finish()));
        Ok(())
    }

    /// Execute reduction synchronously.
    ///
    /// Blocks the current thread on the async [`ReductionKernel::execute`].
    ///
    /// # Errors
    ///
    /// Returns an error if execution fails.
    pub fn execute_blocking<T: Pod + Copy>(
        &self,
        input: &GpuBuffer<T>,
        op: ReductionOp,
    ) -> GpuResult<T> {
        pollster::block_on(self.execute(input, op))
    }
}
264
/// Histogram parameters.
///
/// `#[repr(C)]` + `Pod` so the value can be uploaded directly as a uniform
/// buffer; the field order must match the `HistogramParams` struct declared
/// in the WGSL shader source.
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct HistogramParams {
    /// Number of bins.
    pub num_bins: u32,
    /// Minimum value of the histogram range (inclusive).
    pub min_value: f32,
    /// Maximum value of the histogram range (inclusive).
    pub max_value: f32,
    /// Padding to a 16-byte struct size for uniform-buffer alignment.
    _padding: u32,
}
278
279impl HistogramParams {
280    /// Create new histogram parameters.
281    pub fn new(num_bins: u32, min_value: f32, max_value: f32) -> Self {
282        Self {
283            num_bins,
284            min_value,
285            max_value,
286            _padding: 0,
287        }
288    }
289
290    /// Create histogram with automatic range.
291    pub fn auto(num_bins: u32) -> Self {
292        Self::new(num_bins, 0.0, 1.0)
293    }
294}
295
/// GPU kernel for histogram computation.
pub struct HistogramKernel {
    /// Cloned handle to the device/queue used for dispatch and readback.
    context: GpuContext,
    /// Pipeline compiled from the histogram WGSL (entry point "histogram").
    pipeline: ComputePipeline,
    /// Layout: binding 0 = read-only input, binding 1 = uniform params,
    /// binding 2 = read-write histogram counters.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Threads per workgroup; must stay in sync with the hard-coded
    /// `@workgroup_size(256)` in the shader source.
    workgroup_size: u32,
}
303
304impl HistogramKernel {
305    /// Create a new histogram kernel.
306    ///
307    /// # Errors
308    ///
309    /// Returns an error if shader compilation or pipeline creation fails.
310    pub fn new(context: &GpuContext) -> GpuResult<Self> {
311        debug!("Creating histogram kernel");
312
313        let shader_source = Self::histogram_shader();
314        let mut shader = WgslShader::new(shader_source, "histogram");
315        let shader_module = shader.compile(context.device())?;
316
317        let bind_group_layout = create_compute_bind_group_layout(
318            context.device(),
319            &[
320                storage_buffer_layout(0, true),  // input
321                uniform_buffer_layout(1),        // params
322                storage_buffer_layout(2, false), // histogram output
323            ],
324            Some("HistogramKernel BindGroupLayout"),
325        )?;
326
327        let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "histogram")
328            .bind_group_layout(&bind_group_layout)
329            .label("HistogramKernel Pipeline")
330            .build()?;
331
332        Ok(Self {
333            context: context.clone(),
334            pipeline,
335            bind_group_layout,
336            workgroup_size: 256,
337        })
338    }
339
340    /// Get histogram shader source.
341    fn histogram_shader() -> String {
342        r#"
343struct HistogramParams {
344    num_bins: u32,
345    min_value: f32,
346    max_value: f32,
347    _padding: u32,
348}
349
350@group(0) @binding(0) var<storage, read> input: array<f32>;
351@group(0) @binding(1) var<uniform> params: HistogramParams;
352@group(0) @binding(2) var<storage, read_write> histogram: array<atomic<u32>>;
353
354@compute @workgroup_size(256)
355fn histogram(@builtin(global_invocation_id) global_id: vec3<u32>) {
356    let idx = global_id.x;
357    if (idx >= arrayLength(&input)) {
358        return;
359    }
360
361    let value = input[idx];
362    let range = params.max_value - params.min_value;
363
364    if (value >= params.min_value && value <= params.max_value && range > 0.0) {
365        let normalized = (value - params.min_value) / range;
366        var bin = u32(normalized * f32(params.num_bins));
367
368        // Clamp to valid bin range
369        if (bin >= params.num_bins) {
370            bin = params.num_bins - 1u;
371        }
372
373        atomicAdd(&histogram[bin], 1u);
374    }
375}
376"#
377        .to_string()
378    }
379
380    /// Compute histogram of GPU buffer.
381    ///
382    /// # Errors
383    ///
384    /// Returns an error if execution fails.
385    pub async fn execute<T: Pod>(
386        &self,
387        input: &GpuBuffer<T>,
388        params: HistogramParams,
389    ) -> GpuResult<Vec<u32>> {
390        // Create histogram buffer (atomic u32)
391        let histogram = GpuBuffer::<u32>::new(
392            &self.context,
393            params.num_bins as usize,
394            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
395        )?;
396
397        // Create params uniform buffer
398        let params_buffer = GpuBuffer::from_data(
399            &self.context,
400            &[params],
401            BufferUsages::UNIFORM | BufferUsages::COPY_DST,
402        )?;
403
404        let bind_group = self
405            .context
406            .device()
407            .create_bind_group(&BindGroupDescriptor {
408                label: Some("HistogramKernel BindGroup"),
409                layout: &self.bind_group_layout,
410                entries: &[
411                    BindGroupEntry {
412                        binding: 0,
413                        resource: input.buffer().as_entire_binding(),
414                    },
415                    BindGroupEntry {
416                        binding: 1,
417                        resource: params_buffer.buffer().as_entire_binding(),
418                    },
419                    BindGroupEntry {
420                        binding: 2,
421                        resource: histogram.buffer().as_entire_binding(),
422                    },
423                ],
424            });
425
426        let mut encoder = self
427            .context
428            .device()
429            .create_command_encoder(&CommandEncoderDescriptor {
430                label: Some("HistogramKernel Encoder"),
431            });
432
433        {
434            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
435                label: Some("HistogramKernel Pass"),
436                timestamp_writes: None,
437            });
438
439            compute_pass.set_pipeline(&self.pipeline);
440            compute_pass.set_bind_group(0, &bind_group, &[]);
441
442            let num_workgroups =
443                (input.len() as u32 + self.workgroup_size - 1) / self.workgroup_size;
444            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
445        }
446
447        self.context.queue().submit(Some(encoder.finish()));
448
449        // Read histogram result
450        let staging = GpuBuffer::staging(&self.context, params.num_bins as usize)?;
451        let mut staging_mut = staging.clone();
452        staging_mut.copy_from(&histogram)?;
453
454        let result = staging.read().await?;
455        debug!("Computed histogram with {} bins", params.num_bins);
456        Ok(result)
457    }
458
459    /// Compute histogram synchronously.
460    ///
461    /// # Errors
462    ///
463    /// Returns an error if execution fails.
464    pub fn execute_blocking<T: Pod>(
465        &self,
466        input: &GpuBuffer<T>,
467        params: HistogramParams,
468    ) -> GpuResult<Vec<u32>> {
469        pollster::block_on(self.execute(input, params))
470    }
471}
472
/// Summary statistics computed over a buffer of values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Statistics {
    /// Minimum value.
    pub min: f32,
    /// Maximum value.
    pub max: f32,
    /// Sum of all values.
    pub sum: f32,
    /// Number of values.
    pub count: usize,
}

impl Statistics {
    /// Arithmetic mean (`sum / count`), or `0.0` when there are no values.
    pub fn mean(&self) -> f32 {
        match self.count {
            0 => 0.0,
            n => self.sum / n as f32,
        }
    }

    /// Spread between the largest and smallest value (`max - min`).
    pub fn range(&self) -> f32 {
        self.max - self.min
    }
}
501
502/// Compute basic statistics on GPU buffer.
503///
504/// # Errors
505///
506/// Returns an error if GPU operations fail.
507pub async fn compute_statistics(
508    context: &GpuContext,
509    input: &GpuBuffer<f32>,
510) -> GpuResult<Statistics> {
511    let sum_kernel = ReductionKernel::new(context, ReductionOp::Sum)?;
512    let min_kernel = ReductionKernel::new(context, ReductionOp::Min)?;
513    let max_kernel = ReductionKernel::new(context, ReductionOp::Max)?;
514
515    let sum = sum_kernel.execute(input, ReductionOp::Sum).await?;
516    let min = min_kernel.execute(input, ReductionOp::Min).await?;
517    let max = max_kernel.execute(input, ReductionOp::Max).await?;
518
519    Ok(Statistics {
520        min,
521        max,
522        sum,
523        count: input.len(),
524    })
525}
526
527/// Compute basic statistics synchronously.
528///
529/// # Errors
530///
531/// Returns an error if GPU operations fail.
532pub fn compute_statistics_blocking(
533    context: &GpuContext,
534    input: &GpuBuffer<f32>,
535) -> GpuResult<Statistics> {
536    pollster::block_on(compute_statistics(context, input))
537}
538
// Re-export under a shorter alias for callers that prefer the brief name.
pub use compute_statistics_blocking as compute_stats_blocking;
541
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduction_op_identity() {
        let sum_identity = ReductionOp::Sum.identity();
        let product_identity = ReductionOp::Product.identity();
        assert_eq!(sum_identity, 0.0);
        assert_eq!(product_identity, 1.0);
    }

    #[test]
    fn test_histogram_params() {
        let params = HistogramParams::new(256, 0.0, 255.0);
        assert_eq!(
            (params.num_bins, params.min_value, params.max_value),
            (256, 0.0, 255.0)
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_reduction_kernel() {
        // Each setup step bails out silently (test passes) when no GPU
        // adapter or resource is available, matching the original behavior.
        let context = match GpuContext::new().await {
            Ok(ctx) => ctx,
            Err(_) => return,
        };
        let kernel = match ReductionKernel::new(&context, ReductionOp::Sum) {
            Ok(k) => k,
            Err(_) => return,
        };

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let buffer = match GpuBuffer::from_data(
            &context,
            &data,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        ) {
            Ok(b) => b,
            Err(_) => return,
        };

        if let Ok(result) = kernel.execute(&buffer, ReductionOp::Sum).await {
            assert!((result - 15.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_statistics_calculations() {
        let stats = Statistics {
            min: 0.0,
            max: 100.0,
            sum: 500.0,
            count: 10,
        };
        let (mean, range) = (stats.mean(), stats.range());
        assert_eq!(mean, 50.0);
        assert_eq!(range, 100.0);
    }
}