1use crate::{
2 AcceleratedTileKind, ReadingStrategy, algorithm::simple::*,
3 components::global::args::RuntimeArgs,
4};
5use crate::{
6 ConvolutionArgs, Strategy, components::ConvolutionOperation, forward::args::ConcreteArgs,
7};
8use crate::{
9 algorithm::Algorithm,
10 components::{ConvolutionProblem, Dimensionality},
11};
12use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
13use cubecl::{Runtime, client::ComputeClient, prelude::*};
14use cubek_matmul::routines::{BlueprintStrategy, TilingArgs};
15use cubek_matmul::{
16 components::tile_matmul::DispatchTileMatmul,
17 definition::{AvailableVectorSizes, MatmulElems},
18 routines::Routine,
19};
20use cubek_std::{InputBinding, MatrixLayout};
21use derive_new::new;
22
23fn tile_kind_to_dispatch(kind: &AcceleratedTileKind) -> DispatchTileMatmul {
24 match kind {
25 AcceleratedTileKind::Cmma => DispatchTileMatmul::Cmma,
26 AcceleratedTileKind::Mma => DispatchTileMatmul::Mma,
27 }
28}
29
30#[allow(clippy::result_large_err, clippy::too_many_arguments)]
39pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
40 strategy: &Strategy,
41 client: &ComputeClient<R>,
42 input: InputBinding<R>,
43 weight: InputBinding<R>,
44 bias: Option<InputBinding<R>>,
45 out: TensorBinding<R>,
46 args: ConvolutionArgs<N_SPATIAL>,
47 dtypes: MatmulElems,
48) -> Result<(), ConvSetupError> {
49 let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);
50
51 match strategy {
52 Strategy::Simple {
53 read_strategy,
54 tile_kind,
55 } => {
56 let kind = tile_kind_to_dispatch(tile_kind);
57 match read_strategy {
58 ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv>(kind),
59 ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv>(kind),
60 ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv>(kind),
61 ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv>(kind),
62 ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv>(kind),
63 ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv>(kind),
64 }
65 }
66 }
67}
68
/// Bundles everything needed to launch one forward convolution so the
/// strategy dispatch in [`launch_ref`] stays compact.
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    /// Compute client the kernel is launched on.
    client: &'a ComputeClient<R>,
    /// Input tensor binding (used as the matmul lhs downstream).
    input: InputBinding<R>,
    /// Weight tensor binding (used as the matmul rhs downstream).
    weight: InputBinding<R>,
    /// Optional bias binding, forwarded as-is to the kernel launch.
    bias: Option<InputBinding<R>>,
    /// Output tensor binding.
    out: TensorBinding<R>,
    /// Stride / padding / dilation for the `N_SPATIAL` spatial dimensions.
    args: ConvolutionArgs<N_SPATIAL>,
    /// Element types used by the underlying matmul.
    dtypes: MatmulElems,
}
79
80impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
81 fn launch<Alg: Algorithm>(self, tile_matmul: DispatchTileMatmul) -> Result<(), ConvSetupError>
82 where
83 Alg::Args: ConcreteArgs<Alg::Routine>,
84 <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
85 {
86 let ConvolutionArgs {
87 stride,
88 padding,
89 dilation,
90 } = self.args;
91
92 let dimensionality = match N_SPATIAL {
93 1 => Dimensionality::Dim1,
94 2 => Dimensionality::Dim2,
95 3 => Dimensionality::Dim3,
96 other => unimplemented!("Unsupported dimensionality {other}"),
97 };
98
99 let mut args = <Alg::Routine as Routine<RuntimeArgs>>::Strategy::default();
100 args.set_tile_matmul(tile_matmul);
101
102 launch_with_algorithm::<R, Alg>(
103 self.client,
104 self.input,
105 self.weight,
106 self.bias,
107 self.out,
108 (&stride, &padding, &dilation),
109 dimensionality,
110 &BlueprintStrategy::Inferred(args),
111 self.dtypes,
112 )
113 }
114}
115
116#[allow(clippy::too_many_arguments)]
117fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
118 client: &ComputeClient<R>,
119 input: InputBinding<R>,
120 weight: InputBinding<R>,
121 bias: Option<InputBinding<R>>,
122 out: TensorBinding<R>,
123 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
124 dimensionality: Dimensionality,
125 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
126 dtypes: MatmulElems,
127) -> Result<(), ConvSetupError>
128where
129 Alg::Args: ConcreteArgs<Alg::Routine>,
130{
131 let rank = input.data().shape.len();
132 let dim_c = rank - 1;
133
134 let n = input.data().shape[0];
135 let c = input.data().shape[dim_c];
136
137 let out_c = weight.data().shape[0];
138
139 let in_shape = &input.data().shape[1..dim_c];
140 let kernel_shape = &weight.data().shape[1..dim_c];
141 let out_shape = &out.shape[1..dim_c];
142
143 let op = ConvolutionOperation::Forward;
144
145 let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
146 let weight_data =
147 Alg::correct_layout(client, weight.clone().into_data(), dtypes.rhs_global, op)?;
148
149 let mut input = input.clone();
150 let mut weight = weight.clone();
151
152 *input.data_mut() = input_data;
153 *weight.data_mut() = weight_data;
154
155 let address_type = input
156 .required_address_type()
157 .max(weight.required_address_type())
158 .max(
159 bias.clone()
160 .map(|bias| bias.required_address_type())
161 .unwrap_or_default(),
162 )
163 .max(out.required_address_type(dtypes.acc_global.size()));
164
165 let problem = ConvolutionProblem {
166 m: n * out_shape.iter().product::<usize>(),
167 n: out_c,
168 k: c * kernel_shape.iter().product::<usize>(),
169 lhs_strides: input.data().strides.clone(),
170 rhs_strides: weight.data().strides.clone(),
171 lhs_layout: MatrixLayout::RowMajor,
172 rhs_layout: MatrixLayout::ColMajor,
173 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
174 stride: stride.iter().map(|it| *it as u32).collect(),
175 padding: padding.iter().map(|it| *it as i32).collect(),
176 dilation: dilation.iter().map(|it| *it as u32).collect(),
177
178 batches: n,
179 in_shape: in_shape.into(),
180 out_shape: out_shape.into(),
181 channels: c,
182 out_channels: out_c,
183
184 padded_channels: c,
185 operation: op,
186
187 dimensionality,
188 global_dtypes: dtypes.as_global_elems(),
189 address_type,
190 };
191
192 launch_kernel::<R, Alg>(
193 client,
194 input,
195 weight,
196 bias,
197 out,
198 problem,
199 blueprint_strategy,
200 dtypes,
201 )
202}
203
204#[allow(clippy::result_large_err, clippy::too_many_arguments)]
205pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
206 client: &ComputeClient<R>,
207 input: InputBinding<R>,
208 weight: InputBinding<R>,
209 bias: Option<InputBinding<R>>,
210 out: TensorBinding<R>,
211 problem: ConvolutionProblem,
212 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
213 dtypes: MatmulElems,
214) -> Result<(), ConvSetupError>
215where
216 Alg::Args: ConcreteArgs<Alg::Routine>,
217{
218 let vector_sizes = AvailableVectorSizes::from_type_sizes(
221 client,
222 input.data_elem_size(),
223 weight.data_elem_size(),
224 dtypes.acc_global.size(),
225 )
226 .filter_lhs_with_tensor(
227 &input.data().strides,
228 &input.data().shape,
229 MatrixLayout::RowMajor,
230 )
231 .filter_rhs_with_tensor(
232 &weight.data().strides,
233 &weight.data().shape,
234 MatrixLayout::RowMajor,
235 )
236 .filter_out_with_tensor(&out.strides, &out.shape);
237
238 let mut vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
239
240 if input.scale().is_some() {
243 vector_sizes.lhs = 1;
244 }
245 if weight.scale().is_some() {
246 vector_sizes.rhs = 1;
247 }
248
249 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
250 client,
251 input,
252 weight,
253 bias,
254 out,
255 problem,
256 vector_sizes,
257 blueprint_strategy,
258 &dtypes,
259 )
260}