cubek_convolution/components/
selection.rs1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 ir::{StorageType, VectorSize},
5};
6use cubek_matmul::components::stage::PartitionBuffering;
7
8use cubek_matmul::definition::{
9 MatmulAvailabilityError, MatmulElems, MatmulVectorSizes, SwizzleModes, TilingBlueprint,
10 TilingScheme, adjust_dtypes,
11};
12use cubek_matmul::{
13 components::tile_matmul::{DispatchTileMatmul, TileMatmulFamily as _},
14 routines::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
15};
16use cubek_std::stage::SwizzleMode;
17
18use crate::components::ConvolutionProblem;
19
/// Picks how many instruction tiles a stage spans along `m` and `n`.
///
/// The starting tile count on each axis is the smallest of `max_tensor_cores`,
/// `256 / instruction_{m,n}` and `16 / stage_size_k`. The `n` count is then halved
/// until it no longer exceeds the number of tiles needed to cover `n`. Finally the
/// `m` count is halved (rounding up) while the expected number of cubes is below
/// `num_sm`, and whichever of the last two candidates yields a cube count closest
/// to `num_sm` is kept (ties keep the larger stage).
///
/// # Returns
///
/// `(tiles_along_m, tiles_along_n)`, always in that order.
///
/// # Panics
///
/// Panics on division by zero when `instruction_m`, `instruction_n` or
/// `stage_size_k` is `0`, or when any of the caps evaluates to `0`
/// (e.g. `stage_size_k > 16`, an instruction size above `256`, or
/// `max_tensor_cores == 0`).
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    // Upper bounds on the number of tiles per axis.
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Do not allocate more tiles along `n` than the problem actually has.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // Remember the configuration before the last shrink so that the closest one to
    // `num_sm` can be selected, whether it lands slightly under or slightly over.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        // A smaller stage along `m` means more cubes are needed to cover the problem.
        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }

    // Both branches must return `(m, n)`; callers destructure in that order.
    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
84
85pub fn convolution_matmul_selection<R: Runtime>(
86 tile_matmul: DispatchTileMatmul,
87 client: &ComputeClient<R>,
88 problem: &ConvolutionProblem,
89 plane_dim: u32,
90 swizzle: bool,
91 vector_sizes: &MatmulVectorSizes,
92 dtypes: &mut MatmulElems,
93) -> Result<TilingBlueprint, MatmulAvailabilityError> {
94 adjust_dtypes(client, dtypes, tile_matmul.requires_accelerator());
95
96 let stage_k = if problem.k >= 4096 { 4 } else { 2 };
99
100 let tile_size = find_instruction_size::<R, _, _>(
101 client,
102 (
103 dtypes.lhs_register,
104 dtypes.rhs_register,
105 dtypes.acc_register,
106 ),
107 (problem.m, problem.n, problem.k).into(),
108 (None, None, None),
109 |c, cfg| tile_matmul.is_supported(c, cfg),
110 |c, l, r, a| tile_matmul.supported_sizes(c, l, r, a),
111 )?;
112
113 let hardware = &client.properties().hardware;
114 let num_sm = hardware
115 .num_streaming_multiprocessors
116 .unwrap_or(NUM_TENSOR_CORES_APPROX);
117 let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
118
119 let (stage_size_m, stage_size_n) = find_stage_size_m_n(
120 problem.m,
121 problem.n,
122 num_sm as usize,
123 max_tensor_cores as usize,
124 tile_size.m() as usize,
125 tile_size.n() as usize,
126 stage_k as usize,
127 );
128
129 let tiling_scheme = TilingScheme::builder()
130 .with_stage_size((stage_size_m as u32, 1, 1).into())
131 .with_tile_size(tile_size)
132 .with_partition_size((1, stage_size_n as u32, stage_k).into())
133 .build()
134 .unwrap();
135
136 let mut builder = TilingBlueprint::builder(
137 tile_matmul,
138 tiling_scheme,
139 plane_dim,
140 &problem.as_matmul_problem(),
141 )
142 .partition_buffering(PartitionBuffering::Single);
143
144 if swizzle {
145 let swizzle_dim = tiling_scheme.elements_per_stage_along_k() as usize;
146
147 let lhs = select_swizzle(swizzle_dim, dtypes.lhs_stage, vector_sizes.lhs);
148 let rhs = select_swizzle(swizzle_dim, dtypes.rhs_stage, vector_sizes.rhs);
149 builder = builder.shared_swizzle(SwizzleModes {
150 lhs,
151 rhs,
152 ..Default::default()
153 });
154 }
155
156 Ok(builder.build())
157}
158
159const SWIZZLE_ATOM: usize = 16;
161
162fn select_swizzle(swizzle_dim: usize, elem: StorageType, vector_size: VectorSize) -> SwizzleMode {
163 if elem.size() * vector_size > SWIZZLE_ATOM {
165 return SwizzleMode::None;
166 }
167 let swizzle_dim_bytes = swizzle_dim * elem.size();
168 if !swizzle_dim_bytes.is_power_of_two() {
169 return SwizzleMode::None;
170 }
171 match swizzle_dim_bytes {
172 32 => SwizzleMode::B32,
173 64 => SwizzleMode::B64,
174 _ => SwizzleMode::B128,
175 }
177}