cubek-reduce 0.2.0

CubeK: Reduce Kernels
Documentation
use cubecl::comptime;
use cubecl::cube;
use cubecl::frontend::CubeIndexMutExpand;
use cubecl::prelude::*;

use crate::components::instructions::AccumulatorFormat;
use crate::components::instructions::plane_topk_insert;
use crate::components::instructions::plane_topk_merge;
use crate::components::instructions::{Accumulator, Item, Value};
use crate::{
    ReduceFamily, ReduceInstruction, ReducePrecision,
    components::instructions::{ReduceRequirements, ReduceStep, SharedAccumulator},
};
use cubecl::frontend::Numeric;

/// Reduce instruction that keeps `k` candidates per accumulator and outputs
/// their coordinates.
///
/// The comparisons in this file's `ReduceInstruction` impl keep the larger
/// value and, when two values are equal, the one with the smaller coordinate.
#[derive(Debug, CubeType, Clone)]
pub struct ArgTopK {
    /// Number of entries to keep. Marked comptime, so it is a plain `usize`
    /// at kernel-expansion time.
    #[cube(comptime)]
    pub k: usize,
}

// `ArgTopK` is its own instruction for every precision `P`.
impl ReduceFamily for ArgTopK {
    type Instruction<P: ReducePrecision> = Self;
    // The configuration is the `k` value (see `from_config` in the
    // `ReduceInstruction` impl, which stores it in `ArgTopK::k`).
    type Config = usize;
}

/// Pair of arrays holding vectorized elements and their `u32` coordinates.
///
/// NOTE(review): nothing else in this file references this type; the
/// instruction below works with `Accumulator<P>` instead. Confirm whether it
/// is still used elsewhere in the crate.
#[derive(CubeType)]
pub struct ArgTopkAccumulator<E: Scalar, S: Size> {
    /// Element vectors.
    pub elements: Array<Vector<E, S>>,
    /// Coordinate vectors; presumably `coordinates[i]` pairs with `elements[i]`.
    pub coordinates: Array<Vector<u32, S>>,
}

/// Shared-memory backing store for `ArgTopK` accumulators.
///
/// Each of the two sequences is meant to hold `k` shared-memory buffers (see
/// `allocate`): buffer `i` stores slot `i` of every accumulator, so one
/// accumulator is spread across the `k` buffers at a single `index`.
#[derive(CubeType)]
pub struct ArgTopKSharedAccumulator<P: ReducePrecision> {
    /// One buffer of element vectors per accumulator slot.
    elements: Sequence<SharedMemory<Vector<P::EA, P::SI>>>,
    /// One buffer of `u32` coordinate vectors per accumulator slot.
    args: Sequence<SharedMemory<Vector<u32, P::SI>>>,
    /// Number of slots per accumulator, copied from `ArgTopK::k`.
    #[cube(comptime)]
    k: usize,
}

#[cube]
impl<P: ReducePrecision> SharedAccumulator<P, ArgTopK> for ArgTopKSharedAccumulator<P> {
    /// Creates the shared-memory buffers for the accumulator slots.
    ///
    /// Every buffer is created with `SharedMemory::new(length)`. The
    /// `_coordinate` flag is ignored: coordinate buffers are always allocated.
    fn allocate(#[comptime] length: usize, #[comptime] _coordinate: bool, inst: &ArgTopK) -> Self {
        let mut elements = Sequence::new();
        let mut args = Sequence::new();
        // NOTE(review): unlike the loops in `read`/`write`, this loop carries no
        // `#[unroll]`. Confirm that iterating over the comptime `inst.k` pushes
        // one buffer per slot as intended.
        for _ in 0..inst.k {
            elements.push(SharedMemory::new(length));
            args.push(SharedMemory::new(length));
        }
        ArgTopKSharedAccumulator::<P> {
            elements,
            args,
            k: inst.k,
        }
    }

    /// Loads the accumulator stored at `index`.
    ///
    /// Gathers entry `index` from each of the `k` element and coordinate
    /// buffers into two fresh arrays of length `k` and wraps them as
    /// `Value::new_Multiple`.
    fn read(accumulator: &Self, index: usize) -> Accumulator<P> {
        let mut values = Array::new(accumulator.k);
        let mut args = Array::new(accumulator.k);
        #[unroll]
        for i in 0..accumulator.k {
            values[i] = accumulator.elements[i][index];
            args[i] = accumulator.args[i][index];
        }
        Accumulator::<P> {
            elements: Value::new_Multiple(values),
            args: Value::new_Multiple(args),
        }
    }

