use crate::{
ReduceInstruction, ReducePrecision, VectorizationMode,
components::{
args::NumericVector,
global::{idle_check, reduction_output_base},
instructions::{
Accumulator, ReduceStep, SharedAccumulator, fuse_accumulator_inplace, reduce_inplace,
},
readers::{Reader, cube::CubeReader},
writers::Writer,
},
routines::CubeBlueprint,
};
use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};
/// Stateless marker type that namespaces the reduction routine in which the units of a
/// cube combine their partial results through an `I::SharedAccumulator` and a single
/// worker (position 0) writes each output. See `GlobalFullCubeReduce::execute`.
#[derive(CubeType)]
pub struct GlobalFullCubeReduce;
#[cube]
impl GlobalFullCubeReduce {
    /// Reduces the outputs assigned to the current cube (`CUBE_POS`).
    ///
    /// For each of the `writer.write_count()` outputs, [`Self::reduce_shared`] collects
    /// one partial accumulator per worker. The partials are then merged either by a
    /// sequential scan executed by worker 0 (`blueprint.use_planes == true`) or by
    /// `reduce_tree` (`false`), and worker 0 hands the merged value to the writer.
    /// After the loop, worker 0 commits the writer if the writer reports that a commit
    /// is required.
    ///
    /// # Arguments
    /// * `input` - tensor that is read by the reader.
    /// * `output` - tensor that receives the reduced values.
    /// * `reduce_axis` - forwarded to `reduction_output_base`, the writer and the reader.
    /// * `out_vec_axis` - forwarded to the writer.
    /// * `inst` - reduce instruction providing accumulators and fuse operations.
    /// * `vectorization_mode` - comptime setting forwarded to writer, reader and idle check.
    /// * `blueprint` - comptime configuration (`use_planes`, `num_shared_accumulators`,
    ///   `cube_idle`, `bound_checks`).
    pub fn execute<P: ReducePrecision, Out: NumericVector, I: ReduceInstruction<P>>(
        input: &VirtualTensor<P::EI, P::SI>,
        output: &mut VirtualTensor<Out::T, Out::N, ReadWrite>,
        reduce_axis: usize,
        out_vec_axis: usize,
        inst: &I,
        #[comptime] vectorization_mode: VectorizationMode,
        #[comptime] blueprint: CubeBlueprint,
    ) {
        let format = I::accumulator_format(inst);
        let out_base = reduction_output_base::<Out::T, Out::N>(
            CUBE_POS,
            output,
            reduce_axis,
            comptime!(format.len()),
        );

        let worker = Self::worker_pos(blueprint);
        let slot_count = blueprint.num_shared_accumulators;

        let mut writer = Writer::<Out>::new::<P>(
            input,
            output,
            reduce_axis,
            out_vec_axis,
            out_base,
            vectorization_mode,
            format,
        );

        // The cube handles `outputs_per_cube` consecutive reduce indices starting here.
        let outputs_per_cube = writer.write_count();
        let first_reduce_index = out_base * outputs_per_cube;

        let idle = idle_check::<P, Out>(
            input,
            output,
            first_reduce_index,
            vectorization_mode,
            blueprint.cube_idle,
        );

        for slot in 0..outputs_per_cube {
            let current_index = first_reduce_index + slot;

            // One partial per worker; `reduce_shared` ends with a `sync_cube()`.
            let mut partials = Self::reduce_shared::<P, Out, I>(
                input,
                output,
                reduce_axis,
                current_index,
                inst,
                idle,
                vectorization_mode,
                blueprint,
            );
            let mut merged = I::null_accumulator(inst);

            match blueprint.use_planes {
                true => {
                    // Only worker 0 walks the partials and writes; every unit then
                    // reaches the `sync_cube()` below before the next iteration.
                    if worker == 0 {
                        reduce_scan::<P, I>(inst, &mut partials, &mut merged, slot_count);
                        writer.write::<P, I>(slot, merged, inst);
                    }
                    sync_cube();
                }
                false => {
                    // Every unit takes part in the tree merge; only worker 0 writes.
                    reduce_tree::<P, I>(inst, &mut partials, &mut merged, worker, slot_count);
                    if worker == 0 {
                        writer.write::<P, I>(slot, merged, inst);
                    }
                }
            };
        }

        let needs_commit = writer.commit_required();
        // NOTE(review): the two conditions are kept nested on purpose (lint silenced);
        // presumably so they expand as separate branches under `#[cube]` — confirm
        // before collapsing.
        #[allow(clippy::collapsible_if)]
        if needs_commit {
            if worker == 0 {
                writer.commit();
            }
        }
    }

    /// Position used to index the shared accumulator: `UNIT_POS_Y` when planes are
    /// used, `UNIT_POS` otherwise.
    fn worker_pos(#[comptime] blueprint: CubeBlueprint) -> usize {
        match blueprint.use_planes {
            false => UNIT_POS as usize,
            true => UNIT_POS_Y as usize,
        }
    }

