cubecl_linalg/convolution/
selection.rs1use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
2
3use super::{
4 algorithm::{Algorithm, StageInput},
5 base::ConvolutionProblem,
6};
7use crate::matmul::{
8 components::{
9 CompleteStageTiling, MatmulPrecision, MatmulProblem, MatmulSelection, MatmulSize,
10 stage::{STAGE_BUFFERING, StageVectorization},
11 tile::TileMatmulFamily,
12 },
13 kernels::matmul::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_shape},
14};
15
16pub fn select_matmul<A: Algorithm, R: Runtime, MP: MatmulPrecision>(
17 client: &ComputeClient<R::Server, R::Channel>,
18 problem: &ConvolutionProblem,
19 plane_dim: u32,
20) -> (MatmulSelection, StageInput) {
21 let mm_problem = problem.as_matmul_problem();
22 let selection = matmul_selection::<A::TileMatmul, MP, R>(client, &mm_problem, plane_dim);
23 let config_input = CompleteStageTiling {
24 tile_shape: selection.tile_shape,
25 tile_count: selection.tile_count,
26 };
27 let vectorization = StageVectorization {
28 stage_line_size: 0,
29 stage_elem_padding: 0,
30 };
31 (selection, (config_input, STAGE_BUFFERING, vectorization))
33}
34
/// Chooses how many tiles a stage spans along `m` and `n`.
///
/// Each dimension starts at the smallest of:
/// - `max_tensor_cores`;
/// - `256 / instruction_{m,n}`;
/// - `16 / stage_size_k`.
///
/// Each count is clamped to at least one tile.
///
/// The `n` count is halved while it exceeds the number of tiles needed to cover `n`. The `m`
/// count is then halved (rounding up) while the expected number of cubes stays below `num_sm`.
/// The function returns whichever of the last two configurations yields a cube count closest to
/// `num_sm`. Ties favour the earlier configuration, which has the larger stage.
///
/// Returns `(tiles_m, tiles_n)`.
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    // Clamp to 1: a zero count (e.g. a device reporting 0 tensor cores) would otherwise make the
    // stage empty and the cube-count divisions below panic.
    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage)
        .max(1);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage)
        .max(1);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't make the stage wider along n than the problem itself.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(dim_num_tiles_m * dim_num_tiles_n);

    // Track the configuration before the last reduction so we can pick whichever lands closer
    // to `num_sm`, whether slightly under or slightly over.
    let mut previous_dim_num_tiles_m = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Shrink the stage along m until there are enough cubes to occupy every SM.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles_m = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        num_cubes_expected = total_tiles.div_ceil(dim_num_tiles_m * dim_num_tiles_n);
    }

    // Both branches return (m, n); the previous implementation swapped the axes in the else arm.
    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles_m, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
99
100pub fn matmul_selection<TMM: TileMatmulFamily, MP: MatmulPrecision, R: Runtime>(
101 client: &ComputeClient<R::Server, R::Channel>,
102 problem: &MatmulProblem,
103 plane_dim: u32,
104) -> MatmulSelection {
105 let stage_size_k = if problem.k >= 4096 { 4 } else { 2 };
108
109 let (instruction_m, instruction_n, instruction_k) = find_instruction_shape(
110 if TMM::requires_tensor_cores() {
111 Some((
112 client.properties(),
113 (
114 MP::ES::as_elem_native_unchecked(),
115 MP::ES::as_elem_native_unchecked(),
116 MP::EA::as_elem_native_unchecked(),
117 ),
118 ))
119 } else {
120 None
121 },
122 problem.m,
123 problem.n,
124 );
125
126 let hardware = client.properties().hardware_properties();
127 let num_sm = hardware
128 .num_streaming_multiprocessors
129 .unwrap_or(NUM_TENSOR_CORES_APPROX);
130 let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
131
132 let (stage_size_m, stage_size_n) = find_stage_size_m_n(
133 problem.m,
134 problem.n,
135 num_sm as usize,
136 max_tensor_cores as usize,
137 instruction_m,
138 instruction_n,
139 stage_size_k,
140 );
141
142 MatmulSelection {
143 tile_shape: MatmulSize {
144 m: instruction_m as u32,
145 n: instruction_n as u32,
146 k: instruction_k as u32,
147 },
148 tile_count: MatmulSize {
149 m: stage_size_m as u32,
150 n: stage_size_n as u32,
151 k: stage_size_k as u32,
152 },
153 plane_dim,
154 rows_per_plane: 1,
155 }
156}