cubecl_convolution/
selection.rs1use cubecl_core::{Runtime, client::ComputeClient, ir::Elem};
2use cubecl_matmul::components::stage::PartitionBuffering;
3
4use super::base::ConvolutionProblem;
5use cubecl_matmul::components::{MatmulSelection, TilingScheme};
6use cubecl_matmul::{
7 components::tile::TileMatmulFamily,
8 kernels::layered::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
9};
10
/// Picks the stage size, in number of tiles, along the `m` and `n` dimensions.
///
/// The initial tile count per dimension is the smallest of three caps:
/// - `max_tensor_cores`;
/// - `256 / instruction_*`;
/// - `16 / stage_size_k`.
///
/// Each count is clamped to at least one tile. The algorithm then proceeds in three steps:
/// 1. Halve the `n` count while it exceeds the number of `instruction_n` tiles the problem has
///    along `n`.
/// 2. Halve the `m` count (rounding up) while the expected number of cubes stays below `num_sm`.
/// 3. Of the last two `m` candidates, keep the one whose cube count is closest to `num_sm`.
///    Ties keep the earlier, larger candidate.
///
/// Returns `(stage_size_m, stage_size_n)`, each at least 1.
///
/// # Panics
///
/// Panics if `instruction_m`, `instruction_n` or `stage_size_k` is zero (integer division by zero).
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    // Clamp to at least one tile: a zero count (e.g. `stage_size_k > 16`, an instruction wider
    // than 256, or `max_tensor_cores == 0`) would make the cube-count division below divide by 0.
    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage)
        .max(1);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage)
        .max(1);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't reserve more tiles along `n` than the problem actually has.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut num_cubes_expected = total_tiles.div_ceil(dim_num_tiles_m * dim_num_tiles_n);

    // Remember the candidate before the last halving so we can keep whichever lands
    // closest to `num_sm`, whether slightly over or slightly under.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Smaller stages along `m` mean more cubes; shrink until there are at least `num_sm`.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        num_cubes_expected = total_tiles.div_ceil(dim_num_tiles_m * dim_num_tiles_n);
    }

    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        // Same `(m, n)` order as the other branch and as the caller's destructuring.
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
75
76pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
77 client: &ComputeClient<R::Server, R::Channel>,
78 problem: &ConvolutionProblem,
79 plane_dim: u32,
80 elem_stage: Elem,
81 elem_acc: Elem,
82) -> MatmulSelection {
83 let stage_k = if problem.k >= 4096 { 4 } else { 2 };
86
87 let tile_size = find_instruction_size(
88 if TMM::requires_accelerator() {
89 Some((client.properties(), (elem_stage, elem_stage, elem_acc)))
90 } else {
91 None
92 },
93 problem.m,
94 problem.n,
95 );
96
97 let hardware = &client.properties().hardware;
98 let num_sm = hardware
99 .num_streaming_multiprocessors
100 .unwrap_or(NUM_TENSOR_CORES_APPROX);
101 let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
102
103 let (stage_size_m, stage_size_n) = find_stage_size_m_n(
104 problem.m,
105 problem.n,
106 num_sm as usize,
107 max_tensor_cores as usize,
108 tile_size.m() as usize,
109 tile_size.n() as usize,
110 stage_k as usize,
111 );
112
113 let tiling_scheme = TilingScheme::builder()
114 .with_stage_size((stage_size_m as u32, 1, 1).into())
115 .with_tile_size(tile_size)
116 .with_partition_size((1, stage_size_n as u32, stage_k).into())
117 .build()
118 .unwrap();
119
120 MatmulSelection::builder(tiling_scheme, plane_dim)
121 .partition_buffering(PartitionBuffering::Single)
122 .build()
123}