use crate::{AcceleratedTileKind, ReadingStrategy};
use crate::{
    ConvolutionArgs, Strategy,
    components::{ConvGemmConfig as _, ConvolutionOperation},
    forward::args::ConcreteArgs,
    kernels::forward::simple::*,
};
use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    kernels::forward::algorithm::Algorithm,
};
use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{CubeOption, tensor::TensorHandle},
};
use cubek_matmul::launch::MatmulInputHandle;
use cubek_matmul::{
    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
    definition::{AvailableLineSizes, MatmulElems, MatrixLayout},
};
use cubek_matmul::{definition, launch::MatmulInputHandleRef};
use derive_new::new;

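/// Binds the tile matmul type matching the requested [`AcceleratedTileKind`] to
/// `$T` and invokes `$launch` with that alias in scope.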
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}

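/// Launches a forward convolution from owned tensor handles.
///
/// Thin wrapper over [`launch_ref`]: it borrows the handles and forwards every
/// argument unchanged.
///
/// A minimal usage sketch (hedged: how the handles and `dtypes` are obtained, and
/// the exact field types of [`ConvolutionArgs`], depend on the caller's setup and
/// are assumed here):
///
/// ```ignore
/// // `client`, `input`, `weight`, `bias`, `out` and `dtypes` are obtained by the
/// // caller; `R` is the caller's `Runtime` type.
/// let strategy = Strategy::Simple {
///     read_strategy: ReadingStrategy::Cyclic,
///     tile_kind: AcceleratedTileKind::Cmma,
/// };
/// launch::<R, 2>(
///     &strategy,
///     &client,
///     input,
///     weight,
///     Some(bias),
///     out,
///     ConvolutionArgs { stride: [1, 1], padding: [0, 0], dilation: [1, 1] },
///     dtypes,
/// )?;
/// ```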
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: MatmulInputHandle<R>,
    weight: MatmulInputHandle<R>,
    bias: Option<MatmulInputHandle<R>>,
    out: TensorHandle<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    launch_ref(
        strategy,
        client,
        &input.as_ref(),
        &weight.as_ref(),
        &bias.as_ref().map(|it| it.as_ref()),
        &out.as_ref(),
        args,
        dtypes,
    )
}

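/// Launches a forward convolution from borrowed tensor handles.
///
/// Dispatches on the [`Strategy`]: the tile kind picks the accelerated tile
/// matmul implementation (CMMA or MMA) and the [`ReadingStrategy`] picks the
/// simple convolution kernel variant to launch.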
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    weight: &MatmulInputHandleRef<'_, R>,
    bias: &Option<MatmulInputHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv<Accelerated>>(),
        }),
    }
}

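/// Bundles the launch arguments so [`Convolution::launch`] can be called once the
/// algorithm type has been resolved by `with_tile_kind!`.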
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    input: &'a MatmulInputHandleRef<'a, R>,
    weight: &'a MatmulInputHandleRef<'a, R>,
    bias: &'a Option<MatmulInputHandleRef<'a, R>>,
    out: &'a TensorHandleRef<'a, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
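    /// Resolves the const spatial dimension count into a runtime [`Dimensionality`]
    /// and forwards everything to [`launch_with_algorithm`].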
    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.input,
            self.weight,
            self.bias,
            self.out,
            (&stride, &padding, &dilation),
            dimensionality,
            self.dtypes,
        )
    }
}

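/// Prepares the tensor handles for `Alg` and describes the convolution as an
/// implicit GEMM [`ConvolutionProblem`] before launching the kernel.
///
/// Shapes are interpreted as channels-last: `input` is `[batch, spatial.., channels]`
/// and `weight` is `[out_channels, kernel.., channels]`.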
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    weight: &MatmulInputHandleRef<'_, R>,
    bias: &Option<MatmulInputHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
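    // Channels-last layout: dim 0 is the batch, the last dim holds the channels,
    // and the dims in between are spatial.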
    let rank = input.data().shape.len();
    let dim_c = rank - 1;

    let n = input.data().shape[0];
    let c = input.data().shape[dim_c];

    let out_c = weight.data().shape[0];

    let in_shape = &input.data().shape[1..dim_c];
    let kernel_shape = &weight.data().shape[1..dim_c];
    let out_shape = &out.shape[1..dim_c];

    let op = ConvolutionOperation::Forward;

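    // Let the algorithm convert the raw input and weight handles into the
    // representation it expects for the requested global element types.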
    let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
    let weight_data = Alg::into_tensor_handle(client, weight.data(), dtypes.rhs_global, op)?;

    let mut input = *input;
    let mut weight = *weight;

    *input.data_mut() = input_data.as_ref();
    *weight.data_mut() = weight_data.as_ref();

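    // Implicit GEMM view of the forward convolution:
    //   m = batches x output spatial size,
    //   n = output channels,
    //   k = input channels x kernel size.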
    let problem = ConvolutionProblem {
        m: n * out_shape.iter().product::<usize>(),
        n: out_c,
        k: c * kernel_shape.iter().product::<usize>(),
        lhs_strides: input.data().strides.to_vec(),
        rhs_strides: weight.data().strides.to_vec(),
        lhs_layout: definition::MatrixLayout::RowMajor,
        rhs_layout: definition::MatrixLayout::ColMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        padded_channels: c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(client, &input, &weight, bias, out, problem, dtypes)
}

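/// Selects line sizes and the algorithm's kernel selection for `problem`, expands
/// the kernel config, and launches the concrete kernel.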
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    weight: &MatmulInputHandleRef<'_, R>,
    bias: &Option<MatmulInputHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    mut dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let plane_dim = client.properties().hardware.plane_size_max;
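    // Start from the line sizes supported for each element size and keep only
    // those compatible with the tensors' shapes and strides.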
    let line_sizes = AvailableLineSizes::from_type_sizes(
        client,
        input.data().elem_size,
        weight.data().elem_size,
        out.elem_size,
    )
    .filter_lhs_with_tensor(
        input.data().strides,
        input.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        weight.data().strides,
        weight.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(out.strides, out.shape);

    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

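    // Ask the algorithm for its selection on this problem, then let the concrete
    // args adjust the problem to match that selection.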
    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
    let problem = Alg::Args::adjust_problem(client, problem, &selection, &dtypes);

    let config = Alg::expand_config(
        client.properties(),
        &problem,
        &selection,
        &line_sizes,
        &dtypes,
    )?;

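    // Use the line sizes as finalized by the expanded config for the launch.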
    let line_sizes = config.line_sizes();

    launch_kernel_concrete::<R, Alg>(
        client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
    )
}