use cubecl::{
Runtime,
client::ComputeClient,
ir::{StorageType, VectorSize},
};
use cubek_matmul::components::stage::PartitionBuffering;
use cubek_matmul::definition::{
MatmulAvailabilityError, MatmulElems, MatmulVectorSizes, SwizzleModes, TilingBlueprint,
TilingScheme, adjust_dtypes,
};
use cubek_matmul::{
components::tile::TileMatmulFamily,
routines::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
};
use cubek_std::stage::SwizzleMode;
use crate::components::ConvolutionProblem;
/// Picks the stage size along m and n so that the expected number of cubes
/// lands as close as possible to the number of streaming multiprocessors.
pub(crate) fn find_stage_size_m_n(
m: usize,
n: usize,
num_sm: usize,
max_tensor_cores: usize,
instruction_m: usize,
instruction_n: usize,
stage_size_k: usize,
) -> (usize, usize) {
    // Heuristic caps: at most 256 elements per stage dimension and at most
    // 16 tiles per stage overall (shared with the k dimension).
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage);
    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage);
    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't allocate more tiles along n than the problem actually contains.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }
    let total_tiles = total_tiles_m * total_tiles_n;
    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // Track the previous configuration so we can later pick whichever cube
    // count is closest to `num_sm`, whether slightly under or over.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Halve the stage along m until the expected cube count reaches `num_sm`.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        // The cube count grows as the stage shrinks.
        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }
    // Keep whichever of the last two configurations is closest to `num_sm`.
    if (previous_num_cubes as isize - num_sm as isize).abs()
        <= (num_cubes_expected as isize - num_sm as isize).abs()
    {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
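
// A minimal sanity check of the stage-size heuristic. The 108-SM /
// 4-tensor-core figures below are hypothetical (loosely modeled on an
// A100-class GPU), and the expected outputs were traced by hand through the
// logic above; they are illustrative, not hardware-derived.
#[cfg(test)]
mod stage_size_tests {
    use super::find_stage_size_m_n;

    #[test]
    fn large_problem_keeps_full_stage() {
        // 1024x1024 with 16x16 instructions gives 64x64 = 4096 tiles; even a
        // full 4x4 stage yields 256 cubes, well above 108 SMs, so no halving.
        assert_eq!(find_stage_size_m_n(1024, 1024, 108, 4, 16, 16, 2), (4, 4));
    }

    #[test]
    fn small_problem_halves_m() {
        // 256x256 gives 16x16 = 256 tiles; a full 4x4 stage yields only 16
        // cubes, so m halves 4 -> 2 -> 1, landing on 64 cubes (closest to 108).
        assert_eq!(find_stage_size_m_n(256, 256, 108, 4, 16, 16, 2), (1, 4));
    }
}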
/// Builds a [`TilingBlueprint`] for a convolution expressed as a matmul:
/// choose an instruction size, tune the stage size to the hardware, and
/// optionally enable a shared-memory swizzle.
pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
client: &ComputeClient<R>,
problem: &ConvolutionProblem,
plane_dim: u32,
swizzle: bool,
vector_sizes: &MatmulVectorSizes,
dtypes: &mut MatmulElems,
) -> Result<TilingBlueprint, MatmulAvailabilityError> {
    // Adjust the dtypes to what the target supports, depending on whether an
    // accelerator is required.
    adjust_dtypes(client, dtypes, TMM::requires_accelerator());

    // Deeper k partitions pay off on large reductions (rough heuristic).
    let stage_k = if problem.k >= 4096 { 4 } else { 2 };
    // Pick an instruction (tile) size supported by the tile matmul for these
    // register dtypes and this problem shape.
    let tile_size = find_instruction_size::<R, _, _>(
client,
(
dtypes.lhs_register,
dtypes.rhs_register,
dtypes.acc_register,
),
(problem.m, problem.n, problem.k).into(),
(None, None, None),
TMM::is_supported,
TMM::supported_sizes,
)?;
    let hardware = &client.properties().hardware;

    // Fall back to rough approximations when the runtime cannot report
    // hardware counts.
    let num_sm = hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_TENSOR_CORES_APPROX);
let (stage_size_m, stage_size_n) = find_stage_size_m_n(
problem.m,
problem.n,
num_sm as usize,
max_tensor_cores as usize,
tile_size.m() as usize,
tile_size.n() as usize,
stage_k as usize,
);
    // The selected m size shapes the stage, while n and k are split at the
    // partition level with single buffering.
    let tiling_scheme = TilingScheme::builder()
        .with_stage_size((stage_size_m as u32, 1, 1).into())
        .with_tile_size(tile_size)
        .with_partition_size((1, stage_size_n as u32, stage_k).into())
        .build()
        .unwrap();

    let mut builder =
        TilingBlueprint::builder(tiling_scheme, plane_dim, &problem.as_matmul_problem())
            .partition_buffering(PartitionBuffering::Single);
    // When requested, pick a shared-memory swizzle per operand from the stage
    // width along k and each operand's element and vector sizes.
    if swizzle {
let swizzle_dim = tiling_scheme.elements_per_stage_along_k() as usize;
let lhs = select_swizzle(swizzle_dim, dtypes.lhs_stage, vector_sizes.lhs);
let rhs = select_swizzle(swizzle_dim, dtypes.rhs_stage, vector_sizes.rhs);
builder = builder.shared_swizzle(SwizzleModes {
lhs,
rhs,
..Default::default()
});
}
Ok(builder.build())
}
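
// A hedged usage sketch, not called anywhere in this module: it shows how a
// launch path might invoke `convolution_matmul_selection`. The plane dim of
// 32 and the enabled swizzle are assumptions, not values this crate mandates.
#[allow(dead_code)]
fn selection_usage_sketch<TMM: TileMatmulFamily, R: Runtime>(
    client: &ComputeClient<R>,
    problem: &ConvolutionProblem,
    vector_sizes: &MatmulVectorSizes,
    mut dtypes: MatmulElems,
) -> Result<TilingBlueprint, MatmulAvailabilityError> {
    // 32 lanes per plane matches most CUDA-like targets; swizzling enabled.
    convolution_matmul_selection::<TMM, R>(client, problem, 32, true, vector_sizes, &mut dtypes)
}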
/// Swizzle atoms span 16 bytes.
const SWIZZLE_ATOM: usize = 16;

/// Selects a shared-memory swizzle mode from the stage width along k (in
/// elements), the operand's storage type, and its vector size.
fn select_swizzle(swizzle_dim: usize, elem: StorageType, vector_size: VectorSize) -> SwizzleMode {
    // A vectorized access must fit inside a single swizzle atom.
    if elem.size() * vector_size > SWIZZLE_ATOM {
        return SwizzleMode::None;
    }

    let swizzle_dim_bytes = swizzle_dim * elem.size();

    // Swizzling only works on power-of-two row widths.
    if !swizzle_dim_bytes.is_power_of_two() {
        return SwizzleMode::None;
    }

    // Map the row width in bytes to the matching swizzle mode.
    match swizzle_dim_bytes {
        32 => SwizzleMode::B32,
        64 => SwizzleMode::B64,
        _ => SwizzleMode::B128,
    }
}
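
// A minimal sketch of the swizzle byte math for a hypothetical 2-byte (f16)
// element; it replays the arithmetic `select_swizzle` performs without
// constructing `StorageType`/`VectorSize` values, whose constructors are not
// shown here.
#[cfg(test)]
mod swizzle_math_tests {
    use super::SWIZZLE_ATOM;

    #[test]
    fn f16_row_widths_map_to_byte_buckets() {
        let elem_size = 2usize; // hypothetical f16 element
        let vector_size = 8usize;
        // An 8-wide f16 access spans 16 bytes, exactly one swizzle atom.
        assert!(elem_size * vector_size <= SWIZZLE_ATOM);
        // Stage widths of 16/32/64 f16 elements along k give 32/64/128-byte
        // rows, i.e. SwizzleMode::B32, B64, and B128 respectively.
        assert_eq!(16 * elem_size, 32);
        assert_eq!(32 * elem_size, 64);
        assert_eq!(64 * elem_size, 128);
    }
}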