cubek-reduce 0.2.0

CubeK: Reduce Kernels
Documentation
use crate::{
    BoundChecks, ReduceInstruction, ReducePrecision, VectorizationMode,
    components::{
        args::NumericVector,
        instructions::{ReduceRequirements, Value},
        readers::{parallel::ParallelReader, perpendicular::PerpendicularReader},
    },
};
use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};

#[derive(CubeType)]
pub enum Reader<P: ReducePrecision> {
    Parallel(ParallelReader<P>),
    Perpendicular(PerpendicularReader<P>),
}

#[cube]
impl<P: ReducePrecision> Reader<P> {
    #[allow(clippy::too_many_arguments)]
    pub fn new<I: ReduceInstruction<P>, Out: NumericVector>(
        input: &VirtualTensor<P::EI, P::SI>,
        output: &mut VirtualTensor<Out::T, Out::N, ReadWrite>,
        inst: &I,
        reduce_axis: usize,
        reduce_index: usize,
        idle: ComptimeOption<bool>,
        #[comptime] bound_checks: BoundChecks,
        #[comptime] vectorization_mode: VectorizationMode,
        #[comptime] plane_dim_ceil: bool,
    ) -> Reader<P> {
        let effective_plane_dim = if plane_dim_ceil {
            min(CUBE_DIM_X, PLANE_DIM)
        } else {
            CUBE_DIM_X
        };
        match vectorization_mode {
            VectorizationMode::Parallel => {
                Reader::<P>::new_Parallel(ParallelReader::<P>::new::<I, Out>(
                    input,
                    output,
                    inst,
                    reduce_axis,
                    reduce_index,
                    idle,
                    effective_plane_dim,
                    bound_checks,
                ))
            }
            VectorizationMode::Perpendicular => {
                Reader::<P>::new_Perpendicular(PerpendicularReader::<P>::new::<I, Out>(
                    input,
                    output,
                    inst,
                    reduce_axis,
                    reduce_index,
                    idle,
                    effective_plane_dim,
                    bound_checks,
                ))
            }
        }
    }
}

#[cube]
pub fn new_coordinates<N: Size>(
    coordinate: usize,
    requirements: ReduceRequirements,
    #[comptime] vectorization_mode: VectorizationMode,
) -> Value<Vector<u32, N>> {
    if requirements.coordinates.comptime() {
        // TODO: Make this generic to allow 64-bit coordinate output.
        // Can't directly use `usize` for the buffer, since its size isn't defined beyond the
        // kernel boundary.
        Value::new_single(fill_coordinate_vector(
            coordinate as u32,
            vectorization_mode,
        ))
    } else {
        Value::new_None()
    }
}

// If vectorization mode is parallel, fill a vector with `x, x+1, ... x+ vector_size - 1` where `x = first`.
// If vectorization mode is perpendicular, fill a vector with `x, x, ... x` where `x = first`.
#[cube]
pub(crate) fn fill_coordinate_vector<N: Size>(
    first: u32,
    #[comptime] vectorization_mode: VectorizationMode,
) -> Vector<u32, N> {
    match vectorization_mode {
        VectorizationMode::Parallel => {
            let mut coordinates = Vector::empty();
            #[unroll]
            for j in 0..N::value() {
                coordinates[j] = first + j as u32;
            }
            coordinates
        }
        VectorizationMode::Perpendicular => Vector::empty().fill(first),
    }
}