1use crate::{
2 AcceleratedTileKind, ReadingStrategy, algorithm::simple::*,
3 components::global::args::RuntimeArgs,
4};
5use crate::{
6 ConvolutionArgs, Strategy, components::ConvolutionOperation, forward::args::ConcreteArgs,
7};
8use crate::{
9 algorithm::Algorithm,
10 components::{ConvolutionProblem, Dimensionality},
11};
12use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
13use cubecl::{Runtime, client::ComputeClient, prelude::*};
14use cubek_matmul::routines::BlueprintStrategy;
15use cubek_matmul::{
16 components::tile::{cmma::CmmaMatmul, mma::MmaMatmul},
17 definition::{AvailableVectorSizes, MatmulElems},
18};
19use cubek_std::{
20 InputBinding,
21 {MatrixLayout, tile::Strided},
22};
23use derive_new::new;
24
/// Dispatches on the accelerated tile backend at runtime while keeping the
/// launch body generic at compile time: each arm introduces a local type alias
/// `$T` bound to the matching tile-matmul implementation, then invokes the
/// `$launch` closure, whose body can refer to `$T` by name.
///
/// NOTE(review): this relies on `$T` being an `ident` (not a hygiene-isolated
/// token) so the alias is visible inside the caller-supplied closure body.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // CMMA backend; tile type parameterized by an optional strided layout.
                type $T = CmmaMatmul<Option<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                // MMA backend; strided lhs/rhs tiles, optionally strided third parameter.
                type $T = MmaMatmul<Strided, Strided, Option<Strided>>;
                ($launch)()
            }
        }
    };
}
39
/// Launches a forward convolution over the given bindings using `strategy`.
///
/// Dispatches first on the accelerated tile backend (CMMA vs. MMA, via
/// [`with_tile_kind!`]) and then on the global-memory reading strategy,
/// selecting the corresponding `Simple*Conv` algorithm.
///
/// # Errors
/// Returns a [`ConvSetupError`] if the chosen configuration cannot be set up
/// (propagated from [`Convolution::launch`]).
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    weight: InputBinding<R>,
    bias: Option<InputBinding<R>>,
    out: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    // Bundle all launch state so each match arm is a single generic call.
    let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv<Accelerated>>(),
        }),
    }
}
75
/// Owned bundle of everything needed to launch one convolution, letting the
/// strategy dispatch funnel into a single generic method with one argument
/// list. Field order is significant: `derive(new)` generates a constructor
/// whose parameters follow declaration order.
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    /// Input activations (used as the lhs of the implicit GEMM problem).
    input: InputBinding<R>,
    /// Convolution weights (used as the rhs of the implicit GEMM problem).
    weight: InputBinding<R>,
    /// Optional per-output-channel bias.
    bias: Option<InputBinding<R>>,
    /// Output tensor binding.
    out: TensorBinding<R>,
    /// Stride / padding / dilation, one entry per spatial dimension.
    args: ConvolutionArgs<N_SPATIAL>,
    /// Element types for the matmul operands and accumulator.
    dtypes: MatmulElems,
}
86
87impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
88 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
89 where
90 Alg::Args: ConcreteArgs<Alg::Routine>,
91 {
92 let ConvolutionArgs {
93 stride,
94 padding,
95 dilation,
96 } = self.args;
97
98 let dimensionality = match N_SPATIAL {
99 1 => Dimensionality::Dim1,
100 2 => Dimensionality::Dim2,
101 3 => Dimensionality::Dim3,
102 other => unimplemented!("Unsupported dimensionality {other}"),
103 };
104
105 launch_with_algorithm::<R, Alg>(
106 self.client,
107 self.input,
108 self.weight,
109 self.bias,
110 self.out,
111 (&stride, &padding, &dilation),
112 dimensionality,
113 &BlueprintStrategy::Inferred(Default::default()),
114 self.dtypes,
115 )
116 }
117}
118
119#[allow(clippy::too_many_arguments)]
120fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
121 client: &ComputeClient<R>,
122 input: InputBinding<R>,
123 weight: InputBinding<R>,
124 bias: Option<InputBinding<R>>,
125 out: TensorBinding<R>,
126 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
127 dimensionality: Dimensionality,
128 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
129 dtypes: MatmulElems,
130) -> Result<(), ConvSetupError>
131where
132 Alg::Args: ConcreteArgs<Alg::Routine>,
133{
134 let rank = input.data().shape.len();
135 let dim_c = rank - 1;
136
137 let n = input.data().shape[0];
138 let c = input.data().shape[dim_c];
139
140 let out_c = weight.data().shape[0];
141
142 let in_shape = &input.data().shape[1..dim_c];
143 let kernel_shape = &weight.data().shape[1..dim_c];
144 let out_shape = &out.shape[1..dim_c];
145
146 let op = ConvolutionOperation::Forward;
147
148 let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
149 let weight_data =
150 Alg::correct_layout(client, weight.clone().into_data(), dtypes.rhs_global, op)?;
151
152 let mut input = input.clone();
153 let mut weight = weight.clone();
154
155 *input.data_mut() = input_data;
156 *weight.data_mut() = weight_data;
157
158 let address_type = input
159 .required_address_type()
160 .max(weight.required_address_type())
161 .max(
162 bias.clone()
163 .map(|bias| bias.required_address_type())
164 .unwrap_or_default(),
165 )
166 .max(out.required_address_type(dtypes.acc_global.size()));
167
168 let problem = ConvolutionProblem {
169 m: n * out_shape.iter().product::<usize>(),
170 n: out_c,
171 k: c * kernel_shape.iter().product::<usize>(),
172 lhs_strides: input.data().strides.clone(),
173 rhs_strides: weight.data().strides.clone(),
174 lhs_layout: MatrixLayout::RowMajor,
175 rhs_layout: MatrixLayout::ColMajor,
176 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
177 stride: stride.iter().map(|it| *it as u32).collect(),
178 padding: padding.iter().map(|it| *it as i32).collect(),
179 dilation: dilation.iter().map(|it| *it as u32).collect(),
180
181 batches: n,
182 in_shape: in_shape.into(),
183 out_shape: out_shape.into(),
184 channels: c,
185 out_channels: out_c,
186
187 padded_channels: c,
188 operation: op,
189
190 dimensionality,
191 global_dtypes: dtypes.as_global_elems(),
192 address_type,
193 };
194
195 launch_kernel::<R, Alg>(
196 client,
197 input,
198 weight,
199 bias,
200 out,
201 problem,
202 blueprint_strategy,
203 dtypes,
204 )
205}
206
207#[allow(clippy::result_large_err, clippy::too_many_arguments)]
208pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
209 client: &ComputeClient<R>,
210 input: InputBinding<R>,
211 weight: InputBinding<R>,
212 bias: Option<InputBinding<R>>,
213 out: TensorBinding<R>,
214 problem: ConvolutionProblem,
215 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
216 dtypes: MatmulElems,
217) -> Result<(), ConvSetupError>
218where
219 Alg::Args: ConcreteArgs<Alg::Routine>,
220{
221 let vector_sizes = AvailableVectorSizes::from_type_sizes(
224 client,
225 input.data_elem_size(),
226 weight.data_elem_size(),
227 dtypes.acc_global.size(),
228 )
229 .filter_lhs_with_tensor(
230 &input.data().strides,
231 &input.data().shape,
232 MatrixLayout::RowMajor,
233 )
234 .filter_rhs_with_tensor(
235 &weight.data().strides,
236 &weight.data().shape,
237 MatrixLayout::RowMajor,
238 )
239 .filter_out_with_tensor(&out.strides, &out.shape);
240
241 let mut vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
242
243 if input.scale().is_some() {
246 vector_sizes.lhs = 1;
247 }
248 if weight.scale().is_some() {
249 vector_sizes.rhs = 1;
250 }
251
252 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
253 client,
254 input,
255 weight,
256 bias,
257 out,
258 problem,
259 vector_sizes,
260 blueprint_strategy,
261 &dtypes,
262 )
263}