1use burn_tensor::{
2 ops::{conv::calculate_conv_output_size, ConvOptions},
3 Shape,
4};
5use cmma::{Matrix, MatrixIdent, MatrixLayout};
6use cubecl::{
7 cube,
8 ir::{Elem, FloatKind},
9 prelude::*,
10 Compiler, CubeCount, CubeDim, Feature,
11};
12use half::f16;
13
14use crate::{
15 kernel::{conv::ConvLaunchError, into_contiguous, slice, slice_assign},
16 ops::{
17 numeric::{empty_device, zeros_device},
18 permute,
19 },
20 tensor::JitTensor,
21 FloatElement, JitRuntime,
22};
23
24use super::nchw_to_nhwc;
25
/// Performs 2D convolution as an implicit GEMM using CMMA (tensor-core)
/// matrix units.
///
/// The input is taken in NCHW layout, internally converted to NHWC (and the
/// weights to HWCN), and the result is permuted back to NCHW. Batch, channel
/// and kernel dimensions are padded so every GEMM dimension divides the CMMA
/// tile sizes; the padding is sliced away before returning.
pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
    input: JitTensor<R>,
    weight: JitTensor<R>,
    bias: Option<JitTensor<R>>,
    options: ConvOptions<2>,
) -> Result<JitTensor<R>, ConvLaunchError> {
    // Stage through tf32 fragments when the element type is f32 and the
    // device supports it; otherwise stage through f16.
    let is_tf32 = F::as_elem_native_unchecked() == Elem::Float(FloatKind::F32)
        && input
            .client
            .properties()
            .feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32)));

    // CMMA k dimension: tf32 tiles use k = 8, f16 tiles use k = 16.
    let k_target = if is_tf32 { 8 } else { 16 };

    let [batch_size, in_channels, height, width] = input.shape.dims();
    let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();
    // Pad the GEMM k components (channels * kernel area) and the n dimension
    // (out channels) so they divide the CMMA tile sizes.
    let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w, k_target);
    let padded_out_channels = out_channels.div_ceil(16) * 16;

    let out_h = calculate_conv_output_size(
        kernel_h,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        height,
    );
    let out_w = calculate_conv_output_size(
        kernel_w,
        options.stride[1],
        options.padding[1],
        options.dilation[1],
        width,
    );

    let padded_batch_size = padded_batch_size(batch_size, out_h, out_w);

    if !can_do_implicit_gemm::<R, F>(
        batch_size,
        in_channels,
        out_channels,
        [kernel_h, kernel_w],
        options.groups,
        out_h,
        out_w,
        &input.client,
    ) {
        // NOTE(review): this function returns `Result<_, ConvLaunchError>`, so
        // returning an `Err` here would let callers fall back to another conv
        // algorithm instead of aborting — confirm whether panicking is intended.
        panic!(
            "Requirements for implicit GEMM not met:
- CMMA must be available
- `groups` must be 1
- subcube size must be non-variable (might not hold on Intel)
 "
        );
    }

    // The kernel reads NHWC; use the dedicated transpose kernel when the
    // input is contiguous, otherwise fall back to permute + copy.
    let input = match input.is_contiguous() {
        true => nchw_to_nhwc::<R, F>(input),
        false => into_contiguous(permute(input, &[0, 2, 3, 1])),
    };
    // Weights become HWCN so the GEMM n dimension (out channels) is innermost.
    let weight = into_contiguous(permute(weight, &[2, 3, 1, 0]));

    let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]);
    let out = empty_device::<R, F>(input.client.clone(), input.device.clone(), out_shape);

    // Implicit GEMM dimensions:
    //   m = output positions (padded batch * out_h * out_w)
    //   n = padded output channels
    //   k = padded in_channels * padded kernel area
    let gemm_m = (padded_batch_size * out_h * out_w) as u32;
    let gemm_n = padded_out_channels as u32;
    let gemm_k = (pad_in_channels * pad_kh * pad_kw) as u32;

    let (cmma_m, cmma_n, cmma_k) =
        find_cmma_size::<R, F>(&input.client, gemm_m, gemm_k, gemm_n).unwrap();

    // Length of one full (kernel_y, kernel_x, channel) sweep along k.
    let slice_size = pad_kh * pad_kw * pad_in_channels;

    let cube_dim_x = 128;
    let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2);

    let input_tile_size = cmma_m * cmma_k;
    let weight_tile_size = cmma_k * cmma_n;

    let topology = input.client.properties().hardware_properties();
    let warp_size = topology.plane_size_min;
    let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size;

    let supported_vecs = R::supported_line_sizes();

    // Pick the widest load vectorization compatible with both the innermost
    // tensor dimension and the per-lane element count.
    let input_elems_per_thread = input_tile_size / warp_size;
    let input_vectorization = find_common_vec(in_channels, input_elems_per_thread, supported_vecs);

    let weight_elems_per_thread = weight_tile_size / warp_size;
    let weight_vectorization =
        find_common_vec(out_channels, weight_elems_per_thread, supported_vecs);

    let has_bias = bias.is_some();
    // The bias buffer must match the padded out-channel count; zero-pad when
    // needed. A dummy 1-element buffer stands in when there is no bias.
    let bias = match bias {
        Some(bias) if out_channels == padded_out_channels => bias,
        Some(bias) => {
            let shape = Shape::new([padded_out_channels]);
            let padded_bias = zeros_device::<R, F>(bias.client.clone(), bias.device.clone(), shape);
            #[allow(clippy::single_range_in_vec_init)]
            slice_assign::<R, F>(padded_bias, &[0..out_channels], bias)
        }
        None => empty_device::<R, F>(input.client.clone(), input.device.clone(), Shape::new([1])),
    };

    let settings = GemmSettings {
        cmma_m,
        cmma_n,
        cmma_k,
        // Bounds checks are only needed along dimensions that were padded.
        check_m: batch_size != padded_batch_size,
        check_n: out_channels != padded_out_channels,
        check_k: (kernel_h * kernel_w * in_channels) as u32 != gemm_k,
        warp_size,
        warps_per_cube,
        cube_dim_x,
    };

    let cube_dim = CubeDim {
        x: cube_dim_x,
        y: cube_dim_y,
        z: 1,
    };

    // Each warp produces one cmma_m x cmma_n output tile; x covers m, y covers n.
    let cube_count_x = gemm_m.div_ceil(cmma_m * cube_dim_x / warp_size);
    let cube_count_y = gemm_n.div_ceil(cmma_n * cube_dim_y);

    // When the grid divides evenly the kernel can skip its bounds guard.
    let aligned = gemm_m / (cmma_m * cube_dim_x / warp_size) == cube_count_x
        && gemm_n / (cmma_n * cube_dim_y) == cube_count_y;

    let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1);

    // Select the staging precision for the matrix fragments.
    let launch = match is_tf32 {
        false => implicit_gemm_kernel::launch::<F, f16, R>,
        true => implicit_gemm_kernel::launch::<F, tf32, R>,
    };

    launch(
        &input.client,
        cube_count,
        cube_dim,
        input.as_tensor_arg::<F>(input_vectorization),
        weight.as_tensor_arg::<F>(weight_vectorization),
        bias.as_tensor_arg::<F>(1),
        out.as_tensor_arg::<F>(1),
        DimensionsLaunch::new(
            ScalarArg::new(gemm_m),
            ScalarArg::new(gemm_n),
            ScalarArg::new(gemm_k),
            ScalarArg::new(slice_size as u32),
            ScalarArg::new(pad_kw as u32),
            ScalarArg::new(pad_in_channels as u32),
            ScalarArg::new(out_h as u32),
            ScalarArg::new(out_w as u32),
        ),
        ConvArgsLaunch::new(
            ScalarArg::new(options.stride[0] as u32),
            ScalarArg::new(options.stride[1] as u32),
            ScalarArg::new(options.padding[0] as i32),
            ScalarArg::new(options.padding[1] as i32),
            ScalarArg::new(options.dilation[0] as u32),
            ScalarArg::new(options.dilation[1] as u32),
        ),
        settings,
        ConvSettings {
            kernel_h: kernel_h as u32,
            kernel_w: kernel_w as u32,
            padding_h: options.padding[0] as i32,
            padding_w: options.padding[1] as i32,
            aligned,
            has_bias,
        },
    );

    // Strip the batch/channel padding and return to NCHW.
    let out = slice::<R, F>(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]);

    Ok(permute(out, &[0, 3, 1, 2]))
}
215
/// Finds the widest supported vectorization (line size) that evenly divides
/// both the channel count and the number of elements each lane loads.
///
/// Returns 1 (scalar loads) when no wider supported width divides both.
fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 {
    // Cap the search bound instead of casting `channels` to `u8`: a plain
    // `as u8` cast truncates (e.g. 256 -> 0), which silently forced wide
    // channel counts down to scalar loads.
    let max_vec = u8::try_from(Ord::min(channels, elems_per_thread as usize)).unwrap_or(u8::MAX);
    (1..=max_vec)
        .rev()
        .filter(|it| supported_vecs.contains(it))
        // Compare in usize so large `channels` values are divided correctly.
        .find(|&vec| {
            let vec = vec as usize;
            channels % vec == 0 && elems_per_thread as usize % vec == 0
        })
        .unwrap_or(1)
}
226
/// Runtime convolution parameters passed to the kernel at launch.
#[derive(CubeLaunch)]
struct ConvArgs {
    stride_h: u32,
    stride_w: u32,
    // Signed so border coordinates can go negative before clamping.
    pad_h: i32,
    pad_w: i32,
    dilation_h: u32,
    dilation_w: u32,
}
236
/// Problem dimensions of the implicit GEMM, precomputed on the host.
#[derive(CubeLaunch)]
struct Dimensions {
    // GEMM shape: m = padded batch * out_h * out_w, n = padded out channels,
    // k = padded in_channels * padded kernel area.
    gemm_m: u32,
    gemm_n: u32,
    gemm_k: u32,
    // Length of one (kernel_y, kernel_x, channel) sweep along k:
    // pad_kh * pad_kw * pad_in_channels.
    slice_size: u32,

