burn_jit/kernel/conv/conv2d/implicit_gemm.rs

use burn_tensor::{
    ops::{conv::calculate_conv_output_size, ConvOptions},
    Shape,
};
use cmma::{Matrix, MatrixIdent, MatrixLayout};
use cubecl::{
    cube,
    ir::{Elem, FloatKind},
    prelude::*,
    Compiler, CubeCount, CubeDim, Feature,
};
use half::f16;

use crate::{
    kernel::{conv::ConvLaunchError, into_contiguous, slice, slice_assign},
    ops::{
        numeric::{empty_device, zeros_device},
        permute,
    },
    tensor::JitTensor,
    FloatElement, JitRuntime,
};

use super::nchw_to_nhwc;

/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
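/// A hypothetical call site (illustrative sketch; `input`, `weight`, `bias` and
/// `options` are assumed to already be valid NCHW tensors on the target device):
///
/// ```ignore
/// let out = conv2d_implicit_gemm::<R, F>(input, weight, Some(bias), options)?;
/// // `out` is an NCHW tensor of shape [batch_size, out_channels, out_h, out_w].
/// ```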
pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
    input: JitTensor<R>,
    weight: JitTensor<R>,
    bias: Option<JitTensor<R>>,
    options: ConvOptions<2>,
) -> Result<JitTensor<R>, ConvLaunchError> {
    let is_tf32 = F::as_elem_native_unchecked() == Elem::Float(FloatKind::F32)
        && input
            .client
            .properties()
            .feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32)));

    let k_target = if is_tf32 { 8 } else { 16 };

    let [batch_size, in_channels, height, width] = input.shape.dims();
    let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();
    let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w, k_target);
    let padded_out_channels = out_channels.div_ceil(16) * 16;

    let out_h = calculate_conv_output_size(
        kernel_h,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        height,
    );
    let out_w = calculate_conv_output_size(
        kernel_w,
        options.stride[1],
        options.padding[1],
        options.dilation[1],
        width,
    );

    let padded_batch_size = padded_batch_size(batch_size, out_h, out_w);

    if !can_do_implicit_gemm::<R, F>(
        batch_size,
        in_channels,
        out_channels,
        [kernel_h, kernel_w],
        options.groups,
        out_h,
        out_w,
        &input.client,
    ) {
        panic!(
            "Requirements for implicit GEMM not met:
- CMMA must be available
- `groups` must be 1
- subcube size must be non-variable (might not hold on Intel)
        "
        );
    }

    // If input is contiguous NCHW, use custom transpose kernel
    let input = match input.is_contiguous() {
        true => nchw_to_nhwc::<R, F>(input),
        false => into_contiguous(permute(input, &[0, 2, 3, 1])),
    };
    let weight = into_contiguous(permute(weight, &[2, 3, 1, 0]));

    let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]);
    let out = empty_device::<R, F>(input.client.clone(), input.device.clone(), out_shape);

    // Implicit GEMM matrix size
    let gemm_m = (padded_batch_size * out_h * out_w) as u32;
    let gemm_n = padded_out_channels as u32;
    let gemm_k = (pad_in_channels * pad_kh * pad_kw) as u32;
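    // Worked example (illustrative): for a 3x3 kernel with 3 input channels and
    // k_target = 16, `padded_k` pads to (3, 4, 4), so gemm_k = 3 * 4 * 4 = 48,
    // a multiple of both the 16-wide (f16) and 8-wide (tf32) CMMA K dimension.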

    let (cmma_m, cmma_n, cmma_k) =
        find_cmma_size::<R, F>(&input.client, gemm_m, gemm_k, gemm_n).unwrap();

    let slice_size = pad_kh * pad_kw * pad_in_channels;

    let cube_dim_x = 128;
    let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2);

    let input_tile_size = cmma_m * cmma_k;
    let weight_tile_size = cmma_k * cmma_n;

    let topology = input.client.properties().hardware_properties();
    let warp_size = topology.plane_size_min;
    let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size;

    let supported_vecs = R::supported_line_sizes();

    let input_elems_per_thread = input_tile_size / warp_size;
    let input_vectorization = find_common_vec(in_channels, input_elems_per_thread, supported_vecs);

    let weight_elems_per_thread = weight_tile_size / warp_size;
    let weight_vectorization =
        find_common_vec(out_channels, weight_elems_per_thread, supported_vecs);

    let has_bias = bias.is_some();
    let bias = match bias {
        Some(bias) if out_channels == padded_out_channels => bias,
        Some(bias) => {
            let shape = Shape::new([padded_out_channels]);
            let padded_bias = zeros_device::<R, F>(bias.client.clone(), bias.device.clone(), shape);
            #[allow(clippy::single_range_in_vec_init)]
            slice_assign::<R, F>(padded_bias, &[0..out_channels], bias)
        }
        None => empty_device::<R, F>(input.client.clone(), input.device.clone(), Shape::new([1])),
    };

    let settings = GemmSettings {
        cmma_m,
        cmma_n,
        cmma_k,
        check_m: batch_size != padded_batch_size,
        check_n: out_channels != padded_out_channels,
        check_k: (kernel_h * kernel_w * in_channels) as u32 != gemm_k,
        warp_size,
        warps_per_cube,
        cube_dim_x,
    };

    // `CUBE_DIM_X` must be a multiple of `WARP_SIZE`
    // 128x2 means we have 8 warps and a cube computes a 32x64 output tile
    let cube_dim = CubeDim {
        x: cube_dim_x,
        y: cube_dim_y,
        z: 1,
    };

    let cube_count_x = gemm_m.div_ceil(cmma_m * cube_dim_x / warp_size);
    let cube_count_y = gemm_n.div_ceil(cmma_n * cube_dim_y);
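    // Worked example (illustrative): with warp_size = 32, cube_dim_x = 128,
    // cube_dim_y = 2 and 16x16x16 CMMA tiles, each cube covers
    // 16 * 128 / 32 = 64 rows of M and 16 * 2 = 32 columns of N.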

    // If div floor == div ceil then the cubes are aligned with the input dimensions
    let aligned = gemm_m / (cmma_m * cube_dim_x / warp_size) == cube_count_x
        && gemm_n / (cmma_n * cube_dim_y) == cube_count_y;

    let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1);

    let launch = match is_tf32 {
        false => implicit_gemm_kernel::launch::<F, f16, R>,
        true => implicit_gemm_kernel::launch::<F, tf32, R>,
    };

    launch(
        &input.client,
        cube_count,
        cube_dim,
        input.as_tensor_arg::<F>(input_vectorization),
        weight.as_tensor_arg::<F>(weight_vectorization),
        bias.as_tensor_arg::<F>(1),
        out.as_tensor_arg::<F>(1),
        DimensionsLaunch::new(
            ScalarArg::new(gemm_m),
            ScalarArg::new(gemm_n),
            ScalarArg::new(gemm_k),
            ScalarArg::new(slice_size as u32),
            ScalarArg::new(pad_kw as u32),
            ScalarArg::new(pad_in_channels as u32),
            ScalarArg::new(out_h as u32),
            ScalarArg::new(out_w as u32),
        ),
        ConvArgsLaunch::new(
            ScalarArg::new(options.stride[0] as u32),
            ScalarArg::new(options.stride[1] as u32),
            ScalarArg::new(options.padding[0] as i32),
            ScalarArg::new(options.padding[1] as i32),
            ScalarArg::new(options.dilation[0] as u32),
            ScalarArg::new(options.dilation[1] as u32),
        ),
        settings,
        ConvSettings {
            kernel_h: kernel_h as u32,
            kernel_w: kernel_w as u32,
            padding_h: options.padding[0] as i32,
            padding_w: options.padding[1] as i32,
            aligned,
            has_bias,
        },
    );

    let out = slice::<R, F>(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]);

    // Reset to NCHW
    Ok(permute(out, &[0, 3, 1, 2]))
}

fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 {
    let channels = channels as u8;
    let elems_per_thread = elems_per_thread as u8;
    let smaller = Ord::min(channels, elems_per_thread);
    (1..=smaller)
        .rev()
        .filter(|it| supported_vecs.contains(it))
        .find(|vec| channels % *vec == 0 && elems_per_thread % *vec == 0)
        .unwrap_or(1)
}

#[derive(CubeLaunch)]
struct ConvArgs {
    stride_h: u32,
    stride_w: u32,
    pad_h: i32,
    pad_w: i32,
    dilation_h: u32,
    dilation_w: u32,
}

#[derive(CubeLaunch)]
struct Dimensions {
    gemm_m: u32,
    gemm_n: u32,
    gemm_k: u32,
    slice_size: u32,

    pad_kw: u32,
    pad_channels: u32,

    out_h: u32,
    out_w: u32,
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct GemmSettings {
    cmma_m: u32,
    cmma_n: u32,
    cmma_k: u32,

    check_m: bool,
    check_n: bool,
    check_k: bool,

    warp_size: u32,
    warps_per_cube: u32,

    cube_dim_x: u32,
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct ConvSettings {
    kernel_h: u32,
    kernel_w: u32,
    padding_h: i32,
    padding_w: i32,
    aligned: bool,
    has_bias: bool,
}

#[derive(Clone, Copy, CubeType)]
struct Positions {
    global_m: u32,
    global_n: u32,

    intra_warp_unit_idx: u32,
    cube_linear_warp_idx: u32,
}

#[derive(CubeType)]
struct Matrices<F: Float, FAcc: Float> {
    a: Matrix<F>,
    b: Matrix<F>,
    acc: Matrix<FAcc>,
}

#[allow(clippy::collapsible_else_if)]
#[cube(launch)]
fn implicit_gemm_kernel<F: Float, FMat: Float>(
    input: &Tensor<Line<F>>,
    weight: &Tensor<Line<F>>,
    bias: &Tensor<F>,
    out: &mut Tensor<F>,
    dims: &Dimensions,
    args: &ConvArgs,
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] conv_settings: ConvSettings,
) {
    let _ = bias[0];

    let GemmSettings {
        cmma_m,
        cmma_n,
        cmma_k,
        warps_per_cube,
        ..
    } = gemm_settings;

    let cmma_out_tile_size = cmma_m * cmma_n;
    let cmma_input_tile_size = cmma_m * cmma_k;
    let cmma_filter_tile_size = cmma_k * cmma_n;

    let pos = calculate_positions(gemm_settings);

    let in_vec = input.line_size();
    let weight_vec = weight.line_size();

    // Shared memory tiles; they currently only hold enough data for each warp to
    // have its own tile for a single MMA op (8 * 16 * 16 elements).
    // Conceptually a WARPS_PER_CUBE x (CMMA_M * CMMA_K) matrix.
    let mut smem_input_tile = SharedMemory::<FMat>::new_lined(
        comptime!(cmma_input_tile_size * warps_per_cube / in_vec),
        in_vec,
    );
    let mut smem_weight_tile = SharedMemory::<FMat>::new_lined(
        comptime!(cmma_filter_tile_size * warps_per_cube / weight_vec),
        weight_vec,
    );

    let input_tile_start = pos.cube_linear_warp_idx * (cmma_input_tile_size / in_vec);
    let weight_tile_start = pos.cube_linear_warp_idx * (cmma_filter_tile_size / weight_vec);
    let mut input_tile =
        smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size);
    let mut weight_tile =
        smem_weight_tile.slice_mut(weight_tile_start, weight_tile_start + cmma_filter_tile_size);

    let out_pos = pos.global_n + pos.global_m * dims.gemm_n;
    let mut out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size);

    if conv_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
        execute_gemm::<F, FMat>(
            input,
            weight,
            bias,
            &mut out,
            &mut input_tile,
            &mut weight_tile,
            dims,
            &pos,
            args,
            gemm_settings,
            conv_settings,
        );
    }
}

#[cube]
fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions {
    let GemmSettings {
        cmma_m,
        cmma_n,
        warp_size,
        cube_dim_x,
        ..
    } = gemm_settings;

    // Tile using a 2D grid (over the output), each threadblock
    // is (128, 2) -> (4, 2) = 8 warps -> 32x64 output
    let global_warp_m = ABSOLUTE_POS_X / warp_size;
    let global_warp_n = ABSOLUTE_POS_Y;
    let cube_warp_m = UNIT_POS_X / warp_size;
    let cube_warp_n = UNIT_POS_Y;
    let num_warps_m = cube_dim_x / warp_size;
    let intra_warp_unit_idx = UNIT_POS_X % warp_size; // Thread idx within warp (0 to 31)
    let cube_linear_warp_idx = (cube_warp_n * num_warps_m) + cube_warp_m; // Warp idx within a block (0 to WARPS_PER_BLOCK - 1)

    Positions {
        global_m: global_warp_m * cmma_m,
        global_n: global_warp_n * cmma_n,
        intra_warp_unit_idx,
        cube_linear_warp_idx,
    }
}

#[cube]
fn make_matrices<F: Float, FAcc: Float>(
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] has_bias: bool,
) -> Matrices<F, FAcc> {
    let GemmSettings {
        cmma_m,
        cmma_n,
        cmma_k,
        ..
    } = gemm_settings;

    let acc = if has_bias {
        unsafe {
            Matrix::<FAcc>::uninitialized(
                MatrixIdent::Accumulator,
                cmma_m,
                cmma_n,
                cmma_k,
                MatrixLayout::Undefined,
            )
        }
    } else {
        Matrix::<FAcc>::from_value(
            MatrixIdent::Accumulator,
            cmma_m,
            cmma_n,
            cmma_k,
            MatrixLayout::Undefined,
            FAcc::new(0.0),
        )
    };

    Matrices::<F, FAcc> {
        a: unsafe {
            Matrix::<F>::uninitialized(
                MatrixIdent::A,
                cmma_m,
                cmma_n,
                cmma_k,
                MatrixLayout::RowMajor,
            )
        },
        b: unsafe {
            Matrix::<F>::uninitialized(
                MatrixIdent::B,
                cmma_m,
                cmma_n,
                cmma_k,
                MatrixLayout::RowMajor,
            )
        },
        acc,
    }
}

#[cube]
fn execute_gemm<F: Float, FMat: Float>(
    input: &Tensor<Line<F>>,
    weight: &Tensor<Line<F>>,
    bias: &Tensor<F>,
    out: &mut SliceMut<F>,
    input_tile: &mut SliceMut<Line<FMat>>,
    weight_tile: &mut SliceMut<Line<FMat>>,
    dims: &Dimensions,
    pos: &Positions,
    args: &ConvArgs,
    #[comptime] g_settings: GemmSettings,
    #[comptime] k_settings: ConvSettings,
) {
    let GemmSettings { cmma_n, cmma_k, .. } = g_settings;
    let has_bias = k_settings.has_bias;

    let matrices = make_matrices::<FMat, F>(g_settings, has_bias);
    if has_bias {
        let bias_tile = bias.slice(pos.global_n, pos.global_n + cmma_n);
        cmma::load_with_layout(&matrices.acc, &bias_tile, 0, MatrixLayout::RowMajor);
    }

    // Loop over the K-dimension
    for k in range_stepped(0, dims.gemm_k, cmma_k) {
        // Load into smem...
        // Each warp should load the 16x16 tile it's responsible for
        // i.e. each thread needs to load 8 elements of input and 8 elements of weight

        load_input_tile(
            input, args, input_tile, dims, pos, k, g_settings, k_settings,
        );

        load_weight_tile(weight, weight_tile, dims, pos, k, g_settings, k_settings);

        // Run CMMA
        cmma::load(&matrices.b, &weight_tile.to_slice(), cmma_n);
        cmma::load(&matrices.a, &input_tile.to_slice(), cmma_k);

        cmma::execute::<FMat, FMat, F, F>(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc);
    }

    cmma::store(out, &matrices.acc, dims.gemm_n, MatrixLayout::RowMajor);
}

#[cube]
fn load_input_tile<F: Float, FMat: Float>(
    input: &Tensor<Line<F>>,
    args: &ConvArgs,
    tile: &mut SliceMut<Line<FMat>>,
    dims: &Dimensions,
    pos: &Positions,
    k: u32,
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] kernel_settings: ConvSettings,
) {
    let GemmSettings {
        cmma_m,
        cmma_k,
        warp_size,
        check_m,
        check_k,
        ..
    } = gemm_settings;

    let ConvSettings {
        kernel_w,
        kernel_h,
        padding_h,
        padding_w,
        ..
    } = kernel_settings;

    let cmma_input_tile_size = cmma_m * cmma_k;
    let elems_per_thread = cmma_input_tile_size / warp_size;
    let vec = input.line_size();

    let height = input.shape(1) as i32;
    let width = input.shape(2) as i32;
    let channels = dims.pad_channels;

    // Row strides in the implicit GEMM matrix
    let batch_stride = dims.out_h * dims.out_w;
    let y_stride = dims.out_w;
    let x_stride = 1;

    // Start index within a slice (0 to `kernel_size * channels - 1`) that a half warp (16 units) is responsible for
    let slice_start_idx = k % dims.slice_size;
    let start = pos.intra_warp_unit_idx * elems_per_thread;

    let rel_slice_row = start / cmma_k; // Relative row (0 - 15)
    let abs_slice_row = pos.global_m + rel_slice_row; // Row of the matrix the slice is on

    // Given the row of the matrix that the slice is in, and the index of the thread
    // within the slice, we want to compute which input element to load.
    // First, compute coordinates in output space (center of the kernel in MxK matrix A).
    let batch = abs_slice_row / batch_stride;

    let m_in_bounds = !check_m || batch < input.shape(0);
    let out_y = (abs_slice_row % batch_stride) / y_stride;
    let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride;
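    // i.e. abs_slice_row decomposes as batch * (out_h * out_w) + out_y * out_w + out_x,
    // the NHW-ordered output position that this GEMM row corresponds to.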

    #[unroll]
    for m in range_stepped(0, elems_per_thread, vec) {
        let m = m + start;
        // Compute where in the slice we are starting

        // Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice
        // we are and also which row the slice is in relative to the start of the CMMA matrix

        // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is responsible for
        let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size;

        let channel = my_slice_idx % channels;

        let kernel_x = (my_slice_idx / channels) % dims.pad_kw;
        let kernel_y = my_slice_idx / (channels * dims.pad_kw);

        let k_in_bounds =
            !check_k || (channel < input.shape(3) && kernel_x < kernel_w && kernel_y < kernel_h);

        let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 - padding_h;
        let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 - padding_w;
        let in_bounds =
            (padding_h == 0 && padding_w == 0) || (x >= 0 && x < width && y >= 0 && y < height);
        let idx = batch * input.stride(0)
            + y as u32 * input.stride(1)
            + x as u32 * input.stride(2)
            + channel;
        let value = select(
            in_bounds && m_in_bounds && k_in_bounds,
            Line::cast_from(input[idx / vec]),
            Line::new(FMat::new(0.0)),
        );

        tile[m / vec] = value;
    }
}

#[cube]
fn load_weight_tile<F: Float, FMat: Float>(
    weight: &Tensor<Line<F>>,
    tile: &mut SliceMut<Line<FMat>>,
    dims: &Dimensions,
    pos: &Positions,
    k: u32,
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] kernel_settings: ConvSettings,
) {
    let GemmSettings {
        cmma_n,
        cmma_k,
        warp_size,
        check_n,
        check_k,
        ..
    } = gemm_settings;

    let ConvSettings {
        kernel_w, kernel_h, ..
    } = kernel_settings;

    let vec = weight.line_size();
    let cmma_filter_tile_size = cmma_k * cmma_n;
    let elems_per_thread = cmma_filter_tile_size / warp_size;
    let start = pos.intra_warp_unit_idx * elems_per_thread;

    let global_k = start / cmma_n + k;

    let (k_idx, k_in_bounds) = if check_k {
        let channel = global_k % dims.pad_channels;
        let kernel_x = global_k / dims.pad_channels % dims.pad_kw;
        let kernel_y = global_k / (dims.pad_channels * dims.pad_kw);
        let k_in_bounds =
            !check_k || (channel < weight.shape(2) && kernel_x < kernel_w && kernel_y < kernel_h);
        let idx =
            kernel_y * weight.stride(0) + kernel_x * weight.stride(1) + channel * weight.stride(2);
        (idx, k_in_bounds)
    } else {
        (global_k * weight.stride(2), true)
    };

    #[unroll]
    for n in range_stepped(0, elems_per_thread, vec) {
        let n = n + start;

        let global_n = (n % cmma_n) + pos.global_n;
        let n_in_bounds = !check_n || global_n < weight.shape(3);

        let idx = k_idx + global_n;

        let value = Line::cast_from(weight[idx / vec]);
        let value = select(k_in_bounds && n_in_bounds, value, Line::new(FMat::new(0.0)));

        tile[n / vec] = value;
    }
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
    batch_size: usize,
    in_channels: usize,
    out_channels: usize,
    kernel_size: [usize; 2],
    groups: usize,
    out_h: usize,
    out_w: usize,
    client: &ComputeClient<R::Server, R::Channel>,
) -> bool {
    let cmma_k = match (
        E::as_elem_native_unchecked(),
        client
            .properties()
            .feature_enabled(Feature::Type(tf32::as_elem_native_unchecked())),
    ) {
        (Elem::Float(FloatKind::F32), true) => 8,
        _ => 16,
    };

    let (in_channels, kernel_h, kernel_w) =
        padded_k(in_channels, kernel_size[0], kernel_size[1], cmma_k);
    let batch_size = padded_batch_size(batch_size, out_h, out_w);
    let out_channels = out_channels.div_ceil(16) * 16;

    let gemm_m = batch_size * out_h * out_w;
    let gemm_n = out_channels;
    let gemm_k = in_channels * kernel_h * kernel_w;

    let size = find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);

    if let Some((cmma_m, cmma_k, cmma_n)) = size {
        let warps_per_cube = 8;

        let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
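        // e.g. for 8 warps of 16x16x16 tiles: (16 + 16) * 16 * 8 * 2 bytes = 8 KiB of shared memory.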
        let topology = client.properties().hardware_properties();
        let not_intel = topology.plane_size_min >= 32;

        <R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1 && not_intel
    } else {
        false
    }
}

fn padded_k(
    in_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    target: usize,
) -> (usize, usize, usize) {
    if in_channels * kernel_h * kernel_w % target == 0 {
        return (in_channels, kernel_h, kernel_w);
    }
    let kernel_h = kernel_h.next_power_of_two();
    let target = target.div_ceil(kernel_h);
    if in_channels * kernel_w % target == 0 {
        return (in_channels, kernel_h, kernel_w);
    }
    let kernel_w = kernel_w.next_power_of_two();
    let target = target.div_ceil(kernel_w);
    if in_channels % target == 0 {
        return (in_channels, kernel_h, kernel_w);
    }
    let in_channels = in_channels.div_ceil(target) * target;
    (in_channels, kernel_h, kernel_w)
}

fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize {
    let out_size = out_h * out_w;
    let target = if out_size.is_power_of_two() || out_size % 16 == 0 {
        (16usize).div_ceil(out_size)
    } else {
        16
    };
    batch_size.div_ceil(target) * target
}

fn find_cmma_size<R: JitRuntime, F: Float>(
    client: &ComputeClient<R::Server, R::Channel>,
    gemm_m: u32,
    gemm_k: u32,
    gemm_n: u32,
) -> Option<(u32, u32, u32)> {
    supported_cmma_sizes::<R, F>(client)
        .into_iter()
        .find(|(m, k, n)| {
            gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0
        })
        .map(|(m, k, n)| (m as u32, n as u32, k as u32))
}

fn supported_cmma_sizes<R: JitRuntime, F: Float>(
    client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<(u8, u8, u8)> {
    let (requested_sizes, matrix_elem) = match (
        F::as_elem_native_unchecked(),
        client
            .properties()
            .feature_enabled(Feature::Type(tf32::as_elem_native_unchecked())),
    ) {
        (Elem::Float(FloatKind::F32), true) => {
            (vec![(16, 8, 16)], tf32::as_elem_native_unchecked())
        }
        _ => (
            vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)],
            f16::as_elem_native_unchecked(),
        ),
    };

    requested_sizes
        .iter()
        .copied()
        .filter(|(m, k, n)| {
            client.properties().feature_enabled(Feature::Cmma {
                a: matrix_elem,
                b: matrix_elem,
                c: F::as_elem_native_unchecked(),
                m: *m,
                k: *k,
                n: *n,
            })
        })
        .collect()
}
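
// NOTE: illustrative sanity checks (a sketch, not part of the upstream file) documenting
// how the pure padding helpers above behave on small worked examples.
#[cfg(test)]
mod padding_sketch_tests {
    use super::*;

    #[test]
    fn padded_k_pads_kernel_dims_then_channels() {
        // 3 * 3 * 3 = 27 is not divisible by 16: kernel dims are rounded up to the
        // next power of two, giving 3 * 4 * 4 = 48, which is divisible by 16.
        assert_eq!(padded_k(3, 3, 3, 16), (3, 4, 4));
        // 4 * 3 * 3 = 36 is not divisible by 16, but padding kernel_h alone is
        // enough: 4 * 4 * 3 = 48.
        assert_eq!(padded_k(4, 3, 3, 16), (4, 4, 3));
        // Already divisible by the target: returned unchanged.
        assert_eq!(padded_k(16, 3, 3, 16), (16, 3, 3));
    }

    #[test]
    fn padded_batch_size_pads_gemm_m() {
        // out_h * out_w = 16 is a power of two, so the target becomes 1 and no padding is needed.
        assert_eq!(padded_batch_size(2, 4, 4), 2);
        // 3 * 3 = 9 is neither a power of two nor a multiple of 16, so the batch is
        // padded up to a multiple of 16.
        assert_eq!(padded_batch_size(2, 3, 3), 16);
    }

    #[test]
    fn find_common_vec_picks_largest_supported_divisor() {
        // 4 divides both the channel count and the per-thread element count.
        assert_eq!(find_common_vec(4, 8, &[1, 2, 4, 8]), 4);
        // No supported size > 1 divides both, so it falls back to 1.
        assert_eq!(find_common_vec(3, 8, &[1, 2, 4, 8]), 1);
    }
}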