use super::partition_view::PartitionView;
use super::tensor_view::TensorView;
/// Edge length, in elements, of the square tiles the 2-D reductions operate on.
///
/// Must be a power of two: the in-tile reductions start at `TILE_SIZE / 2` and halve their
/// stride each pass, which only folds every lane into lane 0 when the size is a power of two.
pub const TILE_SIZE: usize = 16;

// Enforce the power-of-two requirement at compile time so a future edit to `TILE_SIZE`
// cannot silently skip elements in the halving-stride reductions.
const _: () = assert!(TILE_SIZE.is_power_of_two(), "TILE_SIZE must be a power of two");
/// A binary reduction over `f32` values that has a neutral starting element.
///
/// The tiled reductions in this module pad partial tiles with [`ReduceOp::identity`] and
/// combine values pairwise in a tree order. Their result therefore only matches a plain
/// sequential fold when `combine` is associative and `identity` is neutral for it.
pub trait ReduceOp {
    /// Neutral element: `combine(identity(), x)` is expected to equal `x`.
    fn identity() -> f32;

    /// Merges two partial results into one.
    fn combine(a: f32, b: f32) -> f32;
}
/// Addition reduction: the identity is `0.0` and values are combined with `+`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SumOp;

impl ReduceOp for SumOp {
    #[inline]
    fn identity() -> f32 {
        0.0
    }

    #[inline]
    fn combine(a: f32, b: f32) -> f32 {
        a + b
    }
}
/// Maximum reduction: the identity is negative infinity and values are combined with
/// [`f32::max`].
///
/// `f32::max` returns the non-NaN operand when exactly one input is NaN, so NaN elements do
/// not propagate through this reduction.
#[derive(Debug, Clone, Copy, Default)]
pub struct MaxOp;

impl ReduceOp for MaxOp {
    #[inline]
    fn identity() -> f32 {
        f32::NEG_INFINITY
    }

    #[inline]
    fn combine(a: f32, b: f32) -> f32 {
        a.max(b)
    }
}
/// Minimum reduction: the identity is positive infinity and values are combined with
/// [`f32::min`].
///
/// `f32::min` returns the non-NaN operand when exactly one input is NaN, so NaN elements do
/// not propagate through this reduction.
#[derive(Debug, Clone, Copy, Default)]
pub struct MinOp;

impl ReduceOp for MinOp {
    #[inline]
    fn identity() -> f32 {
        f32::INFINITY
    }

    #[inline]
    fn combine(a: f32, b: f32) -> f32 {
        a.min(b)
    }
}
/// Reduces the `width x height` matrix stored row-major in `data` to a single value with `Op`.
///
/// Each `TILE_SIZE x TILE_SIZE` tile is reduced independently. The per-tile results are then
/// merged left to right, starting from `Op::identity()`. Empty input, or a zero dimension,
/// produces a single identity partial, so the result is the identity for a lawful `Op`.
pub fn tiled_reduce_2d<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> f32 {
    let mut total = Op::identity();
    for tile_value in collect_tile_results::<Op>(data, width, height) {
        total = Op::combine(total, tile_value);
    }
    total
}
/// Sums every element of the `width x height` row-major matrix in `data`.
///
/// Values are added in a tile-wise tree order, so the last bits may differ from a strictly
/// sequential sum. Empty input yields `0.0`.
#[inline]
pub fn tiled_sum_2d(data: &[f32], width: usize, height: usize) -> f32 {
    tiled_reduce_2d::<SumOp>(data, width, height)
}

/// Returns the largest element of the `width x height` row-major matrix in `data`.
///
/// NaN elements are skipped because [`f32::max`] prefers the non-NaN operand. Empty input
/// yields negative infinity.
#[inline]
pub fn tiled_max_2d(data: &[f32], width: usize, height: usize) -> f32 {
    tiled_reduce_2d::<MaxOp>(data, width, height)
}