    /// Stores `item` at `index`, the inverse of `read`.
    ///
    /// Slot `i` of `item` is written to entry `index` of buffer `i`, for both
    /// the elements and the coordinates.
    fn write(accumulator: &mut Self, index: usize, item: Accumulator<P>) {
        let values = item.elements.multiple();
        let args = item.args.multiple();
        #[unroll]
        for i in 0..accumulator.k {
            let values_acc = values[i];
            let args_acc = args[i];

            // The sequence entry is bound to a mutable local before being
            // indexed for the write; presumably the local is a handle to the
            // same shared memory rather than a copy of its contents.
            let mut shared_acc = accumulator.elements[i];
            shared_acc[index] = values_acc;

            let mut shared_arg_acc = accumulator.args[i];
            shared_arg_acc[index] = args_acc;
        }
    }
}

#[cube]
impl<P: ReducePrecision> ReduceInstruction<P> for ArgTopK {
    type SharedAccumulator = ArgTopKSharedAccumulator<P>;
    type Config = usize;

    /// Declares that this instruction needs coordinates.
    fn requirements(_this: &Self) -> super::ReduceRequirements {
        ReduceRequirements { coordinates: true }
    }

    /// Reports that the accumulator holds `k` entries
    /// (`AccumulatorFormat::Multiple(k)`).
    fn accumulator_format(this: &Self) -> comptime_type!(AccumulatorFormat) {
        comptime!(AccumulatorFormat::Multiple(this.k))
    }

    /// Builds the instruction from its configuration, which is `k` itself.
    fn from_config(#[comptime] config: Self::Config) -> Self {
        ArgTopK { k: config }
    }

    /// Returns an input vector filled with the minimum value of the input
    /// type.
    fn null_input(_this: &Self) -> Vector<P::EI, P::SI> {
        Vector::empty().fill(P::EI::min_value())
    }

    /// Returns an accumulator whose `k` slots all hold the minimum value of
    /// the accumulation type paired with the coordinate `u32::MAX`.
    fn null_accumulator(this: &Self) -> Accumulator<P> {
        let mut elements = Array::new(comptime!(this.k));
        let mut args = Array::new(comptime!(this.k));
        #[unroll]
        for i in 0..this.k {
            elements[i] = Vector::new(P::EA::min_value());
            args[i] = Vector::new(u32::MAX);
        }

        Accumulator::<P> {
            elements: Value::new_Multiple(elements),
            args: Value::new_Multiple(args),
        }
    }

    /// Folds one `item` into `accumulator`.
    ///
    /// * `ReduceStep::Plane` delegates to `plane_topk_insert`.
    /// * `ReduceStep::Identity` performs a single insertion pass over the `k`
    ///   slots directly in this function.
    fn reduce(
        this: &Self,
        accumulator: &mut Accumulator<P>,
        item: Item<P>,
        #[comptime] reduce_step: ReduceStep,
    ) {
        let elements = accumulator.elements.multiple_mut();

        match reduce_step {
            ReduceStep::Plane => {
                // NOTE(review): the meaning of the trailing `true` flag is defined
                // by `plane_topk_insert`, which is not visible here — confirm.
                plane_topk_insert::<P::EA, P::SI>(
                    elements,
                    &mut accumulator.args,
                    Vector::cast_from(item.elements),
                    &item.args,
                    this.k,
                    true,
                );
            }
            ReduceStep::Identity => {
                let coordinates = accumulator.args.multiple_mut();
                // Candidate carried through the slots; after each slot it is
                // replaced by whichever entry lost the comparison there.
                let mut insert_val = Vector::cast_from(item.elements);
                let mut insert_coord = item.args.item();

                for j in 0..this.k {
                    // Keep the stored entry when it is greater than the
                    // candidate, or, on equal values, when its coordinate is
                    // smaller. `select_many(cond, a, b)` presumably picks `a`
                    // where `cond` holds and `b` otherwise, per vector lane.
                    let to_keep = select_many(
                        elements[j].equal(insert_val),
                        coordinates[j].less_than(insert_coord),
                        elements[j].greater_than(insert_val),
                    );
                    let best_value = select_many(to_keep, elements[j], insert_val);
                    let loser_value = select_many(to_keep, insert_val, elements[j]);
                    let best_coordinate = select_many(to_keep, coordinates[j], insert_coord);
                    let loser_coordinate = select_many(to_keep, insert_coord, coordinates[j]);

                    // The winner stays in slot `j`; the loser moves on to be
                    // compared against slot `j + 1`.
                    elements[j] = best_value;
                    coordinates[j] = best_coordinate;
                    insert_val = loser_value;
                    insert_coord = loser_coordinate;
                }
            }
        };
    }

