Skip to main content

cubek_convolution/components/
selection.rs

1use cubecl::{
2    Runtime,
3    client::ComputeClient,
4    ir::{StorageType, VectorSize},
5};
6use cubek_matmul::components::stage::PartitionBuffering;
7
8use cubek_matmul::definition::{
9    MatmulAvailabilityError, MatmulElems, MatmulVectorSizes, TilingBlueprint, TilingScheme,
10    adjust_dtypes,
11};
12use cubek_matmul::{
13    components::tile::TileMatmulKind,
14    routines::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
15};
16use cubek_std::SwizzleModes;
17use cubek_std::stage::SwizzleMode;
18
19use crate::components::ConvolutionProblem;
20
/// A heuristic to find the number of tiles in the stage, returned as `(tiles_m, tiles_n)`.
///
/// Maximizes tensor core usage unless doing so would significantly impair
/// parallelization across SMs. It ensures the number of cubes is as close as
/// possible to the available SMs.
///
/// Both tile counts start from the smallest of `max_tensor_cores`,
/// `256 / instruction_{m,n}` and `16 / stage_size_k`, clamped to at least one tile.
/// The `n` count is halved while it exceeds the number of instruction tiles needed
/// to cover `n`. The `m` count is then halved (rounding up) for as long as the
/// expected number of cubes stays below `num_sm`. Of the last two candidates, the one
/// whose cube count is closest to `num_sm` is returned; ties keep the larger stage.
///
/// # Panics
///
/// Panics if `instruction_m`, `instruction_n` or `stage_size_k` is zero.
pub(crate) fn find_stage_size_m_n(
    m: usize,
    n: usize,
    num_sm: usize,
    max_tensor_cores: usize,
    instruction_m: usize,
    instruction_n: usize,
    stage_size_k: usize,
) -> (usize, usize) {
    let max_tiles_elems_m = 256 / instruction_m;
    let max_tiles_elems_n = 256 / instruction_n;
    let max_tiles_total_stage = 16 / stage_size_k;

    // Clamp to at least one tile: a zero count (e.g. `stage_size_k > 16`, an instruction
    // larger than 256 elements, or zero reported tensor cores) would otherwise lead to a
    // division by zero when computing the expected number of cubes.
    let mut dim_num_tiles_m = max_tensor_cores
        .min(max_tiles_elems_m)
        .min(max_tiles_total_stage)
        .max(1);

    let mut dim_num_tiles_n = max_tensor_cores
        .min(max_tiles_elems_n)
        .min(max_tiles_total_stage)
        .max(1);

    let total_tiles_m = m.div_ceil(instruction_m);
    let total_tiles_n = n.div_ceil(instruction_n);

    // Don't allocate more tiles along `n` than the problem can actually fill.
    while total_tiles_n < dim_num_tiles_n && dim_num_tiles_n > 1 {
        dim_num_tiles_n /= 2;
    }

    let total_tiles = total_tiles_m * total_tiles_n;

    let mut stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;
    let mut num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);

    // We keep track of two configurations to select the closest to `num_sm`, whether it's a
    // bit over or under.
    let mut previous_dim_num_tiles = dim_num_tiles_m;
    let mut previous_num_cubes = num_cubes_expected;

    // Refine tensor core usage to stay as close as possible to `num_sm`.
    while num_cubes_expected < num_sm && dim_num_tiles_m > 1 {
        previous_dim_num_tiles = dim_num_tiles_m;
        previous_num_cubes = num_cubes_expected;

        // Reduce tensor core usage.
        dim_num_tiles_m = dim_num_tiles_m.div_ceil(2);
        stage_num_tiles = dim_num_tiles_m * dim_num_tiles_n;

        // Number of cubes grows as a consequence of the smaller stage.
        num_cubes_expected = total_tiles.div_ceil(stage_num_tiles);
    }

    // Compare previous and current values to determine the closest to `num_sm`.
    // Both arms return `(m, n)` in that order.
    if previous_num_cubes.abs_diff(num_sm) <= num_cubes_expected.abs_diff(num_sm) {
        (previous_dim_num_tiles, dim_num_tiles_n)
    } else {
        (dim_num_tiles_m, dim_num_tiles_n)
    }
}
85
86pub fn convolution_matmul_selection<R: Runtime>(
87    tile_matmul: TileMatmulKind,
88    client: &ComputeClient<R>,
89    problem: &ConvolutionProblem,
90    plane_dim: u32,
91    swizzle: bool,
92    vector_sizes: &MatmulVectorSizes,
93    dtypes: &mut MatmulElems,
94) -> Result<TilingBlueprint, MatmulAvailabilityError> {
95    adjust_dtypes(client, dtypes, tile_matmul.requires_accelerator());
96
97    // rough heuristic based on previous bench results where 512 channels with a 3x3 kernel seemed
98    // to be the rough cutoff for the k=4 size.
99    let stage_k = if problem.k >= 4096 { 4 } else { 2 };
100
101    let tile_size = find_instruction_size::<R, _, _>(
102        client,
103        (
104            dtypes.lhs_register,
105            dtypes.rhs_register,
106            dtypes.acc_register,
107        ),
108        (problem.m, problem.n, problem.k).into(),
109        (None, None, None),
110        |c, cfg| tile_matmul.is_supported(c, cfg),
111        |c, l, r, a| tile_matmul.supported_sizes(c, l, r, a),
112    )?;
113
114    let hardware = &client.properties().hardware;
115    let num_sm = hardware
116        .num_streaming_multiprocessors
117        .unwrap_or(NUM_TENSOR_CORES_APPROX);
118    let max_tensor_cores = hardware.num_tensor_cores.unwrap_or(NUM_SM_APPROX);
119
120    let (stage_size_m, stage_size_n) = find_stage_size_m_n(
121        problem.m,
122        problem.n,
123        num_sm as usize,
124        max_tensor_cores as usize,
125        tile_size.m() as usize,
126        tile_size.n() as usize,
127        stage_k as usize,
128    );
129
130    let tiling_scheme = TilingScheme::builder()
131        .with_stage_size((stage_size_m as u32, 1, 1).into())
132        .with_tile_size(tile_size)
133        .with_partition_size((1, stage_size_n as u32, stage_k).into())
134        .build()
135        .unwrap();
136
137    let mut builder = TilingBlueprint::builder(
138        tile_matmul,
139        tiling_scheme,
140        plane_dim,
141        &problem.as_matmul_problem(),
142    )
143    .partition_buffering(PartitionBuffering::Single);
144
145    if swizzle {
146        let swizzle_dim = tiling_scheme.elements_per_stage_along_k() as usize;
147
148        let lhs = select_swizzle(swizzle_dim, dtypes.lhs_stage, vector_sizes.lhs);
149        let rhs = select_swizzle(swizzle_dim, dtypes.rhs_stage, vector_sizes.rhs);
150        builder = builder.shared_swizzle(SwizzleModes {
151            lhs,
152            rhs,
153            ..Default::default()
154        });
155    }
156
157    Ok(builder.build())
158}
159
160/// All modes currently use atom size 16
161const SWIZZLE_ATOM: usize = 16;
162
163fn select_swizzle(swizzle_dim: usize, elem: StorageType, vector_size: VectorSize) -> SwizzleMode {
164    // Vector size exceeds swizzle atom
165    if elem.size() * vector_size > SWIZZLE_ATOM {
166        return SwizzleMode::None;
167    }
168    let swizzle_dim_bytes = swizzle_dim * elem.size();
169    if !swizzle_dim_bytes.is_power_of_two() {
170        return SwizzleMode::None;
171    }
172    match swizzle_dim_bytes {
173        32 => SwizzleMode::B32,
174        64 => SwizzleMode::B64,
175        _ => SwizzleMode::B128,
176        //_ => SwizzleMode::None,
177    }
178}