1use std::marker::PhantomData;
2
3use crate::{
4 algorithm::simple_tma::check_problem_tma,
5 base::{
6 Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
7 ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
8 },
9 loader::{
10 bias::BiasLoader,
11 im2col_tma::{TmaIm2colLoader, TmaIm2colTiling},
12 weight_tma::{TmaWeightLoader, TmaWeightTiling},
13 },
14};
15use cubecl_core::prelude::*;
16use cubecl_core::{
17 self as cubecl,
18 prelude::barrier::{Barrier, BarrierLevel},
19};
20use cubecl_matmul::components::{
21 AvailableLineSizes, EA, EI, EO, ES, InputRuntimeArg, MatmulLineSizes, MatmulPrecision,
22 MatmulSelection, MatmulSetupError, MatmulSpec, OutputRuntimeArg,
23 global::{
24 AccumulatorLoader, GlobalConfig,
25 load::{NoLoadingValidation, arrive_tma},
26 single_stage::tma::SimpleTmaConfig,
27 },
28 stage::{FullReaderFamily, FullStageToTileReader, StageConfig, StageMatmul, StageMatmulFamily},
29};
30use cubecl_std::{
31 CubeOption, FastDivmodArgs,
32 tensor::r#virtual::{ReadWrite, VirtualTensor},
33};
34
35use super::base::{
36 config::{self, ConvolutionConfig},
37 implicit_conv, shape_divmod,
38};
39
/// Convolution algorithm driving a single stage matmul whose stages are
/// filled through TMA (tensor-memory-accelerator) loaders: im2col for the
/// lhs and a weight loader for the rhs (see the `Convolution` impl below).
///
/// Zero-sized: both fields are `PhantomData` markers that only carry the
/// `MP` (precision) and `SMM` (stage matmul) type parameters.
pub struct SimpleTmaConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
    _cs: PhantomData<MP>,
    _stage_matmul: PhantomData<SMM>,
}
47
#[cube]
impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleTmaConvolution<MP, SMM>
where
    // The stage matmul must consume full-stage readers over the im2col /
    // weight tilings produced by the TMA loaders below.
    SMM: StageMatmul<
        MP,
        LhsReader = FullStageToTileReader<MP::ES, TmaIm2colTiling>,
        RhsReader = FullStageToTileReader<MP::ES, TmaWeightTiling>,
    >,
{
    type LhsLoader = TmaIm2colLoader<MP, Self::Config>;
    type Config = ConvolutionConfig<SimpleTmaConfig<SMM::Config>>;
    type RhsLoader = TmaWeightLoader<MP, SMM::Config>;
    type AccumulatorLoader = BiasLoader<MP>;

    type Writer = SMM::Writer;
    type Accumulator = SMM::Accumulator;

    /// Runs the k-loop of the implicit GEMM: repeatedly fills the lhs/rhs
    /// stages via TMA, waits on the barrier, executes the stage matmul into
    /// `acc`, then writes the accumulated results out.
    fn execute(
        mut lhs_loader: Self::LhsLoader,
        mut rhs_loader: Self::RhsLoader,
        mut acc_loader: Self::AccumulatorLoader,
        mut out_writer: Self::Writer,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    ) {
        let k_step = config.k_step;
        let range = k_range.1 - k_range.0;
        // Ceil-division: number of k iterations needed to cover `range`.
        // NOTE(review): manual form kept under the clippy allow — presumably
        // `u32::div_ceil` is not usable inside `#[cube]`; confirm.
        #[allow(unknown_lints)] #[allow(clippy::manual_div_ceil)]
        let num_loops = (range + k_step - 1) / k_step;

        // Total element count transferred into both stages per iteration;
        // passed to `arrive_tma` as the barrier's expected arrival count.
        let total_stage_elems = config.tiling_scheme().elements_in_stage_mk()
            + config.tiling_scheme().elements_in_stage_nk();

        // Bias (accumulator init) is staged once, before the k-loop.
        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.stage_config());

        // Make the bias stage visible to all units before reading it.
        sync_cube();

        SMM::fill_accumulator::<Self::AccumulatorLoader>(
            &mut acc_loader,
            acc,
            config.stage_config(),
        );

        // Single TMA-proxy barrier shared by both stage loads each iteration.
        let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32));

        for _ in 0..num_loops {
            // Ensure the previous iteration's compute is done before the
            // stages are overwritten by the next TMA fills.
            sync_cube();

            Self::LhsLoader::fill_stage(&mut lhs_loader, &barrier, 0u32, config);
            Self::RhsLoader::fill_stage(&mut rhs_loader, &barrier, 0u32, config.stage_config());

            // Signal the expected number of elements, then block until both
            // async copies have landed in the stages.
            arrive_tma::<MP::ES>(&barrier, total_stage_elems);

            barrier.wait();

            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader, 0u32);
            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader, 0u32);

            SMM::execute(
                lhs_stage_reader,
                rhs_stage_reader,
                &mut lhs_tile,
                &mut rhs_tile,
                acc,
                config.stage_config(),
            );

            // Advance both global views by one k step for the next iteration.
            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
        }

        sync_cube();

        SMM::write_results::<Self::Config>(acc, &mut out_writer, config.stage_config(), config);
    }

    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader {
        // Single-stage pipeline: stage count fixed to 1.
        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, 1u32, config)
    }

    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader {
        Self::RhsLoader::new::<Self::Config>(
            rhs.as_tensor_map(),
            x_offset,
            y_offset,
            // No extra option payload for the weight loader here.
            CubeOption::new_None(),
            runtime_args,
            1u32,
            config,
        )
    }

    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader {
        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
    }

    fn init_writer(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Writer {
        SMM::init_writer(out, x_offset, y_offset, 0)
    }

    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
        SMM::init_accumulator(config.stage_config())
    }
}
175
/// Family type selecting [`SimpleTmaConvolution`] for a given stage matmul
/// family `SMM`. Zero-sized `PhantomData` marker.
pub struct SimpleTmaConvolutionFamily<SMM: StageMatmulFamily> {
    _smm: PhantomData<SMM>,
}
179
180impl<SMM> ConvolutionFamily for SimpleTmaConvolutionFamily<SMM>
181where
182 SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
183{
184 type Convolution<MP: MatmulPrecision> =
185 SimpleTmaConvolution<MP, SMM::Matmul<MP, TmaIm2colTiling, TmaWeightTiling>>;
186
187 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
188 available_line_sizes
189 .filter_lhs(|ls| *ls == 1)
190 .filter_rhs(|ls| *ls == 1)
191 }
192}
193
194impl<SMM> ConvolutionConfigFactory for SimpleTmaConvolutionFamily<SMM>
195where
196 SMM: StageMatmulFamily,
197{
198 type Config = config::ConvolutionConfig<SimpleTmaConfig<SMM::Config>>;
199
200 fn setup<R: Runtime, MP: MatmulPrecision>(
201 client: &ComputeClient<R::Server, R::Channel>,
202 problem: &ConvolutionProblem,
203 selection: &MatmulSelection,
204 line_sizes: &MatmulLineSizes,
205 ) -> Result<Self::Config, MatmulSetupError> {
206 check_problem_tma(problem)?;
207
208 assert!(line_sizes.lhs == 1);
211 assert!(line_sizes.rhs == 1);
212
213 let stage_config = SMM::setup::<MP, R>(
214 client,
215 &problem.as_matmul_problem(),
216 selection,
217 line_sizes,
218 (1, 1).into(),
219 None,
220 false,
221 )?;
222 let stage_k = stage_config.tiling_scheme().elements_in_stage_k();
223
224 config::ConvolutionConfig::new(
225 SimpleTmaConfig::new::<NoLoadingValidation, NoLoadingValidation, MP, R>(
226 client,
227 stage_config,
228 stage_config.num_main_flow_planes(),
229 true,
231 true,
232 true,
233 stage_k,
234 selection.loading_precompute_strategy,
235 selection.loader_mode,
236 )?,
237 &problem.kernel_size,
238 &problem.stride,
239 &problem.dilation,
240 &problem.padding,
241 problem.dimensionality,
242 1,
243 )
244 }
245}
246
impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
    ConvolutionLaunch for SimpleTmaConvolutionFamily<SMM>
{
    /// Prepares the runtime arguments (padded k size, fast-divmod helpers)
    /// and launches the `implicit_conv` kernel.
    ///
    /// # Safety
    ///
    /// No validation is performed here: the caller must guarantee that
    /// `config`, `cube_dim`/`cube_count` and `problem` are consistent with
    /// the tensors behind `input`, `bias` and `output` (see the
    /// `ConvolutionLaunch` trait, declared elsewhere, for the exact
    /// contract — confirm there).
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        bias: Option<TensorArg<'a, R>>,
        output: OutputRuntimeArg<'a, MS, R>,
        problem: &ConvolutionProblem,
        config: <Self as ConvolutionConfigFactory>::Config,
    ) {
        // Pad channels up to a multiple of the tile's k extent so the im2col
        // k dimension tiles evenly.
        // NOTE(review): the `as u32` casts here and below silently truncate
        // if `channels`/`m`/`n` exceed u32::MAX — presumably problem sizes
        // are validated upstream; confirm.
        let padded_channels =
            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());

        // Total k extent of the implicit GEMM: spatial kernel volume times
        // the padded channel count.
        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(problem.m as u32),
            ScalarArg::new(problem.n as u32),
            ScalarArg::new(size_k),
            FastDivmodArgs::new(client, padded_channels),
            shape_divmod(client, &problem.out_shape),
        );

        // SAFETY: forwards the caller's guarantees (see `# Safety` above) to
        // the generated kernel entry point unchanged.
        unsafe {
            implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
                client,
                cube_count,
                cube_dim,
                input,
                bias.into(),
                output,
                runtime_args,
                config,
            );
        }
    }
}
286}