1use crate::{
2 AcceleratedTileKind, ReadingStrategy, algorithm::simple::*,
3 components::global::args::RuntimeArgs,
4};
5use crate::{
6 ConvolutionArgs, Strategy, components::ConvolutionOperation, forward::args::ConcreteArgs,
7};
8use crate::{
9 algorithm::Algorithm,
10 components::{ConvolutionProblem, Dimensionality},
11};
12use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
13use cubecl::{
14 Runtime,
15 client::ComputeClient,
16 prelude::*,
17 std::{CubeOption, tensor::TensorHandle},
18};
19use cubek_matmul::{
20 components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
21 definition::{AvailableLineSizes, MatmulElems, MatrixLayout},
22};
23use cubek_matmul::{definition, launch::MatmulInputHandleRef};
24use cubek_matmul::{launch::MatmulInputHandle, routines::BlueprintStrategy};
25use derive_new::new;
26
/// Expands `$launch` with the type alias `$T` bound to the concrete tile
/// matmul implementation selected by `$kind`.
///
/// A macro (rather than a generic function) is needed here because the two
/// branches bind *different* concrete types, and the alias must be visible
/// inside the caller-supplied closure body.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            // CMMA tile backend.
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            // MMA tile backend with strided lhs/rhs fragments.
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}
41
42#[allow(clippy::result_large_err, clippy::too_many_arguments)]
43pub fn launch<R: Runtime, const N_SPATIAL: usize>(
44 strategy: &Strategy,
45 client: &ComputeClient<R>,
46 input: MatmulInputHandle<R>,
47 weight: MatmulInputHandle<R>,
48 bias: Option<MatmulInputHandle<R>>,
49 out: TensorHandle<R>,
50 args: ConvolutionArgs<N_SPATIAL>,
51 dtypes: MatmulElems,
52) -> Result<(), ConvSetupError> {
53 launch_ref(
54 strategy,
55 client,
56 &input.as_ref(),
57 &weight.as_ref(),
58 &bias.as_ref().map(|it| it.as_ref()),
59 &out.as_ref(),
60 args,
61 dtypes,
62 )
63}
64
65#[allow(clippy::result_large_err, clippy::too_many_arguments)]
74pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
75 strategy: &Strategy,
76 client: &ComputeClient<R>,
77 input: &MatmulInputHandleRef<'_, R>,
78 weight: &MatmulInputHandleRef<'_, R>,
79 bias: &Option<MatmulInputHandleRef<'_, R>>,
80 out: &TensorHandleRef<'_, R>,
81 args: ConvolutionArgs<N_SPATIAL>,
82 dtypes: MatmulElems,
83) -> Result<(), ConvSetupError> {
84 let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);
85
86 match strategy {
87 Strategy::Simple {
88 read_strategy,
89 tile_kind,
90 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
91 ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv<Accelerated>>(),
92 ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv<Accelerated>>(),
93 ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
94 ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
95 ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv<Accelerated>>(),
96 ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv<Accelerated>>(),
97 }),
98 }
99}
100
/// Bundles every argument needed to launch a convolution so the strategy
/// dispatch in `launch_ref` can forward them through a single generic call.
///
/// NOTE: `#[derive(new)]` generates `Convolution::new` with parameters in
/// field declaration order, so the field order below is part of the
/// constructor's signature — do not reorder.
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    // Compute client used to launch the kernel.
    client: &'a ComputeClient<R>,
    // Input (lhs) tensor handle.
    input: &'a MatmulInputHandleRef<'a, R>,
    // Weight (rhs) tensor handle.
    weight: &'a MatmulInputHandleRef<'a, R>,
    // Optional bias tensor handle.
    bias: &'a Option<MatmulInputHandleRef<'a, R>>,
    // Output tensor handle.
    out: &'a TensorHandleRef<'a, R>,
    // Stride/padding/dilation, one entry per spatial dimension.
    args: ConvolutionArgs<N_SPATIAL>,
    // Element dtypes for the matmul operands and output.
    dtypes: MatmulElems,
}
111
112impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
113 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
114 where
115 Alg::Args: ConcreteArgs<Alg::Routine>,
116 {
117 let ConvolutionArgs {
118 stride,
119 padding,
120 dilation,
121 } = self.args;
122
123 let dimensionality = match N_SPATIAL {
124 1 => Dimensionality::Dim1,
125 2 => Dimensionality::Dim2,
126 3 => Dimensionality::Dim3,
127 other => unimplemented!("Unsupported dimensionality {other}"),
128 };
129
130 launch_with_algorithm::<R, Alg>(
131 self.client,
132 self.input,
133 self.weight,
134 self.bias,
135 self.out,
136 (&stride, &padding, &dilation),
137 dimensionality,
138 &BlueprintStrategy::Inferred(Default::default()),
139 self.dtypes,
140 )
141 }
142}
143
144#[allow(clippy::too_many_arguments)]
145fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
146 client: &ComputeClient<R>,
147 input: &MatmulInputHandleRef<'_, R>,
148 weight: &MatmulInputHandleRef<'_, R>,
149 bias: &Option<MatmulInputHandleRef<'_, R>>,
150 out: &TensorHandleRef<'_, R>,
151 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
152 dimensionality: Dimensionality,
153 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
154 dtypes: MatmulElems,
155) -> Result<(), ConvSetupError>
156where
157 Alg::Args: ConcreteArgs<Alg::Routine>,
158{
159 let rank = input.data().shape.len();
160 let dim_c = rank - 1;
161
162 let n = input.data().shape[0];
163 let c = input.data().shape[dim_c];
164
165 let out_c = weight.data().shape[0];
166
167 let in_shape = &input.data().shape[1..dim_c];
168 let kernel_shape = &weight.data().shape[1..dim_c];
169 let out_shape = &out.shape[1..dim_c];
170
171 let op = ConvolutionOperation::Forward;
172
173 let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
174 let weight_data = Alg::into_tensor_handle(client, weight.data(), dtypes.rhs_global, op)?;
175
176 let mut input = *input;
177 let mut weight = *weight;
178
179 *input.data_mut() = input_data.as_ref();
180 *weight.data_mut() = weight_data.as_ref();
181
182 let problem = ConvolutionProblem {
183 m: n * out_shape.iter().product::<usize>(),
184 n: out_c,
185 k: c * kernel_shape.iter().product::<usize>(),
186 lhs_strides: input.data().strides.to_vec(),
187 rhs_strides: weight.data().strides.to_vec(),
188 lhs_layout: definition::MatrixLayout::RowMajor,
189 rhs_layout: definition::MatrixLayout::ColMajor,
190 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
191 stride: stride.iter().map(|it| *it as u32).collect(),
192 padding: padding.iter().map(|it| *it as i32).collect(),
193 dilation: dilation.iter().map(|it| *it as u32).collect(),
194
195 batches: n,
196 in_shape: in_shape.to_vec(),
197 out_shape: out_shape.to_vec(),
198 channels: c,
199 out_channels: out_c,
200
201 padded_channels: c,
202 operation: op,
203
204 dimensionality,
205 global_dtypes: dtypes.as_global_elems(),
206 };
207
208 launch_kernel::<R, Alg>(
209 client,
210 &input,
211 &weight,
212 bias,
213 out,
214 problem,
215 blueprint_strategy,
216 dtypes,
217 )
218}
219
220#[allow(clippy::result_large_err, clippy::too_many_arguments)]
221pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
222 client: &ComputeClient<R>,
223 input: &MatmulInputHandleRef<'_, R>,
224 weight: &MatmulInputHandleRef<'_, R>,
225 bias: &Option<MatmulInputHandleRef<'_, R>>,
226 out: &TensorHandleRef<'_, R>,
227 problem: ConvolutionProblem,
228 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
229 dtypes: MatmulElems,
230) -> Result<(), ConvSetupError>
231where
232 Alg::Args: ConcreteArgs<Alg::Routine>,
233{
234 let line_sizes = AvailableLineSizes::from_type_sizes(
237 client,
238 input.data().elem_size,
239 weight.data().elem_size,
240 out.elem_size,
241 )
242 .filter_lhs_with_tensor(
243 input.data().strides,
244 input.data().shape,
245 MatrixLayout::RowMajor,
246 )
247 .filter_rhs_with_tensor(
248 weight.data().strides,
249 weight.data().shape,
250 MatrixLayout::RowMajor,
251 )
252 .filter_out_with_tensor(out.strides, out.shape);
253
254 let mut line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
255
256 if input.scale().is_some() {
259 line_sizes.lhs = 1;
260 }
261 if weight.scale().is_some() {
262 line_sizes.rhs = 1;
263 }
264
265 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
266 client,
267 input,
268 weight,
269 bias,
270 out,
271 problem,
272 line_sizes,
273 blueprint_strategy,
274 &dtypes,
275 )
276}