trueno/backends/gpu/tiled_reduction/
mod.rs1use super::partition_view::PartitionView;
12use super::tensor_view::TensorView;
13
/// Edge length, in elements, of the square tiles the 2-D reductions work on.
///
/// `reduce_rows` and `reduce_columns` fold a tile by repeatedly halving a
/// stride that starts at `TILE_SIZE / 2`; that scheme only visits every
/// element when this value is a power of two.
pub const TILE_SIZE: usize = 16;

// Fail the build, rather than silently dropping elements, if the tile size is
// ever changed to a value the halving reduction cannot cover.
const _: () = assert!(TILE_SIZE.is_power_of_two(), "TILE_SIZE must be a power of two");
16
/// A binary reduction over `f32` values, described by a starting value and a
/// rule for merging two values into one.
///
/// Both items are associated functions (no `self`), so an operator is picked
/// purely through a type parameter, e.g. `tiled_reduce_2d::<SumOp>(..)`.
pub trait ReduceOp {
    /// The value accumulators and unused tile slots are seeded with.
    fn identity() -> f32;

    /// Merges `a` and `b` into a single value.
    fn combine(a: f32, b: f32) -> f32;
}
24
25pub struct SumOp;
27
28impl ReduceOp for SumOp {
29 #[inline]
30 fn identity() -> f32 {
31 0.0
32 }
33
34 #[inline]
35 fn combine(a: f32, b: f32) -> f32 {
36 a + b
37 }
38}
39
40pub struct MaxOp;
42
43impl ReduceOp for MaxOp {
44 #[inline]
45 fn identity() -> f32 {
46 f32::NEG_INFINITY
47 }
48
49 #[inline]
50 fn combine(a: f32, b: f32) -> f32 {
51 a.max(b)
52 }
53}
54
55pub struct MinOp;
57
58impl ReduceOp for MinOp {
59 #[inline]
60 fn identity() -> f32 {
61 f32::INFINITY
62 }
63
64 #[inline]
65 fn combine(a: f32, b: f32) -> f32 {
66 a.min(b)
67 }
68}
69
70pub fn tiled_reduce_2d<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> f32 {
77 let partial = collect_tile_results::<Op>(data, width, height);
78 partial.iter().copied().fold(Op::identity(), Op::combine)
79}
80
81#[inline]
83pub fn tiled_sum_2d(data: &[f32], width: usize, height: usize) -> f32 {
84 tiled_reduce_2d::<SumOp>(data, width, height)
85}
86
87#[inline]
89pub fn tiled_max_2d(data: &[f32], width: usize, height: usize) -> f32 {
90 tiled_reduce_2d::<MaxOp>(data, width, height)
91}
92
93#[inline]
95pub fn tiled_min_2d(data: &[f32], width: usize, height: usize) -> f32 {
96 tiled_reduce_2d::<MinOp>(data, width, height)
97}
98
99pub fn tiled_reduce_partial<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> Vec<f32> {
104 collect_tile_results::<Op>(data, width, height)
105}
106
107fn collect_tile_results<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> Vec<f32> {
109 if data.is_empty() || width == 0 || height == 0 {
110 return vec![Op::identity()];
111 }
112
113 let view: TensorView<f32> = TensorView::new([height, width, 1, 1]);
114 let partition: PartitionView<f32> = PartitionView::new(view, [TILE_SIZE, TILE_SIZE, 1, 1]);
115
116 let tiles_y = partition.tile_count()[0];
117 let tiles_x = partition.tile_count()[1];
118
119 let mut results = Vec::with_capacity(tiles_y * tiles_x);
120 for tile_y in 0..tiles_y {
121 for tile_x in 0..tiles_x {
122 results.push(reduce_tile::<Op>(data, width, height, tile_x, tile_y));
123 }
124 }
125 results
126}
127
128fn load_tile(
130 tile: &mut [[f32; TILE_SIZE]; TILE_SIZE],
131 data: &[f32],
132 width: usize,
133 height: usize,
134 start_x: usize,
135 start_y: usize,
136) {
137 #[allow(clippy::needless_range_loop)]
138 for ly in 0..TILE_SIZE {
139 let gy = start_y + ly;
140 if gy >= height {
141 break;
142 }
143 #[allow(clippy::needless_range_loop)]
144 for lx in 0..TILE_SIZE {
145 let gx = start_x + lx;
146 if gx >= width {
147 break;
148 }
149 tile[ly][lx] = data[gy * width + gx];
150 }
151 }
152}
153
154fn reduce_rows<Op: ReduceOp>(tile: &mut [[f32; TILE_SIZE]; TILE_SIZE]) {
156 #[allow(clippy::needless_range_loop)]
157 for ly in 0..TILE_SIZE {
158 let mut stride = TILE_SIZE / 2;
159 while stride > 0 {
160 for lx in 0..stride {
161 tile[ly][lx] = Op::combine(tile[ly][lx], tile[ly][lx + stride]);
162 }
163 stride /= 2;
164 }
165 }
166}
167
168fn reduce_columns<Op: ReduceOp>(tile: &mut [[f32; TILE_SIZE]; TILE_SIZE]) {
170 let mut stride = TILE_SIZE / 2;
171 while stride > 0 {
172 for ly in 0..stride {
173 tile[ly][0] = Op::combine(tile[ly][0], tile[ly + stride][0]);
174 }
175 stride /= 2;
176 }
177}
178
179fn reduce_tile<Op: ReduceOp>(
181 data: &[f32],
182 width: usize,
183 height: usize,
184 tile_x: usize,
185 tile_y: usize,
186) -> f32 {
187 let mut tile = [[Op::identity(); TILE_SIZE]; TILE_SIZE];
188 load_tile(&mut tile, data, width, height, tile_x * TILE_SIZE, tile_y * TILE_SIZE);
189 reduce_rows::<Op>(&mut tile);
190 reduce_columns::<Op>(&mut tile);
191 tile[0][0]
192}
193
194#[cfg(test)]
195mod tests;