cubecl_convolution/components/
selection.rs

1use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType};
2use cubecl_matmul::components::{
3    MatmulLineSizes, SwizzleConfig,
4    stage::{PartitionBuffering, SwizzleMode},
5};
6
7use cubecl_matmul::components::{
8    MatmulAvailabilityError, MatmulElems, MatmulSelection, TilingScheme, adjust_dtypes,
9};
10use cubecl_matmul::{
11    components::tile::TileMatmulFamily,
12    kernels::layered::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
13};
14
15use crate::components::ConvolutionProblem;
16
/// A heuristic to find the number of tiles in the stage along `m` and `n`.
///
/// Maximizes tensor core usage unless doing so would significantly impair
/// parallelization across SMs. It ensures the number of cubes is as close as
/// possible to the available SMs.
///
/// Each dimension starts at the largest tile count allowed by `max_tensor_cores`,
/// by `256 / instruction_{m,n}`, and by `16 / stage_size_k`. It is never less than
/// one tile.
///
/// The `n` count is then halved while it exceeds the number of tiles the problem
/// actually has along `n`. Next, the `m` count is halved (rounding up) while the
/// expected number of cubes stays below `num_sm`.
///
/// Of the last two `m` candidates, the one whose cube count is closest to `num_sm`
/// wins. Ties go to the larger stage.
///
/// Returns `(num_tiles_m, num_tiles_n)`.
///
/// # Panics
///
/// Panics if `instruction_m`, `instruction_n` or `stage_size_k` is zero.
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    // Clamp to at least one tile: a zero count (e.g. `max_tensor_cores == 0` or
    // `stage_size_k > 16`) would otherwise divide by zero below.
    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage)
        .max(1);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage)
        .max(1);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't make the stage wider along `n` than the problem itself.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // We keep track of two configurations to select the closest to `num_sm`,
    // whether it's a bit over or under.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Refine tensor core usage to stay as close as possible to `num_sm`.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        // Reduce tensor core usage.
        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        // Number of cubes grows as a consequence of the smaller stage.
        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }

    // Compare previous and current values to determine the closest to `num_sm`.
    // Both branches must return the pair in `(m, n)` order.
    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
81
82pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
83    client: &ComputeClient<R>,
84    problem: &ConvolutionProblem,
85    plane_dim: u32,
86    swizzle: bool,
87    line_sizes: &MatmulLineSizes,
88    dtypes: &mut MatmulElems,
89) -> Result<MatmulSelection, MatmulAvailabilityError> {
90    adjust_dtypes(client, dtypes, TMM::requires_accelerator());
91
92    // rough heuristic based on previous bench results where 512 channels with a 3x3 kernel seemed
93    // to be the rough cutoff for the k=4 size.
94    let stage_k = if problem.k >= 4096 { 4 } else { 2 };
95
96    let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n)?;
97
98    let hardware = &client.properties().hardware;
99    let num_sm = hardware
100        .num_streaming_multiprocessors
101        .unwrap_or(NUM_TENSOR_CORES_APPROX);
102    let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
103
104    let (stage_size_m, stage_size_n) = find_stage_size_m_n(
105        problem.m,
106        problem.n,
107        num_sm as usize,
108        max_tensor_cores as usize,
109        tile_size.m() as usize,
110        tile_size.n() as usize,
111        stage_k as usize,
112    );
113
114    let tiling_scheme = TilingScheme::builder()
115        .with_stage_size((stage_size_m as u32, 1, 1).into())
116        .with_tile_size(tile_size)
117        .with_partition_size((1, stage_size_n as u32, stage_k).into())
118        .build()
119        .unwrap();
120
121    let mut builder = MatmulSelection::builder(tiling_scheme, plane_dim)
122        .partition_buffering(PartitionBuffering::Single);
123
124    if swizzle {
125        let swizzle_dim = tiling_scheme.elements_per_stage_along_k();
126
127        let lhs = select_swizzle(swizzle_dim, *dtypes.lhs_stage, line_sizes.lhs);
128        let rhs = select_swizzle(swizzle_dim, *dtypes.rhs_stage, line_sizes.rhs);
129        builder = builder.shared_swizzle(SwizzleConfig {
130            lhs,
131            rhs,
132            ..Default::default()
133        });
134    }
135
136    Ok(builder.build())
137}
138
139/// All modes currently use atom size 16
140const SWIZZLE_ATOM: usize = 16;
141
142fn select_swizzle(swizzle_dim: u32, elem: StorageType, line_size: u8) -> SwizzleMode {
143    // Line size exceeds swizzle atom
144    if elem.size() * line_size as usize > SWIZZLE_ATOM {
145        return SwizzleMode::None;
146    }
147    let swizzle_dim_bytes = swizzle_dim as usize * elem.size();
148    if !swizzle_dim_bytes.is_power_of_two() {
149        return SwizzleMode::None;
150    }
151    match swizzle_dim_bytes {
152        32 => SwizzleMode::B32,
153        64 => SwizzleMode::B64,
154        _ => SwizzleMode::B128,
155        //_ => SwizzleMode::None,
156    }
157}