    /// Combines accumulators across the plane in place by delegating to
    /// `plane_topk_merge`.
    ///
    /// NOTE(review): the meaning of the trailing `true` flag is defined by
    /// `plane_topk_merge`, which is not visible here — confirm.
    fn plane_reduce_inplace(this: &Self, accumulator: &mut Accumulator<P>) {
        plane_topk_merge::<P::EA, P::SI>(
            accumulator.elements.multiple_mut(),
            &mut accumulator.args,
            this.k,
            true,
        );
    }

    /// Merges `other` into `accumulator`.
    ///
    /// Each of the `k` entries of `other` is inserted with the same pass used
    /// by the `ReduceStep::Identity` branch of `reduce`, i.e. `k * k`
    /// compare-and-swap steps in total.
    fn fuse_accumulators(this: &Self, accumulator: &mut Accumulator<P>, other: &Accumulator<P>) {
        let elements = accumulator.elements.multiple_mut();
        let coordinates = accumulator.args.multiple_mut();
        let other_elements = other.elements.multiple();
        let other_coords = other.args.multiple();

        for i in 0..this.k {
            // Entry `i` of `other` becomes the candidate for this pass.
            let mut insert_val = other_elements[i];
            let mut insert_coord = other_coords[i];
            for j in 0..this.k {
                // Same rule as in `reduce`: the stored entry wins when greater,
                // or when equal with a smaller coordinate.
                let to_keep = select_many(
                    elements[j].equal(insert_val),
                    coordinates[j].less_than(insert_coord),
                    elements[j].greater_than(insert_val),
                );
                let best_value = select_many(to_keep, elements[j], insert_val);
                let best_coordinate = select_many(to_keep, coordinates[j], insert_coord);
                let loser_value = select_many(to_keep, insert_val, elements[j]);
                let loser_coordinate = select_many(to_keep, insert_coord, coordinates[j]);

                elements[j] = best_value;
                coordinates[j] = best_coordinate;
                insert_val = loser_value;
                insert_coord = loser_coordinate;
            }
        }
    }

    /// Collapses the vectorized accumulator into `k` scalar coordinates.
    ///
    /// Every lane of every slot (`k * vector_size` candidates) is inserted
    /// into a scalar list of `k` entries using the same ordering rule as
    /// `reduce`. Only the coordinates are returned, cast to `Out`; the values
    /// are used for comparison only. `_shape_axis_reduce` is unused.
    fn to_output_parallel<Out: Numeric>(
        this: &Self,
        accumulator: Accumulator<P>,
        _shape_axis_reduce: usize,
    ) -> Value<Out> {
        let coords = accumulator.args.multiple();
        let vals = accumulator.elements.multiple();
        let vector_size = coords[0].size().comptime();

        let mut topk_vals = Array::new(this.k);
        let mut topk_coords = Array::new(this.k);

        // Start from the same null state as `null_accumulator`, in scalar form.
        #[unroll]
        for slot in 0..this.k {
            topk_vals[slot] = P::EA::min_value();
            topk_coords[slot] = u32::MAX;
        }

        #[unroll]
        for i in 0..this.k {
            #[unroll]
            for j in 0..vector_size {
                // Candidate: lane `j` of slot `i`.
                let mut value = vals[i][j];
                let mut coordinate = coords[i][j];

                #[unroll]
                for slot in 0..this.k {
                    let current_value = topk_vals[slot];
                    let current_coordinate = topk_coords[slot];

                    // Scalar version of the rule used in `reduce`: the stored
                    // entry wins when greater, or when equal with a smaller
                    // coordinate.
                    let to_keep = select(
                        current_value == value,
                        current_coordinate < coordinate,
                        current_value > value,
                    );

                    topk_vals[slot] = select(to_keep, current_value, value);
                    topk_coords[slot] = select(to_keep, current_coordinate, coordinate);

                    // The loser is carried on to the next slot.
                    value = select(to_keep, value, current_value);
                    coordinate = select(to_keep, coordinate, current_coordinate);
                }
            }
        }

        let mut out = Array::new(this.k);
        #[unroll]
        for i in 0..this.k {
            out[i] = Out::cast_from(topk_coords[i]);
        }
        Value::new_Multiple(out)
    }

    /// Returns the `k` coordinate vectors of the accumulator, each cast to
    /// `Out`, without any cross-lane merging.
    ///
    /// The accumulator's elements and `_shape_axis_reduce` are not used.
    fn to_output_perpendicular<Out: Numeric>(
        this: &Self,
        accumulator: Accumulator<P>,
        _shape_axis_reduce: usize,
    ) -> Value<Vector<Out, P::SI>> {
        let acc_args = accumulator.args.multiple();
        let mut output = Array::new(this.k);

        #[unroll]
        for i in 0..this.k {
            output[i] = Vector::cast_from(acc_args[i]);
        }

        Value::new_Multiple(output)
    }
}