cubek-reduce 0.2.0-pre.2

CubeK: Reduce Kernels
Documentation
use crate::{LineMode, ReduceDtypes, ReduceError, routines::ReduceBlueprint};
use cubecl::prelude::*;

#[derive(Debug)]
pub struct ReduceLineSettings {
    pub line_mode: LineMode,
    pub line_size_input: LineSize,
    pub line_size_output: LineSize,
}

#[derive(Debug)]
pub struct ReduceLaunchSettings {
    pub cube_dim: CubeDim,
    pub cube_count: CubeCount,
    pub address_type: AddressType,
    pub line: ReduceLineSettings,
}

#[derive(Debug)]
pub struct ReduceProblem {
    pub vector_size: usize,
    pub vector_count: usize,
    pub axis: usize,
    pub dtypes: ReduceDtypes,
    /// The address type, defined by the max of each handle's `required_address_type`
    pub address_type: AddressType,
}

#[derive(Debug, Clone)]
pub enum BlueprintStrategy<R: Routine> {
    Forced(R::Blueprint, CubeDim),
    Inferred(R::Strategy),
}

pub trait Routine: core::fmt::Debug + Clone + Sized {
    type Strategy: core::fmt::Debug + Clone + Send + 'static;
    type Blueprint: core::fmt::Debug + Clone + Send + 'static;

    fn prepare<R: Runtime>(
        &self,
        client: &ComputeClient<R>,
        problem: ReduceProblem,
        settings: ReduceLineSettings,
        strategy: BlueprintStrategy<Self>,
    ) -> Result<(ReduceBlueprint, ReduceLaunchSettings), ReduceError>;
}

pub fn cube_count_safe<R: Runtime>(
    client: &ComputeClient<R>,
    num_cubes: usize,
) -> (CubeCount, usize) {
    let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);

    (
        CubeCount::Static(
            cube_count[0] as u32,
            cube_count[1] as u32,
            cube_count[2] as u32,
        ),
        cube_count[0] * cube_count[1] * cube_count[2],
    )
}

fn cube_count_spread(max_cube_count: &(u32, u32, u32), num_cubes: usize) -> [usize; 3] {
    let max_cube_count = [max_cube_count.0, max_cube_count.1, max_cube_count.2];
    let mut num_cubes = [num_cubes, 1, 1];
    let base = 2;

    let mut reduce_count = |i: usize| {
        if num_cubes[i] <= max_cube_count[i] as usize {
            return true;
        }

        loop {
            num_cubes[i] = num_cubes[i].div_ceil(base);
            num_cubes[i + 1] *= base;

            if num_cubes[i] <= max_cube_count[i] as usize {
                return false;
            }
        }
    };

    for i in 0..2 {
        if reduce_count(i) {
            break;
        }
    }

    num_cubes
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn safe_num_cubes_even() {
        let max = (32, 32, 32);
        let required = 2048;

        let actual = cube_count_spread(&max, required);
        let expected = [32, 32, 2];
        assert_eq!(actual, expected);
    }

    #[test]
    fn safe_num_cubes_odd() {
        let max = (48, 32, 16);
        let required = 3177;

        let actual = cube_count_spread(&max, required);
        let expected = [25, 32, 4];
        assert_eq!(actual, expected);
    }
}