cubecl_convolution/components/
selection.rs1use cubecl_core::{Runtime, client::ComputeClient};
2use cubecl_matmul::components::stage::PartitionBuffering;
3
4use cubecl_matmul::components::{MatmulElems, MatmulSelection, TilingScheme};
5use cubecl_matmul::{
6 components::tile::TileMatmulFamily,
7 kernels::layered::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
8};
9
10use crate::components::ConvolutionProblem;
11
/// Picks how many instruction tiles a stage spans along `m` and `n`.
///
/// Returns `(tiles_m, tiles_n)`, always in that order.
///
/// The search starts from the largest tile count allowed in each dimension,
/// namely the minimum of `max_tensor_cores`, `256 / instruction_{m,n}` and
/// `16 / stage_size_k`. The `n` count is then halved while it exceeds the
/// number of tiles needed to cover `n`. Finally the `m` count is halved
/// (rounding up) while the expected number of cubes is still below `num_sm`,
/// and whichever of the last two candidates yields a cube count closest to
/// `num_sm` is returned (ties favour the larger stage).
///
/// # Panics
///
/// Panics on division by zero if `instruction_m`, `instruction_n` or
/// `stage_size_k` is zero, or if the initial tile count of either dimension
/// evaluates to zero (for example `max_tensor_cores == 0`, `stage_size_k > 16`
/// or an instruction size above 256).
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    // Upper bounds on the number of tiles per dimension.
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage);

    // Number of instruction tiles required to cover the whole problem.
    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Do not make the stage wider along `n` than the problem itself.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(dim_num_tiles_m * dim_num_tiles_n);

    // Remember the candidate preceding the current one, so the final choice
    // can be whichever lands closest to `num_sm`, be it slightly over or under.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Shrink the stage along `m` until there are at least `num_sm` cubes.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);

        // A smaller stage means more cubes are needed to cover the problem.
        num_cubes_expected = total_tiles.div_ceil(dim_num_tiles_m * dim_num_tiles_n);
    }

    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        // Keep the `(m, n)` order here as well; returning `(n, m)` would
        // silently transpose the stage shape for the caller.
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
76
77pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
78 client: &ComputeClient<R::Server>,
79 problem: &ConvolutionProblem,
80 plane_dim: u32,
81 matmul_elems: MatmulElems,
82) -> MatmulSelection {
83 let stage_k = if problem.k >= 4096 { 4 } else { 2 };
86
87 let tile_size = find_instruction_size(
88 if TMM::requires_accelerator() {
89 Some((
90 client.properties(),
91 (
92 matmul_elems.lhs_register,
93 matmul_elems.rhs_register,
94 matmul_elems.acc_register,
95 ),
96 ))
97 } else {
98 None
99 },
100 problem.m,
101 problem.n,
102 );
103
104 let hardware = &client.properties().hardware;
105 let num_sm = hardware
106 .num_streaming_multiprocessors
107 .unwrap_or(NUM_TENSOR_CORES_APPROX);
108 let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
109
110 let (stage_size_m, stage_size_n) = find_stage_size_m_n(
111 problem.m,
112 problem.n,
113 num_sm as usize,
114 max_tensor_cores as usize,
115 tile_size.m() as usize,
116 tile_size.n() as usize,
117 stage_k as usize,
118 );
119
120 let tiling_scheme = TilingScheme::builder()
121 .with_stage_size((stage_size_m as u32, 1, 1).into())
122 .with_tile_size(tile_size)
123 .with_partition_size((1, stage_size_n as u32, stage_k).into())
124 .build()
125 .unwrap();
126
127 MatmulSelection::builder(tiling_scheme, plane_dim)
128 .partition_buffering(PartitionBuffering::Single)
129 .build()
130}