cubek-reduce 0.2.0-pre.2

CubeK: Reduce Kernels
Documentation
use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
use crate::{components::instructions::ReduceRequirements, components::precision::ReducePrecision};
use cubecl::prelude::*;

// TODO Add to test framework.
/// Reduce instruction that yields the maximum absolute value of the reduced items.
///
/// Each item is passed through `abs` before it is compared, so the result is the
/// largest magnitude itself rather than the original (possibly negative) item.
#[derive(Debug, CubeType, Clone)]
pub struct MaxAbs;

/// Registers `MaxAbs` as a reduce family.
impl ReduceFamily for MaxAbs {
    /// The same unit struct serves as the instruction for every precision `P`.
    type Instruction<P: ReducePrecision> = Self;
    /// `MaxAbs` carries no configuration.
    type Config = ();
}

#[cube]
impl<P: ReducePrecision> ReduceInstruction<P> for MaxAbs {
    type AccumulatorItem = Line<P::EA>;
    type SharedAccumulator = SharedMemory<Line<P::EA>>;
    type Config = ();

    fn requirements(_this: &Self) -> ReduceRequirements {
        ReduceRequirements { coordinates: false }
    }

    fn from_config(_config: Self::Config) -> Self {
        MaxAbs {}
    }

    fn null_input(_this: &Self, #[comptime] line_size: LineSize) -> Line<P::EI> {
        Line::empty(line_size).fill(P::EI::from_int(0))
    }

    fn null_accumulator(_this: &Self, #[comptime] line_size: LineSize) -> Self::AccumulatorItem {
        Line::empty(line_size).fill(P::EA::from_int(0))
    }

    fn assign_accumulator(
        _this: &Self,
        destination: &mut Self::AccumulatorItem,
        source: &Self::AccumulatorItem,
    ) {
        *destination = *source;
    }

    fn reduce(
        _this: &Self,
        accumulator: &Self::AccumulatorItem,
        item: Line<P::EI>,
        _coordinate: ReduceCoordinate,
        #[comptime] use_planes: bool,
    ) -> Self::AccumulatorItem {
        if use_planes {
            let candidate_item = Line::cast_from(plane_max(Line::abs(item)));
            select_many(
                accumulator.greater_than(candidate_item),
                *accumulator,
                candidate_item,
            )
        } else {
            let item_abs = Line::cast_from(Line::abs(item));
            select_many(accumulator.greater_than(item_abs), *accumulator, item_abs)
        }
    }

    fn read_accumulator(
        _this: &Self,
        accumulator: &Line<P::EA>,
    ) -> (Line<P::EI>, ReduceCoordinate) {
        (
            Line::cast_from(*accumulator),
            ReduceCoordinate::new_NotRequired(),
        )
    }

    fn fuse_accumulators(
        _this: &Self,
        lhs: Self::AccumulatorItem,
        rhs: Self::AccumulatorItem,
    ) -> Self::AccumulatorItem {
        select_many(lhs.greater_than(rhs), lhs, rhs)
    }

    fn merge_line<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: usize,
    ) -> Out {
        let mut max = P::EA::from_int(0);
        #[unroll]
        for k in 0..accumulator.size() {
            let candidate = accumulator[k];
            max = select(candidate > max, candidate, max);
        }
        Out::cast_from(max)
    }

    fn to_output_perpendicular<Out: Numeric>(
        _this: &Self,
        accumulator: Self::AccumulatorItem,
        _shape_axis_reduce: usize,
    ) -> Line<Out> {
        Line::cast_from(accumulator)
    }
}