1use std::{any::TypeId, marker::PhantomData};
2
3use crate::{
4 convolution::{
5 ConvGemmConfig,
6 base::{
7 Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
8 ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
9 },
10 loader::{
11 bias::BiasLoader,
12 im2col_tma::{TmaIm2colLoader, TmaIm2colTiling},
13 weight_tma::{TmaWeightLoader, TmaWeightTiling},
14 },
15 },
16 matmul::{
17 components::{
18 EA, EI, EO, ES, Ident, InputRuntimeArg, InvalidConfigError, MatmulPrecision,
19 MatmulSpec, OutputRuntimeArg,
20 global::{AccumulatorLoader, GlobalConfig, output_loader::Unloader, single_stage},
21 stage::{FullReader, FullReaderFamily, StageMatmul, StageMatmulFamily},
22 },
23 kernels::MatmulAvailabilityError,
24 },
25};
26use cubecl_core::prelude::*;
27use cubecl_core::{
28 self as cubecl,
29 prelude::barrier::{Barrier, BarrierLevel},
30};
31use cubecl_std::{
32 CubeOption, FastDivmodArgs,
33 tensor::r#virtual::{ReadWrite, VirtualTensor},
34};
35
36use super::base::{
37 config::{self, HomogeneousConfig},
38 implicit_conv,
39};
40
/// Single-stage convolution whose lhs (im2col) and rhs (weight) stages are
/// filled with TMA-based loaders behind an async barrier.
///
/// Zero-sized marker: the type parameters select the numeric precision and the
/// stage matmul implementation; all state lives in the loaders and config.
pub struct SimpleTmaConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
    _cs: PhantomData<MP>,
    _stage_matmul: PhantomData<SMM>,
}
48
/// Device-side convolution loop. NOTE: the body is cubecl DSL expanded by
/// `#[cube]`; the `sync_units` / barrier sequencing below is order-sensitive.
#[cube]
impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleTmaConvolution<MP, SMM>
where
    SMM: StageMatmul<
        MP,
        LhsReader = FullReader<MP::ES, TmaIm2colTiling>,
        RhsReader = FullReader<MP::ES, TmaWeightTiling>,
    >,
{
    type LhsLoader = TmaIm2colLoader<MP, Self::Config>;
    type Config = HomogeneousConfig<single_stage::Config<SMM::Config>>;
    type RhsLoader = TmaWeightLoader<MP, SMM::Config>;
    type AccumulatorLoader = BiasLoader<MP>;

    type Out = Unloader<MP::EO>;
    type Accumulator = SMM::Accumulator;

    /// Runs the k-loop: seeds the accumulator from the bias stage, then for
    /// each k-step fills both stages via TMA behind a barrier, executes the
    /// stage matmul, and finally writes the accumulator through the unloader.
    fn execute(
        mut lhs_loader: Self::LhsLoader,
        mut rhs_loader: Self::RhsLoader,
        mut acc_loader: Self::AccumulatorLoader,
        mut out_unloader: Self::Out,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    ) {
        let k_step = config.k_step;
        let range = k_range.1 - k_range.0;
        // Ceiling division in manual form (see the clippy allow); number of
        // k-step iterations needed to cover [k_range.0, k_range.1).
        #[allow(unknown_lints)] #[allow(clippy::manual_div_ceil)]
        let num_loops = (range + k_step - 1) / k_step;

        // Stage the bias, then copy it into the accumulator once the stage is
        // visible to all units (hence the sync_units between the two calls).
        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.to_smm_config());

        sync_units();

        SMM::fill_accumulator::<Self::AccumulatorLoader>(
            &mut acc_loader,
            acc,
            config.to_smm_config(),
        );

        // Cube-scoped barrier with a TMA proxy: tracks async bulk-copy
        // completion in addition to unit arrival.
        let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32));

        for _ in 0..num_loops {
            // Make sure the previous iteration's compute is done before the
            // stages are overwritten by the next TMA fills.
            sync_units();

            Self::LhsLoader::fill_stage(&mut lhs_loader, &barrier, config);
            Self::RhsLoader::fill_stage(&mut rhs_loader, &barrier, config.to_smm_config());

            if UNIT_POS == 0 {
                // Unit 0 registers the expected transaction byte count (both
                // stages, in stage-element bytes) on the barrier; the other
                // units only signal arrival.
                let total_stage = config.tiling_dimensions(Ident::Rhs).total_size()
                    + config.tiling_dimensions(Ident::Lhs).total_size();
                barrier.arrive_tx(1, total_stage * MP::ES::elem_size());
            } else {
                barrier.arrive();
            }

            // Blocks until all units arrived and the registered TMA bytes
            // have landed in the stages.
            barrier.wait();

            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader);
            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader);

            SMM::execute(
                lhs_stage_reader,
                rhs_stage_reader,
                &mut lhs_tile,
                &mut rhs_tile,
                acc,
                config.to_smm_config(),
            );

            // Slide both loader views forward along k for the next iteration.
            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
        }

        sync_units();

        SMM::read_accumulator::<Self::Out, Self::Config>(
            acc,
            &mut out_unloader,
            config.to_smm_config(),
            config,
        );
    }

    /// Builds the TMA im2col loader positioned at (x_offset, y_offset).
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader {
        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, config)
    }

    /// Builds the TMA weight loader; the weight tensor is passed as a tensor
    /// map. The `CubeOption::new_None()` argument is an absent optional input
    /// (NOTE(review): presumably an out-of-bounds/remainder mask — confirm
    /// against `TmaWeightLoader::new`).
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader {
        Self::RhsLoader::new::<Self::Config>(
            rhs.as_tensor_map(),
            x_offset,
            y_offset,
            CubeOption::new_None(),
            runtime_args,
            config,
        )
    }

    /// Builds the (optional) bias loader at the given column offset.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader {
        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
    }

    /// Builds the output unloader; the trailing 0 is the batch/slab offset
    /// expected by `Unloader::new`.
    fn init_unloader(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Out {
        Self::Out::new(out, x_offset, y_offset, 0)
    }

    /// Allocates a zeroed/uninitialized accumulator from the stage matmul.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
        SMM::init_accumulator(config.to_smm_config())
    }
}
183
/// Type-level family (factory) for [`SimpleTmaConvolution`], generic over the
/// stage matmul family. Zero-sized marker.
pub struct SimpleTmaConvolutionFamily<SMM: StageMatmulFamily> {
    _smm: PhantomData<SMM>,
}
187
impl<SMM> ConvolutionFamily for SimpleTmaConvolutionFamily<SMM>
where
    SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
{
    // Instantiate the concrete convolution with the TMA im2col/weight tilings.
    type Convolution<MP: MatmulPrecision> =
        SimpleTmaConvolution<MP, SMM::Matmul<MP, TmaIm2colTiling, TmaWeightTiling>>;
}
195
196impl<SMM> ConvolutionConfigFactory for SimpleTmaConvolutionFamily<SMM>
197where
198 SMM: StageMatmulFamily,
199{
200 type Config = config::HomogeneousConfig<single_stage::Config<SMM::Config>>;
201 type Input = SMM::Input;
202
203 fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
204 SMM::check_config(&config.to_smm_config())
205 }
206
207 fn make_config(
208 input: Self::Input,
209 problem: &ConvolutionProblem,
210 cube_dim: &CubeDim,
211 cube_count: &CubeCount,
212 ) -> Self::Config {
213 let mut problem = problem.clone();
214
215 problem.lhs_line_size = 1;
218 problem.rhs_line_size = 1;
219
220 let smm_config = SMM::make_config(
221 input,
222 &problem.as_matmul_problem(),
223 cube_dim,
224 cube_count,
225 false,
226 );
227 let size = SMM::stage_shape(&smm_config);
228
229 config::HomogeneousConfig::new(
230 single_stage::Config::new(
231 smm_config,
232 true,
234 true,
235 true,
236 problem.lhs_layout,
237 problem.rhs_layout,
238 problem.lhs_line_size as u32,
239 problem.rhs_line_size as u32,
240 problem.out_line_size as u32,
241 size.k,
242 ),
243 problem.kernel_size,
244 problem.stride,
245 problem.dilation,
246 problem.padding,
247 )
248 }
249
250 fn check_availability<R: Runtime, MP: MatmulPrecision>(
251 client: &ComputeClient<R::Server, R::Channel>,
252 config: &Self::Config,
253 ) -> Result<(), MatmulAvailabilityError> {
254 let id_ei = TypeId::of::<MP::EI>();
255 let id_es = TypeId::of::<MP::ES>();
256 let is_tf32 = id_ei == TypeId::of::<f32>() && id_es == TypeId::of::<tf32>();
257
258 if id_ei != id_es && !is_tf32 {
259 return Err(MatmulAvailabilityError::TmaUnavailable);
260 }
261
262 SMM::check_availability::<R, MP>(client, &config.to_smm_config())
263 }
264}
265
266impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
267 ConvolutionLaunch for SimpleTmaConvolutionFamily<SMM>
268{
269 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
270 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
271 cube_dim: CubeDim,
272 cube_count: CubeCount,
273 input: InputRuntimeArg<'a, MS, R>,
274 bias: Option<TensorArg<'a, R>>,
275 output: OutputRuntimeArg<'a, MS, R>,
276 problem: &ConvolutionProblem,
277 config: <Self as ConvolutionConfigFactory>::Config,
278 ) {
279 let tiling_dims = config.tiling_dimensions(Ident::Lhs);
280 let padded_channels =
281 (problem.channels as u32).next_multiple_of(tiling_dims.tile_shape_col());
282
283 let size_m = problem.batches * problem.out_h * problem.out_w;
284 let size_n = problem.n;
285 let size_k = config.kernel_size(0) * config.kernel_size(1) * padded_channels;
286
287 let runtime_args = RuntimeArgsLaunch::new(
288 ScalarArg::new(size_m as u32),
289 ScalarArg::new(size_n as u32),
290 ScalarArg::new(size_k),
291 FastDivmodArgs::new(client, padded_channels),
292 FastDivmodArgs::new(client, problem.out_h as u32),
293 FastDivmodArgs::new(client, problem.out_w as u32),
294 );
295
296 unsafe {
297 implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
298 client,
299 cube_count,
300 cube_dim,
301 input,
302 bias.into(),
303 output,
304 runtime_args,
305 config,
306 );
307 }
308 }
309}