use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::VirtualTensor;
use crate::BoundChecksInner;
use crate::LineMode;
use crate::ReduceParams;
use crate::instructions::*;
use crate::precision::ReducePrecision;
/// Describes how a reduction traverses a list of items: the index to start reading at, the
/// index increment, and the stepped coordinate interval that drives the loop.
#[derive(CubeType)]
pub struct ReduceRange {
/// Index of the first item read by the `reduce_slice*` functions.
pub index_start: u32,
/// Index distance between two neighbouring positions of the traversal.
pub index_step: u32,
/// First value of the stepped coordinate loop.
pub coordinate_start: u32,
/// End of the stepped coordinate loop; the bound-checked readers only load items whose
/// coordinate is strictly below it.
pub coordinate_end: u32,
/// Increment of the coordinate between two loop iterations.
pub coordinate_step: u32,
}
#[cube]
impl ReduceRange {
    /// Builds the range describing how to traverse `input` for the reduction identified by
    /// `reduce_index`.
    ///
    /// The construction is chosen at compile time from `params.line_mode` and delegated to
    /// `new_parallel` or `new_perpendicular`.
    pub(crate) fn new<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        match comptime!(params.line_mode) {
            LineMode::Perpendicular => {
                Self::new_perpendicular::<P, Out>(reduce_index, input, output, axis_reduce, params)
            }
            LineMode::Parallel => {
                Self::new_parallel::<P, Out>(reduce_index, input, output, axis_reduce, params)
            }
        }
    }

    /// Range for `LineMode::Parallel`.
    ///
    /// The start index is the dot product of the coordinates of `reduce_index` in `output`
    /// with the strides of `input`, divided by the input line size. The index step is 1 and
    /// the coordinate step is the input line size, multiplied by `CUBE_DIM` when a shared
    /// accumulator is configured or by `CUBE_DIM_X` when planes are used.
    fn new_parallel<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let axis_length = input.shape(axis_reduce);

        // Number of coordinates covered by one iteration of the traversal.
        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM * params.line_size_input
        } else if params.use_planes {
            CUBE_DIM_X * params.line_size_input
        } else {
            params.line_size_input.runtime()
        };

        // Element offset accumulated over every axis, then converted to a line index.
        let mut offset = 0;
        for dim in 0..input.rank() {
            let position = output.coordinate(reduce_index, dim);
            offset += position * input.stride(dim);
        }
        offset /= params.line_size_input;

        ReduceRange {
            index_start: offset,
            index_step: 1,
            coordinate_start: 0,
            coordinate_end: axis_length,
            coordinate_step,
        }
    }

    /// Range for `LineMode::Perpendicular`.
    ///
    /// The coordinates are looked up in `output` at `reduce_index * params.line_size_input`.
    /// The index step is the stride of the reduced axis divided by the input line size, and
    /// the coordinate step is 1, `CUBE_DIM_X` (planes) or `CUBE_DIM` (shared accumulator).
    fn new_perpendicular<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let axis_length = input.shape(axis_reduce);

        // Number of coordinates covered by one iteration of the traversal.
        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM
        } else if params.use_planes {
            CUBE_DIM_X
        } else {
            1_u32.runtime()
        };

        // Element offset accumulated over every axis, then converted to a line index.
        let mut offset = 0;
        for dim in 0..input.rank() {
            let position = output.coordinate(reduce_index * params.line_size_input, dim);
            offset += position * input.stride(dim);
        }
        offset /= params.line_size_input;

        let line_stride = input.stride(axis_reduce) / params.line_size_input;

        ReduceRange {
            index_start: offset,
            index_step: line_stride,
            coordinate_start: 0,
            coordinate_step,
            coordinate_end: axis_length,
        }
    }
}
/// Reduces the items selected by `range` one after the other and returns the accumulator.
///
/// The accumulator starts as `R::null_accumulator`. For each coordinate of the stepped loop,
/// the item at the current index is passed to `reduce_inplace`, then the index advances by
/// `range.index_step`. No unit position is involved: each calling unit walks the whole range.
///
/// Coordinates are only materialised (through `fill_coordinate_line`) when the instruction's
/// requirements ask for them.
#[cube]
pub fn reduce_slice<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
    items: &I,
    range: ReduceRange,
    inst: &R,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
) -> R::AccumulatorItem {
    let requirements = R::requirements(inst);
    let mut accumulator = R::null_accumulator(inst, line_size);
    let mut cursor = range.index_start;

    for position in range_stepped(
        range.coordinate_start,
        range.coordinate_end,
        range.coordinate_step,
    ) {
        let coordinates = if comptime!(requirements.coordinates) {
            let line = fill_coordinate_line(position, line_size, line_mode);
            ReduceCoordinate::new_Required(line)
        } else {
            ReduceCoordinate::new_NotRequired()
        };

        let item = items.read(cursor);
        // NOTE(review): the trailing `false` is presumably the `use_planes` flag — confirm.
        reduce_inplace::<P, R>(inst, &mut accumulator, item, coordinates, false);

        cursor += range.index_step;
    }

    accumulator
}
/// Reduces the items selected by `range`, with each unit reading a different item per
/// iteration according to its `UNIT_POS_X`.
///
/// On every iteration of the stepped coordinate loop, the unit at `UNIT_POS_X` reads the item
/// at `first_index + UNIT_POS_X * range.index_step` and feeds it to `reduce_inplace`;
/// `first_index` then advances by `CUBE_DIM_X * range.index_step`.
///
/// `bound_checks` selects how a unit whose coordinate is not below `range.coordinate_end` is
/// handled:
/// * `None`: the item is read unconditionally.
/// * `Mask`: the read index is forced to 0 and the value is replaced by `R::null_input`.
/// * `Branch`: the read is skipped and `R::null_input` is used instead.
///
/// NOTE(review): `CUBE_DIM_X` is used as the plane size, so this presumably expects the x
/// dimension of the cube to equal the hardware plane dimension — confirm against the launch
/// configuration.
#[cube]
pub fn reduce_slice_plane<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
items: &I,
inst: &R,
range: ReduceRange,
#[comptime] line_size: u32,
#[comptime] line_mode: LineMode,
#[comptime] bound_checks: BoundChecksInner,
) -> R::AccumulatorItem {
let plane_dim = CUBE_DIM_X;
let mut accumulator = R::null_accumulator(inst, line_size);
let mut first_index = range.index_start;
for first_coordinate in range_stepped(
range.coordinate_start,
range.coordinate_end,
range.coordinate_step,
) {
// Coordinate offset of this unit within the current step: neighbouring units are
// `line_size` coordinates apart in parallel mode and one coordinate apart otherwise.
let unit_coordinate_offset = match line_mode {
LineMode::Parallel => UNIT_POS_X * line_size,
LineMode::Perpendicular => UNIT_POS_X,
};
let unit_coordinate = first_coordinate + unit_coordinate_offset;
let requirements = R::requirements(inst);
// NOTE(review): unlike `reduce_slice_shared`, out-of-range coordinates are not replaced
// by `u32::MAX` here — confirm this is intended.
let coordinates = if comptime![requirements.coordinates] {
ReduceCoordinate::new_Required(fill_coordinate_line(
unit_coordinate,
line_size,
line_mode,
))
} else {
ReduceCoordinate::new_NotRequired()
};
let index = first_index + UNIT_POS_X * range.index_step;
let item = match bound_checks {
BoundChecksInner::None => items.read(index),
BoundChecksInner::Mask => {
let mask = unit_coordinate < range.coordinate_end;
// Multiplying by the mask (0 or 1) redirects out-of-range units to index 0; the
// value read there is discarded by `select`.
let index = index * u32::cast_from(mask);
select(mask, items.read(index), R::null_input(inst, line_size))
}
BoundChecksInner::Branch => {
if unit_coordinate < range.coordinate_end {
items.read(index)
} else {
R::null_input(inst, line_size)
}
}
};
// NOTE(review): the trailing `true` is presumably the `use_planes` flag — confirm.
reduce_inplace::<P, R>(inst, &mut accumulator, item, coordinates, true);
first_index += plane_dim * range.index_step;
}
accumulator
}
/// Reduces the items selected by `range` into a shared accumulator and returns it.
///
/// A `R::SharedAccumulator` with `accumulator_size` slots is allocated, and the slot of the
/// calling unit (`UNIT_POS_Y` when `use_planes` is set, `UNIT_POS` otherwise) is initialised
/// with `R::null_accumulator`. On every iteration of the stepped coordinate loop, the unit
/// reads the item at `first_index + UNIT_POS * range.index_step` and merges it into its slot
/// with `reduce_shared_inplace`; `first_index` then advances by `range.index_step * CUBE_DIM`.
///
/// `bound_checks` selects how a unit whose coordinate is not below `range.coordinate_end` is
/// handled:
/// * `None`: the item is read unconditionally.
/// * `Mask`: the read index is forced to 0 and the value is replaced by `R::null_input`.
/// * `Branch`: the read is skipped and `R::null_input` is used instead.
///
/// No `sync_cube` is issued in this function.
/// NOTE(review): callers presumably have to synchronise before reading slots written by other
/// units — confirm at the call sites.
#[cube]
pub fn reduce_slice_shared<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
items: &I,
inst: &R,
range: ReduceRange,
#[comptime] accumulator_size: u32,
#[comptime] line_size: u32,
#[comptime] line_mode: LineMode,
#[comptime] use_planes: bool,
#[comptime] bound_checks: BoundChecksInner,
) -> R::SharedAccumulator {
// Slot of the shared accumulator this unit writes to.
let accumulator_index = if use_planes { UNIT_POS_Y } else { UNIT_POS };
let requirements = R::requirements(inst);
let mut accumulator =
R::SharedAccumulator::allocate(accumulator_size, line_size, requirements.coordinates);
// Start from the null accumulator so the first merge behaves like a plain store.
R::SharedAccumulator::write(
&mut accumulator,
accumulator_index,
R::null_accumulator(inst, line_size),
);
let mut first_index = range.index_start;
for first_coordinate in range_stepped(
range.coordinate_start,
range.coordinate_end,
range.coordinate_step,
) {
// Coordinate offset of this unit within the current step: neighbouring units are
// `line_size` coordinates apart in parallel mode and one coordinate apart otherwise.
let unit_coordinate_offset = match line_mode {
LineMode::Parallel => UNIT_POS * line_size,
LineMode::Perpendicular => UNIT_POS,
};
let unit_coordinate = first_coordinate + unit_coordinate_offset;
let index = first_index + UNIT_POS * range.index_step;
let item = match bound_checks {
BoundChecksInner::None => items.read(index),
BoundChecksInner::Mask => {
let mask = unit_coordinate < range.coordinate_end;
// Multiplying by the mask (0 or 1) redirects out-of-range units to index 0; the
// value read there is discarded by `select`.
let index = index * u32::cast_from(mask);
select(mask, items.read(index), R::null_input(inst, line_size))
}
BoundChecksInner::Branch => {
if unit_coordinate < range.coordinate_end {
items.read(index)
} else {
R::null_input(inst, line_size)
}
}
};
let coordinates = if comptime! {requirements.coordinates} {
let coordinate = fill_coordinate_line(unit_coordinate, line_size, line_mode);
// Units past the end of the range get `u32::MAX` in every lane instead of their
// computed coordinate.
let coordinate = select(
unit_coordinate < range.coordinate_end,
coordinate,
Line::empty(line_size).fill(u32::MAX),
);
ReduceCoordinate::new_Required(coordinate)
} else {
ReduceCoordinate::new_NotRequired()
};
reduce_shared_inplace::<P, R>(
inst,
&mut accumulator,
accumulator_index,
item,
coordinates,
use_planes,
);
first_index += range.index_step * CUBE_DIM;
}
accumulator
}
/// Builds the line of coordinates matching a line of items whose first coordinate is `first`.
///
/// * `LineMode::Parallel`: element `j` holds `first + j`.
/// * `LineMode::Perpendicular`: every element holds `first`.
#[cube]
fn fill_coordinate_line(
    first: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
) -> Line<u32> {
    match comptime!(line_mode) {
        LineMode::Perpendicular => Line::empty(line_size).fill(first),
        LineMode::Parallel => {
            let mut line = Line::empty(line_size);
            #[unroll]
            for lane in 0..line_size {
                line[lane] = first + lane;
            }
            line
        }
    }
}
/// Fuses the first `size` slots of the shared `accumulator` in place, round by round, and
/// returns the content of slot 0.
///
/// In each round, an active unit calls `fuse_accumulator_inplace` with destination
/// `jump * 2 * UNIT_POS` and origin `jump * (2 * UNIT_POS + 1)`; `jump` starts at 1 and doubles
/// after every round, and `sync_cube` is called at the end of every round.
///
/// When `size` is a power of two (checked at compile time), the number of active units halves
/// each round. Otherwise, `num_remaining_items / 2` units are active and the number of remaining
/// items is halved rounding up, so an unpaired trailing slot is carried over to a later round.
///
/// NOTE(review): no barrier is issued before the first round, so the slots are presumably
/// expected to be fully written and synchronised by the caller — confirm.
#[cube]
pub fn reduce_tree<P: ReducePrecision, Inst: ReduceInstruction<P>>(
inst: &Inst,
accumulator: &mut Inst::SharedAccumulator,
#[comptime] size: u32,
) -> Inst::AccumulatorItem {
if comptime!(size.is_power_of_two()) {
let mut num_active_units = size.runtime();
let mut jump = 1;
while num_active_units > 1 {
num_active_units /= 2;
let destination = jump * 2 * UNIT_POS;
let origin = jump * (2 * UNIT_POS + 1);
if UNIT_POS < num_active_units {
fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
}
jump *= 2;
// Every unit reaches this barrier, whether or not it took part in the round.
sync_cube();
}
} else {
let mut num_remaining_items = size.runtime();
let mut jump = 1;
while num_remaining_items > 1 {
let destination = jump * 2 * UNIT_POS;
let origin = jump * (2 * UNIT_POS + 1);
// Only complete pairs are fused; with an odd count the last slot is left untouched.
if UNIT_POS < num_remaining_items / 2 {
fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
}
num_remaining_items = num_remaining_items.div_ceil(2);
jump *= 2;
// Every unit reaches this barrier, whether or not it took part in the round.
sync_cube();
}
}
// Barrier before reading slot 0; it is also reached when the loops above ran no round at
// all (`size <= 1`).
sync_cube();
Inst::SharedAccumulator::read(accumulator, 0)
}
/// Sums `value` with the values held by other units through five XOR-shuffle exchanges, using
/// the offsets 16, 8, 4, 2 and 1 in that order.
///
/// NOTE(review): the fixed offsets span 32 units, so this presumably assumes a plane of
/// 32 units — confirm before using it on hardware with a different plane size.
#[cube]
pub fn reduce_sum_shuffle<F: Float>(value: F) -> F {
    let mut total = value;
    total = total + plane_shuffle_xor(total, 16);
    total = total + plane_shuffle_xor(total, 8);
    total = total + plane_shuffle_xor(total, 4);
    total = total + plane_shuffle_xor(total, 2);
    total = total + plane_shuffle_xor(total, 1);
    total
}
/// Computes the maximum of `value` and the values held by other units through five
/// XOR-shuffle exchanges, using the offsets 16, 8, 4, 2 and 1 in that order.
///
/// NOTE(review): the fixed offsets span 32 units, so this presumably assumes a plane of
/// 32 units — confirm before using it on hardware with a different plane size.
#[cube]
pub fn reduce_max_shuffle<F: Float>(value: F) -> F {
    let mut largest = value;
    largest = F::max(largest, plane_shuffle_xor(largest, 16));
    largest = F::max(largest, plane_shuffle_xor(largest, 8));
    largest = F::max(largest, plane_shuffle_xor(largest, 4));
    largest = F::max(largest, plane_shuffle_xor(largest, 2));
    largest = F::max(largest, plane_shuffle_xor(largest, 1));
    largest
}
/// Computes the minimum of `value` and the values held by other units through five
/// XOR-shuffle exchanges, using the offsets 16, 8, 4, 2 and 1 in that order.
///
/// NOTE(review): the fixed offsets span 32 units, so this presumably assumes a plane of
/// 32 units — confirm before using it on hardware with a different plane size.
#[cube]
pub fn reduce_min_shuffle<F: Float>(value: F) -> F {
    let mut smallest = value;
    smallest = F::min(smallest, plane_shuffle_xor(smallest, 16));
    smallest = F::min(smallest, plane_shuffle_xor(smallest, 8));
    smallest = F::min(smallest, plane_shuffle_xor(smallest, 4));
    smallest = F::min(smallest, plane_shuffle_xor(smallest, 2));
    smallest = F::min(smallest, plane_shuffle_xor(smallest, 1));
    smallest
}
/// Multiplies `value` with the values held by other units through five XOR-shuffle exchanges,
/// using the offsets 16, 8, 4, 2 and 1 in that order.
///
/// NOTE(review): the fixed offsets span 32 units, so this presumably assumes a plane of
/// 32 units — confirm before using it on hardware with a different plane size.
#[cube]
pub fn reduce_prod_shuffle<F: Float>(value: F) -> F {
    let mut product = value;
    product = product * plane_shuffle_xor(product, 16);
    product = product * plane_shuffle_xor(product, 8);
    product = product * plane_shuffle_xor(product, 4);
    product = product * plane_shuffle_xor(product, 2);
    product = product * plane_shuffle_xor(product, 1);
    product
}