1use std::any::TypeId;
2
3use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
4use cubecl_matmul::components::global::GlobalConfig;
5use half::f16;
6
7use crate::ConvGemmConfig;
8use crate::base::ConvolutionLaunch;
9use cubecl_matmul::components::global::args::{ConcreteOutputFactory, MatmulArgs};
10use cubecl_matmul::components::{
11 self, AvailableLineSizes, InputIdent, MatmulPrecision, MatmulSelection,
12};
13
14use super::{
15 ConvLaunchError,
16 algorithm::Algorithm,
17 args::ConvInputsLaunch,
18 base::{ConvolutionProblem, Dimensionality},
19};
20
/// Concrete input-arguments type for algorithm `Alg` at precision `MP`,
/// as declared by the algorithm's `MatmulArgs` implementation.
type Input<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Input<<MP as MatmulPrecision>::EI>;
/// Concrete output-arguments type for algorithm `Alg` at precision `MP`.
type Output<Alg, MP> =
    <<Alg as Algorithm>::Args as MatmulArgs>::Output<<MP as MatmulPrecision>::EO>;
24
/// Spatial hyper-parameters for a convolution launch.
///
/// `N_SPATIAL` is the number of spatial dimensions (1, 2 or 3); each array
/// holds one value per spatial axis, in the same order as the tensor's
/// spatial dimensions.
///
/// All fields are plain `usize` arrays, so the type is `Copy` and `Debug`
/// for convenient passing and diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Step of the convolution window along each spatial axis.
    pub stride: [usize; N_SPATIAL],
    /// Zero-padding applied on each side of each spatial axis.
    pub padding: [usize; N_SPATIAL],
    /// Spacing between kernel elements along each spatial axis.
    pub dilation: [usize; N_SPATIAL],
}
31
32#[allow(clippy::result_large_err)]
41pub fn launch_conv<R: Runtime, MP: MatmulPrecision, Alg: Algorithm, const N_SPATIAL: usize>(
42 client: &ComputeClient<R::Server, R::Channel>,
43 input: &TensorHandleRef<'_, R>,
44 weight: &TensorHandleRef<'_, R>,
45 bias: &Option<TensorHandleRef<'_, R>>,
46 out: &TensorHandleRef<'_, R>,
47 args: ConvolutionArgs<N_SPATIAL>,
48) -> Result<(), ConvLaunchError>
49where
50 Input<Alg, MP>: ConvInputsLaunch,
51 Output<Alg, MP>: ConcreteOutputFactory,
52{
53 let ConvolutionArgs {
54 stride,
55 padding,
56 dilation,
57 } = args;
58
59 let dimensionality = match N_SPATIAL {
60 1 => Dimensionality::Dim1,
61 2 => Dimensionality::Dim2,
62 3 => Dimensionality::Dim3,
63 other => unimplemented!("Unsupported dimensionality {other}"),
64 };
65
66 launch::<R, MP, Alg>(
67 client,
68 input,
69 weight,
70 bias,
71 out,
72 (&stride, &padding, &dilation),
73 dimensionality,
74 )
75}
76
77fn launch<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
78 client: &ComputeClient<R::Server, R::Channel>,
79 input: &TensorHandleRef<'_, R>,
80 weight: &TensorHandleRef<'_, R>,
81 bias: &Option<TensorHandleRef<'_, R>>,
82 out: &TensorHandleRef<'_, R>,
83 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
84 dimensionality: Dimensionality,
85) -> Result<(), ConvLaunchError>
86where
87 Input<Alg, MP>: ConvInputsLaunch,
88 Output<Alg, MP>: ConcreteOutputFactory,
89{
90 let rank = input.shape.len();
91 let dim_c = rank - 1;
92
93 let n = input.shape[0];
94 let c = input.shape[dim_c];
95
96 let out_c = weight.shape[0];
97
98 let in_shape = &input.shape[1..dim_c];
99 let kernel_shape = &weight.shape[1..dim_c];
100 let out_shape = &out.shape[1..dim_c];
101
102 let input = Alg::into_tensor_handle::<R, MP::EI>(client, input, InputIdent::Lhs);
103 let weight = Alg::into_tensor_handle::<R, MP::EI>(client, weight, InputIdent::Rhs);
104
105 let plane_dim = client.properties().hardware.plane_size_max;
106
107 let problem = ConvolutionProblem {
108 m: n * out_shape.iter().product::<usize>(),
109 n: out_c,
110 k: c * kernel_shape.iter().product::<usize>(),
111 lhs_layout: components::MatrixLayout::RowMajor,
112 rhs_layout: components::MatrixLayout::ColMajor,
113 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
114 stride: stride.iter().map(|it| *it as u32).collect(),
115 padding: padding.iter().map(|it| *it as i32).collect(),
116 dilation: dilation.iter().map(|it| *it as u32).collect(),
117
118 batches: n,
119 shape: in_shape.to_vec(),
120 out_shape: out_shape.to_vec(),
121 channels: c,
122
123 dimensionality,
124 };
125
126 let selection = Alg::selection::<R>(
127 client,
128 &problem,
129 plane_dim,
130 MP::ES::as_elem_native_unchecked(),
131 MP::EA::as_elem_native_unchecked(),
132 );
133
134 let launch = if TypeId::of::<MP::EI>() == TypeId::of::<f32>() {
135 if tf32::is_supported(client) {
136 launch_kernel::<R, (MP::EI, tf32, f32, MP::EO), Alg>
137 } else {
138 launch_kernel::<R, (MP::EI, f16, f32, MP::EO), Alg>
139 }
140 } else {
141 launch_kernel::<R, MP, Alg>
142 };
143
144 launch(
145 client,
146 &input.as_ref(),
147 &weight.as_ref(),
148 bias,
149 out,
150 problem,
151 selection,
152 )
153}
154
/// Negotiates line sizes and the algorithm config for an already-prepared
/// convolution problem, then launches the global convolution kernel.
///
/// `input` and `weight` are expected to already be in the layout produced by
/// `Alg::into_tensor_handle` (see `launch`); `out` is re-viewed here as a 2D
/// matrix by collapsing all leading dims into rows, with channels as columns.
///
/// # Errors
/// Returns `ConvLaunchError` when `pick_max` finds no usable line size or when
/// `Alg::setup` rejects the problem/selection combination.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    selection: MatmulSelection,
) -> Result<(), ConvLaunchError>
where
    Input<Alg, MP>: ConvInputsLaunch,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let rank = out.shape.len();
    let dim_c = rank - 1;

    // Collapse output to 2D: rows = product of all leading dims, cols = channels.
    // NOTE(review): taking `out.strides[rank - 2]` as the row stride presumes the
    // collapsed leading dims are contiguous in memory — confirm with callers.
    let out_shape = [out.shape[0..dim_c].iter().product(), out.shape[dim_c]];
    let out_strides = [out.strides[rank - 2], out.strides[dim_c]];

    // SAFETY (as reviewed): the shape/stride arrays are locals that outlive this
    // handle, and they describe the same underlying buffer as `out`.
    let out = unsafe {
        TensorHandleRef::from_raw_parts(out.handle, &out_strides, &out_shape, out.elem_size)
    };

    // Start from line sizes valid for the element types, then intersect with
    // what each tensor's actual layout supports.
    let line_sizes = AvailableLineSizes::from_elem_types::<R>(
        &MP::EI::as_elem_native_unchecked(),
        &MP::EO::as_elem_native_unchecked(),
    )
    .filter_lhs_with_tensor(input.strides, input.shape, problem.lhs_layout)
    .filter_rhs_with_tensor(weight.strides, weight.shape, problem.rhs_layout)
    .filter_out_with_tensor(out.strides, out.shape);

    // Let the algorithm restrict the candidates further; take the largest left.
    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

    let config = Alg::setup::<R, MP>(client, &problem, &selection, &line_sizes)?;

    // The config may override the picked line sizes; use its final choice below.
    let line_sizes = config.line_sizes();

    let input = <Input<Alg, MP> as ConvInputsLaunch>::create(
        input,
        weight,
        &selection,
        &problem,
        &line_sizes,
    );
    let output = <Output<Alg, MP> as ConcreteOutputFactory>::create(
        &out,
        &selection,
        &problem.as_matmul_problem(),
        &line_sizes,
    );
    // Bias is optional; when present it is vectorized with the output line size.
    let bias = bias.as_ref().map(|bias| bias.as_tensor_arg(line_sizes.out));

    unsafe {
        Alg::GlobalConvolution::launch_unchecked::<(MP, Alg::Args), R>(
            client,
            config.cube_dim(),
            Alg::cube_count(&selection, &problem),
            input,
            bias,
            output,
            &problem,
            config,
        );
    }

    Ok(())
}