cubek_convolution/components/
selection.rs1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 ir::{LineSize, StorageType},
5};
6use cubek_matmul::components::stage::{PartitionBuffering, SwizzleMode};
7
8use cubek_matmul::definition::{
9 MatmulAvailabilityError, MatmulElems, MatmulLineSizes, SwizzleModes, TilingBlueprint,
10 TilingScheme, adjust_dtypes,
11};
12use cubek_matmul::{
13 components::tile::TileMatmulFamily,
14 routines::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
15};
16
17use crate::components::ConvolutionProblem;
18
/// Chooses how many tiles a stage covers along `m` and along `n`.
///
/// Starts from the largest stage allowed by three limits:
/// - `max_tensor_cores`;
/// - 256 elements per dimension;
/// - `dim_tiles * stage_size_k <= 16`.
///
/// It then shrinks the `n` dimension to what the problem actually has. After that, it halves
/// the `m` dimension while fewer than `num_sm` cubes would be launched. Of the last two `m`
/// candidates, the one whose cube count is closest to `num_sm` is kept. Ties favor the larger
/// stage.
///
/// Returns `(tiles_along_m, tiles_along_n)`. Both values are always at least 1.
///
/// # Panics
/// Panics if `instruction_m`, `instruction_n` or `stage_size_k` is zero (division by zero).
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    // Clamp every limit to at least one tile. Otherwise an oversized instruction,
    // a large `stage_size_k` or a zero `max_tensor_cores` yields an empty stage
    // and a division by zero when counting cubes below.
    let max_tiles_elems_m = (256 / instruction_m).max(1);
    let max_tiles_elems_n = (256 / instruction_n).max(1);
    let max_tiles_total_stage = (16 / stage_size_k).max(1);

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage)
        .max(1);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage)
        .max(1);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't reserve more tiles along n than the problem has.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;
    let cubes_for = |tiles_m: usize| total_tiles.div_ceil(tiles_m * dim_num_tiles_n);

    let mut num_cubes_expected = cubes_for(dim_num_tiles_m);
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Shrink the stage along m until enough cubes are launched to cover every SM.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        num_cubes_expected = cubes_for(dim_num_tiles_m);
    }

    // Keep whichever of the last two candidates lands closest to `num_sm` cubes.
    // Both branches must return (m, n) in that order.
    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
83
84pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
85 client: &ComputeClient<R>,
86 problem: &ConvolutionProblem,
87 plane_dim: u32,
88 swizzle: bool,
89 line_sizes: &MatmulLineSizes,
90 dtypes: &mut MatmulElems,
91) -> Result<TilingBlueprint, MatmulAvailabilityError> {
92 adjust_dtypes(client, dtypes, TMM::requires_accelerator());
93
94 let stage_k = if problem.k >= 4096 { 4 } else { 2 };
97
98 let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n)?;
99
100 let hardware = &client.properties().hardware;
101 let num_sm = hardware
102 .num_streaming_multiprocessors
103 .unwrap_or(NUM_TENSOR_CORES_APPROX);
104 let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
105
106 let (stage_size_m, stage_size_n) = find_stage_size_m_n(
107 problem.m,
108 problem.n,
109 num_sm as usize,
110 max_tensor_cores as usize,
111 tile_size.m() as usize,
112 tile_size.n() as usize,
113 stage_k as usize,
114 );
115
116 let tiling_scheme = TilingScheme::builder()
117 .with_stage_size((stage_size_m as u32, 1, 1).into())
118 .with_tile_size(tile_size)
119 .with_partition_size((1, stage_size_n as u32, stage_k).into())
120 .build()
121 .unwrap();
122
123 let mut builder =
124 TilingBlueprint::builder(tiling_scheme, plane_dim, &problem.as_matmul_problem())
125 .partition_buffering(PartitionBuffering::Single);
126
127 if swizzle {
128 let swizzle_dim = tiling_scheme.elements_per_stage_along_k() as usize;
129
130 let lhs = select_swizzle(swizzle_dim, dtypes.lhs_stage, line_sizes.lhs);
131 let rhs = select_swizzle(swizzle_dim, dtypes.rhs_stage, line_sizes.rhs);
132 builder = builder.shared_swizzle(SwizzleModes {
133 lhs,
134 rhs,
135 ..Default::default()
136 });
137 }
138
139 Ok(builder.build())
140}
141
142const SWIZZLE_ATOM: usize = 16;
144
145fn select_swizzle(swizzle_dim: usize, elem: StorageType, line_size: LineSize) -> SwizzleMode {
146 if elem.size() * line_size > SWIZZLE_ATOM {
148 return SwizzleMode::None;
149 }
150 let swizzle_dim_bytes = swizzle_dim * elem.size();
151 if !swizzle_dim_bytes.is_power_of_two() {
152 return SwizzleMode::None;
153 }
154 match swizzle_dim_bytes {
155 32 => SwizzleMode::B32,
156 64 => SwizzleMode::B64,
157 _ => SwizzleMode::B128,
158 }
160}