1use crate::{
2 components::{ConvGemmConfig as _, ConvSetupError},
3 kernels::layered::selector::launch_kernel_concrete,
4};
5use crate::{
6 components::{
7 ConvolutionProblem, Dimensionality,
8 global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
9 },
10 kernels::layered::algorithm::Algorithm,
11};
12use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
13use cubecl_matmul::components::{
14 self, AvailableLineSizes, MatmulElems, MatmulIdent, MatmulSelection,
15};
16use cubecl_matmul::{
17 MatmulInputHandleRef,
18 components::{InputArg, OutputArg},
19};
20
/// Per-dimension convolution parameters for a convolution with `N_SPATIAL`
/// spatial dimensions.
///
/// Each array holds one entry per spatial dimension, in the same order as the
/// spatial axes of the tensors handed to the launcher.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride along each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding along each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation along each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
27
28#[allow(clippy::result_large_err)]
37pub fn launch_conv<R: Runtime, Alg: Algorithm, const N_SPATIAL: usize>(
38 client: &ComputeClient<R::Server>,
39 input: &MatmulInputHandleRef<'_, R>,
40 weight: &MatmulInputHandleRef<'_, R>,
41 bias: &Option<TensorHandleRef<'_, R>>,
42 out: &TensorHandleRef<'_, R>,
43 args: ConvolutionArgs<N_SPATIAL>,
44 dtypes: MatmulElems,
45) -> Result<(), ConvSetupError>
46where
47 InputArg<Alg::Args>: ConcreteInputsFactory,
48 OutputArg<Alg::Args>: ConcreteOutputFactory,
49{
50 let ConvolutionArgs {
51 stride,
52 padding,
53 dilation,
54 } = args;
55
56 let dimensionality = match N_SPATIAL {
57 1 => Dimensionality::Dim1,
58 2 => Dimensionality::Dim2,
59 3 => Dimensionality::Dim3,
60 other => unimplemented!("Unsupported dimensionality {other}"),
61 };
62
63 launch::<R, Alg>(
64 client,
65 input,
66 weight,
67 bias,
68 out,
69 (&stride, &padding, &dilation),
70 dimensionality,
71 dtypes,
72 )
73}
74
75#[allow(clippy::too_many_arguments)]
76fn launch<R: Runtime, Alg: Algorithm>(
77 client: &ComputeClient<R::Server>,
78 input: &MatmulInputHandleRef<'_, R>,
79 weight: &MatmulInputHandleRef<'_, R>,
80 bias: &Option<TensorHandleRef<'_, R>>,
81 out: &TensorHandleRef<'_, R>,
82 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
83 dimensionality: Dimensionality,
84 mut dtypes: MatmulElems,
85) -> Result<(), ConvSetupError>
86where
87 InputArg<Alg::Args>: ConcreteInputsFactory,
88 OutputArg<Alg::Args>: ConcreteOutputFactory,
89{
90 let rank = input.data().shape.len();
91 let dim_c = rank - 1;
92
93 let n = input.data().shape[0];
94 let c = input.data().shape[dim_c];
95
96 let out_c = weight.data().shape[0];
97
98 let in_shape = &input.data().shape[1..dim_c];
99 let kernel_shape = &weight.data().shape[1..dim_c];
100 let out_shape = &out.shape[1..dim_c];
101
102 let input_data =
103 Alg::into_tensor_handle::<R>(client, input.data(), MatmulIdent::Lhs, dtypes.lhs_global);
104 let weight_data =
105 Alg::into_tensor_handle::<R>(client, weight.data(), MatmulIdent::Rhs, dtypes.rhs_global);
106
107 let mut input = *input;
108 let mut weight = *weight;
109
110 *input.data_mut() = input_data.as_ref();
111 *weight.data_mut() = weight_data.as_ref();
112
113 let plane_dim = client.properties().hardware.plane_size_max;
114
115 let problem = ConvolutionProblem {
116 m: n * out_shape.iter().product::<usize>(),
117 n: out_c,
118 k: c * kernel_shape.iter().product::<usize>(),
119 lhs_layout: components::MatrixLayout::RowMajor,
120 rhs_layout: components::MatrixLayout::ColMajor,
121 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
122 stride: stride.iter().map(|it| *it as u32).collect(),
123 padding: padding.iter().map(|it| *it as i32).collect(),
124 dilation: dilation.iter().map(|it| *it as u32).collect(),
125
126 batches: n,
127 shape: in_shape.to_vec(),
128 out_shape: out_shape.to_vec(),
129 channels: c,
130
131 dimensionality,
132 };
133
134 let selection = Alg::selection::<R>(client, &problem, plane_dim, &mut dtypes)?;
135
136 launch_kernel::<R, Alg>(
137 client, &input, &weight, bias, out, problem, selection, dtypes,
138 )
139}
140
141#[allow(clippy::result_large_err, clippy::too_many_arguments)]
142pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
143 client: &ComputeClient<R::Server>,
144 input: &MatmulInputHandleRef<'_, R>,
145 weight: &MatmulInputHandleRef<'_, R>,
146 bias: &Option<TensorHandleRef<'_, R>>,
147 out: &TensorHandleRef<'_, R>,
148 problem: ConvolutionProblem,
149 selection: MatmulSelection,
150 dtypes: MatmulElems,
151) -> Result<(), ConvSetupError>
152where
153 InputArg<Alg::Args>: ConcreteInputsFactory,
154 OutputArg<Alg::Args>: ConcreteOutputFactory,
155{
156 let line_sizes = AvailableLineSizes::from_type_sizes::<R>(
157 input.data().elem_size,
158 weight.data().elem_size,
159 out.elem_size,
160 )
161 .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
162 .filter_rhs_with_tensor(
163 weight.data().strides,
164 weight.data().shape,
165 problem.rhs_layout,
166 )
167 .filter_out_with_tensor(out.strides, out.shape);
168
169 let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
170
171 let config = Alg::setup::<R>(client, &problem, &selection, &line_sizes, &dtypes)?;
172
173 let line_sizes = config.line_sizes();
174
175 launch_kernel_concrete::<R, Alg>(
176 client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
177 )
178}