    // Padded kernel width / channel count used to decode k indices.
    pad_kw: u32,
    pad_channels: u32,

    out_h: u32,
    out_w: u32,
}
250
/// Compile-time GEMM configuration, baked into the kernel as comptime values.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct GemmSettings {
    // CMMA tile shape.
    cmma_m: u32,
    cmma_n: u32,
    cmma_k: u32,

    // Whether bounds checks are needed along each GEMM dimension
    // (true when the corresponding logical size was padded).
    check_m: bool,
    check_n: bool,
    check_k: bool,

    warp_size: u32,
    warps_per_cube: u32,

    cube_dim_x: u32,
}
266
/// Compile-time convolution configuration, baked into the kernel.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct ConvSettings {
    // Unpadded kernel extents (used for k bounds checks).
    kernel_h: u32,
    kernel_w: u32,
    padding_h: i32,
    padding_w: i32,
    // True when the launch grid divides evenly, so bounds guards can be skipped.
    aligned: bool,
    has_bias: bool,
}
276
/// Per-unit position information derived from cube/unit indices.
#[derive(Clone, Copy, CubeType)]
struct Positions {
    // Top-left coordinates of this warp's output tile in the GEMM.
    global_m: u32,
    global_n: u32,

    // Lane index within the warp.
    intra_warp_unit_idx: u32,
    // Linearized warp index inside the cube (selects the shared-memory slice).
    cube_linear_warp_idx: u32,
}
285
/// CMMA fragments for one warp: A (input), B (weight) and the accumulator.
#[derive(CubeType)]
struct Matrices<F: Float, FAcc: Float> {
    a: Matrix<F>,
    b: Matrix<F>,
    acc: Matrix<FAcc>,
}
292
/// Kernel body: each warp computes one `cmma_m x cmma_n` tile of the implicit
/// GEMM. `F` is the storage precision, `FMat` the fragment/staging precision.
#[allow(clippy::collapsible_else_if)]
#[cube(launch)]
fn implicit_gemm_kernel<F: Float, FMat: Float>(
    input: &Tensor<Line<F>>,
    weight: &Tensor<Line<F>>,
    bias: &Tensor<F>,
    out: &mut Tensor<F>,
    dims: &Dimensions,
    args: &ConvArgs,
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] conv_settings: ConvSettings,
) {
    // NOTE(review): presumably a dummy read so the bias binding is always
    // considered used even when `has_bias` is false — confirm this is
    // required by the launch machinery.
    let _ = bias[0];

    let GemmSettings {
        cmma_m,
        cmma_n,
        cmma_k,
        warps_per_cube,
        ..
    } = gemm_settings;

    let cmma_out_tile_size = cmma_m * cmma_n;
    let cmma_input_tile_size = cmma_m * cmma_k;
    let cmma_filter_tile_size = cmma_k * cmma_n;

    let pos = calculate_positions(gemm_settings);

    let in_vec = input.line_size();
    let weight_vec = weight.line_size();

    // Shared-memory staging: one input tile and one weight tile per warp.
    let mut smem_input_tile = SharedMemory::<FMat>::new_lined(
        comptime!(cmma_input_tile_size * warps_per_cube / in_vec),
        in_vec,
    );
    let mut smem_weight_tile = SharedMemory::<FMat>::new_lined(
        comptime!(cmma_filter_tile_size * warps_per_cube / weight_vec),
        weight_vec,
    );

    // Each warp takes its own slice of the shared tiles. NOTE(review): the
    // start offsets are in lines (divided by vec) but the slice lengths are
    // in elements; only the first `tile_size / vec` entries are written —
    // confirm the generous end bound is intentional.
    let input_tile_start = pos.cube_linear_warp_idx * (cmma_input_tile_size / in_vec);
    let weight_tile_start = pos.cube_linear_warp_idx * (cmma_filter_tile_size / weight_vec);
    let mut input_tile =
        smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size);
    let mut weight_tile =
        smem_weight_tile.slice_mut(weight_tile_start, weight_tile_start + cmma_filter_tile_size);

    let out_pos = pos.global_n + pos.global_m * dims.gemm_n;
    let mut out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size);

    // `&&` binds tighter than `||`: run when the launch is fully aligned, or
    // when this warp's tile starts inside the GEMM bounds.
    if conv_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
        execute_gemm::<F, FMat>(
            input,
            weight,
            bias,
            &mut out,
            &mut input_tile,
            &mut weight_tile,
            dims,
            &pos,
            args,
            gemm_settings,
            conv_settings,
        );
    }
}
362
/// Computes this unit's warp-level tile coordinates and lane index.
#[cube]
fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions {
    let GemmSettings {
        cmma_m,
        cmma_n,
        warp_size,
        cube_dim_x,
        ..
    } = gemm_settings;

    // Warps are laid out along x for the m dimension and along y for n.
    let global_warp_m = ABSOLUTE_POS_X / warp_size;
    let global_warp_n = ABSOLUTE_POS_Y;
    let cube_warp_m = UNIT_POS_X / warp_size;
    let cube_warp_n = UNIT_POS_Y;
    let num_warps_m = cube_dim_x / warp_size;
    // Lane index within the warp.
    let intra_warp_unit_idx = UNIT_POS_X % warp_size;
    // Linear warp index inside the cube, row-major over (n, m).
    let cube_linear_warp_idx = (cube_warp_n * num_warps_m) + cube_warp_m;

    Positions {
        global_m: global_warp_m * cmma_m,
        global_n: global_warp_n * cmma_n,
        intra_warp_unit_idx,
        cube_linear_warp_idx,
    }
}
390
/// Creates the CMMA fragments for one warp.
///
/// The accumulator is left uninitialized when a bias is present — the caller
/// (`execute_gemm`) overwrites it with the bias via `cmma::load_with_layout`
/// before the first `cmma::execute` — otherwise it is zero-filled.
#[cube]
fn make_matrices<F: Float, FAcc: Float>(
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] has_bias: bool,
) -> Matrices<F, FAcc> {
    let GemmSettings {
        cmma_m,
        cmma_n,
        cmma_k,
        ..
    } = gemm_settings;

    let acc = if has_bias {
        // Invariant: the caller loads the bias into the accumulator before use.
        unsafe {
            Matrix::<FAcc>::uninitialized(
                MatrixIdent::Accumulator,
                cmma_m,
                cmma_n,
                cmma_k,
                MatrixLayout::Undefined,
            )
        }
    } else {
        Matrix::<FAcc>::from_value(
            MatrixIdent::Accumulator,
            cmma_m,
            cmma_n,
            cmma_k,
            MatrixLayout::Undefined,
            FAcc::new(0.0),
        )
    };

    Matrices::<F, FAcc> {
        // A and B are fully overwritten by `cmma::load` on every k iteration.
        a: unsafe {
            Matrix::<F>::uninitialized(
                MatrixIdent::A,
                cmma_m,
                cmma_n,
                cmma_k,
                MatrixLayout::RowMajor,
            )
        },
        b: unsafe {
            Matrix::<F>::uninitialized(
                MatrixIdent::B,
                cmma_m,
                cmma_n,
                cmma_k,
                MatrixLayout::RowMajor,
            )
        },
        acc,
    }
}
446
/// Runs the k loop for one warp's output tile: stage input/weight tiles into
/// shared memory, feed them through CMMA, and store the accumulated result.
#[cube]
fn execute_gemm<F: Float, FMat: Float>(
    input: &Tensor<Line<F>>,
    weight: &Tensor<Line<F>>,
    bias: &Tensor<F>,
    out: &mut SliceMut<F>,
    input_tile: &mut SliceMut<Line<FMat>>,
    weight_tile: &mut SliceMut<Line<FMat>>,
    dims: &Dimensions,
    pos: &Positions,
    args: &ConvArgs,
    #[comptime] g_settings: GemmSettings,
    #[comptime] k_settings: ConvSettings,
) {
    let GemmSettings { cmma_n, cmma_k, .. } = g_settings;
    let has_bias = k_settings.has_bias;

    let matrices = make_matrices::<FMat, F>(g_settings, has_bias);
    if has_bias {
        // Seed the accumulator with this tile's bias slice; stride 0 —
        // presumably broadcasting the row across the tile, confirm against
        // the cmma::load_with_layout contract.
        let bias_tile = bias.slice(pos.global_n, pos.global_n + cmma_n);
        cmma::load_with_layout(&matrices.acc, &bias_tile, 0, MatrixLayout::RowMajor);
    }

    // March over the k dimension one CMMA tile at a time.
    for k in range_stepped(0, dims.gemm_k, cmma_k) {
        load_input_tile(
            input, args, input_tile, dims, pos, k, g_settings, k_settings,
        );

        load_weight_tile(weight, weight_tile, dims, pos, k, g_settings, k_settings);

        cmma::load(&matrices.b, &weight_tile.to_slice(), cmma_n);
        cmma::load(&matrices.a, &input_tile.to_slice(), cmma_k);

        // acc += a * b
        cmma::execute::<FMat, FMat, F, F>(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc);
    }

    cmma::store(out, &matrices.acc, dims.gemm_n, MatrixLayout::RowMajor);
}
491
/// Stages one `cmma_m x cmma_k` input tile into shared memory, performing the
/// implicit im2col: each GEMM row is an output position (batch, out_y, out_x)
/// and each GEMM column a (kernel_y, kernel_x, channel) triple.
#[cube]
fn load_input_tile<F: Float, FMat: Float>(
    input: &Tensor<Line<F>>,
    args: &ConvArgs,
    tile: &mut SliceMut<Line<FMat>>,
    dims: &Dimensions,
    pos: &Positions,
    k: u32,
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] kernel_settings: ConvSettings,
) {
    let GemmSettings {
        cmma_m,
        cmma_k,
        warp_size,
        check_m,
        check_k,
        ..
    } = gemm_settings;

    let ConvSettings {
        kernel_w,
        kernel_h,
        padding_h,
        padding_w,
        ..
    } = kernel_settings;

    let cmma_input_tile_size = cmma_m * cmma_k;
    let elems_per_thread = cmma_input_tile_size / warp_size;
    let vec = input.line_size();

    // Input is NHWC: shape(1) = height, shape(2) = width, shape(3) = channels.
    let height = input.shape(1) as i32;
    let width = input.shape(2) as i32;
    let channels = dims.pad_channels;

    // Strides over the flattened (batch, out_y, out_x) GEMM row index.
    let batch_stride = dims.out_h * dims.out_w;
    let y_stride = dims.out_w;
    let x_stride = 1;

    // Offset of `k` within the (kernel_y, kernel_x, channel) sweep.
    let slice_start_idx = k % dims.slice_size;
    // First element this lane is responsible for.
    let start = pos.intra_warp_unit_idx * elems_per_thread;

    // All of this lane's elements lie on one GEMM row.
    let rel_slice_row = start / cmma_k;
    let abs_slice_row = pos.global_m + rel_slice_row;
    let batch = abs_slice_row / batch_stride;

    let m_in_bounds = !check_m || batch < input.shape(0);
    let out_y = (abs_slice_row % batch_stride) / y_stride;
    let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride;

    #[unroll]
    for m in range_stepped(0, elems_per_thread, vec) {
        let m = m + start;
        // Decode the GEMM column into (kernel_y, kernel_x, channel).
        let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size;

        let channel = my_slice_idx % channels;

        let kernel_x = (my_slice_idx / channels) % dims.pad_kw;
        let kernel_y = my_slice_idx / (channels * dims.pad_kw);

        // Zero out columns that only exist because of k padding.
        let k_in_bounds =
            !check_k || (channel < input.shape(3) && kernel_x < kernel_w && kernel_y < kernel_h);

        // Map to input coordinates; may be negative at the borders.
        let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 - padding_h;
        let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 - padding_w;
        let in_bounds =
            (padding_h == 0 && padding_w == 0) || (x >= 0 && x < width && y >= 0 && y < height);
        // NOTE(review): `channel` is added without `stride(3)` — assumes the
        // channel dimension is contiguous (stride 1) after the NHWC transform.
        let idx = batch * input.stride(0)
            + y as u32 * input.stride(1)
            + x as u32 * input.stride(2)
            + channel;
        // Out-of-bounds or padded elements contribute zero to the GEMM.
        let value = select(
            in_bounds && m_in_bounds && k_in_bounds,
            Line::cast_from(input[idx / vec]),
            Line::new(FMat::new(0.0)),
        );

        tile[m / vec] = value;
    }
}
585
/// Stages one `cmma_k x cmma_n` weight tile into shared memory. The weight
/// tensor is HWCN, so consecutive GEMM columns (out channels) are contiguous.
#[cube]
fn load_weight_tile<F: Float, FMat: Float>(
    weight: &Tensor<Line<F>>,
    tile: &mut SliceMut<Line<FMat>>,
    dims: &Dimensions,
    pos: &Positions,
    k: u32,
    #[comptime] gemm_settings: GemmSettings,
    #[comptime] kernel_settings: ConvSettings,
) {
    let GemmSettings {
        cmma_n,
        cmma_k,
        warp_size,
        check_n,
        check_k,
        ..
    } = gemm_settings;

    let ConvSettings {
        kernel_w, kernel_h, ..
    } = kernel_settings;

    let vec = weight.line_size();
    let cmma_filter_tile_size = cmma_k * cmma_n;
    let elems_per_thread = cmma_filter_tile_size / warp_size;
    let start = pos.intra_warp_unit_idx * elems_per_thread;

    // All of this lane's elements share one GEMM row (k index).
    let global_k = start / cmma_n + k;

    let (k_idx, k_in_bounds) = if check_k {
        // Decode `global_k` into (kernel_y, kernel_x, channel) against the
        // padded sizes and flag rows that exist only because of padding.
        let channel = global_k % dims.pad_channels;
        let kernel_x = global_k / dims.pad_channels % dims.pad_kw;
        let kernel_y = global_k / (dims.pad_channels * dims.pad_kw);
        let k_in_bounds =
            !check_k || (channel < weight.shape(2) && kernel_x < kernel_w && kernel_y < kernel_h);
        let idx =
            kernel_y * weight.stride(0) + kernel_x * weight.stride(1) + channel * weight.stride(2);
        (idx, k_in_bounds)
    } else {
        // No k padding: rows map linearly onto the channel stride.
        (global_k * weight.stride(2), true)
    };

    #[unroll]
    for n in range_stepped(0, elems_per_thread, vec) {
        let n = n + start;

        let global_n = (n % cmma_n) + pos.global_n;
        let n_in_bounds = !check_n || global_n < weight.shape(3);

        let idx = k_idx + global_n;

        let value = Line::cast_from(weight[idx / vec]);
        // Padded rows/columns contribute zero.
        let value = select(k_in_bounds && n_in_bounds, value, Line::new(FMat::new(0.0)));

        tile[n / vec] = value;
    }
}
644
645#[allow(clippy::too_many_arguments)]
646pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
647 batch_size: usize,
648 in_channels: usize,
649 out_channels: usize,
650 kernel_size: [usize; 2],
651 groups: usize,
652 out_h: usize,
653 out_w: usize,
654 client: &ComputeClient<R::Server, R::Channel>,
655) -> bool {
656 let cmma_k = match (
657 E::as_elem_native_unchecked(),
658 client
659 .properties()
660 .feature_enabled(Feature::Type(tf32::as_elem_native_unchecked())),
661 ) {
662 (Elem::Float(FloatKind::F32), true) => 8,
663 _ => 16,
664 };
665
666 let (in_channels, kernel_h, kernel_w) =
667 padded_k(in_channels, kernel_size[0], kernel_size[1], cmma_k);
668 let batch_size = padded_batch_size(batch_size, out_h, out_w);
669 let out_channels = out_channels.div_ceil(16) * 16;
670
671 let gemm_m = batch_size * out_h * out_w;
672 let gemm_n = out_channels;
673 let gemm_k = in_channels * kernel_h * kernel_w;
674
675 let size = find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);
676
677 if let Some((cmma_m, cmma_k, cmma_n)) = size {
678 let warps_per_cube = 8;
679
680 let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
681 let topology = client.properties().hardware_properties();
682 let not_intel = topology.plane_size_min >= 32;
683
684 <R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1 && not_intel
685 } else {
686 false
687 }
688}
689
/// Pads the GEMM k components so `in_channels * kernel_h * kernel_w` becomes
/// a multiple of `target`, preferring small kernel-dimension padding before
/// rounding the channel count up.
fn padded_k(
    in_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    target: usize,
) -> (usize, usize, usize) {
    // Already divisible: nothing to pad.
    if (in_channels * kernel_h * kernel_w) % target == 0 {
        return (in_channels, kernel_h, kernel_w);
    }

    // Round the kernel height up to a power of two and retry with the
    // remaining divisibility requirement.
    let kernel_h = kernel_h.next_power_of_two();
    let remaining = target.div_ceil(kernel_h);
    if (in_channels * kernel_w) % remaining == 0 {
        return (in_channels, kernel_h, kernel_w);
    }

    // Same treatment for the kernel width.
    let kernel_w = kernel_w.next_power_of_two();
    let remaining = remaining.div_ceil(kernel_w);
    if in_channels % remaining == 0 {
        return (in_channels, kernel_h, kernel_w);
    }

    // Finally, round the channel count up to a multiple of what is left.
    (in_channels.div_ceil(remaining) * remaining, kernel_h, kernel_w)
}
712
713fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize {
714 let out_size = out_h * out_w;
715 let target = if out_size.is_power_of_two() || out_size % 16 == 0 {
716 (16usize).div_ceil(out_size)
717 } else {
718 16
719 };
720 batch_size.div_ceil(target) * target
721}
722
723fn find_cmma_size<R: JitRuntime, F: Float>(
724 client: &ComputeClient<R::Server, R::Channel>,
725 gemm_m: u32,
726 gemm_k: u32,
727 gemm_n: u32,
728) -> Option<(u32, u32, u32)> {
729 supported_cmma_sizes::<R, F>(client)
730 .into_iter()
731 .find(|(m, k, n)| {
732 gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0
733 })
734 .map(|(m, k, n)| (m as u32, n as u32, k as u32))
735}
736
737fn supported_cmma_sizes<R: JitRuntime, F: Float>(
738 client: &ComputeClient<R::Server, R::Channel>,
739) -> Vec<(u8, u8, u8)> {
740 let (requested_sizes, matrix_elem) = match (
741 F::as_elem_native_unchecked(),
742 client
743 .properties()
744 .feature_enabled(Feature::Type(tf32::as_elem_native_unchecked())),
745 ) {
746 (Elem::Float(FloatKind::F32), true) => {
747 (vec![(16, 8, 16)], tf32::as_elem_native_unchecked())
748 }
749 _ => (
750 vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)],
751 f16::as_elem_native_unchecked(),
752 ),
753 };
754
755 requested_sizes
756 .iter()
757 .copied()
758 .filter(|(m, k, n)| {
759 client.properties().feature_enabled(Feature::Cmma {
760 a: matrix_elem,
761 b: matrix_elem,
762 c: F::as_elem_native_unchecked(),
763 m: *m,
764 k: *k,
765 n: *n,
766 })
767 })
768 .collect()
769}