use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::ReadWrite;
use cubecl_std::tensor::r#virtual::VirtualTensor;
use crate::BoundChecksInner;
use crate::LineMode;
use crate::ReduceParams;
use crate::instructions::*;
/// Iteration bounds for reducing one slice of the input along the reduced axis.
///
/// Built by `ReduceRange::new` and consumed by [`reduce_slice`], [`reduce_slice_plane`]
/// and [`reduce_slice_shared`], which loop over
/// `range_stepped(coordinate_start, coordinate_end, coordinate_step)` while advancing a
/// read index that starts at `index_start` and moves in multiples of `index_step`.
#[derive(CubeType)]
pub struct ReduceRange {
    /// Index of the first item to read from the input list. The constructors divide the
    /// element offset by the input line size, so this indexes lines, not scalars.
    pub index_start: u32,
    /// Distance, in list items, between two consecutive items along the reduced axis.
    pub index_step: u32,
    /// First coordinate along the reduced axis (the constructors always use 0).
    pub coordinate_start: u32,
    /// Exclusive end coordinate along the reduced axis (the input's extent on that axis).
    pub coordinate_end: u32,
    /// Number of reduced-axis coordinates advanced per loop iteration; depends on the
    /// line mode and on whether planes or shared memory are used.
    pub coordinate_step: u32,
}
#[cube]
impl ReduceRange {
    /// Build the [`ReduceRange`] for the reduction identified by `reduce_index`.
    ///
    /// Dispatches at comptime on `params.line_mode` to `new_parallel` or
    /// `new_perpendicular`.
    pub(crate) fn new<In: Numeric, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<In>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        match comptime!(params.line_mode) {
            LineMode::Parallel => {
                Self::new_parallel::<In, Out>(reduce_index, input, output, axis_reduce, params)
            }
            LineMode::Perpendicular => {
                Self::new_perpendicular::<In, Out>(reduce_index, input, output, axis_reduce, params)
            }
        }
    }

    /// Range for `LineMode::Parallel`: consecutive list items are read (`index_step = 1`).
    ///
    /// `coordinate_step` is the number of reduced-axis coordinates covered by one loop
    /// iteration:
    /// - shared memory: `CUBE_DIM` units times `line_size_input` coordinates each;
    /// - planes: `CUBE_DIM_X` units times `line_size_input` coordinates each;
    /// - otherwise: `line_size_input` for the single unit.
    fn new_parallel<In: Numeric, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<In>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let shape_axis = input.shape(axis_reduce);

        // Element offset of this reduction's first input element: the output coordinates
        // of `reduce_index` mapped through the input strides.
        // NOTE(review): relies on the output coordinate along `axis_reduce` being 0
        // (presumably the output has extent 1 on that axis) — confirm.
        let mut index_start = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(reduce_index, axis);
            index_start += coordinate * input.stride(axis);
        }
        // Convert the element offset into an index over lines of `line_size_input` elements.
        index_start /= params.line_size_input;

        let coordinate_end = shape_axis;

        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM * params.line_size_input
        } else if params.use_planes {
            CUBE_DIM_X * params.line_size_input
        } else {
            params.line_size_input.runtime()
        };

        ReduceRange {
            index_start,
            index_step: 1,
            coordinate_start: 0,
            coordinate_end,
            coordinate_step,
        }
    }

    /// Range for `LineMode::Perpendicular`: each read moves one coordinate along the
    /// reduced axis, so the read index advances by that axis' stride expressed in lines.
    ///
    /// `coordinate_step` is one coordinate per participating unit: `CUBE_DIM` (shared
    /// memory), `CUBE_DIM_X` (planes) or 1 (single unit).
    fn new_perpendicular<In: Numeric, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<In>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let shape_axis = input.shape(axis_reduce);

        // The output position used is `reduce_index * line_size_input`.
        // NOTE(review): presumably each reduction in this mode produces `line_size_input`
        // adjacent outputs (one per line lane) and this is the first of them — confirm.
        let mut index_start = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(reduce_index * params.line_size_input, axis);
            index_start += coordinate * input.stride(axis);
        }
        // Convert the element offset into an index over lines of `line_size_input` elements.
        index_start /= params.line_size_input;

        // One step along the reduced axis is `stride(axis_reduce)` elements, i.e. this many
        // lines. NOTE(review): assumes the stride is a multiple of `line_size_input`.
        let index_step = input.stride(axis_reduce) / params.line_size_input;

        let coordinate_end = shape_axis;

        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM
        } else if params.use_planes {
            CUBE_DIM_X
        } else {
            1_u32.runtime()
        };

        ReduceRange {
            index_start,
            index_step,
            coordinate_start: 0,
            coordinate_step,
            coordinate_end,
        }
    }
}
/// Sequentially reduce the slice of `items` described by `range` with a single unit.
///
/// The loop walks the reduced axis with
/// `range_stepped(range.coordinate_start, range.coordinate_end, range.coordinate_step)`,
/// reading one item per iteration starting at `range.index_start`. The read index advances
/// by `range.index_step` each time. Each item is folded into the accumulator with
/// `reduce_inplace` in non-plane mode. Coordinates are only built when the instruction's
/// requirements ask for them.
///
/// Returns the accumulator after the whole slice has been consumed.
#[cube]
pub fn reduce_slice<N: Numeric, I: List<Line<N>>, R: ReduceInstruction<N>>(
    items: &I,
    range: ReduceRange,
    inst: &R,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
) -> R::AccumulatorItem {
    // The requirements do not depend on the loop position: query them once, as
    // `reduce_slice_shared` does, instead of on every iteration.
    let requirements = R::requirements(inst);

    let mut accumulator = R::null_accumulator(inst, line_size);

    let mut index = range.index_start;
    for coordinate in range_stepped(
        range.coordinate_start,
        range.coordinate_end,
        range.coordinate_step,
    ) {
        let coordinates = if comptime![requirements.coordinates] {
            ReduceCoordinate::new_Required(fill_coordinate_line(coordinate, line_size, line_mode))
        } else {
            ReduceCoordinate::new_NotRequired()
        };
        reduce_inplace::<N, R>(
            inst,
            &mut accumulator,
            items.read(index),
            coordinates,
            false,
        );
        index += range.index_step;
    }
    accumulator
}
/// Reduce the slice described by `range` cooperatively with the units of a plane
/// (indexed by `UNIT_POS_X`), returning this unit's accumulator.
///
/// On each iteration, unit `UNIT_POS_X` reads the item at
/// `first_index + UNIT_POS_X * range.index_step`. `first_index` then advances by
/// `CUBE_DIM_X * range.index_step`. Items are folded with `reduce_inplace` in plane mode
/// (`use_planes = true`).
///
/// `bound_checks` selects how a unit whose coordinate is `>= range.coordinate_end` is
/// handled:
/// - `None`: no check, the read is unconditional;
/// - `Mask`: the read index is forced to 0 and the value replaced by `null_input`;
/// - `Branch`: the read is skipped and `null_input` is used instead.
#[cube]
pub fn reduce_slice_plane<N: Numeric, I: List<Line<N>>, R: ReduceInstruction<N>>(
    items: &I,
    inst: &R,
    range: ReduceRange,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
    #[comptime] bound_checks: BoundChecksInner,
) -> R::AccumulatorItem {
    let plane_dim = CUBE_DIM_X;

    // Loop-invariant: query the instruction's requirements once, as `reduce_slice_shared`
    // does, rather than on every iteration.
    let requirements = R::requirements(inst);

    let mut accumulator = R::null_accumulator(inst, line_size);

    let mut first_index = range.index_start;
    for first_coordinate in range_stepped(
        range.coordinate_start,
        range.coordinate_end,
        range.coordinate_step,
    ) {
        // Parallel: each unit covers `line_size` consecutive coordinates.
        // Perpendicular: each unit covers a single coordinate.
        let unit_coordinate_offset = match line_mode {
            LineMode::Parallel => UNIT_POS_X * line_size,
            LineMode::Perpendicular => UNIT_POS_X,
        };
        let unit_coordinate = first_coordinate + unit_coordinate_offset;

        let coordinates = if comptime![requirements.coordinates] {
            ReduceCoordinate::new_Required(fill_coordinate_line(
                unit_coordinate,
                line_size,
                line_mode,
            ))
        } else {
            ReduceCoordinate::new_NotRequired()
        };

        let index = first_index + UNIT_POS_X * range.index_step;
        let item = match bound_checks {
            BoundChecksInner::None => items.read(index),
            BoundChecksInner::Mask => {
                let mask = unit_coordinate < range.coordinate_end;
                // Out-of-range units read index 0 instead of `index`; the value is then
                // discarded by the `select` below.
                let index = index * u32::cast_from(mask);
                select(mask, items.read(index), R::null_input(inst, line_size))
            }
            BoundChecksInner::Branch => {
                if unit_coordinate < range.coordinate_end {
                    items.read(index)
                } else {
                    R::null_input(inst, line_size)
                }
            }
        };

        reduce_inplace::<N, R>(inst, &mut accumulator, item, coordinates, true);

        first_index += plane_dim * range.index_step;
    }
    accumulator
}
/// Cooperatively reduce the slice described by `range` with all units of the cube, using
/// a shared accumulator of `accumulator_size` entries that is returned to the caller.
///
/// Each unit first writes `null_accumulator` into its own slot: `UNIT_POS_Y` when
/// `use_planes` is set, `UNIT_POS` otherwise. NOTE(review): with planes this is presumably
/// one slot per plane — confirm. Then, on each iteration, unit `UNIT_POS` reads the item at
/// `first_index + UNIT_POS * range.index_step` and folds it into its slot with
/// `reduce_shared_inplace`. `first_index` advances by `range.index_step * CUBE_DIM`.
///
/// `bound_checks` selects how a unit whose coordinate is `>= range.coordinate_end` is
/// handled:
/// - `None`: no check;
/// - `Mask`: read index forced to 0 and value replaced by `null_input`;
/// - `Branch`: read skipped, `null_input` used.
///
/// The returned accumulator still holds one partial result per slot. It is presumably
/// combined afterwards by the caller, e.g. with [`reduce_tree`].
#[cube]
pub fn reduce_slice_shared<N: Numeric, I: List<Line<N>>, R: ReduceInstruction<N>>(
    items: &I,
    inst: &R,
    range: ReduceRange,
    #[comptime] accumulator_size: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
    #[comptime] use_planes: bool,
    #[comptime] bound_checks: BoundChecksInner,
) -> R::SharedAccumulator {
    // Slot of the shared accumulator this unit reads from and writes to.
    let accumulator_index = if use_planes { UNIT_POS_Y } else { UNIT_POS };

    let requirements = R::requirements(inst);
    let mut accumulator =
        R::SharedAccumulator::allocate(accumulator_size, line_size, requirements.coordinates);

    R::SharedAccumulator::write(
        &mut accumulator,
        accumulator_index,
        R::null_accumulator(inst, line_size),
    );

    let mut first_index = range.index_start;
    for first_coordinate in range_stepped(
        range.coordinate_start,
        range.coordinate_end,
        range.coordinate_step,
    ) {
        // Parallel: each unit covers `line_size` consecutive coordinates.
        // Perpendicular: each unit covers a single coordinate.
        let unit_coordinate_offset = match line_mode {
            LineMode::Parallel => UNIT_POS * line_size,
            LineMode::Perpendicular => UNIT_POS,
        };
        let unit_coordinate = first_coordinate + unit_coordinate_offset;

        let index = first_index + UNIT_POS * range.index_step;

        let item = match bound_checks {
            BoundChecksInner::None => items.read(index),
            BoundChecksInner::Mask => {
                let mask = unit_coordinate < range.coordinate_end;
                // Out-of-range units read index 0 instead of `index`; the value is then
                // discarded by the `select` below.
                let index = index * u32::cast_from(mask);
                select(mask, items.read(index), R::null_input(inst, line_size))
            }
            BoundChecksInner::Branch => {
                if unit_coordinate < range.coordinate_end {
                    items.read(index)
                } else {
                    R::null_input(inst, line_size)
                }
            }
        };

        // Out-of-range units get `u32::MAX` in every lane instead of a real coordinate.
        let coordinates = if comptime! {requirements.coordinates} {
            let coordinate = fill_coordinate_line(unit_coordinate, line_size, line_mode);
            let coordinate = select(
                unit_coordinate < range.coordinate_end,
                coordinate,
                Line::empty(line_size).fill(u32::MAX),
            );

            ReduceCoordinate::new_Required(coordinate)
        } else {
            ReduceCoordinate::new_NotRequired()
        };

        reduce_shared_inplace::<N, R>(
            inst,
            &mut accumulator,
            accumulator_index,
            item,
            coordinates,
            use_planes,
        );
        first_index += range.index_step * CUBE_DIM;
    }
    accumulator
}
/// Build the line of reduced-axis coordinates for an item whose first coordinate is
/// `first`.
///
/// - `LineMode::Perpendicular`: every lane holds `first`.
/// - `LineMode::Parallel`: lane `k` holds `first + k`, i.e. the lanes are the consecutive
///   coordinates `first, first + 1, …, first + line_size - 1`.
#[cube]
fn fill_coordinate_line(
    first: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
) -> Line<u32> {
    match comptime!(line_mode) {
        // All lanes sit at the same position on the reduced axis.
        LineMode::Perpendicular => Line::empty(line_size).fill(first),
        // Lanes run along the reduced axis: offset each one by its lane number.
        LineMode::Parallel => {
            let mut lanes = Line::empty(line_size).fill(first);
            if line_size > 1 {
                #[unroll]
                for lane in 0..line_size {
                    lanes[lane] += lane;
                }
            }
            lanes
        }
    }
}
/// Combine the first `size` entries of the shared `accumulator` with a pairwise tree
/// reduction and return the combined entry, read from index 0.
///
/// At each level, unit `UNIT_POS` fuses the entry at `jump * (2 * UNIT_POS + 1)` into the
/// entry at `jump * 2 * UNIT_POS` via `fuse_accumulator_inplace`, then `jump` doubles.
/// `sync_units()` is called after every level and once more at the end.
///
/// NOTE(review): the loop trip counts depend only on the comptime `size`, and the body
/// contains `sync_units()` barriers. This function presumably must therefore be reached by
/// every unit of the cube — confirm.
#[cube]
pub fn reduce_tree<In: Numeric, Inst: ReduceInstruction<In>>(
    inst: &Inst,
    accumulator: &mut Inst::SharedAccumulator,
    #[comptime] size: u32,
) -> Inst::AccumulatorItem {
    if comptime!(size.is_power_of_two()) {
        // Power-of-two size: every level exactly halves the number of live entries.
        let mut num_active_units = size.runtime();
        let mut jump = 1;
        while num_active_units > 1 {
            num_active_units /= 2;
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            if UNIT_POS < num_active_units {
                fuse_accumulator_inplace::<In, Inst>(inst, accumulator, destination, origin);
            }
            jump *= 2;
            // Synchronize before the next level reads the entries fused at this level.
            sync_units();
        }
    } else {
        // General size: `num_remaining_items / 2` pairs are fused per level. When the count
        // is odd, the last entry has no partner and is carried over unchanged, hence the
        // ceiling division for the next level's count.
        let mut num_remaining_items = size.runtime();
        let mut jump = 1;
        while num_remaining_items > 1 {
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            if UNIT_POS < num_remaining_items / 2 {
                fuse_accumulator_inplace::<In, Inst>(inst, accumulator, destination, origin);
            }
            num_remaining_items = div_ceil(num_remaining_items, 2);
            jump *= 2;
            // Synchronize before the next level reads the entries fused at this level.
            sync_units();
        }
    }
    // When `size <= 1` neither loop runs, so this is the only barrier executed before
    // entry 0 is read.
    sync_units();
    Inst::SharedAccumulator::read(accumulator, 0)
}
#[cube]
#[allow(unknown_lints)] #[allow(clippy::manual_div_ceil)]
fn div_ceil(a: u32, b: u32) -> u32 {
(a + b - 1) / b
}