use crate::{LineMode, ReduceDtypes, ReduceError, routines::ReduceBlueprint};
use cubecl::prelude::*;
#[derive(Debug)]
pub struct ReduceLineSettings {
pub line_mode: LineMode,
pub line_size_input: LineSize,
pub line_size_output: LineSize,
}
#[derive(Debug)]
pub struct ReduceLaunchSettings {
pub cube_dim: CubeDim,
pub cube_count: CubeCount,
pub address_type: AddressType,
pub line: ReduceLineSettings,
}
#[derive(Debug)]
pub struct ReduceProblem {
pub vector_size: usize,
pub vector_count: usize,
pub axis: usize,
pub dtypes: ReduceDtypes,
pub address_type: AddressType,
}
#[derive(Debug, Clone)]
pub enum BlueprintStrategy<R: Routine> {
Forced(R::Blueprint, CubeDim),
Inferred(R::Strategy),
}
pub trait Routine: core::fmt::Debug + Clone + Sized {
type Strategy: core::fmt::Debug + Clone + Send + 'static;
type Blueprint: core::fmt::Debug + Clone + Send + 'static;
fn prepare<R: Runtime>(
&self,
client: &ComputeClient<R>,
problem: ReduceProblem,
settings: ReduceLineSettings,
strategy: BlueprintStrategy<Self>,
) -> Result<(ReduceBlueprint, ReduceLaunchSettings), ReduceError>;
}
pub fn cube_count_safe<R: Runtime>(
client: &ComputeClient<R>,
num_cubes: usize,
) -> (CubeCount, usize) {
let cube_count = cube_count_spread(&client.properties().hardware.max_cube_count, num_cubes);
(
CubeCount::Static(
cube_count[0] as u32,
cube_count[1] as u32,
cube_count[2] as u32,
),
cube_count[0] * cube_count[1] * cube_count[2],
)
}
fn cube_count_spread(max_cube_count: &(u32, u32, u32), num_cubes: usize) -> [usize; 3] {
let max_cube_count = [max_cube_count.0, max_cube_count.1, max_cube_count.2];
let mut num_cubes = [num_cubes, 1, 1];
let base = 2;
let mut reduce_count = |i: usize| {
if num_cubes[i] <= max_cube_count[i] as usize {
return true;
}
loop {
num_cubes[i] = num_cubes[i].div_ceil(base);
num_cubes[i + 1] *= base;
if num_cubes[i] <= max_cube_count[i] as usize {
return false;
}
}
};
for i in 0..2 {
if reduce_count(i) {
break;
}
}
num_cubes
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn safe_num_cubes_even() {
let max = (32, 32, 32);
let required = 2048;
let actual = cube_count_spread(&max, required);
let expected = [32, 32, 2];
assert_eq!(actual, expected);
}
#[test]
fn safe_num_cubes_odd() {
let max = (48, 32, 16);
let required = 3177;
let actual = cube_count_spread(&max, required);
let expected = [25, 32, 4];
assert_eq!(actual, expected);
}
}