1use std::marker::PhantomData;
2
3use crate::{
4 convolution::{
5 ConvGemmConfig,
6 base::{
7 Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
8 ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
9 },
10 loader::{bias::BiasLoader, im2col::SimpleIm2colLoader},
11 },
12 matmul::components::{
13 EA, EI, EO, ES, InputIdent, InputRuntimeArg, InvalidConfigError, MatmulPrecision,
14 MatmulSpec, OutputRuntimeArg,
15 global::{
16 AccumulatorLoader, GlobalConfig,
17 load::{SyncFullLoader, sync_full_cyclic},
18 output_loader::Unloader,
19 single_stage,
20 },
21 stage::{
22 ContiguousTilingLayout, FullReader, FullReaderFamily, RowMajorTilingOrder, StageMatmul,
23 StageMatmulFamily,
24 },
25 },
26};
27use cubecl_core as cubecl;
28use cubecl_core::prelude::*;
29use cubecl_std::{
30 CubeOption, FastDivmodArgs,
31 tensor::r#virtual::{ReadWrite, VirtualTensor},
32};
33
34use super::base::{
35 config::{self, HomogeneousConfig},
36 implicit_conv,
37};
38
39pub struct SimpleConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
43 _cs: PhantomData<MP>,
44 _stage_matmul: PhantomData<SMM>,
45}
46
// `Convolution` implementation for `SimpleConvolution`.
//
// Only stage matmuls whose LHS and RHS readers are both
// `FullReader<MP::ES, ConvTilingLayout>` are accepted (see the `where` clause).
#[cube]
impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleConvolution<MP, SMM>
where
    SMM: StageMatmul<
        MP,
        LhsReader = FullReader<MP::ES, ConvTilingLayout>,
        RhsReader = FullReader<MP::ES, ConvTilingLayout>,
    >,
{
    // LHS stage is filled by the im2col loader.
    type LhsLoader = SimpleIm2colLoader<MP, Self::Config>;
    // Convolution config wrapping a single-stage global matmul config.
    type Config = HomogeneousConfig<single_stage::Config<SMM::Config>>;
    // RHS stage is filled by the regular matmul full loader with the cyclic
    // loading strategy in row-major tiling order.
    type RhsLoader =
        SyncFullLoader<MP, SMM::Config, sync_full_cyclic::LoadingStrategy<RowMajorTilingOrder>>;
    // The accumulator is initialised through the bias loader.
    type AccumulatorLoader = BiasLoader<MP>;

    type Out = Unloader<MP::EO>;
    type Accumulator = SMM::Accumulator;

    // Runs the full convolution for one cube.
    //
    // Steps, in order:
    //   1. fill the accumulator from `acc_loader`,
    //   2. for each `k_step`-sized slice of `k_range`: fill both stages, run the
    //      stage matmul into `acc`, then advance both loaders by `k_step`,
    //   3. write `acc` out through `out_unloader`.
    //
    // `k_range` is a `(start, end)` pair; only `end - start` is used here.
    fn execute(
        mut lhs_loader: Self::LhsLoader,
        mut rhs_loader: Self::RhsLoader,
        mut acc_loader: Self::AccumulatorLoader,
        mut out_unloader: Self::Out,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    ) {
        let k_step = config.k_step;
        let range = k_range.1 - k_range.0;
        // Ceiling division: a trailing partial slice still needs one iteration.
        // The manual form is kept and the clippy lint silenced; `unknown_lints`
        // is allowed too, presumably so toolchains that do not know
        // `manual_div_ceil` stay warning-free (TODO confirm).
        #[allow(unknown_lints)] #[allow(clippy::manual_div_ceil)]
        let num_loops = (range + k_step - 1) / k_step;

        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.to_smm_config());

        // Synchronise after the accumulator loader's stage is filled and before
        // it is consumed by `fill_accumulator` below.
        sync_units();

        SMM::fill_accumulator::<Self::AccumulatorLoader>(
            &mut acc_loader,
            acc,
            config.to_smm_config(),
        );

        for _ in 0..num_loops {
            // Synchronise before the stages are overwritten by the next fill.
            sync_units();

            Self::LhsLoader::fill_stage(&mut lhs_loader, config);
            Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config());

            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader);
            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader);

            // Synchronise so the freshly filled stages are complete before the
            // stage matmul reads them.
            sync_units();

            SMM::execute(
                lhs_stage_reader,
                rhs_stage_reader,
                &mut lhs_tile,
                &mut rhs_tile,
                acc,
                config.to_smm_config(),
            );

            // Move both loaders forward by one `k_step` for the next iteration.
            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
        }

        sync_units();

        // Write the accumulated result out through the unloader.
        SMM::read_accumulator::<Self::Out, Self::Config>(
            acc,
            &mut out_unloader,
            config.to_smm_config(),
            config,
        );
    }

    // Builds the im2col LHS loader; every argument is forwarded unchanged.
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader {
        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, config)
    }

    // Builds the RHS loader. `_runtime_args` is unused; it is only present to
    // satisfy the trait signature.
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        _runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader {
        Self::RhsLoader::new::<Self::Config>(
            rhs,
            x_offset,
            y_offset,
            // NOTE(review): presumably the batch offset, fixed to 0 here;
            // confirm against `SyncFullLoader::new`.
            0,
            // No optional extra argument is supplied; its meaning is not visible
            // from this file (TODO confirm against `SyncFullLoader::new`).
            CubeOption::new_None(),
            InputIdent::Rhs,
            config,
        )
    }

    // Builds the bias loader used to initialise the accumulator. `bias` is
    // optional and is forwarded as-is.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader {
        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
    }

    // Builds the output unloader. The trailing `0` is presumably a batch
    // offset (TODO confirm against `Unloader::new`).
    fn init_unloader(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Out {
        Self::Out::new(out, x_offset, y_offset, 0)
    }

    // Creates the accumulator by delegating to the stage matmul.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
        SMM::init_accumulator(config.to_smm_config())
    }
}
173
174pub struct SimpleConvolutionFamily<SMM: StageMatmulFamily> {
175 _smm: PhantomData<SMM>,
176}
177
/// Tiling layout used for both the LHS and RHS stages of the convolution:
/// a contiguous layout with row-major tiling order.
pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
179
180impl<SMM> ConvolutionFamily for SimpleConvolutionFamily<SMM>
181where
182 SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
183{
184 type Convolution<MP: MatmulPrecision> =
185 SimpleConvolution<MP, SMM::Matmul<MP, ConvTilingLayout, ConvTilingLayout>>;
186}
187
188impl<SMM> ConvolutionConfigFactory for SimpleConvolutionFamily<SMM>
189where
190 SMM: StageMatmulFamily,
191{
192 type Config = config::HomogeneousConfig<single_stage::Config<SMM::Config>>;
193 type Input = SMM::Input;
194
195 fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
196 SMM::check_config(&config.to_smm_config())
197 }
198
199 fn make_config(
200 input: Self::Input,
201 problem: &ConvolutionProblem,
202 cube_dim: &CubeDim,
203 cube_count: &CubeCount,
204 ) -> Self::Config {
205 let smm_config = SMM::make_config(
206 input,
207 &problem.as_matmul_problem(),
208 cube_dim,
209 cube_count,
210 false,
211 );
212 let size = SMM::stage_shape(&smm_config);
213
214 config::HomogeneousConfig::new(
215 single_stage::Config::new(
216 smm_config,
217 true,
219 true,
220 true,
221 problem.lhs_layout,
222 problem.rhs_layout,
223 problem.lhs_line_size as u32,
224 problem.rhs_line_size as u32,
225 problem.out_line_size as u32,
226 size.k,
227 ),
228 problem.kernel_size,
229 problem.stride,
230 problem.dilation,
231 problem.padding,
232 )
233 }
234
235 fn check_availability<R: Runtime, MP: MatmulPrecision>(
236 client: &ComputeClient<R::Server, R::Channel>,
237 config: &Self::Config,
238 ) -> Result<(), crate::matmul::kernels::MatmulAvailabilityError> {
239 SMM::check_availability::<R, MP>(client, &config.to_smm_config())
240 }
241}
242
243impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
244 ConvolutionLaunch for SimpleConvolutionFamily<SMM>
245{
246 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
247 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
248 cube_dim: CubeDim,
249 cube_count: CubeCount,
250 input: InputRuntimeArg<'a, MS, R>,
251 bias: Option<TensorArg<'a, R>>,
252 output: OutputRuntimeArg<'a, MS, R>,
253 problem: &ConvolutionProblem,
254 config: <Self as ConvolutionConfigFactory>::Config,
255 ) {
256 let size_m = problem.batches * problem.out_h * problem.out_w;
257 let size_n = problem.n;
258 let size_k = config.kernel_size(0) * config.kernel_size(1) * problem.channels as u32;
259
260 let runtime_args = RuntimeArgsLaunch::new(
261 ScalarArg::new(size_m as u32),
262 ScalarArg::new(size_n as u32),
263 ScalarArg::new(size_k),
264 FastDivmodArgs::new(client, problem.channels as u32),
265 FastDivmodArgs::new(client, problem.out_h as u32),
266 FastDivmodArgs::new(client, problem.out_w as u32),
267 );
268
269 unsafe {
270 implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
271 client,
272 cube_count,
273 cube_dim,
274 input,
275 bias.into(),
276 output,
277 runtime_args,
278 config,
279 );
280 }
281 }
282}