1use crate::{GpuDevice, Result};
4
/// Reduction operations understood by [`ReduceKernel`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReduceOp {
    /// Add every element together.
    Sum,
    /// Smallest element.
    Min,
    /// Largest element.
    Max,
    /// Arithmetic mean of the elements.
    Mean,
    /// Smallest and largest element, produced together.
    MinMax,
    /// Number of elements that are not zero.
    CountNonZero,
    /// Occurrence count per value.
    Histogram,
}
23
24impl ReduceOp {
25 #[must_use]
27 pub fn name(self) -> &'static str {
28 match self {
29 Self::Sum => "Sum",
30 Self::Min => "Min",
31 Self::Max => "Max",
32 Self::Mean => "Mean",
33 Self::MinMax => "MinMax",
34 Self::CountNonZero => "CountNonZero",
35 Self::Histogram => "Histogram",
36 }
37 }
38
39 #[must_use]
41 pub fn is_multi_pass(self) -> bool {
42 matches!(self, Self::MinMax | Self::Mean)
43 }
44}
45
46pub struct ReduceKernel {
48 operation: ReduceOp,
49 workgroup_size: u32,
50}
51
52impl ReduceKernel {
53 #[must_use]
55 pub fn new(operation: ReduceOp) -> Self {
56 Self {
57 operation,
58 workgroup_size: 256, }
60 }
61
62 #[must_use]
64 pub fn sum() -> Self {
65 Self::new(ReduceOp::Sum)
66 }
67
68 #[must_use]
70 pub fn min() -> Self {
71 Self::new(ReduceOp::Min)
72 }
73
74 #[must_use]
76 pub fn max() -> Self {
77 Self::new(ReduceOp::Max)
78 }
79
80 #[must_use]
82 pub fn mean() -> Self {
83 Self::new(ReduceOp::Mean)
84 }
85
86 #[must_use]
88 pub fn with_workgroup_size(mut self, size: u32) -> Self {
89 self.workgroup_size = size;
90 self
91 }
92
93 pub fn execute_u8(&self, _device: &GpuDevice, input: &[u8]) -> Result<Vec<u8>> {
115 match self.operation {
116 ReduceOp::Sum => {
117 let sum: u64 = input.iter().map(|&v| u64::from(v)).sum();
118 Ok(sum.to_le_bytes().to_vec())
119 }
120 ReduceOp::Min => {
121 let min = input.iter().copied().min().unwrap_or(0);
122 Ok(vec![min])
123 }
124 ReduceOp::Max => {
125 let max = input.iter().copied().max().unwrap_or(0);
126 Ok(vec![max])
127 }
128 ReduceOp::Mean => {
129 if input.is_empty() {
130 return Ok(0.0f32.to_le_bytes().to_vec());
131 }
132 let sum: u64 = input.iter().map(|&v| u64::from(v)).sum();
133 let mean = sum as f32 / input.len() as f32;
134 Ok(mean.to_le_bytes().to_vec())
135 }
136 ReduceOp::MinMax => {
137 let min = input.iter().copied().min().unwrap_or(0);
138 let max = input.iter().copied().max().unwrap_or(0);
139 Ok(vec![min, max])
140 }
141 ReduceOp::CountNonZero => {
142 let count: u64 = input.iter().filter(|&&v| v != 0).count() as u64;
143 Ok(count.to_le_bytes().to_vec())
144 }
145 ReduceOp::Histogram => {
146 let mut counts = [0u32; 256];
147 for &v in input {
148 counts[v as usize] += 1;
149 }
150 let mut out = Vec::with_capacity(256 * 4);
151 for c in counts {
152 out.extend_from_slice(&c.to_le_bytes());
153 }
154 Ok(out)
155 }
156 }
157 }
158
159 pub fn execute_f32(&self, _device: &GpuDevice, input: &[f32]) -> Result<Vec<f32>> {
181 match self.operation {
182 ReduceOp::Sum => {
183 let sum: f32 = input.iter().copied().sum();
184 Ok(vec![sum])
185 }
186 ReduceOp::Min => {
187 let min = input.iter().copied().fold(f32::INFINITY, f32::min);
188 Ok(vec![if min.is_infinite() { 0.0 } else { min }])
189 }
190 ReduceOp::Max => {
191 let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
192 Ok(vec![if max.is_infinite() { 0.0 } else { max }])
193 }
194 ReduceOp::Mean => {
195 if input.is_empty() {
196 return Ok(vec![0.0f32]);
197 }
198 let sum: f32 = input.iter().copied().sum();
199 Ok(vec![sum / input.len() as f32])
200 }
201 ReduceOp::MinMax => {
202 let min = input.iter().copied().fold(f32::INFINITY, f32::min);
203 let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
204 let min = if min.is_infinite() { 0.0 } else { min };
205 let max = if max.is_infinite() { 0.0 } else { max };
206 Ok(vec![min, max])
207 }
208 ReduceOp::CountNonZero => {
209 let count = input.iter().filter(|&&v| v != 0.0).count() as f32;
210 Ok(vec![count])
211 }
212 ReduceOp::Histogram => {
213 Ok(Vec::new())
215 }
216 }
217 }
218
219 #[must_use]
221 pub fn operation(&self) -> ReduceOp {
222 self.operation
223 }
224
225 #[must_use]
227 pub fn workgroup_size(&self) -> u32 {
228 self.workgroup_size
229 }
230
231 #[must_use]
233 pub fn passes_required(&self, input_size: usize) -> u32 {
234 let mut size = input_size as u32;
235 let mut passes = 0;
236
237 while size > 1 {
238 size = size.div_ceil(self.workgroup_size);
239 passes += 1;
240 }
241
242 passes
243 }
244
245 #[must_use]
247 pub fn estimate_flops(input_size: usize, operation: ReduceOp) -> u64 {
248 let n = input_size as u64;
249
250 match operation {
251 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max | ReduceOp::CountNonZero => {
252 n
254 }
255 ReduceOp::Mean => {
256 n + 1
258 }
259 ReduceOp::MinMax => {
260 n * 2
262 }
263 ReduceOp::Histogram => {
264 n * 2
266 }
267 }
268 }
269}
270
/// Histogram configuration: a bin count and the value range the bins span.
///
/// `Debug`, `Clone` and `PartialEq` are derived so configurations can be
/// logged, duplicated and compared like the other value types in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct HistogramKernel {
    /// Number of bins in the output.
    num_bins: usize,
    /// Lower edge of the first bin.
    min_value: f32,
    /// Upper edge of the last bin.
    max_value: f32,
}
277
278impl HistogramKernel {
279 #[must_use]
287 pub fn new(num_bins: usize, min_value: f32, max_value: f32) -> Self {
288 Self {
289 num_bins,
290 min_value,
291 max_value,
292 }
293 }
294
295 #[must_use]
297 pub fn default_u8() -> Self {
298 Self::new(256, 0.0, 256.0)
299 }
300
301 pub fn execute(&self, _device: &GpuDevice, input: &[u8]) -> Result<Vec<u32>> {
315 let mut counts = vec![0u32; self.num_bins];
316 let range = self.max_value - self.min_value;
317 if range <= 0.0 || self.num_bins == 0 {
318 return Ok(counts);
319 }
320 for &byte in input {
321 let normalized = (f32::from(byte) - self.min_value) / range;
322 let bin = (normalized * self.num_bins as f32) as isize;
323 let bin = bin.clamp(0, self.num_bins as isize - 1) as usize;
324 counts[bin] += 1;
325 }
326 Ok(counts)
327 }
328
329 #[must_use]
331 pub fn num_bins(&self) -> usize {
332 self.num_bins
333 }
334
335 #[must_use]
337 pub fn value_range(&self) -> (f32, f32) {
338 (self.min_value, self.max_value)
339 }
340
341 #[must_use]
343 pub fn bin_width(&self) -> f32 {
344 (self.max_value - self.min_value) / self.num_bins as f32
345 }
346}
347
/// Stateless entry point for image statistics; all functionality lives in
/// associated functions.
///
/// The usual marker-type traits are derived so the type can be logged,
/// copied and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct StatsKernel;
350
351impl StatsKernel {
352 pub fn compute(_device: &GpuDevice, input: &[u8]) -> Result<ImageStats> {
363 if input.is_empty() {
364 return Ok(ImageStats::default());
365 }
366 let count = input.len() as u64;
367 let min = f32::from(input.iter().copied().min().unwrap_or(0));
368 let max = f32::from(input.iter().copied().max().unwrap_or(0));
369 let sum: u64 = input.iter().map(|&v| u64::from(v)).sum();
370 let mean = sum as f32 / count as f32;
371 let variance: f32 = input
372 .iter()
373 .map(|&v| {
374 let diff = f32::from(v) - mean;
375 diff * diff
376 })
377 .sum::<f32>()
378 / count as f32;
379 let std_dev = variance.sqrt();
380 Ok(ImageStats::new(min, max, mean, std_dev, count))
381 }
382
383 pub fn compute_channels(
398 _device: &GpuDevice,
399 input: &[u8],
400 channels: usize,
401 ) -> Result<Vec<ImageStats>> {
402 if channels == 0 {
403 return Ok(Vec::new());
404 }
405 let mut result = Vec::with_capacity(channels);
406 for ch in 0..channels {
407 let channel_data: Vec<u8> = input.iter().skip(ch).step_by(channels).copied().collect();
408 if channel_data.is_empty() {
409 result.push(ImageStats::default());
410 continue;
411 }
412 let count = channel_data.len() as u64;
413 let min = f32::from(channel_data.iter().copied().min().unwrap_or(0));
414 let max = f32::from(channel_data.iter().copied().max().unwrap_or(0));
415 let sum: u64 = channel_data.iter().map(|&v| u64::from(v)).sum();
416 let mean = sum as f32 / count as f32;
417 let variance: f32 = channel_data
418 .iter()
419 .map(|&v| {
420 let diff = f32::from(v) - mean;
421 diff * diff
422 })
423 .sum::<f32>()
424 / count as f32;
425 let std_dev = variance.sqrt();
426 result.push(ImageStats::new(min, max, mean, std_dev, count));
427 }
428 Ok(result)
429 }
430}
431
/// Summary statistics for a set of samples.
///
/// The derived `Default` is all zeros.
#[derive(Clone, Copy, Debug, Default)]
pub struct ImageStats {
    /// Smallest sample value.
    pub min: f32,
    /// Largest sample value.
    pub max: f32,
    /// Arithmetic mean of the samples.
    pub mean: f32,
    /// Standard deviation of the samples.
    pub std_dev: f32,
    /// Number of samples the statistics were computed from.
    pub count: u64,
}
446
447impl ImageStats {
448 #[must_use]
450 pub fn new(min: f32, max: f32, mean: f32, std_dev: f32, count: u64) -> Self {
451 Self {
452 min,
453 max,
454 mean,
455 std_dev,
456 count,
457 }
458 }
459
460 #[must_use]
462 pub fn range(&self) -> f32 {
463 self.max - self.min
464 }
465
466 #[must_use]
468 pub fn coefficient_of_variation(&self) -> f32 {
469 if self.mean == 0.0 {
470 0.0
471 } else {
472 self.std_dev / self.mean
473 }
474 }
475}
476
/// Prefix-sum (scan) configuration.
///
/// The common value-type traits are derived so a kernel can be logged,
/// copied and compared like the other small types in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScanKernel {
    /// `true` when each output element includes its own input element;
    /// `false` when it only covers the elements before it.
    inclusive: bool,
}
481
482impl ScanKernel {
483 #[must_use]
485 pub fn inclusive() -> Self {
486 Self { inclusive: true }
487 }
488
489 #[must_use]
491 pub fn exclusive() -> Self {
492 Self { inclusive: false }
493 }
494
495 pub fn execute(&self, _device: &GpuDevice, input: &[u32], output: &mut [u32]) -> Result<()> {
513 if output.len() != input.len() {
514 return Err(crate::GpuError::NotSupported(format!(
515 "Scan output length {} differs from input length {}",
516 output.len(),
517 input.len()
518 )));
519 }
520 if input.is_empty() {
521 return Ok(());
522 }
523 let mut running: u32 = 0;
524 if self.inclusive {
525 for (i, &val) in input.iter().enumerate() {
526 running = running.wrapping_add(val);
527 output[i] = running;
528 }
529 } else {
530 for (i, &val) in input.iter().enumerate() {
531 output[i] = running;
532 running = running.wrapping_add(val);
533 }
534 }
535 Ok(())
536 }
537
538 #[must_use]
540 pub fn is_inclusive(&self) -> bool {
541 self.inclusive
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    // Names of the basic operations and the multi-pass classification.
    #[test]
    fn test_reduce_operation_properties() {
        let named = [
            (ReduceOp::Sum, "Sum"),
            (ReduceOp::Min, "Min"),
            (ReduceOp::Max, "Max"),
        ];
        for (op, label) in named {
            assert_eq!(op.name(), label);
        }

        assert!(!ReduceOp::Sum.is_multi_pass());
        assert!(ReduceOp::Mean.is_multi_pass());
        assert!(ReduceOp::MinMax.is_multi_pass());
    }

    // Pass counts for the default workgroup size.
    #[test]
    fn test_reduce_kernel_passes() {
        let kernel = ReduceKernel::new(ReduceOp::Sum);
        let cases: [(usize, u32); 3] = [(256, 1), (1024, 2), (100000, 3)];
        for (size, expected) in cases {
            assert_eq!(kernel.passes_required(size), expected);
        }
    }

    // Accessors of the default byte histogram.
    #[test]
    fn test_histogram_kernel() {
        let histogram = HistogramKernel::default_u8();
        let bins = histogram.num_bins();
        let bounds = histogram.value_range();
        let width = histogram.bin_width();
        assert_eq!(bins, 256);
        assert_eq!(bounds, (0.0, 256.0));
        assert_eq!(width, 1.0);
    }

    // Derived quantities of a hand-built `ImageStats`.
    #[test]
    fn test_image_stats() {
        let stats = ImageStats::new(0.0, 255.0, 127.5, 50.0, 1000);
        assert_eq!(stats.range(), 255.0);
        let deviation = stats.coefficient_of_variation() - (50.0 / 127.5);
        assert!(deviation.abs() < 0.001);
    }

    // Both constructors report the mode they were built with.
    #[test]
    fn test_scan_kernel() {
        assert!(ScanKernel::inclusive().is_inclusive());
        assert!(!ScanKernel::exclusive().is_inclusive());
    }

    // Operation-count estimates for single- and double-cost reductions.
    #[test]
    fn test_flops_estimation() {
        assert_eq!(ReduceKernel::estimate_flops(1000, ReduceOp::Sum), 1000);
        assert_eq!(ReduceKernel::estimate_flops(1000, ReduceOp::MinMax), 2000);
    }

    // Reference byte sum used by `test_u8_sum_direct`.
    fn run_u8_sum(input: &[u8]) -> u64 {
        input.iter().copied().map(u64::from).sum()
    }

    #[test]
    fn test_u8_sum_direct() {
        let cases: [(&[u8], u64); 3] = [(&[1, 2, 3, 4], 10), (&[], 0), (&[255, 255], 510)];
        for (input, expected) in cases {
            assert_eq!(run_u8_sum(input), expected);
        }
    }

    // Reference per-value counting, done inline.
    #[test]
    fn test_u8_histogram_direct() {
        let samples = [0u8, 0, 128, 255];
        let mut counts = [0u32; 256];
        for value in samples {
            counts[usize::from(value)] += 1;
        }
        assert_eq!(counts[0], 2);
        assert_eq!(counts[128], 1);
        assert_eq!(counts[255], 1);
    }

    // Reference binning arithmetic for four bins over 0..256, done inline.
    #[test]
    fn test_histogram_kernel_execute_direct() {
        let _hist = HistogramKernel::new(4, 0.0, 256.0);
        let mut expected = vec![0u32; 4];
        for sample in [0u8, 64, 128, 192] {
            let normalized = (f32::from(sample) - 0.0) / 256.0;
            let raw_bin = (normalized * 4.0) as isize;
            let bin = raw_bin.clamp(0, 3) as usize;
            expected[bin] += 1;
        }
        assert_eq!(expected, vec![1, 1, 1, 1]);
    }

    // Reference min/max/mean/std-dev arithmetic, done inline.
    #[test]
    fn test_stats_direct() {
        let input: Vec<u8> = vec![0, 100, 200];
        let count = input.len() as u64;
        let min = f32::from(input.iter().copied().min().unwrap_or(0));
        let max = f32::from(input.iter().copied().max().unwrap_or(0));
        let sum: u64 = input.iter().copied().map(u64::from).sum();
        let mean = sum as f32 / count as f32;
        let squared_deviations: f32 = input
            .iter()
            .map(|&v| {
                let diff = f32::from(v) - mean;
                diff * diff
            })
            .sum::<f32>();
        let variance = squared_deviations / count as f32;
        let std_dev = variance.sqrt();

        assert_eq!(count, 3);
        assert!((min - 0.0).abs() < 0.001);
        assert!((max - 200.0).abs() < 0.001);
        assert!((mean - 100.0).abs() < 0.001);
        assert!(std_dev > 0.0);
    }

    // Reference per-channel means of interleaved data, done inline.
    #[test]
    fn test_stats_channels_direct() {
        let input: Vec<u8> = vec![10, 20, 30, 40, 50, 60];
        let channels = 3usize;
        for ch in 0..channels {
            let ch_data: Vec<u8> = input.iter().skip(ch).step_by(channels).copied().collect();
            let sum: u64 = ch_data.iter().copied().map(u64::from).sum();
            let mean = sum as f32 / ch_data.len() as f32;
            let expected_mean = match ch {
                0 => 25.0f32,
                1 => 35.0f32,
                _ => 45.0f32,
            };
            assert!((mean - expected_mean).abs() < 0.01, "ch {ch} mean mismatch");
        }
    }

    // Reference inclusive prefix sum, done inline.
    #[test]
    fn test_scan_inclusive_direct() {
        let input = vec![1u32, 2, 3, 4];
        let mut output = vec![0u32; 4];
        let mut running = 0u32;
        for (slot, &value) in output.iter_mut().zip(&input) {
            running = running.wrapping_add(value);
            *slot = running;
        }
        assert_eq!(output, vec![1, 3, 6, 10]);
    }

    // Reference exclusive prefix sum, done inline.
    #[test]
    fn test_scan_exclusive_direct() {
        let input = vec![1u32, 2, 3, 4];
        let mut output = vec![0u32; 4];
        let mut running = 0u32;
        for (slot, &value) in output.iter_mut().zip(&input) {
            *slot = running;
            running = running.wrapping_add(value);
        }
        assert_eq!(output, vec![0, 1, 3, 6]);
    }

    // Reference f32 min/max folds, done inline.
    #[test]
    fn test_f32_minmax_direct() {
        let input = vec![3.0f32, 1.0, 4.0, 1.0, 5.0];
        let (min, max) = input
            .iter()
            .copied()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        assert!((min - 1.0).abs() < 0.001);
        assert!((max - 5.0).abs() < 0.001);
    }

    // Reference non-zero count for f32 data, done inline.
    #[test]
    fn test_f32_count_nonzero_direct() {
        let input = vec![0.0f32, 1.0, 0.0, 2.0, 3.0];
        let nonzero = input.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(nonzero, 3);
    }
}