use crate::components::precision::ReducePrecision;
use cubecl::prelude::*;
/// Ties a reduce instruction to the compile-time configuration that builds it.
///
/// An implementor selects the concrete [`ReduceInstruction`] for each precision `P`;
/// the instruction's `Config` is constrained to be the family's own `Config`.
pub trait ReduceFamily: Send + Sync + 'static + std::fmt::Debug {
/// Instruction type for precision `P`. Its `Config` must equal `Self::Config`.
type Instruction<P: ReducePrecision>: ReduceInstruction<P, Config = Self::Config>;
/// Compile-time configuration type shared with the instruction
/// (the type accepted by [`ReduceInstruction::from_config`]).
type Config: CubeComptime + Send + Sync;
}
/// Requirements reported by [`ReduceInstruction::requirements`].
#[derive(CubeType, Clone, Copy)]
pub struct ReduceRequirements {
/// Compile-time flag (marked `#[cube(comptime)]`).
/// NOTE(review): presumably `true` when the instruction needs item coordinates
/// (see `ReduceCoordinate::Required`) — confirm against the implementors.
#[cube(comptime)]
pub coordinates: bool,
}
/// Contract implemented by every reduce instruction for a given precision `P`.
///
/// All methods are associated functions that take the instruction as an explicit
/// `this: &Self` argument. Inputs are lines of `P::EI`; partial results are carried
/// in [`Self::AccumulatorItem`](ReduceInstruction::AccumulatorItem).
#[cube]
pub trait ReduceInstruction<P: ReducePrecision>:
Send + Sync + 'static + std::fmt::Debug + CubeType
{
/// Compile-time configuration from which the instruction is built.
type Config: CubeComptime + Send + Sync;
/// Returns the requirements of this instruction.
fn requirements(this: &Self) -> ReduceRequirements;
/// Intermediate value produced by `reduce` and `fuse_accumulators`.
type AccumulatorItem: CubeType;
/// Indexable storage whose items are `Self::AccumulatorItem`.
type SharedAccumulator: SharedAccumulator<Item = Self::AccumulatorItem>;
/// Builds the instruction from its compile-time `config`.
fn from_config(#[comptime] config: Self::Config) -> Self;
/// Returns an input line with the given compile-time `line_size`.
/// NOTE(review): presumably the neutral input for `reduce` — confirm with implementors.
fn null_input(this: &Self, #[comptime] line_size: LineSize) -> Line<P::EI>;
/// Returns an accumulator built for the given compile-time `line_size`.
/// NOTE(review): presumably the neutral element for `fuse_accumulators` — confirm.
fn null_accumulator(this: &Self, #[comptime] line_size: LineSize) -> Self::AccumulatorItem;
/// Updates `destination` from `source`.
/// NOTE(review): expected to behave like `*destination = *source` — confirm with implementors.
fn assign_accumulator(
this: &Self,
destination: &mut Self::AccumulatorItem,
source: &Self::AccumulatorItem,
);
/// Extracts an input-typed line and a coordinate from `accumulator`.
fn read_accumulator(
this: &Self,
accumulator: &Self::AccumulatorItem,
) -> (Line<P::EI>, ReduceCoordinate);
/// Combines `item` (and its `coordinate`) with `accumulator` and returns the new accumulator.
///
/// `use_planes` is a compile-time switch.
/// NOTE(review): presumably enables a plane-level reduction of `item` — confirm.
fn reduce(
this: &Self,
accumulator: &Self::AccumulatorItem,
item: Line<P::EI>,
coordinate: ReduceCoordinate,
#[comptime] use_planes: bool,
) -> Self::AccumulatorItem;
/// Combines the accumulators `lhs` and `rhs` into a single accumulator.
fn fuse_accumulators(
this: &Self,
lhs: Self::AccumulatorItem,
rhs: Self::AccumulatorItem,
) -> Self::AccumulatorItem;
/// Produces a single `Out` value from `accumulator`.
/// NOTE(review): `shape_axis_reduce` is presumably the length of the reduced axis — confirm.
fn merge_line<Out: Numeric>(
this: &Self,
accumulator: Self::AccumulatorItem,
shape_axis_reduce: usize,
) -> Out;
/// Produces a `Line<Out>` from `accumulator`.
/// NOTE(review): `shape_axis_reduce` is presumably the length of the reduced axis — confirm.
fn to_output_perpendicular<Out: Numeric>(
this: &Self,
accumulator: Self::AccumulatorItem,
shape_axis_reduce: usize,
) -> Line<Out>;
}
/// Coordinate value passed to [`ReduceInstruction::reduce`] and returned by
/// [`ReduceInstruction::read_accumulator`].
#[derive(CubeType)]
pub enum ReduceCoordinate {
/// A line of `u32` coordinates is provided.
Required(Line<u32>),
/// No coordinate is provided.
NotRequired,
}
/// Indexable storage for accumulator items, read and written one `Item` at a time.
#[cube]
pub trait SharedAccumulator: CubeType + Send + Sync + 'static {
/// Value stored at each index.
type Item: CubeType;
/// Creates the storage from compile-time parameters.
///
/// `length` and `line_size` size the allocation; `_coordinate` is accepted by the
/// signature but not used by the implementations in this file.
fn allocate(
#[comptime] length: usize,
#[comptime] line_size: LineSize,
#[comptime] _coordinate: bool,
) -> Self;
/// Returns the item stored at `index`.
fn read(accumulator: &Self, index: usize) -> Self::Item;
/// Stores `item` at `index`.
fn write(accumulator: &mut Self, index: usize, item: Self::Item);
}
/// Uses a single lined shared memory as the accumulator storage; each item is one `Line<In>`.
#[cube]
impl<In: Numeric> SharedAccumulator for SharedMemory<Line<In>> {
    type Item = Line<In>;

    /// Allocates a lined shared memory of `length` lines of size `line_size`.
    /// The `_coordinate` flag is ignored.
    fn allocate(
        #[comptime] length: usize,
        #[comptime] line_size: LineSize,
        #[comptime] _coordinate: bool,
    ) -> Self {
        let memory = SharedMemory::new_lined(length, line_size);
        memory
    }

    /// Returns the line stored at `index`.
    fn read(accumulator: &Self, index: usize) -> Self::Item {
        let line = accumulator[index];
        line
    }

    /// Overwrites the line stored at `index` with `item`.
    fn write(accumulator: &mut Self, index: usize, item: Self::Item) {
        accumulator[index] = item;
    }
}
/// Pair of shared memories holding, for each index, a line of `N` values and a line of `u32`.
/// NOTE(review): the name suggests use by arg-style reductions (value plus position) — confirm.
#[derive(CubeType)]
pub struct ArgAccumulator<N: Numeric> {
/// Lines of values.
pub elements: SharedMemory<Line<N>>,
/// Lines of `u32` stored alongside `elements` at the same indices.
pub args: SharedMemory<Line<u32>>,
}
/// Stores each item as a `(Line<In>, Line<u32>)` pair split across the two shared memories.
#[cube]
impl<In: Numeric> SharedAccumulator for ArgAccumulator<In> {
    type Item = (Line<In>, Line<u32>);

    /// Allocates both shared memories with the same `length` and `line_size`.
    /// The `_coordinate` flag is ignored.
    fn allocate(
        #[comptime] length: usize,
        #[comptime] line_size: LineSize,
        #[comptime] _coordinate: bool,
    ) -> Self {
        // Allocation order (elements first, then args) matches the field order.
        let element_lines = SharedMemory::new_lined(length, line_size);
        let arg_lines = SharedMemory::new_lined(length, line_size);
        ArgAccumulator::<In> {
            elements: element_lines,
            args: arg_lines,
        }
    }

    /// Returns the `(element, arg)` pair stored at `index`.
    fn read(accumulator: &Self, index: usize) -> Self::Item {
        let element = accumulator.elements[index];
        let arg = accumulator.args[index];
        (element, arg)
    }

    /// Stores the two halves of `item` at `index` in their respective memories.
    fn write(accumulator: &mut Self, index: usize, item: Self::Item) {
        let element = item.0;
        let arg = item.1;
        accumulator.elements[index] = element;
        accumulator.args[index] = arg;
    }
}
/// Reduces `item` into `accumulator` in place.
///
/// Calls [`ReduceInstruction::reduce`] with the current accumulator, then stores the
/// result back through [`ReduceInstruction::assign_accumulator`].
#[cube]
pub fn reduce_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
    inst: &R,
    accumulator: &mut R::AccumulatorItem,
    item: Line<P::EI>,
    coordinate: ReduceCoordinate,
    #[comptime] use_planes: bool,
) {
    let updated = R::reduce(inst, accumulator, item, coordinate, use_planes);
    R::assign_accumulator(inst, accumulator, &updated);
}
/// Reduces `item` into the slot `index` of a shared accumulator.
///
/// Reads the slot, passes it to [`ReduceInstruction::reduce`] together with `item`
/// and `coordinate`, and writes the result back to the same slot.
#[cube]
pub fn reduce_shared_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
    inst: &R,
    accumulator: &mut R::SharedAccumulator,
    index: usize,
    item: Line<P::EI>,
    coordinate: ReduceCoordinate,
    #[comptime] use_planes: bool,
) {
    let current = R::SharedAccumulator::read(accumulator, index);
    let updated = R::reduce(inst, &current, item, coordinate, use_planes);
    R::SharedAccumulator::write(accumulator, index, updated);
}
/// Fuses the slot `origin` into the slot `destination` of a shared accumulator.
///
/// Both slots are read, combined with [`ReduceInstruction::fuse_accumulators`]
/// (`destination` as `lhs`, `origin` as `rhs`), and the result is written to
/// `destination`. The `origin` slot is left untouched.
#[cube]
pub fn fuse_accumulator_inplace<P: ReducePrecision, R: ReduceInstruction<P>>(
    inst: &R,
    accumulator: &mut R::SharedAccumulator,
    destination: usize,
    origin: usize,
) {
    // Read order is kept: destination first, then origin.
    let lhs = R::SharedAccumulator::read(accumulator, destination);
    let rhs = R::SharedAccumulator::read(accumulator, origin);
    let fused = R::fuse_accumulators(inst, lhs, rhs);
    R::SharedAccumulator::write(accumulator, destination, fused);
}