Skip to main content

trueno/backends/gpu/tiled_reduction/
mod.rs

1//! CPU fallback implementation of tiled reduction algorithms
2//!
3//! This module provides CPU implementations that mirror the GPU tiled reduction
4//! algorithms. These are useful for:
5//! - Testing and validation (compare GPU results against CPU reference)
6//! - Fallback when GPU is unavailable
7//! - Understanding the algorithm without GPU complexity
8//!
9//! The algorithms use the same 16×16 tile structure as the GPU shaders.
10
11use super::partition_view::PartitionView;
12use super::tensor_view::TensorView;
13
/// Default tile size for 2D reductions (matches GPU workgroup size).
///
/// Tiles are square: `TILE_SIZE × TILE_SIZE` elements.
///
/// The tree reductions in this module start at a stride of `TILE_SIZE / 2`
/// and halve it every step, so every element is only folded in when this
/// value is a power of two.
pub const TILE_SIZE: usize = 16;
16
/// Reduction operation trait for generic tile reduction.
///
/// Implementors are used purely as type parameters (`Op::identity()`,
/// `Op::combine(..)`); no instance is ever required.
///
/// The tile reducer pads out-of-bounds tile cells with [`ReduceOp::identity`]
/// and folds values pairwise in tree order, so for the result to equal the
/// reduction over the real elements an implementation should make
/// `identity()` neutral for `combine` and tolerate operands being regrouped.
pub trait ReduceOp {
    /// Identity element for the reduction (0 for sum, -inf for max, inf for min).
    fn identity() -> f32;
    /// Combine two values (raw elements or partial results) into one.
    fn combine(a: f32, b: f32) -> f32;
}
24
25/// Sum reduction operation
26pub struct SumOp;
27
28impl ReduceOp for SumOp {
29    #[inline]
30    fn identity() -> f32 {
31        0.0
32    }
33
34    #[inline]
35    fn combine(a: f32, b: f32) -> f32 {
36        a + b
37    }
38}
39
40/// Max reduction operation
41pub struct MaxOp;
42
43impl ReduceOp for MaxOp {
44    #[inline]
45    fn identity() -> f32 {
46        f32::NEG_INFINITY
47    }
48
49    #[inline]
50    fn combine(a: f32, b: f32) -> f32 {
51        a.max(b)
52    }
53}
54
55/// Min reduction operation
56pub struct MinOp;
57
58impl ReduceOp for MinOp {
59    #[inline]
60    fn identity() -> f32 {
61        f32::INFINITY
62    }
63
64    #[inline]
65    fn combine(a: f32, b: f32) -> f32 {
66        a.min(b)
67    }
68}
69
70/// Perform tiled reduction on 2D data (CPU fallback)
71///
72/// This simulates the GPU algorithm:
73/// 1. Partition input into 16×16 tiles
74/// 2. Reduce each tile to a single value
75/// 3. Combine partial results
76pub fn tiled_reduce_2d<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> f32 {
77    let partial = collect_tile_results::<Op>(data, width, height);
78    partial.iter().copied().fold(Op::identity(), Op::combine)
79}
80
/// Convenience function for tiled sum reduction.
///
/// Equivalent to `tiled_reduce_2d::<SumOp>(data, width, height)`: sums the
/// `width × height` row-major buffer `data` tile by tile, then sums the
/// per-tile partials.
///
/// Returns `0.0` when `data` is empty or either dimension is zero.
#[inline]
pub fn tiled_sum_2d(data: &[f32], width: usize, height: usize) -> f32 {
    tiled_reduce_2d::<SumOp>(data, width, height)
}
86
/// Convenience function for tiled max reduction.
///
/// Equivalent to `tiled_reduce_2d::<MaxOp>(data, width, height)`: takes the
/// maximum of each tile of the `width × height` row-major buffer `data`, then
/// the maximum of the per-tile partials.
///
/// Returns `f32::NEG_INFINITY` when `data` is empty or either dimension is
/// zero.
#[inline]
pub fn tiled_max_2d(data: &[f32], width: usize, height: usize) -> f32 {
    tiled_reduce_2d::<MaxOp>(data, width, height)
}
92
/// Convenience function for tiled min reduction.
///
/// Equivalent to `tiled_reduce_2d::<MinOp>(data, width, height)`: takes the
/// minimum of each tile of the `width × height` row-major buffer `data`, then
/// the minimum of the per-tile partials.
///
/// Returns `f32::INFINITY` when `data` is empty or either dimension is zero.
#[inline]
pub fn tiled_min_2d(data: &[f32], width: usize, height: usize) -> f32 {
    tiled_reduce_2d::<MinOp>(data, width, height)
}
98
/// Compute partial tile results for verification.
///
/// Returns the partial reduction result for each tile, which can be
/// compared against GPU partial results buffer for validation.
///
/// Partials are ordered tile-row by tile-row (all tiles of the first tile row,
/// left to right, then the next row, and so on). When `data` is empty or
/// either dimension is zero, the result is a single-element vector holding
/// `Op::identity()`.
pub fn tiled_reduce_partial<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> Vec<f32> {
    collect_tile_results::<Op>(data, width, height)
}
106
107/// Shared implementation: reduce each tile and return partial results.
108fn collect_tile_results<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> Vec<f32> {
109    if data.is_empty() || width == 0 || height == 0 {
110        return vec![Op::identity()];
111    }
112
113    let view: TensorView<f32> = TensorView::new([height, width, 1, 1]);
114    let partition: PartitionView<f32> = PartitionView::new(view, [TILE_SIZE, TILE_SIZE, 1, 1]);
115
116    let tiles_y = partition.tile_count()[0];
117    let tiles_x = partition.tile_count()[1];
118
119    let mut results = Vec::with_capacity(tiles_y * tiles_x);
120    for tile_y in 0..tiles_y {
121        for tile_x in 0..tiles_x {
122            results.push(reduce_tile::<Op>(data, width, height, tile_x, tile_y));
123        }
124    }
125    results
126}
127
128/// Load data into a tile with bounds checking.
129fn load_tile(
130    tile: &mut [[f32; TILE_SIZE]; TILE_SIZE],
131    data: &[f32],
132    width: usize,
133    height: usize,
134    start_x: usize,
135    start_y: usize,
136) {
137    #[allow(clippy::needless_range_loop)]
138    for ly in 0..TILE_SIZE {
139        let gy = start_y + ly;
140        if gy >= height {
141            break;
142        }
143        #[allow(clippy::needless_range_loop)]
144        for lx in 0..TILE_SIZE {
145            let gx = start_x + lx;
146            if gx >= width {
147                break;
148            }
149            tile[ly][lx] = data[gy * width + gx];
150        }
151    }
152}
153
154/// Tree reduction along rows: halve stride each step until 1.
155fn reduce_rows<Op: ReduceOp>(tile: &mut [[f32; TILE_SIZE]; TILE_SIZE]) {
156    #[allow(clippy::needless_range_loop)]
157    for ly in 0..TILE_SIZE {
158        let mut stride = TILE_SIZE / 2;
159        while stride > 0 {
160            for lx in 0..stride {
161                tile[ly][lx] = Op::combine(tile[ly][lx], tile[ly][lx + stride]);
162            }
163            stride /= 2;
164        }
165    }
166}
167
168/// Tree reduction along columns: halve stride each step on column 0.
169fn reduce_columns<Op: ReduceOp>(tile: &mut [[f32; TILE_SIZE]; TILE_SIZE]) {
170    let mut stride = TILE_SIZE / 2;
171    while stride > 0 {
172        for ly in 0..stride {
173            tile[ly][0] = Op::combine(tile[ly][0], tile[ly + stride][0]);
174        }
175        stride /= 2;
176    }
177}
178
179/// Reduce a single 16×16 tile using tree reduction pattern
180fn reduce_tile<Op: ReduceOp>(
181    data: &[f32],
182    width: usize,
183    height: usize,
184    tile_x: usize,
185    tile_y: usize,
186) -> f32 {
187    let mut tile = [[Op::identity(); TILE_SIZE]; TILE_SIZE];
188    load_tile(&mut tile, data, width, height, tile_x * TILE_SIZE, tile_y * TILE_SIZE);
189    reduce_rows::<Op>(&mut tile);
190    reduce_columns::<Op>(&mut tile);
191    tile[0][0]
192}
193
194#[cfg(test)]
195mod tests;