Skip to main content

oximedia_gpu/kernels/
reduce.rs

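//! Reduction operations (sum, min, max, mean, min/max, non-zero count, histogram),
//! image statistics, and prefix-sum scans, with CPU fallback implementations.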
2
3use crate::{GpuDevice, Result};
4
/// Kind of reduction applied to a buffer of samples.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum of all elements.
    Sum,
    /// Smallest element.
    Min,
    /// Largest element.
    Max,
    /// Arithmetic mean (average) of the elements.
    Mean,
    /// Smallest and largest element together.
    MinMax,
    /// Number of elements that are not zero.
    CountNonZero,
    /// Per-value histogram of the elements.
    Histogram,
}

impl ReduceOp {
    /// Human-readable name of the operation; identical to the variant identifier.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            ReduceOp::Histogram => "Histogram",
            ReduceOp::CountNonZero => "CountNonZero",
            ReduceOp::MinMax => "MinMax",
            ReduceOp::Mean => "Mean",
            ReduceOp::Max => "Max",
            ReduceOp::Min => "Min",
            ReduceOp::Sum => "Sum",
        }
    }

    /// Whether a GPU implementation of this operation needs more than one pass.
    ///
    /// `Mean` (sum, then divide) and `MinMax` (two independent reductions) do;
    /// every other operation is a single reduction.
    #[must_use]
    pub fn is_multi_pass(self) -> bool {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            ReduceOp::Mean | ReduceOp::MinMax => true,
            ReduceOp::Sum
            | ReduceOp::Min
            | ReduceOp::Max
            | ReduceOp::CountNonZero
            | ReduceOp::Histogram => false,
        }
    }
}
45
46/// Reduction kernel for parallel reduction operations
47pub struct ReduceKernel {
48    operation: ReduceOp,
49    workgroup_size: u32,
50}
51
52impl ReduceKernel {
53    /// Create a new reduction kernel
54    #[must_use]
55    pub fn new(operation: ReduceOp) -> Self {
56        Self {
57            operation,
58            workgroup_size: 256, // Default workgroup size
59        }
60    }
61
62    /// Create a sum reduction kernel
63    #[must_use]
64    pub fn sum() -> Self {
65        Self::new(ReduceOp::Sum)
66    }
67
68    /// Create a min reduction kernel
69    #[must_use]
70    pub fn min() -> Self {
71        Self::new(ReduceOp::Min)
72    }
73
74    /// Create a max reduction kernel
75    #[must_use]
76    pub fn max() -> Self {
77        Self::new(ReduceOp::Max)
78    }
79
80    /// Create a mean reduction kernel
81    #[must_use]
82    pub fn mean() -> Self {
83        Self::new(ReduceOp::Mean)
84    }
85
86    /// Set the workgroup size
87    #[must_use]
88    pub fn with_workgroup_size(mut self, size: u32) -> Self {
89        self.workgroup_size = size;
90        self
91    }
92
93    /// Execute the reduction operation on u8 data (CPU fallback).
94    ///
95    /// # Output encoding
96    ///
97    /// | Operation      | Output format                              |
98    /// |----------------|--------------------------------------------|
99    /// | Sum            | 8-byte little-endian `u64`                 |
100    /// | Min / Max      | 1 byte                                     |
101    /// | Mean           | 4-byte little-endian `f32`                 |
102    /// | `MinMax`         | 2 bytes `[min, max]`                       |
103    /// | `CountNonZero`   | 8-byte little-endian `u64`                 |
104    /// | Histogram      | 256 × 4-byte little-endian `u32` counts   |
105    ///
106    /// # Arguments
107    ///
108    /// * `_device` - GPU device (CPU fallback: unused)
109    /// * `input` - Input data buffer
110    ///
111    /// # Errors
112    ///
113    /// Returns an error only on internal logic failures (currently infallible).
114    pub fn execute_u8(&self, _device: &GpuDevice, input: &[u8]) -> Result<Vec<u8>> {
115        match self.operation {
116            ReduceOp::Sum => {
117                let sum: u64 = input.iter().map(|&v| u64::from(v)).sum();
118                Ok(sum.to_le_bytes().to_vec())
119            }
120            ReduceOp::Min => {
121                let min = input.iter().copied().min().unwrap_or(0);
122                Ok(vec![min])
123            }
124            ReduceOp::Max => {
125                let max = input.iter().copied().max().unwrap_or(0);
126                Ok(vec![max])
127            }
128            ReduceOp::Mean => {
129                if input.is_empty() {
130                    return Ok(0.0f32.to_le_bytes().to_vec());
131                }
132                let sum: u64 = input.iter().map(|&v| u64::from(v)).sum();
133                let mean = sum as f32 / input.len() as f32;
134                Ok(mean.to_le_bytes().to_vec())
135            }
136            ReduceOp::MinMax => {
137                let min = input.iter().copied().min().unwrap_or(0);
138                let max = input.iter().copied().max().unwrap_or(0);
139                Ok(vec![min, max])
140            }
141            ReduceOp::CountNonZero => {
142                let count: u64 = input.iter().filter(|&&v| v != 0).count() as u64;
143                Ok(count.to_le_bytes().to_vec())
144            }
145            ReduceOp::Histogram => {
146                let mut counts = [0u32; 256];
147                for &v in input {
148                    counts[v as usize] += 1;
149                }
150                let mut out = Vec::with_capacity(256 * 4);
151                for c in counts {
152                    out.extend_from_slice(&c.to_le_bytes());
153                }
154                Ok(out)
155            }
156        }
157    }
158
159    /// Execute the reduction operation on f32 data (CPU fallback).
160    ///
161    /// # Output encoding
162    ///
163    /// | Operation      | Output (`Vec<f32>`)                        |
164    /// |----------------|--------------------------------------------|
165    /// | Sum            | `[total_sum]`                              |
166    /// | Min / Max      | `[value]`                                  |
167    /// | Mean           | `[mean]`                                   |
168    /// | `MinMax`         | `[min, max]`                               |
169    /// | `CountNonZero`   | `[count as f32]`                           |
170    /// | Histogram      | empty (not meaningful for f32)             |
171    ///
172    /// # Arguments
173    ///
174    /// * `_device` - GPU device (CPU fallback: unused)
175    /// * `input` - Input data buffer
176    ///
177    /// # Errors
178    ///
179    /// Returns an error only on internal logic failures (currently infallible).
180    pub fn execute_f32(&self, _device: &GpuDevice, input: &[f32]) -> Result<Vec<f32>> {
181        match self.operation {
182            ReduceOp::Sum => {
183                let sum: f32 = input.iter().copied().sum();
184                Ok(vec![sum])
185            }
186            ReduceOp::Min => {
187                let min = input.iter().copied().fold(f32::INFINITY, f32::min);
188                Ok(vec![if min.is_infinite() { 0.0 } else { min }])
189            }
190            ReduceOp::Max => {
191                let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
192                Ok(vec![if max.is_infinite() { 0.0 } else { max }])
193            }
194            ReduceOp::Mean => {
195                if input.is_empty() {
196                    return Ok(vec![0.0f32]);
197                }
198                let sum: f32 = input.iter().copied().sum();
199                Ok(vec![sum / input.len() as f32])
200            }
201            ReduceOp::MinMax => {
202                let min = input.iter().copied().fold(f32::INFINITY, f32::min);
203                let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
204                let min = if min.is_infinite() { 0.0 } else { min };
205                let max = if max.is_infinite() { 0.0 } else { max };
206                Ok(vec![min, max])
207            }
208            ReduceOp::CountNonZero => {
209                let count = input.iter().filter(|&&v| v != 0.0).count() as f32;
210                Ok(vec![count])
211            }
212            ReduceOp::Histogram => {
213                // Not meaningful for f32 with arbitrary range.
214                Ok(Vec::new())
215            }
216        }
217    }
218
219    /// Get the operation type
220    #[must_use]
221    pub fn operation(&self) -> ReduceOp {
222        self.operation
223    }
224
225    /// Get the workgroup size
226    #[must_use]
227    pub fn workgroup_size(&self) -> u32 {
228        self.workgroup_size
229    }
230
231    /// Calculate the number of reduction passes needed
232    #[must_use]
233    pub fn passes_required(&self, input_size: usize) -> u32 {
234        let mut size = input_size as u32;
235        let mut passes = 0;
236
237        while size > 1 {
238            size = size.div_ceil(self.workgroup_size);
239            passes += 1;
240        }
241
242        passes
243    }
244
245    /// Estimate FLOPS for the reduction
246    #[must_use]
247    pub fn estimate_flops(input_size: usize, operation: ReduceOp) -> u64 {
248        let n = input_size as u64;
249
250        match operation {
251            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max | ReduceOp::CountNonZero => {
252                // Simple reduction: O(N)
253                n
254            }
255            ReduceOp::Mean => {
256                // Sum + division: O(N) + 1
257                n + 1
258            }
259            ReduceOp::MinMax => {
260                // Two reductions: O(2N)
261                n * 2
262            }
263            ReduceOp::Histogram => {
264                // Atomic operations per element
265                n * 2
266            }
267        }
268    }
269}
270
271/// Histogram computation kernel
272pub struct HistogramKernel {
273    num_bins: usize,
274    min_value: f32,
275    max_value: f32,
276}
277
278impl HistogramKernel {
279    /// Create a new histogram kernel
280    ///
281    /// # Arguments
282    ///
283    /// * `num_bins` - Number of histogram bins
284    /// * `min_value` - Minimum value for histogram range
285    /// * `max_value` - Maximum value for histogram range
286    #[must_use]
287    pub fn new(num_bins: usize, min_value: f32, max_value: f32) -> Self {
288        Self {
289            num_bins,
290            min_value,
291            max_value,
292        }
293    }
294
295    /// Create a histogram with default range [0, 256) for 8-bit images
296    #[must_use]
297    pub fn default_u8() -> Self {
298        Self::new(256, 0.0, 256.0)
299    }
300
301    /// Execute histogram computation (CPU fallback).
302    ///
303    /// Each byte value `v` in `input` is mapped to a bin via:
304    /// `bin = clamp(((v - min) / (max - min)) * num_bins, 0, num_bins-1)`
305    ///
306    /// # Arguments
307    ///
308    /// * `_device` - GPU device (CPU fallback: unused)
309    /// * `input` - Input image data
310    ///
311    /// # Errors
312    ///
313    /// Returns an error only on internal logic failures (currently infallible).
314    pub fn execute(&self, _device: &GpuDevice, input: &[u8]) -> Result<Vec<u32>> {
315        let mut counts = vec![0u32; self.num_bins];
316        let range = self.max_value - self.min_value;
317        if range <= 0.0 || self.num_bins == 0 {
318            return Ok(counts);
319        }
320        for &byte in input {
321            let normalized = (f32::from(byte) - self.min_value) / range;
322            let bin = (normalized * self.num_bins as f32) as isize;
323            let bin = bin.clamp(0, self.num_bins as isize - 1) as usize;
324            counts[bin] += 1;
325        }
326        Ok(counts)
327    }
328
329    /// Get the number of bins
330    #[must_use]
331    pub fn num_bins(&self) -> usize {
332        self.num_bins
333    }
334
335    /// Get the value range
336    #[must_use]
337    pub fn value_range(&self) -> (f32, f32) {
338        (self.min_value, self.max_value)
339    }
340
341    /// Get bin width
342    #[must_use]
343    pub fn bin_width(&self) -> f32 {
344        (self.max_value - self.min_value) / self.num_bins as f32
345    }
346}
347
348/// Statistics computation kernel
349pub struct StatsKernel;
350
351impl StatsKernel {
352    /// Compute image statistics (min, max, mean, std dev) in a single pass (CPU fallback).
353    ///
354    /// # Arguments
355    ///
356    /// * `_device` - GPU device (CPU fallback: unused)
357    /// * `input` - Input image data
358    ///
359    /// # Errors
360    ///
361    /// Returns an error only on internal logic failures (currently infallible).
362    pub fn compute(_device: &GpuDevice, input: &[u8]) -> Result<ImageStats> {
363        if input.is_empty() {
364            return Ok(ImageStats::default());
365        }
366        let count = input.len() as u64;
367        let min = f32::from(input.iter().copied().min().unwrap_or(0));
368        let max = f32::from(input.iter().copied().max().unwrap_or(0));
369        let sum: u64 = input.iter().map(|&v| u64::from(v)).sum();
370        let mean = sum as f32 / count as f32;
371        let variance: f32 = input
372            .iter()
373            .map(|&v| {
374                let diff = f32::from(v) - mean;
375                diff * diff
376            })
377            .sum::<f32>()
378            / count as f32;
379        let std_dev = variance.sqrt();
380        Ok(ImageStats::new(min, max, mean, std_dev, count))
381    }
382
383    /// Compute channel-wise statistics for multi-channel images (CPU fallback).
384    ///
385    /// `input` is expected to be interleaved channel data
386    /// (e.g., RGBRGB… for 3 channels).
387    ///
388    /// # Arguments
389    ///
390    /// * `_device` - GPU device (CPU fallback: unused)
391    /// * `input` - Input image data (interleaved channels)
392    /// * `channels` - Number of channels
393    ///
394    /// # Errors
395    ///
396    /// Returns an error only on internal logic failures (currently infallible).
397    pub fn compute_channels(
398        _device: &GpuDevice,
399        input: &[u8],
400        channels: usize,
401    ) -> Result<Vec<ImageStats>> {
402        if channels == 0 {
403            return Ok(Vec::new());
404        }
405        let mut result = Vec::with_capacity(channels);
406        for ch in 0..channels {
407            let channel_data: Vec<u8> = input.iter().skip(ch).step_by(channels).copied().collect();
408            if channel_data.is_empty() {
409                result.push(ImageStats::default());
410                continue;
411            }
412            let count = channel_data.len() as u64;
413            let min = f32::from(channel_data.iter().copied().min().unwrap_or(0));
414            let max = f32::from(channel_data.iter().copied().max().unwrap_or(0));
415            let sum: u64 = channel_data.iter().map(|&v| u64::from(v)).sum();
416            let mean = sum as f32 / count as f32;
417            let variance: f32 = channel_data
418                .iter()
419                .map(|&v| {
420                    let diff = f32::from(v) - mean;
421                    diff * diff
422                })
423                .sum::<f32>()
424                / count as f32;
425            let std_dev = variance.sqrt();
426            result.push(ImageStats::new(min, max, mean, std_dev, count));
427        }
428        Ok(result)
429    }
430}
431
/// Summary statistics of an image or of a single image channel.
#[derive(Debug, Clone, Copy, Default)]
pub struct ImageStats {
    /// Minimum value
    pub min: f32,
    /// Maximum value
    pub max: f32,
    /// Mean (average) value
    pub mean: f32,
    /// Standard deviation
    pub std_dev: f32,
    /// Number of samples
    pub count: u64,
}

impl ImageStats {
    /// Bundle already-computed statistics into an `ImageStats`.
    #[must_use]
    pub fn new(min: f32, max: f32, mean: f32, std_dev: f32, count: u64) -> Self {
        Self {
            count,
            std_dev,
            mean,
            max,
            min,
        }
    }

    /// Spread between the largest and smallest value (`max - min`).
    #[must_use]
    pub fn range(&self) -> f32 {
        let Self { min, max, .. } = *self;
        max - min
    }

    /// Coefficient of variation (`std_dev / mean`).
    ///
    /// Returns 0.0 when the mean is exactly zero, to avoid dividing by zero.
    #[must_use]
    pub fn coefficient_of_variation(&self) -> f32 {
        if self.mean != 0.0 {
            self.std_dev / self.mean
        } else {
            0.0
        }
    }
}
476
477/// Prefix sum (scan) operation
478pub struct ScanKernel {
479    inclusive: bool,
480}
481
482impl ScanKernel {
483    /// Create an inclusive scan kernel
484    #[must_use]
485    pub fn inclusive() -> Self {
486        Self { inclusive: true }
487    }
488
489    /// Create an exclusive scan kernel
490    #[must_use]
491    pub fn exclusive() -> Self {
492        Self { inclusive: false }
493    }
494
495    /// Execute the scan (prefix sum) operation (CPU fallback).
496    ///
497    /// * **Inclusive**: `output[i] = input[0] + … + input[i]`
498    /// * **Exclusive**: `output[0] = 0`, `output[i] = input[0] + … + input[i-1]`
499    ///
500    /// Wrapping arithmetic is used to avoid panics on overflow.
501    /// `output` must have the same length as `input`.
502    ///
503    /// # Arguments
504    ///
505    /// * `_device` - GPU device (CPU fallback: unused)
506    /// * `input` - Input data
507    /// * `output` - Output buffer for scan results
508    ///
509    /// # Errors
510    ///
511    /// Returns an error if `output.len() != input.len()`.
512    pub fn execute(&self, _device: &GpuDevice, input: &[u32], output: &mut [u32]) -> Result<()> {
513        if output.len() != input.len() {
514            return Err(crate::GpuError::NotSupported(format!(
515                "Scan output length {} differs from input length {}",
516                output.len(),
517                input.len()
518            )));
519        }
520        if input.is_empty() {
521            return Ok(());
522        }
523        let mut running: u32 = 0;
524        if self.inclusive {
525            for (i, &val) in input.iter().enumerate() {
526                running = running.wrapping_add(val);
527                output[i] = running;
528            }
529        } else {
530            for (i, &val) in input.iter().enumerate() {
531                output[i] = running;
532                running = running.wrapping_add(val);
533            }
534        }
535        Ok(())
536    }
537
538    /// Check if this is an inclusive scan
539    #[must_use]
540    pub fn is_inclusive(&self) -> bool {
541        self.inclusive
542    }
543}
544
#[cfg(test)]
mod tests {
    //! Unit tests for the reduction kernels.
    //!
    //! Most `*_direct` tests cannot call the `execute*` methods, because those
    //! methods require a `GpuDevice`. Instead, these tests re-implement the CPU
    //! fallback math inline and check it on small inputs. They therefore
    //! validate the formulas, not the kernel methods themselves.
    use super::*;

    #[test]
    fn test_reduce_operation_properties() {
        assert_eq!(ReduceOp::Sum.name(), "Sum");
        assert_eq!(ReduceOp::Min.name(), "Min");
        assert_eq!(ReduceOp::Max.name(), "Max");

        assert!(!ReduceOp::Sum.is_multi_pass());
        assert!(ReduceOp::Mean.is_multi_pass());
        assert!(ReduceOp::MinMax.is_multi_pass());
    }

    #[test]
    fn test_reduce_kernel_passes() {
        // Default workgroup size is 256: 256 -> 1; 1024 -> 4 -> 1; 100000 -> 391 -> 2 -> 1.
        let kernel = ReduceKernel::new(ReduceOp::Sum);
        assert_eq!(kernel.passes_required(256), 1);
        assert_eq!(kernel.passes_required(1024), 2);
        assert_eq!(kernel.passes_required(100000), 3);
    }

    #[test]
    fn test_histogram_kernel() {
        let histogram = HistogramKernel::default_u8();
        assert_eq!(histogram.num_bins(), 256);
        assert_eq!(histogram.value_range(), (0.0, 256.0));
        assert_eq!(histogram.bin_width(), 1.0);
    }

    #[test]
    fn test_image_stats() {
        let stats = ImageStats::new(0.0, 255.0, 127.5, 50.0, 1000);
        assert_eq!(stats.range(), 255.0);
        assert!((stats.coefficient_of_variation() - (50.0 / 127.5)).abs() < 0.001);
    }

    #[test]
    fn test_scan_kernel() {
        let scan = ScanKernel::inclusive();
        assert!(scan.is_inclusive());

        let scan = ScanKernel::exclusive();
        assert!(!scan.is_inclusive());
    }

    #[test]
    fn test_flops_estimation() {
        let flops_sum = ReduceKernel::estimate_flops(1000, ReduceOp::Sum);
        let flops_minmax = ReduceKernel::estimate_flops(1000, ReduceOp::MinMax);

        assert_eq!(flops_sum, 1000);
        assert_eq!(flops_minmax, 2000); // MinMax is 2x
    }

    // --- CPU implementation unit tests (no GpuDevice required) ----------------

    /// Helper: reference u8 sum computed inline, mirroring the expression used by
    /// the `ReduceOp::Sum` arm of `execute_u8`.
    fn run_u8_sum(input: &[u8]) -> u64 {
        // We bypass `execute_u8` to avoid needing a GpuDevice in tests.
        input.iter().map(|&v| v as u64).sum()
    }

    #[test]
    fn test_u8_sum_direct() {
        assert_eq!(run_u8_sum(&[1, 2, 3, 4]), 10);
        assert_eq!(run_u8_sum(&[]), 0);
        assert_eq!(run_u8_sum(&[255, 255]), 510);
    }

    #[test]
    fn test_u8_histogram_direct() {
        // Mirrors the 256-bin counting used by the `ReduceOp::Histogram` arm of `execute_u8`.
        let mut counts = [0u32; 256];
        for &v in &[0u8, 0, 128, 255] {
            counts[v as usize] += 1;
        }
        assert_eq!(counts[0], 2);
        assert_eq!(counts[128], 1);
        assert_eq!(counts[255], 1);
    }

    #[test]
    fn test_histogram_kernel_execute_direct() {
        // Test HistogramKernel binning logic without GpuDevice.
        // The kernel is constructed only for reference; the binning formula is
        // re-implemented below rather than calling `execute`.
        let _hist = HistogramKernel::new(4, 0.0, 256.0);
        // bin width = 64; byte 0 -> bin 0, byte 64 -> bin 1, byte 192 -> bin 3
        let mut expected = vec![0u32; 4];
        for &b in &[0u8, 64, 128, 192] {
            let normalized = (b as f32 - 0.0) / 256.0;
            let bin = (normalized * 4.0) as isize;
            let bin = bin.clamp(0, 3) as usize;
            expected[bin] += 1;
        }
        // All four bins should have exactly one count.
        assert_eq!(expected, vec![1, 1, 1, 1]);
    }

    #[test]
    fn test_stats_direct() {
        // Verify single-pass stats math.
        // Re-implements the min/max/mean/population-variance formulas of
        // `StatsKernel::compute` without calling it.
        let input: Vec<u8> = vec![0, 100, 200];
        let count = input.len() as u64;
        let min = input.iter().copied().min().unwrap_or(0) as f32;
        let max = input.iter().copied().max().unwrap_or(0) as f32;
        let sum: u64 = input.iter().map(|&v| v as u64).sum();
        let mean = sum as f32 / count as f32;
        let variance: f32 = input
            .iter()
            .map(|&v| {
                let diff = v as f32 - mean;
                diff * diff
            })
            .sum::<f32>()
            / count as f32;
        let std_dev = variance.sqrt();

        assert_eq!(count, 3);
        assert!((min - 0.0).abs() < 0.001);
        assert!((max - 200.0).abs() < 0.001);
        assert!((mean - 100.0).abs() < 0.001);
        assert!(std_dev > 0.0);
    }

    #[test]
    fn test_stats_channels_direct() {
        // Interleaved RGB: R=10, G=20, B=30, R=40, G=50, B=60
        // Uses the same skip/step_by channel extraction as `StatsKernel::compute_channels`.
        let input: Vec<u8> = vec![10, 20, 30, 40, 50, 60];
        let channels = 3usize;
        for ch in 0..channels {
            let ch_data: Vec<u8> = input.iter().skip(ch).step_by(channels).copied().collect();
            let sum: u64 = ch_data.iter().map(|&v| v as u64).sum();
            let mean = sum as f32 / ch_data.len() as f32;
            let expected_mean = match ch {
                0 => 25.0f32,
                1 => 35.0f32,
                _ => 45.0f32,
            };
            assert!((mean - expected_mean).abs() < 0.01, "ch {ch} mean mismatch");
        }
    }

    #[test]
    fn test_scan_inclusive_direct() {
        // inclusive[i] = sum(input[0..=i])
        let input = vec![1u32, 2, 3, 4];
        let mut output = vec![0u32; 4];
        let mut running = 0u32;
        for (i, &v) in input.iter().enumerate() {
            running = running.wrapping_add(v);
            output[i] = running;
        }
        assert_eq!(output, vec![1, 3, 6, 10]);
    }

    #[test]
    fn test_scan_exclusive_direct() {
        // exclusive[0]=0, exclusive[i] = sum(input[0..i]), i.e. input[0] + … + input[i-1]
        let input = vec![1u32, 2, 3, 4];
        let mut output = vec![0u32; 4];
        let mut running = 0u32;
        for (i, &v) in input.iter().enumerate() {
            output[i] = running;
            running = running.wrapping_add(v);
        }
        assert_eq!(output, vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_f32_minmax_direct() {
        let input = vec![3.0f32, 1.0, 4.0, 1.0, 5.0];
        let min = input.iter().copied().fold(f32::INFINITY, f32::min);
        let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        assert!((min - 1.0).abs() < 0.001);
        assert!((max - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_f32_count_nonzero_direct() {
        let input = vec![0.0f32, 1.0, 0.0, 2.0, 3.0];
        let count = input.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(count, 3);
    }
}