cubek-std 0.2.0-pre.5

CubeK: Standard Library
Documentation
use cubecl::{CubeCount, Runtime, client::ComputeClient};

/// Builds a [`CubeCount`] able to hold at least `num_cubes` cubes within the
/// hardware limits reported by `client`, and returns it together with the total
/// number of cubes it actually launches.
///
/// The cubes are spread over the x/y/z dimensions by `cube_count_spread`, which
/// rounds up while splitting. The returned total is therefore `>= num_cubes` and
/// may be strictly larger. Presumably callers use the total to guard against the
/// surplus cubes — verify at call sites.
pub fn cube_count_spread_with_total<R: Runtime>(
    client: &ComputeClient<R>,
    num_cubes: usize,
) -> (CubeCount, usize) {
    let [x, y, z] = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
    let total = x * y * z;

    // x and y never exceed their u32 maxima after spreading, so those casts are lossless.
    (CubeCount::Static(x as u32, y as u32, z as u32), total)
}

/// Spreads `num_cubes` over three cube-count dimensions `[x, y, z]` so that
/// x and y respect the corresponding entries of `max_cube_count`.
///
/// The spread starts from `[num_cubes, 1, 1]`. While x exceeds its maximum,
/// x is halved (rounding up) and y is doubled. The same is then done from y
/// into z.
///
/// Because of the rounding, the product of the returned dimensions is always
/// `>= num_cubes` and may over-provision.
///
/// The z dimension is *not* checked against `max_cube_count.2`. If `num_cubes`
/// exceeds what x and y can absorb, the surplus lands in z, whatever its size.
///
/// # Panics
///
/// Panics if the x or y maximum is zero while cubes still need to be placed in
/// that dimension. Previously this case looped forever, because
/// `1.div_ceil(2) == 1` can never drop to `0`.
fn cube_count_spread(max_cube_count: &(u32, u32, u32), num_cubes: usize) -> [usize; 3] {
    const BASE: usize = 2;

    let max = [
        max_cube_count.0 as usize,
        max_cube_count.1 as usize,
        max_cube_count.2 as usize,
    ];
    let mut counts = [num_cubes, 1, 1];

    // Push the overflow of x into y, then the overflow of y into z.
    for dim in 0..2 {
        if counts[dim] <= max[dim] {
            // This dimension already fits, so nothing spills further.
            break;
        }

        // With a zero maximum the halving below could never terminate.
        assert!(
            max[dim] > 0,
            "max_cube_count[{dim}] is zero but {} cubes must be placed in that dimension",
            counts[dim]
        );

        while counts[dim] > max[dim] {
            counts[dim] = counts[dim].div_ceil(BASE);
            counts[dim + 1] *= BASE;
        }
    }

    counts
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that spreading `required` cubes under the per-dimension `max` yields `expected`.
    fn check(max: (u32, u32, u32), required: usize, expected: [usize; 3]) {
        assert_eq!(cube_count_spread(&max, required), expected);
    }

    /// 2048 = 32 * 32 * 2: x and y saturate exactly and the remainder lands in z.
    #[test]
    fn safe_num_cubes_even() {
        check((32, 32, 32), 2048, [32, 32, 2]);
    }

    /// 3177 does not split evenly. Rounding up over-provisions to 25 * 32 * 4 = 3200 cubes.
    #[test]
    fn safe_num_cubes_odd() {
        check((48, 32, 16), 3177, [25, 32, 4]);
    }
}