1use crate::components::ConvGemmConfig as _;
2use crate::{components::ConvSetupError, kernels::layered::selector::launch_kernel_concrete};
3use crate::{
4 components::{
5 ConvolutionProblem, Dimensionality,
6 global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
7 },
8 kernels::layered::algorithm::Algorithm,
9};
10use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
11use cubecl_matmul::components::{self, AvailableLineSizes, MatmulElems, MatrixLayout};
12use cubecl_matmul::{
13 MatmulInputHandleRef,
14 components::{InputArg, OutputArg},
15};
16
/// Per-spatial-dimension convolution parameters.
///
/// `N_SPATIAL` is the number of spatial dimensions of the convolution; every field holds exactly
/// one entry per spatial dimension.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride of the convolution along each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding of the convolution along each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation of the convolution along each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
23
24#[allow(clippy::result_large_err)]
33pub fn launch_conv<R: Runtime, Alg: Algorithm, const N_SPATIAL: usize>(
34 client: &ComputeClient<R>,
35 input: &MatmulInputHandleRef<'_, R>,
36 weight: &MatmulInputHandleRef<'_, R>,
37 bias: &Option<TensorHandleRef<'_, R>>,
38 out: &TensorHandleRef<'_, R>,
39 args: ConvolutionArgs<N_SPATIAL>,
40 dtypes: MatmulElems,
41) -> Result<(), ConvSetupError>
42where
43 InputArg<Alg::Args>: ConcreteInputsFactory,
44 OutputArg<Alg::Args>: ConcreteOutputFactory,
45{
46 let ConvolutionArgs {
47 stride,
48 padding,
49 dilation,
50 } = args;
51
52 let dimensionality = match N_SPATIAL {
53 1 => Dimensionality::Dim1,
54 2 => Dimensionality::Dim2,
55 3 => Dimensionality::Dim3,
56 other => unimplemented!("Unsupported dimensionality {other}"),
57 };
58
59 launch::<R, Alg>(
60 client,
61 input,
62 weight,
63 bias,
64 out,
65 (&stride, &padding, &dilation),
66 dimensionality,
67 dtypes,
68 )
69}
70
71#[allow(clippy::too_many_arguments)]
72fn launch<R: Runtime, Alg: Algorithm>(
73 client: &ComputeClient<R>,
74 input: &MatmulInputHandleRef<'_, R>,
75 weight: &MatmulInputHandleRef<'_, R>,
76 bias: &Option<TensorHandleRef<'_, R>>,
77 out: &TensorHandleRef<'_, R>,
78 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
79 dimensionality: Dimensionality,
80 dtypes: MatmulElems,
81) -> Result<(), ConvSetupError>
82where
83 InputArg<Alg::Args>: ConcreteInputsFactory,
84 OutputArg<Alg::Args>: ConcreteOutputFactory,
85{
86 let rank = input.data().shape.len();
87 let dim_c = rank - 1;
88
89 let n = input.data().shape[0];
90 let c = input.data().shape[dim_c];
91
92 let out_c = weight.data().shape[0];
93
94 let in_shape = &input.data().shape[1..dim_c];
95 let kernel_shape = &weight.data().shape[1..dim_c];
96 let out_shape = &out.shape[1..dim_c];
97
98 let input_data = Alg::into_tensor_handle(client, input.data(), *dtypes.lhs_global)?;
99 let weight_data = Alg::into_tensor_handle(client, weight.data(), *dtypes.rhs_global)?;
100
101 let mut input = *input;
102 let mut weight = *weight;
103
104 *input.data_mut() = input_data.as_ref();
105 *weight.data_mut() = weight_data.as_ref();
106
107 let problem = ConvolutionProblem {
108 m: n * out_shape.iter().product::<usize>(),
109 n: out_c,
110 k: c * kernel_shape.iter().product::<usize>(),
111 lhs_strides: input.data().strides.to_vec(),
112 rhs_strides: weight.data().strides.to_vec(),
113 lhs_layout: components::MatrixLayout::RowMajor,
114 rhs_layout: components::MatrixLayout::ColMajor,
115 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
116 stride: stride.iter().map(|it| *it as u32).collect(),
117 padding: padding.iter().map(|it| *it as i32).collect(),
118 dilation: dilation.iter().map(|it| *it as u32).collect(),
119
120 batches: n,
121 shape: in_shape.to_vec(),
122 out_shape: out_shape.to_vec(),
123 channels: c,
124
125 dimensionality,
126 };
127
128 launch_kernel::<R, Alg>(client, &input, &weight, bias, out, problem, dtypes)
129}
130
131#[allow(clippy::result_large_err, clippy::too_many_arguments)]
132pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
133 client: &ComputeClient<R>,
134 input: &MatmulInputHandleRef<'_, R>,
135 weight: &MatmulInputHandleRef<'_, R>,
136 bias: &Option<TensorHandleRef<'_, R>>,
137 out: &TensorHandleRef<'_, R>,
138 problem: ConvolutionProblem,
139 mut dtypes: MatmulElems,
140) -> Result<(), ConvSetupError>
141where
142 InputArg<Alg::Args>: ConcreteInputsFactory,
143 OutputArg<Alg::Args>: ConcreteOutputFactory,
144{
145 let plane_dim = client.properties().hardware.plane_size_max;
146 let line_sizes = AvailableLineSizes::from_type_sizes(
149 client,
150 input.data().elem_size,
151 weight.data().elem_size,
152 out.elem_size,
153 )
154 .filter_lhs_with_tensor(
155 input.data().strides,
156 input.data().shape,
157 MatrixLayout::RowMajor,
158 )
159 .filter_rhs_with_tensor(
160 weight.data().strides,
161 weight.data().shape,
162 MatrixLayout::RowMajor,
163 )
164 .filter_out_with_tensor(out.strides, out.shape);
165
166 let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
167
168 let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
169
170 let config = Alg::setup(client, &problem, &selection, &line_sizes, &dtypes)?;
171
172 let line_sizes = config.line_sizes();
173
174 launch_kernel_concrete::<R, Alg>(
175 client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
176 )
177}