use std::any::TypeId;

use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
use cubecl_matmul::{MatmulInputHandleRef, components::AccG};
use cubecl_runtime::TypeUsage;
use half::f16;

use crate::{
    components::{ConvGemmConfig as _, ConvSetupError},
    kernels::layered::selector::launch_kernel_concrete,
};
use crate::{
    components::{
        ConvolutionProblem, Dimensionality,
        global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
    },
    kernels::layered::algorithm::Algorithm,
};
use cubecl_matmul::components::global::args::MatmulArgs;
use cubecl_matmul::components::{
    self, AvailableLineSizes, LhsG, MatmulElems, MatmulIdent, MatmulPrecision, MatmulSelection,
    MatrixPrecision, RhsG,
};

type Input<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Input<
    <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Global,
    <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Global,
    <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global,
>;
type Output<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Output<
    <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global,
>;

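/// Per-spatial-dimension stride, padding, and dilation for an `N_SPATIAL`-dimensional
/// convolution.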
#[derive(Clone)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    pub stride: [usize; N_SPATIAL],
    pub padding: [usize; N_SPATIAL],
    pub dilation: [usize; N_SPATIAL],
}

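/// Launch an `N_SPATIAL`-dimensional convolution with the GEMM-based algorithm `Alg`
/// and matmul precision `MP`.
///
/// * `input` - the input feature map
/// * `weight` - the convolution weights
/// * `bias` - an optional bias tensor
/// * `out` - the output feature map
/// * `args` - stride, padding, and dilation per spatial dimension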
#[allow(clippy::result_large_err)]
pub fn launch_conv<R: Runtime, MP: MatmulPrecision, Alg: Algorithm, const N_SPATIAL: usize>(
    client: &ComputeClient<R::Server>,
    input: &MatmulInputHandleRef<'_, R>,
    weight: &MatmulInputHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
) -> Result<(), ConvSetupError>
where
    Input<Alg, MP>: ConcreteInputsFactory,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let ConvolutionArgs {
        stride,
        padding,
        dilation,
    } = args;

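    // Map the compile-time spatial dimension count onto the runtime `Dimensionality` enum.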
    let dimensionality = match N_SPATIAL {
        1 => Dimensionality::Dim1,
        2 => Dimensionality::Dim2,
        3 => Dimensionality::Dim3,
        other => unimplemented!("Unsupported dimensionality {other}"),
    };

    launch::<R, MP, Alg>(
        client,
        input,
        weight,
        bias,
        out,
        (&stride, &padding, &dilation),
        dimensionality,
    )
}

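/// Dimension-erased launch: builds the `ConvolutionProblem` from the tensor shapes,
/// asks the algorithm for a `MatmulSelection`, then dispatches `launch_kernel` with a
/// precision suited to the global element types.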
fn launch<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server>,
    input: &MatmulInputHandleRef<'_, R>,
    weight: &MatmulInputHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
) -> Result<(), ConvSetupError>
where
    Input<Alg, MP>: ConcreteInputsFactory,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let rank = input.data().shape.len();
    let dim_c = rank - 1;

    let n = input.data().shape[0];
    let c = input.data().shape[dim_c];

    let out_c = weight.data().shape[0];

    let in_shape = &input.data().shape[1..dim_c];
    let kernel_shape = &weight.data().shape[1..dim_c];
    let out_shape = &out.shape[1..dim_c];

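    // Let the algorithm convert input and weight into the handles it expects for the
    // Lhs and Rhs operands, then rebind the local handles to the converted data.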
    let input_data = Alg::into_tensor_handle::<R, LhsG<MP>>(client, input.data(), MatmulIdent::Lhs);
    let weight_data =
        Alg::into_tensor_handle::<R, RhsG<MP>>(client, weight.data(), MatmulIdent::Rhs);

    let mut input = *input;
    let mut weight = *weight;

    *input.data_mut() = input_data.as_ref();
    *weight.data_mut() = weight_data.as_ref();

    let plane_dim = client.properties().hardware.plane_size_max;

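    // GEMM view of the convolution:
    //   m = batches * output spatial elements
    //   n = output channels
    //   k = input channels * kernel spatial elements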
    let problem = ConvolutionProblem {
        m: n * out_shape.iter().product::<usize>(),
        n: out_c,
        k: c * kernel_shape.iter().product::<usize>(),
        lhs_layout: components::MatrixLayout::RowMajor,
        rhs_layout: components::MatrixLayout::ColMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,

        dimensionality,
    };

    let selection = Alg::selection::<R>(client, &problem, plane_dim, MatmulElems::new::<MP>())?;

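    // f32 global operands go through tf32 when the runtime supports converting to it,
    // otherwise through f16.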
    let lhs_is_f32 = TypeId::of::<LhsG<MP>>() == TypeId::of::<f32>();
    let rhs_is_f32 = TypeId::of::<RhsG<MP>>() == TypeId::of::<f32>();

    let launch = if lhs_is_f32 || rhs_is_f32 {
        if tf32::supported_uses(client).contains(TypeUsage::Conversion) {
            if lhs_is_f32 && rhs_is_f32 {
                launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, tf32, tf32, f32), Alg>
            } else if lhs_is_f32 {
                launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, tf32, f32, f32), Alg>
            } else {
                launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f32, tf32, f32), Alg>
            }
        } else if lhs_is_f32 && rhs_is_f32 {
            launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f16, f16, f32), Alg>
        } else if lhs_is_f32 {
            launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f16, f32, f32), Alg>
        } else {
            launch_kernel::<R, (LhsG<MP>, RhsG<MP>, AccG<MP>, f32, f16, f32), Alg>
        }
    } else {
        launch_kernel::<R, MP, Alg>
    };

    launch(client, &input, &weight, bias, out, problem, selection)
}

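/// Launch the convolution kernel for an already resolved `problem` and `selection`:
/// pick line sizes compatible with the tensors, finalize the algorithm's config, and
/// dispatch the concrete kernel.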
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server>,
    input: &MatmulInputHandleRef<'_, R>,
    weight: &MatmulInputHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    selection: MatmulSelection,
) -> Result<(), ConvSetupError>
where
    Input<Alg, MP>: ConcreteInputsFactory,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let line_sizes = AvailableLineSizes::from_type_sizes::<R>(
        input.data().elem_size,
        weight.data().elem_size,
        out.elem_size,
    )
    .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
    .filter_rhs_with_tensor(
        weight.data().strides,
        weight.data().shape,
        problem.rhs_layout,
    )
    .filter_out_with_tensor(out.strides, out.shape);

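    // Let the algorithm narrow the candidate line sizes, then take the largest remaining one.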
    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

    let config = Alg::setup::<R, MP>(client, &problem, &selection, &line_sizes)?;

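    // Use the line sizes recorded in the final config.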
    let line_sizes = config.line_sizes();

    launch_kernel_concrete::<(MP, Alg::Args), R, Alg>(
        client, input, weight, bias, out, problem, line_sizes, selection,
    )
}