1use std::marker::PhantomData;
2
3use crate::{
4 base::{
5 Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
6 ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
7 },
8 loader::{bias::BiasLoader, im2col::SimpleIm2colLoader},
9};
10use cubecl_core as cubecl;
11use cubecl_core::prelude::*;
12use cubecl_matmul::components::{
13 AvailableLineSizes, EA, EI, EO, ES, InputIdent, InputRuntimeArg, MatmulLineSizes,
14 MatmulPrecision, MatmulSelection, MatmulSetupError, MatmulSpec, OutputRuntimeArg,
15 global::{
16 AccumulatorLoader, GlobalConfig,
17 load::{NoLoadingValidation, SyncFullLoader, sync_full_cyclic},
18 single_stage::simple::SimpleConfig,
19 },
20 stage::{
21 ContiguousTilingLayout, FullReaderFamily, FullStageToTileReader, RowMajorTilingOrder,
22 StageConfig, StageMatmul, StageMatmulFamily,
23 },
24};
25use cubecl_std::{
26 CubeOption, FastDivmodArgs,
27 tensor::r#virtual::{ReadWrite, VirtualTensor},
28};
29
30use super::base::{
31 config::{self, ConvolutionConfig},
32 implicit_conv, shape_divmod,
33};
34
/// Type-level marker selecting the simple (single-stage) convolution algorithm for
/// matmul precision `MP` and stage matmul `SMM`.
///
/// It stores no data; all behavior lives in its `Convolution` implementation.
pub struct SimpleConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
    // Marker for the matmul precision `MP`; holds no data.
    _cs: PhantomData<MP>,
    // Marker for the stage matmul `SMM`; holds no data.
    _stage_matmul: PhantomData<SMM>,
}
42
// `Convolution` implementation for `SimpleConvolution`. It requires a stage matmul whose
// LHS and RHS readers are both `FullStageToTileReader`s over `ConvTilingLayout`.
//
// Loader choices:
// - LHS: `SimpleIm2colLoader` (im2col-based loading of the input).
// - RHS: `SyncFullLoader` using `SyncFullCyclicLoading<RowMajorTilingOrder>`.
// - Accumulator: `BiasLoader` (seeds the accumulator from the optional bias tensor).
//
// Plain `//` comments are used throughout so nothing extra is fed to the `#[cube]` macro.
#[cube]
impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleConvolution<MP, SMM>
where
    SMM: StageMatmul<
            MP,
            LhsReader = FullStageToTileReader<MP::ES, ConvTilingLayout>,
            RhsReader = FullStageToTileReader<MP::ES, ConvTilingLayout>,
        >,
{
    type LhsLoader = SimpleIm2colLoader<MP, Self::Config>;
    // Single-stage global config (`SimpleConfig`) wrapping the stage matmul's config.
    type Config = ConvolutionConfig<SimpleConfig<SMM::Config>>;
    type RhsLoader = SyncFullLoader<
        MP,
        Self::Config,
        sync_full_cyclic::SyncFullCyclicLoading<RowMajorTilingOrder>,
    >;
    type AccumulatorLoader = BiasLoader<MP>;

    // Output writing and accumulation are delegated to the stage matmul.
    type Writer = SMM::Writer;
    type Accumulator = SMM::Accumulator;

    // Runs the k-loop for one output tile and writes the result.
    //
    // `k_range` is `(start, end)` along k. The span `end - start` is covered in
    // `ceil(span / k_step)` iterations, so the last iteration may step past `end`.
    // NOTE(review): presumably the loaders mask out-of-range reads; confirm.
    fn execute(
        mut lhs_loader: Self::LhsLoader,
        mut rhs_loader: Self::RhsLoader,
        mut acc_loader: Self::AccumulatorLoader,
        mut out_writer: Self::Writer,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    ) {
        let k_step = config.k_step;
        let range = k_range.1 - k_range.0;
        // Manual ceiling division.
        // NOTE(review): presumably kept instead of `div_ceil` because the `#[cube]`
        // macro may not support it; confirm before replacing.
        #[allow(unknown_lints)]
        #[allow(clippy::manual_div_ceil)]
        let num_loops = (range + k_step - 1) / k_step;

        // Seed the accumulator: fill the bias stage, create the tile inputs, then
        // (after a barrier) load the bias stage into `acc`.
        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.stage_config());

        sync_cube();

        SMM::fill_accumulator::<Self::AccumulatorLoader>(
            &mut acc_loader,
            acc,
            config.stage_config(),
        );

        for _ in 0..num_loops {
            // Barrier before refilling the stages. Presumably this keeps units from
            // still reading the previous iteration's stage data.
            sync_cube();

            Self::LhsLoader::fill_stage(&mut lhs_loader, config);
            Self::RhsLoader::fill_stage(&mut rhs_loader, config);

            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader);
            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader);

            // Barrier so the freshly filled stages are complete before the stage
            // matmul reads them (presumed intent).
            sync_cube();

            SMM::execute(
                lhs_stage_reader,
                rhs_stage_reader,
                &mut lhs_tile,
                &mut rhs_tile,
                acc,
                config.stage_config(),
            );

            // Move both loaders' views forward by one k step.
            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
        }

        // Final barrier before the accumulator is written out.
        sync_cube();

        SMM::write_results::<Self::Config>(acc, &mut out_writer, config.stage_config(), config);
    }

    // Creates the im2col LHS loader at `(x_offset, y_offset)`, forwarding `runtime_args`.
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader {
        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, config)
    }

    // Creates the RHS loader at `(x_offset, y_offset)`; runtime args are unused here.
    // NOTE(review): the literal `0` and `CubeOption::new_None()` fill `SyncFullLoader::new`
    // parameters whose meaning is defined there; confirm against its signature.
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        _runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader {
        Self::RhsLoader::new(
            rhs,
            x_offset,
            y_offset,
            0,
            CubeOption::new_None(),
            InputIdent::Rhs,
            config,
        )
    }

    // Creates the bias loader from the optional bias tensor, starting at `n_offset`.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader {
        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
    }

    // Delegates to the stage matmul's writer.
    // NOTE(review): the trailing `0` is forwarded as-is; its meaning is defined by
    // `SMM::init_writer`.
    fn init_writer(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Writer {
        SMM::init_writer(out, x_offset, y_offset, 0)
    }

    // Creates the accumulator through the stage matmul using the stage config.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
        SMM::init_accumulator(config.stage_config())
    }
}
167
/// Family type for [`SimpleConvolution`], generic over a stage matmul family `SMM`.
///
/// It carries no data. Configuration (`ConvolutionConfigFactory`), launching
/// (`ConvolutionLaunch`) and the concrete `Convolution` type (`ConvolutionFamily`)
/// are all provided through trait impls on it.
pub struct SimpleConvolutionFamily<SMM: StageMatmulFamily> {
    // Marker for the stage matmul family; holds no data.
    _smm: PhantomData<SMM>,
}
171
/// Tiling layout used for both stage operands (LHS and RHS) of the simple convolution:
/// a `ContiguousTilingLayout` with `RowMajorTilingOrder` tile ordering.
pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
173
174impl<SMM> ConvolutionFamily for SimpleConvolutionFamily<SMM>
175where
176 SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
177{
178 type Convolution<MP: MatmulPrecision> =
179 SimpleConvolution<MP, SMM::Matmul<MP, ConvTilingLayout, ConvTilingLayout>>;
180
181 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
182 available_line_sizes
183 }
184}
185
186impl<SMM> ConvolutionConfigFactory for SimpleConvolutionFamily<SMM>
187where
188 SMM: StageMatmulFamily,
189{
190 type Config = config::ConvolutionConfig<SimpleConfig<SMM::Config>>;
191
192 fn setup<R: Runtime, MP: MatmulPrecision>(
193 client: &ComputeClient<R::Server, R::Channel>,
194 problem: &ConvolutionProblem,
195 selection: &MatmulSelection,
196 line_sizes: &MatmulLineSizes,
197 ) -> Result<Self::Config, MatmulSetupError> {
198 let stage_config = SMM::setup::<MP, R>(
199 client,
200 &problem.as_matmul_problem(),
201 selection,
202 line_sizes,
203 (1, 1).into(),
204 None,
205 false,
206 )?;
207 let stage_k = stage_config.tiling_scheme().elements_in_stage_k();
208
209 config::ConvolutionConfig::new(
210 SimpleConfig::new::<NoLoadingValidation, NoLoadingValidation, MP, R>(
211 client,
212 stage_config,
213 stage_config.num_main_flow_planes(),
214 true,
215 true,
216 true,
217 stage_k,
218 selection.loading_precompute_strategy,
219 selection.loader_mode,
220 )?,
221 &problem.kernel_size,
222 &problem.stride,
223 &problem.dilation,
224 &problem.padding,
225 problem.dimensionality,
226 1,
227 )
228 }
229}
230
231impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
232 ConvolutionLaunch for SimpleConvolutionFamily<SMM>
233{
234 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
235 client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
236 cube_dim: CubeDim,
237 cube_count: CubeCount,
238 input: InputRuntimeArg<'a, MS, R>,
239 bias: Option<TensorArg<'a, R>>,
240 output: OutputRuntimeArg<'a, MS, R>,
241 problem: &ConvolutionProblem,
242 config: <Self as ConvolutionConfigFactory>::Config,
243 ) {
244 let runtime_args = RuntimeArgsLaunch::new(
245 ScalarArg::new(problem.m as u32),
246 ScalarArg::new(problem.n as u32),
247 ScalarArg::new(problem.k as u32),
248 FastDivmodArgs::new(client, problem.channels as u32),
249 shape_divmod(client, &problem.out_shape),
250 );
251
252 unsafe {
253 implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
254 client,
255 cube_count,
256 cube_dim,
257 input,
258 bias.into(),
259 output,
260 runtime_args,
261 config,
262 );
263 }
264 }
265}