1use crate::components::ConvGemmConfig as _;
2use crate::{components::ConvSetupError, kernels::layered::selector::launch_kernel_concrete};
3use crate::{
4 components::{
5 ConvolutionProblem, Dimensionality,
6 global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
7 },
8 kernels::layered::algorithm::Algorithm,
9};
10use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
11use cubecl_matmul::components::{
12 self, AvailableLineSizes, MatmulElems, MatmulIdent, MatmulSelection,
13};
14use cubecl_matmul::{
15 MatmulInputHandleRef,
16 components::{InputArg, OutputArg},
17};
18
/// Per-spatial-dimension parameters of a convolution.
///
/// `N_SPATIAL` is the number of spatial dimensions; every field holds exactly one entry per
/// spatial dimension. The axis order is presumably the same as the spatial axes of the input
/// tensor (TODO confirm against callers).
#[derive(Debug, Clone)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride of the convolution along each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding of the convolution along each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation of the convolution along each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
25
26#[allow(clippy::result_large_err)]
35pub fn launch_conv<R: Runtime, Alg: Algorithm, const N_SPATIAL: usize>(
36 client: &ComputeClient<R>,
37 input: &MatmulInputHandleRef<'_, R>,
38 weight: &MatmulInputHandleRef<'_, R>,
39 bias: &Option<TensorHandleRef<'_, R>>,
40 out: &TensorHandleRef<'_, R>,
41 args: ConvolutionArgs<N_SPATIAL>,
42 dtypes: MatmulElems,
43) -> Result<(), ConvSetupError>
44where
45 InputArg<Alg::Args>: ConcreteInputsFactory,
46 OutputArg<Alg::Args>: ConcreteOutputFactory,
47{
48 let ConvolutionArgs {
49 stride,
50 padding,
51 dilation,
52 } = args;
53
54 let dimensionality = match N_SPATIAL {
55 1 => Dimensionality::Dim1,
56 2 => Dimensionality::Dim2,
57 3 => Dimensionality::Dim3,
58 other => unimplemented!("Unsupported dimensionality {other}"),
59 };
60
61 launch::<R, Alg>(
62 client,
63 input,
64 weight,
65 bias,
66 out,
67 (&stride, &padding, &dilation),
68 dimensionality,
69 dtypes,
70 )
71}
72
73#[allow(clippy::too_many_arguments)]
74fn launch<R: Runtime, Alg: Algorithm>(
75 client: &ComputeClient<R>,
76 input: &MatmulInputHandleRef<'_, R>,
77 weight: &MatmulInputHandleRef<'_, R>,
78 bias: &Option<TensorHandleRef<'_, R>>,
79 out: &TensorHandleRef<'_, R>,
80 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
81 dimensionality: Dimensionality,
82 mut dtypes: MatmulElems,
83) -> Result<(), ConvSetupError>
84where
85 InputArg<Alg::Args>: ConcreteInputsFactory,
86 OutputArg<Alg::Args>: ConcreteOutputFactory,
87{
88 let rank = input.data().shape.len();
89 let dim_c = rank - 1;
90
91 let n = input.data().shape[0];
92 let c = input.data().shape[dim_c];
93
94 let out_c = weight.data().shape[0];
95
96 let in_shape = &input.data().shape[1..dim_c];
97 let kernel_shape = &weight.data().shape[1..dim_c];
98 let out_shape = &out.shape[1..dim_c];
99
100 let input_data =
101 Alg::into_tensor_handle(client, input.data(), MatmulIdent::Lhs, dtypes.lhs_global)?;
102 let weight_data =
103 Alg::into_tensor_handle(client, weight.data(), MatmulIdent::Rhs, dtypes.rhs_global)?;
104
105 let mut input = *input;
106 let mut weight = *weight;
107
108 *input.data_mut() = input_data.as_ref();
109 *weight.data_mut() = weight_data.as_ref();
110
111 let plane_dim = client.properties().hardware.plane_size_max;
112
113 let problem = ConvolutionProblem {
114 m: n * out_shape.iter().product::<usize>(),
115 n: out_c,
116 k: c * kernel_shape.iter().product::<usize>(),
117 lhs_layout: components::MatrixLayout::RowMajor,
118 rhs_layout: components::MatrixLayout::ColMajor,
119 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
120 stride: stride.iter().map(|it| *it as u32).collect(),
121 padding: padding.iter().map(|it| *it as i32).collect(),
122 dilation: dilation.iter().map(|it| *it as u32).collect(),
123
124 batches: n,
125 shape: in_shape.to_vec(),
126 out_shape: out_shape.to_vec(),
127 channels: c,
128
129 dimensionality,
130 };
131
132 let selection = Alg::selection(client, &problem, plane_dim, &mut dtypes)?;
133
134 launch_kernel::<R, Alg>(
135 client, &input, &weight, bias, out, problem, selection, dtypes,
136 )
137}
138
139#[allow(clippy::result_large_err, clippy::too_many_arguments)]
140pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
141 client: &ComputeClient<R>,
142 input: &MatmulInputHandleRef<'_, R>,
143 weight: &MatmulInputHandleRef<'_, R>,
144 bias: &Option<TensorHandleRef<'_, R>>,
145 out: &TensorHandleRef<'_, R>,
146 problem: ConvolutionProblem,
147 selection: MatmulSelection,
148 dtypes: MatmulElems,
149) -> Result<(), ConvSetupError>
150where
151 InputArg<Alg::Args>: ConcreteInputsFactory,
152 OutputArg<Alg::Args>: ConcreteOutputFactory,
153{
154 let line_sizes = AvailableLineSizes::from_type_sizes(
155 client,
156 input.data().elem_size,
157 weight.data().elem_size,
158 out.elem_size,
159 )
160 .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
161 .filter_rhs_with_tensor(
162 weight.data().strides,
163 weight.data().shape,
164 problem.rhs_layout,
165 )
166 .filter_out_with_tensor(out.strides, out.shape);
167
168 let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
169
170 let config = Alg::setup(client, &problem, &selection, &line_sizes, &dtypes)?;
171
172 let line_sizes = config.line_sizes();
173
174 launch_kernel_concrete::<R, Alg>(
175 client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
176 )
177}