    /// Produces the per-worker partial accumulators for `reduce_index`.
    ///
    /// Each unit folds every item yielded by its `CubeReader` into a local accumulator
    /// (starting from `I::null_accumulator`), applies `I::plane_reduce_inplace` when
    /// `blueprint.use_planes` is set, and stores the result at index
    /// [`Self::worker_pos`] of a newly allocated `I::SharedAccumulator` holding
    /// `blueprint.num_shared_accumulators` entries. The function calls `sync_cube()`
    /// before returning the shared accumulator.
    #[allow(clippy::too_many_arguments)]
    fn reduce_shared<P: ReducePrecision, Out: NumericVector, I: ReduceInstruction<P>>(
        input: &VirtualTensor<P::EI, P::SI>,
        output: &mut VirtualTensor<Out::T, Out::N, ReadWrite>,
        reduce_axis: usize,
        reduce_index: usize,
        inst: &I,
        idle: ComptimeOption<bool>,
        #[comptime] vectorization_mode: VectorizationMode,
        #[comptime] blueprint: CubeBlueprint,
    ) -> I::SharedAccumulator {
        // NOTE(review): the meaning of the trailing `false` flag is defined by
        // `Reader::new`, which is not visible from this file.
        let base_reader = Reader::<P>::new::<I, Out>(
            input,
            output,
            inst,
            reduce_axis,
            reduce_index,
            idle,
            blueprint.bound_checks,
            vectorization_mode,
            false,
        );
        let cube_reader = CubeReader::<P>::new(base_reader);

        let mut local = I::null_accumulator(inst);
        for step in 0..cube_reader.length() {
            let value = cube_reader.read(step);
            reduce_inplace::<P, I>(inst, &mut local, value, ReduceStep::Identity);
        }

        let slot = Self::worker_pos(blueprint);

        // NOTE(review): with planes, all units sharing a `UNIT_POS_Y` write the same
        // slot; presumably `plane_reduce_inplace` leaves them with equal values — confirm.
        let contribution = match blueprint.use_planes {
            true => {
                I::plane_reduce_inplace(inst, &mut local);
                local
            }
            false => local,
        };

        let slot_count = blueprint.num_shared_accumulators;
        let needs = I::requirements(inst);
        let mut shared = I::SharedAccumulator::allocate(slot_count, needs.coordinates, inst);
        I::SharedAccumulator::write(&mut shared, slot, contribution);

        // Every partial must be stored before any unit starts merging them.
        sync_cube();
        shared
    }
}
/// Sequentially merges entries `0..size` of `shared_accumulator` into `accumulator`.
///
/// Each entry is read with `I::SharedAccumulator::read` and combined through
/// `I::fuse_accumulators`, in increasing index order. The function performs no
/// synchronization itself.
#[cube]
fn reduce_scan<P: ReducePrecision, I: ReduceInstruction<P>>(
    inst: &I,
    shared_accumulator: &mut I::SharedAccumulator,
    accumulator: &mut Accumulator<P>,
    #[comptime] size: usize,
) {
    for slot in 0..size {
        let partial = I::SharedAccumulator::read(shared_accumulator, slot);
        I::fuse_accumulators(inst, accumulator, &partial);
    }
}
/// Merges the `size` entries of `shared_accumulator` pairwise, round by round, then
/// fuses entry 0 into `accumulator`.
///
/// In every round an active worker `w` calls `fuse_accumulator_inplace` with the index
/// pair (`stride * 2 * w`, `stride * (2 * w + 1)`) — presumably merging the second
/// entry into the first. `stride` doubles after each round and each round ends with
/// `sync_cube()`.
///
/// * When `size` is a power of two, the number of active workers is halved at the start
///   of each round.
/// * Otherwise, `remaining / 2` workers are active and the remaining count is updated
///   with `div_ceil(2)`, so an unpaired last entry is carried into a later round.
///
/// After a final `sync_cube()`, every caller reads entry 0 and fuses it into
/// `accumulator`.
#[cube]
fn reduce_tree<P: ReducePrecision, I: ReduceInstruction<P>>(
    inst: &I,
    shared_accumulator: &mut I::SharedAccumulator,
    accumulator: &mut Accumulator<P>,
    worker_index: usize,
    #[comptime] size: usize,
) {
    if size.is_power_of_two() {
        let mut active = size.runtime();
        let mut stride = 1;
        while active > 1 {
            active /= 2;
            let dst = stride * 2 * worker_index;
            let src = stride * (2 * worker_index + 1);
            if worker_index < active {
                fuse_accumulator_inplace::<P, I>(inst, shared_accumulator, dst, src);
            }
            stride *= 2;
            sync_cube();
        }
    } else {
        let mut remaining = size.runtime();
        let mut stride = 1;
        while remaining > 1 {
            let dst = stride * 2 * worker_index;
            let src = stride * (2 * worker_index + 1);
            if worker_index < remaining / 2 {
                fuse_accumulator_inplace::<P, I>(inst, shared_accumulator, dst, src);
            }
            stride *= 2;
            remaining = remaining.div_ceil(2);
            sync_cube();
        }
    }
    // Kept unconditionally: it is the only barrier executed when the loops above do not
    // run (`size <= 1`).
    sync_cube();
    let root = I::SharedAccumulator::read(shared_accumulator, 0);
    I::fuse_accumulators(inst, accumulator, &root);
}