cubecl_linalg/convolution/
selection.rs

1use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
2
3use super::{
4    algorithm::{Algorithm, StageInput},
5    base::ConvolutionProblem,
6};
7use crate::matmul::{
8    components::{
9        CompleteStageTiling, MatmulPrecision, MatmulProblem, MatmulSelection, MatmulSize,
10        stage::{STAGE_BUFFERING, StageVectorization},
11        tile::TileMatmulFamily,
12    },
13    kernels::matmul::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_shape},
14};
15
16pub fn select_matmul<A: Algorithm, R: Runtime, MP: MatmulPrecision>(
17    client: &ComputeClient<R::Server, R::Channel>,
18    problem: &ConvolutionProblem,
19    plane_dim: u32,
20) -> (MatmulSelection, StageInput) {
21    let mm_problem = problem.as_matmul_problem();
22    let selection = matmul_selection::<A::TileMatmul, MP, R>(client, &mm_problem, plane_dim);
23    let config_input = CompleteStageTiling {
24        tile_shape: selection.tile_shape,
25        tile_count: selection.tile_count,
26    };
27    let vectorization = StageVectorization {
28        stage_line_size: 0,
29        stage_elem_padding: 0,
30    };
31    // TODO Allows to select double buffering
32    (selection, (config_input, STAGE_BUFFERING, vectorization))
33}
34
/// A heuristic to find the number of tiles in the stage.
///
/// Maximizes tensor core usage unless doing so would significantly impair
/// parallelization across SMs. It ensures the number of cubes is as close as
/// possible to the available SMs.
///
/// Returns the stage size as `(num_tiles_m, num_tiles_n)`.
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    // Cap the stage so each dimension stays within 256 elements, and the total
    // tile count within 16 / stage_size_k.
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Shrink the n dimension when the problem has fewer n-tiles than the stage.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // We keep track of two configurations to select the closest to `num_sm`,
    // whether it's a bit over or under.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Refine tensor core usage to stay as close as possible to `num_sm`.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        // Reduce tensor core usage.
        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        // Number of cubes grows as a consequence of the smaller stage.
        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }

    // Compare previous and current values to determine the closest to `num_sm`.
    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        // Bug fix: was `(dim_num_tiles_n, dim_num_tiles_m)`, which swapped the
        // m/n stage sizes returned by this branch.
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
99
100pub fn matmul_selection<TMM: TileMatmulFamily, MP: MatmulPrecision, R: Runtime>(
101    client: &ComputeClient<R::Server, R::Channel>,
102    problem: &MatmulProblem,
103    plane_dim: u32,
104) -> MatmulSelection {
105    // rough heuristic based on previous bench results where 512 channels with a 3x3 kernel seemed
106    // to be the rough cutoff for the k=4 size.
107    let stage_size_k = if problem.k >= 4096 { 4 } else { 2 };
108
109    let (instruction_m, instruction_n, instruction_k) = find_instruction_shape(
110        if TMM::requires_tensor_cores() {
111            Some((
112                client.properties(),
113                (
114                    MP::ES::as_elem_native_unchecked(),
115                    MP::ES::as_elem_native_unchecked(),
116                    MP::EA::as_elem_native_unchecked(),
117                ),
118            ))
119        } else {
120            None
121        },
122        problem.m,
123        problem.n,
124    );
125
126    let hardware = client.properties().hardware_properties();
127    let num_sm = hardware
128        .num_streaming_multiprocessors
129        .unwrap_or(NUM_TENSOR_CORES_APPROX);
130    let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
131
132    let (stage_size_m, stage_size_n) = find_stage_size_m_n(
133        problem.m,
134        problem.n,
135        num_sm as usize,
136        max_tensor_cores as usize,
137        instruction_m,
138        instruction_n,
139        stage_size_k,
140    );
141
142    MatmulSelection {
143        tile_shape: MatmulSize {
144            m: instruction_m as u32,
145            n: instruction_n as u32,
146            k: instruction_k as u32,
147        },
148        tile_count: MatmulSize {
149            m: stage_size_m as u32,
150            n: stage_size_n as u32,
151            k: stage_size_k as u32,
152        },
153        plane_dim,
154        rows_per_plane: 1,
155    }
156}