/// Returns the smallest element of the `width x height` row-major matrix in `data`.
///
/// NaN elements are skipped because [`f32::min`] prefers the non-NaN operand. Empty input
/// yields positive infinity.
#[inline]
pub fn tiled_min_2d(data: &[f32], width: usize, height: usize) -> f32 {
    tiled_reduce_2d::<MinOp>(data, width, height)
}
/// Returns the per-tile reductions of the `width x height` row-major matrix in `data`,
/// without merging them.
///
/// Results are ordered tile row by tile row, left to right within each row. When `data` is
/// empty or either dimension is zero, the vector holds a single `Op::identity()` rather than
/// being empty.
pub fn tiled_reduce_partial<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> Vec<f32> {
    collect_tile_results::<Op>(data, width, height)
}
/// Reduces each `TILE_SIZE x TILE_SIZE` tile of `data` with `Op` and returns one value per tile.
///
/// `data` is read as a row-major matrix of `height` rows by `width` columns. Results are
/// ordered tile row by tile row, left to right within each row. If `data` is empty or either
/// dimension is zero, a single `Op::identity()` is returned so callers folding the result
/// still get the neutral element.
///
/// # Panics
///
/// Panics if `width * height` overflows `usize`, or if `data` holds fewer than
/// `width * height` elements.
fn collect_tile_results<Op: ReduceOp>(data: &[f32], width: usize, height: usize) -> Vec<f32> {
    if data.is_empty() || width == 0 || height == 0 {
        return vec![Op::identity()];
    }

    // Validate the buffer size up front. Otherwise a mis-sized buffer only fails with an
    // out-of-bounds index deep inside `load_tile`, after some tiles were already reduced.
    let required = width
        .checked_mul(height)
        .expect("tiled reduction: width * height overflows usize");
    assert!(
        data.len() >= required,
        "tiled reduction: data has {} elements but a {}x{} matrix needs {}",
        data.len(),
        width,
        height,
        required
    );

    let view: TensorView<f32> = TensorView::new([height, width, 1, 1]);
    let partition: PartitionView<f32> = PartitionView::new(view, [TILE_SIZE, TILE_SIZE, 1, 1]);
    let tile_counts = partition.tile_count();
    let (tiles_y, tiles_x) = (tile_counts[0], tile_counts[1]);

    let mut results = Vec::with_capacity(tiles_y * tiles_x);
    for tile_y in 0..tiles_y {
        results.extend(
            (0..tiles_x).map(|tile_x| reduce_tile::<Op>(data, width, height, tile_x, tile_y)),
        );
    }
    results
}
/// Copies the `TILE_SIZE x TILE_SIZE` window of `data` whose top-left corner is at
/// (`start_x`, `start_y`) into `tile`.
///
/// `data` is read as a row-major `width x height` matrix: element `(x, y)` is at
/// `y * width + x`. Window cells that fall outside the matrix are left untouched, so values
/// the caller pre-filled act as padding. A window that starts at or past the matrix edge
/// copies nothing.
///
/// # Panics
///
/// Panics if `data` is too short to contain a matrix row that the window overlaps.
fn load_tile(
    tile: &mut [[f32; TILE_SIZE]; TILE_SIZE],
    data: &[f32],
    width: usize,
    height: usize,
    start_x: usize,
    start_y: usize,
) {
    // Clamp the window to the matrix. `saturating_sub` yields 0 when the window starts past
    // an edge, in which case nothing is copied.
    let rows = height.saturating_sub(start_y).min(TILE_SIZE);
    let cols = width.saturating_sub(start_x).min(TILE_SIZE);
    if cols == 0 {
        // Nothing to copy. Returning early also avoids slicing `data` at an offset that may
        // lie past its end.
        return;
    }
    for (ly, dst) in tile.iter_mut().take(rows).enumerate() {
        let src_start = (start_y + ly) * width + start_x;
        dst[..cols].copy_from_slice(&data[src_start..src_start + cols]);
    }
}
/// Reduces every row of `tile` in place, leaving the result for row `ly` in `tile[ly][0]`.
///
/// Each row goes through a pairwise tree reduction. Every pass folds the half of the active
/// prefix at `[stride, 2 * stride)` onto `[0, stride)`, then halves `stride` until it reaches
/// zero. Cells other than column 0 are left holding intermediate values.
fn reduce_rows<Op: ReduceOp>(tile: &mut [[f32; TILE_SIZE]; TILE_SIZE]) {
    for row in tile.iter_mut() {
        let mut stride = TILE_SIZE / 2;
        while stride > 0 {
            // `lower` is written and `upper` only read within a pass, so splitting the row
            // gives disjoint borrows without index arithmetic.
            let (lower, upper) = row.split_at_mut(stride);
            for (acc, &other) in lower.iter_mut().zip(upper.iter()) {
                *acc = Op::combine(*acc, other);
            }
            stride /= 2;
        }
    }
}
/// Reduces column 0 of `tile` in place, leaving the result in `tile[0][0]`.
///
/// This is the same pairwise tree reduction as the row pass, applied down the first column.
/// Each pass folds rows `[stride, 2 * stride)` onto rows `[0, stride)`. No column other than
/// column 0 is read or written.
fn reduce_columns<Op: ReduceOp>(tile: &mut [[f32; TILE_SIZE]; TILE_SIZE]) {
    let mut stride = TILE_SIZE / 2;
    while stride > 0 {
        let (upper_rows, lower_rows) = tile.split_at_mut(stride);
        for (dst_row, src_row) in upper_rows.iter_mut().zip(lower_rows.iter()) {
            dst_row[0] = Op::combine(dst_row[0], src_row[0]);
        }
        stride /= 2;
    }
}
/// Reduces the tile at tile coordinates (`tile_x`, `tile_y`) of the row-major
/// `width x height` matrix in `data` to a single value.
///
/// The scratch tile is pre-filled with `Op::identity()`, so cells beyond the matrix edge act
/// as neutral padding. The tile is then reduced along each row, and the row results are
/// reduced down column 0.
fn reduce_tile<Op: ReduceOp>(
    data: &[f32],
    width: usize,
    height: usize,
    tile_x: usize,
    tile_y: usize,
) -> f32 {
    let origin_x = tile_x * TILE_SIZE;
    let origin_y = tile_y * TILE_SIZE;
    let mut scratch = [[Op::identity(); TILE_SIZE]; TILE_SIZE];
    load_tile(&mut scratch, data, width, height, origin_x, origin_y);
    reduce_rows::<Op>(&mut scratch);
    reduce_columns::<Op>(&mut scratch);
    scratch[0][0]
}
#[cfg(test)]
mod tests;