cubek-reduce 0.2.0-pre.2

CubeK: Reduce Kernels
Documentation
use crate::{IdleMode, LineMode, ReducePrecision};
use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};

#[cube]
pub(crate) fn reduce_count(
    output_size: usize,
    #[comptime] line_mode: LineMode,
    #[comptime] input_line_size: LineSize,
) -> usize {
    match line_mode {
        LineMode::Parallel => output_size,
        LineMode::Perpendicular => output_size / input_line_size,
    }
}

#[cube]
pub fn idle_check<P: ReducePrecision, Out: Numeric>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    reduce_index_start: usize,
    #[comptime] line_mode: LineMode,
    #[comptime] idle_mode: IdleMode,
) -> Option<bool> {
    if idle_mode.is_enabled() {
        let reduce_count = reduce_count(
            output.len() * output.line_size(),
            line_mode,
            input.line_size(),
        );

        match idle_mode {
            IdleMode::None => Option::new_None(),
            IdleMode::Mask => Option::new_Some(reduce_index_start >= reduce_count),
            IdleMode::Terminate => {
                if reduce_index_start >= reduce_count {
                    terminate!();
                }
                Option::new_None()
            }
        }
    } else {
        Option::new_None()
    }
}