cubecl_convolution/components/selection.rs

use cubecl_core::{Runtime, client::ComputeClient};
use cubecl_matmul::components::stage::PartitionBuffering;

use cubecl_matmul::components::{
    MatmulAvailabilityError, MatmulElems, MatmulSelection, TilingScheme, adjust_dtypes,
};
use cubecl_matmul::{
    components::tile::TileMatmulFamily,
    kernels::layered::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
};

use crate::components::ConvolutionProblem;

/// A heuristic to find the number of tiles in the stage.
///
/// Maximizes tensor core usage unless doing so would significantly impair
/// parallelization across SMs: it keeps the number of cubes as close as
/// possible to the number of available SMs. (See the illustrative tests below
/// for worked examples.)
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // We keep track of two configurations to select the closest to `num_sm`, whether it's a bit over or under
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Refine tensor core usage to stay as close as possible to `num_sm`
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        // Reduce tensor core usage
        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        // Number of cubes grows as a consequence of smaller stage
        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }

    // Compare previous and current values to determine the closest to `num_sm`
    if (previous_num_cubes as isize - num_sm as isize).abs()
        <= (num_cubes_expected as isize - num_sm as isize).abs()
    {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
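
// Minimal illustrative checks of the heuristic above, added as a sketch: the hardware
// figures (64 SMs, 4 tensor cores, 16x16 instruction tiles, stage_size_k = 2) are assumed
// values for the example, not measurements from a specific GPU.
#[cfg(test)]
mod stage_size_tests {
    use super::find_stage_size_m_n;

    #[test]
    fn large_problem_keeps_full_tensor_core_usage() {
        // 4096x4096 output: far more tiles than cubes are needed, so both stage
        // dimensions stay at the tensor-core maximum of 4.
        assert_eq!(find_stage_size_m_n(4096, 4096, 64, 4, 16, 16, 2), (4, 4));
    }

    #[test]
    fn small_problem_trades_stage_size_for_sm_occupancy() {
        // 256x256 output: only 256 tiles in total. The m dimension of the stage is
        // halved from 4 down to 1 so that the expected cube count reaches 64,
        // matching the assumed SM count.
        assert_eq!(find_stage_size_m_n(256, 256, 64, 4, 16, 16, 2), (1, 4));
    }
}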

pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
    client: &ComputeClient<R::Server, R::Channel>,
    problem: &ConvolutionProblem,
    plane_dim: u32,
    dtypes: &mut MatmulElems,
) -> Result<MatmulSelection, MatmulAvailabilityError> {
    adjust_dtypes(client, dtypes, TMM::requires_accelerator());

    // Rough heuristic based on previous benchmark results, where 512 channels with a 3x3
    // kernel seemed to be the cutoff for the stage_k = 4 size.
    let stage_k = if problem.k >= 4096 { 4 } else { 2 };
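    // Worked example of the cutoff above (illustrative): 512 channels with a 3x3 kernel give
    // k = 512 * 3 * 3 = 4608, just over the 4096 threshold, so stage_k = 4 is chosen; 256
    // channels with the same kernel give k = 2304 and stay at stage_k = 2.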

    let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n)?;

    let hardware = &client.properties().hardware;
    let num_sm = hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_TENSOR_CORES_APPROX);

    let (stage_size_m, stage_size_n) = find_stage_size_m_n(
        problem.m,
        problem.n,
        num_sm as usize,
        max_tensor_cores as usize,
        tile_size.m() as usize,
        tile_size.n() as usize,
        stage_k as usize,
    );

    let tiling_scheme = TilingScheme::builder()
        .with_stage_size((stage_size_m as u32, 1, 1).into())
        .with_tile_size(tile_size)
        .with_partition_size((1, stage_size_n as u32, stage_k).into())
        .build()
        .unwrap();

    Ok(MatmulSelection::builder(tiling_scheme, plane_dim)
        .partition_buffering(PartitionBuffering::Single)
        .build())
}