cubecl_convolution/
args.rs1use std::any::TypeId;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use crate::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner};
7use cubecl_matmul::components::{
8 MatmulLineSizes, MatmulSelection,
9 global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
10};
11
12use super::base::ConvolutionProblem;
13
14pub trait ConvInputsLaunch: LaunchArg {
15 fn create<'a, R: Runtime>(
16 lhs: &'a TensorHandleRef<'a, R>,
17 rhs: &'a TensorHandleRef<'a, R>,
18 selection: &MatmulSelection,
19 problem: &ConvolutionProblem,
20 line_sizes: &MatmulLineSizes,
21 ) -> Self::RuntimeArg<'a, R>;
22}
23
24impl<EI: Numeric> ConvInputsLaunch for TensorInputs<EI> {
25 fn create<'a, R: Runtime>(
26 lhs: &'a TensorHandleRef<'a, R>,
27 rhs: &'a TensorHandleRef<'a, R>,
28 _selection: &MatmulSelection,
29 _problem: &ConvolutionProblem,
30 line_sizes: &MatmulLineSizes,
31 ) -> Self::RuntimeArg<'a, R> {
32 TensorInputsLaunch::new(
33 lhs.as_tensor_arg(line_sizes.lhs),
34 None.into(),
35 rhs.as_tensor_arg(line_sizes.rhs),
36 None.into(),
37 )
38 }
39}
40
41impl<EI: Numeric> ConvInputsLaunch for TensorMapInputs<EI> {
42 fn create<'a, R: Runtime>(
43 lhs: &'a TensorHandleRef<'a, R>,
44 rhs: &'a TensorHandleRef<'a, R>,
45 selection: &MatmulSelection,
46 problem: &ConvolutionProblem,
47 line_sizes: &MatmulLineSizes,
48 ) -> Self::RuntimeArg<'a, R> {
49 let tiling_scheme = selection.tiling_scheme;
50 let stage_m = tiling_scheme.elements_in_stage_m();
51 let stage_n = tiling_scheme.elements_in_stage_n();
52 let tile_size_k = tiling_scheme.elements_in_tile_k();
53 let stage_size_rhs = vec![stage_n, 1, tile_size_k];
54
55 let elem_size = size_of::<EI>();
56
57 fn prefetch(bytes: usize) -> TensorMapPrefetch {
58 match bytes {
59 ..64 => TensorMapPrefetch::None,
60 64..128 => TensorMapPrefetch::B64,
61 128..256 => TensorMapPrefetch::B128,
62 256.. => TensorMapPrefetch::B256,
63 }
64 }
65
66 let prefetch_lhs = prefetch(tile_size_k as usize * elem_size);
67 let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * elem_size);
68
69 let elem = if TypeId::of::<EI>() == TypeId::of::<f32>() {
72 tf32::as_elem_native_unchecked()
73 } else {
74 EI::as_elem_native_unchecked()
75 };
76
77 let mut elem_stride = vec![1; 2 + problem.stride.len()];
78
79 for (i, stride) in problem.stride.iter().enumerate() {
80 elem_stride[i + 1] = *stride as usize;
81 }
82
83 let lhs = TensorMapArg::new(
84 TensorMapFormat::Im2col {
85 pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
86 pixel_box_upper_corner: calculate_upper_corner(
87 &problem.padding,
88 &problem.kernel_size,
89 &problem.dilation,
90 ),
91 channels_per_pixel: tile_size_k,
92 pixels_per_column: stage_m,
93 },
94 lhs.as_tensor_arg(line_sizes.lhs),
95 elem,
96 )
97 .with_elem_stride(elem_stride)
98 .with_prefetch(prefetch_lhs);
99
100 let rhs = TensorMapArg::new(
101 TensorMapFormat::Tiled {
102 tile_size: stage_size_rhs,
103 },
104 rhs.as_tensor_arg(1),
105 EI::as_elem_native_unchecked(),
106 )
107 .with_prefetch(prefetch_rhs);
108
109 TensorMapInputsLaunch::new(lhs, rhs)
110 }
111}