cubek_convolution/components/
selection.rs

1use cubecl::{
2    Runtime,
3    client::ComputeClient,
4    ir::{LineSize, StorageType},
5};
6use cubek_matmul::components::stage::{PartitionBuffering, SwizzleMode};
7
8use cubek_matmul::definition::{
9    MatmulAvailabilityError, MatmulElems, MatmulLineSizes, SwizzleModes, TilingBlueprint,
10    TilingScheme, adjust_dtypes,
11};
12use cubek_matmul::{
13    components::tile::TileMatmulFamily,
14    routines::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
15};
16
17use crate::components::ConvolutionProblem;
18
/// A heuristic to find the number of tiles in the stage.
///
/// Maximizes tensor core usage unless doing so would significantly impair
/// parallelization across SMs. It ensures the number of cubes is as close as
/// possible to the available SMs.
///
/// Returns `(stage_size_m, stage_size_n)`, expressed in number of tiles.
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    // Cap each stage dimension at 256 elements per axis, and the whole stage
    // at 16 tiles shared across the k dimension.
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    // `.max(1)` guards against `stage_size_k > 16`, which would otherwise
    // yield a zero tile budget and a divide-by-zero in `div_ceil` below.
    let max_tiles_total_stage = (16 / stage_size_k).max(1);

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't allocate more tiles along n than the problem actually has.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // We keep track of two configurations to select the closest to `num_sm`, whether it's a bit over or under
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Refine tensor core usage to stay as close as possible to `num_sm`
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        // Reduce tensor core usage
        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        // Number of cubes grows as a consequence of smaller stage
        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }

    // Compare previous and current values to determine the closest to `num_sm`
    if (previous_num_cubes as isize - num_sm as isize).abs()
        <= (num_cubes_expected as isize - num_sm as isize).abs()
    {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        // Bug fix: this arm previously returned `(dim_num_tiles_n, dim_num_tiles_m)`,
        // swapping the m/n components whenever the refined configuration won.
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
83
84pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
85    client: &ComputeClient<R>,
86    problem: &ConvolutionProblem,
87    plane_dim: u32,
88    swizzle: bool,
89    line_sizes: &MatmulLineSizes,
90    dtypes: &mut MatmulElems,
91) -> Result<TilingBlueprint, MatmulAvailabilityError> {
92    adjust_dtypes(client, dtypes, TMM::requires_accelerator());
93
94    // rough heuristic based on previous bench results where 512 channels with a 3x3 kernel seemed
95    // to be the rough cutoff for the k=4 size.
96    let stage_k = if problem.k >= 4096 { 4 } else { 2 };
97
98    let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n)?;
99
100    let hardware = &client.properties().hardware;
101    let num_sm = hardware
102        .num_streaming_multiprocessors
103        .unwrap_or(NUM_TENSOR_CORES_APPROX);
104    let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
105
106    let (stage_size_m, stage_size_n) = find_stage_size_m_n(
107        problem.m,
108        problem.n,
109        num_sm as usize,
110        max_tensor_cores as usize,
111        tile_size.m() as usize,
112        tile_size.n() as usize,
113        stage_k as usize,
114    );
115
116    let tiling_scheme = TilingScheme::builder()
117        .with_stage_size((stage_size_m as u32, 1, 1).into())
118        .with_tile_size(tile_size)
119        .with_partition_size((1, stage_size_n as u32, stage_k).into())
120        .build()
121        .unwrap();
122
123    let mut builder =
124        TilingBlueprint::builder(tiling_scheme, plane_dim, &problem.as_matmul_problem())
125            .partition_buffering(PartitionBuffering::Single);
126
127    if swizzle {
128        let swizzle_dim = tiling_scheme.elements_per_stage_along_k() as usize;
129
130        let lhs = select_swizzle(swizzle_dim, dtypes.lhs_stage, line_sizes.lhs);
131        let rhs = select_swizzle(swizzle_dim, dtypes.rhs_stage, line_sizes.rhs);
132        builder = builder.shared_swizzle(SwizzleModes {
133            lhs,
134            rhs,
135            ..Default::default()
136        });
137    }
138
139    Ok(builder.build())
140}
141
142/// All modes currently use atom size 16
143const SWIZZLE_ATOM: usize = 16;
144
145fn select_swizzle(swizzle_dim: usize, elem: StorageType, line_size: LineSize) -> SwizzleMode {
146    // Line size exceeds swizzle atom
147    if elem.size() * line_size > SWIZZLE_ATOM {
148        return SwizzleMode::None;
149    }
150    let swizzle_dim_bytes = swizzle_dim * elem.size();
151    if !swizzle_dim_bytes.is_power_of_two() {
152        return SwizzleMode::None;
153    }
154    match swizzle_dim_bytes {
155        32 => SwizzleMode::B32,
156        64 => SwizzleMode::B64,
157        _ => SwizzleMode::B128,
158        //_ => SwizzleMode::None,
159    }
160}