Skip to main content

trueno/backends/gpu/device/reductions/
mod.rs

1//! GPU reduction operations
2//!
3//! Parallel max/sum reductions and 2D tiled reductions (sum/max/min).
4//!
5//! # Submodules
6//!
7//! - `reduce_1d` - 1D parallel reductions (max, sum) used by activation functions
8//! - `tiled_2d` - Generic 2D tiled reduction infrastructure
9
10mod reduce_1d;
11mod tiled_2d;
12
13#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
14use super::super::runtime;
15use super::super::shaders;
16use super::GpuDevice;
17
18impl GpuDevice {
19    /// 2D Tiled Sum Reduction on GPU (sync, native only)
20    ///
21    /// Uses 16x16 workgroups for efficient parallel reduction with
22    /// optimal memory coalescing. GPU version of `tiled_sum_2d`.
23    ///
24    /// # Arguments
25    ///
26    /// * `data` - Input 2D data in row-major order
27    /// * `width` - Number of columns
28    /// * `height` - Number of rows
29    ///
30    /// # Returns
31    ///
32    /// Sum of all elements
33    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
34    pub fn tiled_sum_2d(&self, data: &[f32], width: usize, height: usize) -> Result<f32, String> {
35        runtime::block_on(self.tiled_sum_2d_async(data, width, height))
36    }
37
38    /// 2D Tiled Sum Reduction on GPU (async, works on all platforms)
39    pub async fn tiled_sum_2d_async(
40        &self,
41        data: &[f32],
42        width: usize,
43        height: usize,
44    ) -> Result<f32, String> {
45        self.tiled_reduce_2d_async(
46            data,
47            width,
48            height,
49            shaders::TILED_SUM_REDUCTION_SHADER,
50            "TiledSum",
51            0.0, // identity for sum
52            |partials| partials.iter().sum(),
53        )
54        .await
55    }
56
57    /// 2D Tiled Max Reduction on GPU (sync, native only)
58    ///
59    /// Uses 16x16 workgroups for efficient parallel max reduction.
60    /// GPU version of `tiled_max_2d`.
61    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
62    pub fn tiled_max_2d(&self, data: &[f32], width: usize, height: usize) -> Result<f32, String> {
63        runtime::block_on(self.tiled_max_2d_async(data, width, height))
64    }
65
66    /// 2D Tiled Max Reduction on GPU (async, works on all platforms)
67    pub async fn tiled_max_2d_async(
68        &self,
69        data: &[f32],
70        width: usize,
71        height: usize,
72    ) -> Result<f32, String> {
73        self.tiled_reduce_2d_async(
74            data,
75            width,
76            height,
77            shaders::TILED_MAX_REDUCTION_SHADER,
78            "TiledMax",
79            f32::NEG_INFINITY, // identity for max
80            |partials| partials.iter().copied().fold(f32::NEG_INFINITY, f32::max),
81        )
82        .await
83    }
84
85    /// 2D Tiled Min Reduction on GPU (sync, native only)
86    ///
87    /// Uses 16x16 workgroups for efficient parallel min reduction.
88    /// GPU version of `tiled_min_2d`.
89    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
90    pub fn tiled_min_2d(&self, data: &[f32], width: usize, height: usize) -> Result<f32, String> {
91        runtime::block_on(self.tiled_min_2d_async(data, width, height))
92    }
93
94    /// 2D Tiled Min Reduction on GPU (async, works on all platforms)
95    pub async fn tiled_min_2d_async(
96        &self,
97        data: &[f32],
98        width: usize,
99        height: usize,
100    ) -> Result<f32, String> {
101        self.tiled_reduce_2d_async(
102            data,
103            width,
104            height,
105            shaders::TILED_MIN_REDUCTION_SHADER,
106            "TiledMin",
107            f32::INFINITY, // identity for min
108            |partials| partials.iter().copied().fold(f32::INFINITY, f32::min),
109        )
110        .await
111    }
112}