cubecl_linalg/convolution/homogeneous/
simple.rs

1use std::marker::PhantomData;
2
3use crate::{
4    convolution::{
5        ConvGemmConfig,
6        base::{
7            Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
8            ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
9        },
10        loader::{bias::BiasLoader, im2col::SimpleIm2colLoader},
11    },
12    matmul::components::{
13        EA, EI, EO, ES, InputIdent, InputRuntimeArg, InvalidConfigError, MatmulPrecision,
14        MatmulSpec, OutputRuntimeArg,
15        global::{
16            AccumulatorLoader, GlobalConfig,
17            load::{SyncFullLoader, sync_full_cyclic},
18            output_loader::Unloader,
19            single_stage,
20        },
21        stage::{
22            ContiguousTilingLayout, FullReader, FullReaderFamily, RowMajorTilingOrder, StageMatmul,
23            StageMatmulFamily,
24        },
25    },
26};
27use cubecl_core as cubecl;
28use cubecl_core::prelude::*;
29use cubecl_std::{
30    CubeOption, FastDivmodArgs,
31    tensor::r#virtual::{ReadWrite, VirtualTensor},
32};
33
34use super::base::{
35    config::{self, HomogeneousConfig},
36    implicit_conv,
37};
38
39/// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities
40/// - All planes load data to the stage
41/// - All planes are used in the stage matmul computation
42pub struct SimpleConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
43    _cs: PhantomData<MP>,
44    _stage_matmul: PhantomData<SMM>,
45}
46
47#[cube]
48impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleConvolution<MP, SMM>
49where
50    SMM: StageMatmul<
51            MP,
52            LhsReader = FullReader<MP::ES, ConvTilingLayout>,
53            RhsReader = FullReader<MP::ES, ConvTilingLayout>,
54        >,
55{
56    type LhsLoader = SimpleIm2colLoader<MP, Self::Config>;
57    type Config = HomogeneousConfig<single_stage::Config<SMM::Config>>;
58    type RhsLoader =
59        SyncFullLoader<MP, SMM::Config, sync_full_cyclic::LoadingStrategy<RowMajorTilingOrder>>;
60    type AccumulatorLoader = BiasLoader<MP>;
61
62    type Out = Unloader<MP::EO>;
63    type Accumulator = SMM::Accumulator;
64
65    fn execute(
66        mut lhs_loader: Self::LhsLoader,
67        mut rhs_loader: Self::RhsLoader,
68        mut acc_loader: Self::AccumulatorLoader,
69        mut out_unloader: Self::Out,
70        acc: &mut Self::Accumulator,
71        k_range: (u32, u32),
72        #[comptime] config: Self::Config,
73    ) {
74        let k_step = config.k_step;
75        let range = k_range.1 - k_range.0;
76        #[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
77        #[allow(clippy::manual_div_ceil)]
78        let num_loops = (range + k_step - 1) / k_step;
79
80        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
81        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.to_smm_config());
82
83        sync_units();
84
85        SMM::fill_accumulator::<Self::AccumulatorLoader>(
86            &mut acc_loader,
87            acc,
88            config.to_smm_config(),
89        );
90
91        for _ in 0..num_loops {
92            sync_units();
93
94            Self::LhsLoader::fill_stage(&mut lhs_loader, config);
95            Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config());
96
97            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader);
98            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader);
99
100            sync_units();
101
102            SMM::execute(
103                lhs_stage_reader,
104                rhs_stage_reader,
105                &mut lhs_tile,
106                &mut rhs_tile,
107                acc,
108                config.to_smm_config(),
109            );
110
111            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
112            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
113        }
114
115        sync_units();
116
117        SMM::read_accumulator::<Self::Out, Self::Config>(
118            acc,
119            &mut out_unloader,
120            config.to_smm_config(),
121            config,
122        );
123    }
124
125    fn init_lhs_loader(
126        lhs: VirtualTensor<MP::EI>,
127        x_offset: u32,
128        y_offset: u32,
129        runtime_args: &RuntimeArgs,
130        #[comptime] config: Self::Config,
131    ) -> Self::LhsLoader {
132        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, config)
133    }
134
135    fn init_rhs_loader(
136        rhs: VirtualTensor<MP::EI>,
137        x_offset: u32,
138        y_offset: u32,
139        _runtime_args: &RuntimeArgs,
140        #[comptime] config: Self::Config,
141    ) -> Self::RhsLoader {
142        Self::RhsLoader::new::<Self::Config>(
143            rhs,
144            x_offset,
145            y_offset,
146            0,
147            CubeOption::new_None(),
148            InputIdent::Rhs,
149            config,
150        )
151    }
152
153    fn init_bias_loader(
154        bias: CubeOption<VirtualTensor<MP::EO>>,
155        n_offset: u32,
156        #[comptime] config: Self::Config,
157    ) -> Self::AccumulatorLoader {
158        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
159    }
160
161    fn init_unloader(
162        out: VirtualTensor<MP::EO, ReadWrite>,
163        x_offset: u32,
164        y_offset: u32,
165    ) -> Self::Out {
166        Self::Out::new(out, x_offset, y_offset, 0)
167    }
168
169    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
170        SMM::init_accumulator(config.to_smm_config())
171    }
172}
173
174pub struct SimpleConvolutionFamily<SMM: StageMatmulFamily> {
175    _smm: PhantomData<SMM>,
176}
177
178pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
179
180impl<SMM> ConvolutionFamily for SimpleConvolutionFamily<SMM>
181where
182    SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
183{
184    type Convolution<MP: MatmulPrecision> =
185        SimpleConvolution<MP, SMM::Matmul<MP, ConvTilingLayout, ConvTilingLayout>>;
186}
187
188impl<SMM> ConvolutionConfigFactory for SimpleConvolutionFamily<SMM>
189where
190    SMM: StageMatmulFamily,
191{
192    type Config = config::HomogeneousConfig<single_stage::Config<SMM::Config>>;
193    type Input = SMM::Input;
194
195    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
196        SMM::check_config(&config.to_smm_config())
197    }
198
199    fn make_config(
200        input: Self::Input,
201        problem: &ConvolutionProblem,
202        cube_dim: &CubeDim,
203        cube_count: &CubeCount,
204    ) -> Self::Config {
205        let smm_config = SMM::make_config(
206            input,
207            &problem.as_matmul_problem(),
208            cube_dim,
209            cube_count,
210            false,
211        );
212        let size = SMM::stage_shape(&smm_config);
213
214        config::HomogeneousConfig::new(
215            single_stage::Config::new(
216                smm_config,
217                // TODO: Find the correct condition to avoid check bounds.
218                true,
219                true,
220                true,
221                problem.lhs_layout,
222                problem.rhs_layout,
223                problem.lhs_line_size as u32,
224                problem.rhs_line_size as u32,
225                problem.out_line_size as u32,
226                size.k,
227            ),
228            problem.kernel_size,
229            problem.stride,
230            problem.dilation,
231            problem.padding,
232        )
233    }
234
235    fn check_availability<R: Runtime, MP: MatmulPrecision>(
236        client: &ComputeClient<R::Server, R::Channel>,
237        config: &Self::Config,
238    ) -> Result<(), crate::matmul::kernels::MatmulAvailabilityError> {
239        SMM::check_availability::<R, MP>(client, &config.to_smm_config())
240    }
241}
242
243impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
244    ConvolutionLaunch for SimpleConvolutionFamily<SMM>
245{
246    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
247        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
248        cube_dim: CubeDim,
249        cube_count: CubeCount,
250        input: InputRuntimeArg<'a, MS, R>,
251        bias: Option<TensorArg<'a, R>>,
252        output: OutputRuntimeArg<'a, MS, R>,
253        problem: &ConvolutionProblem,
254        config: <Self as ConvolutionConfigFactory>::Config,
255    ) {
256        let size_m = problem.batches * problem.out_h * problem.out_w;
257        let size_n = problem.n;
258        let size_k = config.kernel_size(0) * config.kernel_size(1) * problem.channels as u32;
259
260        let runtime_args = RuntimeArgsLaunch::new(
261            ScalarArg::new(size_m as u32),
262            ScalarArg::new(size_n as u32),
263            ScalarArg::new(size_k),
264            FastDivmodArgs::new(client, problem.channels as u32),
265            FastDivmodArgs::new(client, problem.out_h as u32),
266            FastDivmodArgs::new(client, problem.out_w as u32),
267        );
268
269        unsafe {
270            implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
271                client,
272                cube_count,
273                cube_dim,
274                input,
275                bias.into(),
276                output,
277                runtime_args,
278                config,
279            );
280        }
281    }
282}