cubecl_convolution/homogeneous/
simple.rs

1use std::marker::PhantomData;
2
3use crate::{
4    base::{
5        Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
6        ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
7    },
8    loader::{bias::BiasLoader, im2col::SimpleIm2colLoader},
9};
10use cubecl_core as cubecl;
11use cubecl_core::prelude::*;
12use cubecl_matmul::components::{
13    AvailableLineSizes, EA, EI, EO, ES, InputIdent, InputRuntimeArg, MatmulLineSizes,
14    MatmulPrecision, MatmulSelection, MatmulSetupError, MatmulSpec, OutputRuntimeArg,
15    global::{
16        AccumulatorLoader, GlobalConfig,
17        load::{NoLoadingValidation, SyncFullLoader, sync_full_cyclic},
18        single_stage::simple::SimpleConfig,
19    },
20    stage::{
21        ContiguousTilingLayout, FullReaderFamily, FullStageToTileReader, RowMajorTilingOrder,
22        StageConfig, StageMatmul, StageMatmulFamily,
23    },
24};
25use cubecl_std::{
26    CubeOption, FastDivmodArgs,
27    tensor::r#virtual::{ReadWrite, VirtualTensor},
28};
29
30use super::base::{
31    config::{self, ConvolutionConfig},
32    implicit_conv, shape_divmod,
33};
34
35/// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities
36/// - All planes load data to the stage
37/// - All planes are used in the stage matmul computation
38pub struct SimpleConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
39    _cs: PhantomData<MP>,
40    _stage_matmul: PhantomData<SMM>,
41}
42
43#[cube]
44impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleConvolution<MP, SMM>
45where
46    SMM: StageMatmul<
47            MP,
48            LhsReader = FullStageToTileReader<MP::ES, ConvTilingLayout>,
49            RhsReader = FullStageToTileReader<MP::ES, ConvTilingLayout>,
50        >,
51{
52    type LhsLoader = SimpleIm2colLoader<MP, Self::Config>;
53    type Config = ConvolutionConfig<SimpleConfig<SMM::Config>>;
54    type RhsLoader = SyncFullLoader<
55        MP,
56        Self::Config,
57        sync_full_cyclic::SyncFullCyclicLoading<RowMajorTilingOrder>,
58    >;
59    type AccumulatorLoader = BiasLoader<MP>;
60
61    type Writer = SMM::Writer;
62    type Accumulator = SMM::Accumulator;
63
64    fn execute(
65        mut lhs_loader: Self::LhsLoader,
66        mut rhs_loader: Self::RhsLoader,
67        mut acc_loader: Self::AccumulatorLoader,
68        mut out_writer: Self::Writer,
69        acc: &mut Self::Accumulator,
70        k_range: (u32, u32),
71        #[comptime] config: Self::Config,
72    ) {
73        let k_step = config.k_step;
74        let range = k_range.1 - k_range.0;
75        #[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
76        #[allow(clippy::manual_div_ceil)]
77        let num_loops = (range + k_step - 1) / k_step;
78
79        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
80        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.stage_config());
81
82        sync_cube();
83
84        SMM::fill_accumulator::<Self::AccumulatorLoader>(
85            &mut acc_loader,
86            acc,
87            config.stage_config(),
88        );
89
90        for _ in 0..num_loops {
91            sync_cube();
92
93            Self::LhsLoader::fill_stage(&mut lhs_loader, config);
94            Self::RhsLoader::fill_stage(&mut rhs_loader, config);
95
96            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader);
97            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader);
98
99            sync_cube();
100
101            SMM::execute(
102                lhs_stage_reader,
103                rhs_stage_reader,
104                &mut lhs_tile,
105                &mut rhs_tile,
106                acc,
107                config.stage_config(),
108            );
109
110            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
111            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
112        }
113
114        sync_cube();
115
116        SMM::write_results::<Self::Config>(acc, &mut out_writer, config.stage_config(), config);
117    }
118
119    fn init_lhs_loader(
120        lhs: VirtualTensor<MP::EI>,
121        x_offset: u32,
122        y_offset: u32,
123        runtime_args: &RuntimeArgs,
124        #[comptime] config: Self::Config,
125    ) -> Self::LhsLoader {
126        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, config)
127    }
128
129    fn init_rhs_loader(
130        rhs: VirtualTensor<MP::EI>,
131        x_offset: u32,
132        y_offset: u32,
133        _runtime_args: &RuntimeArgs,
134        #[comptime] config: Self::Config,
135    ) -> Self::RhsLoader {
136        Self::RhsLoader::new(
137            rhs,
138            x_offset,
139            y_offset,
140            0,
141            CubeOption::new_None(),
142            InputIdent::Rhs,
143            config,
144        )
145    }
146
147    fn init_bias_loader(
148        bias: CubeOption<VirtualTensor<MP::EO>>,
149        n_offset: u32,
150        #[comptime] config: Self::Config,
151    ) -> Self::AccumulatorLoader {
152        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
153    }
154
155    fn init_writer(
156        out: VirtualTensor<MP::EO, ReadWrite>,
157        x_offset: u32,
158        y_offset: u32,
159    ) -> Self::Writer {
160        SMM::init_writer(out, x_offset, y_offset, 0)
161    }
162
163    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
164        SMM::init_accumulator(config.stage_config())
165    }
166}
167
168pub struct SimpleConvolutionFamily<SMM: StageMatmulFamily> {
169    _smm: PhantomData<SMM>,
170}
171
172pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
173
174impl<SMM> ConvolutionFamily for SimpleConvolutionFamily<SMM>
175where
176    SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
177{
178    type Convolution<MP: MatmulPrecision> =
179        SimpleConvolution<MP, SMM::Matmul<MP, ConvTilingLayout, ConvTilingLayout>>;
180
181    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
182        available_line_sizes
183    }
184}
185
186impl<SMM> ConvolutionConfigFactory for SimpleConvolutionFamily<SMM>
187where
188    SMM: StageMatmulFamily,
189{
190    type Config = config::ConvolutionConfig<SimpleConfig<SMM::Config>>;
191
192    fn setup<R: Runtime, MP: MatmulPrecision>(
193        client: &ComputeClient<R::Server, R::Channel>,
194        problem: &ConvolutionProblem,
195        selection: &MatmulSelection,
196        line_sizes: &MatmulLineSizes,
197    ) -> Result<Self::Config, MatmulSetupError> {
198        let stage_config = SMM::setup::<MP, R>(
199            client,
200            &problem.as_matmul_problem(),
201            selection,
202            line_sizes,
203            (1, 1).into(),
204            None,
205            false,
206        )?;
207        let stage_k = stage_config.tiling_scheme().elements_in_stage_k();
208
209        config::ConvolutionConfig::new(
210            SimpleConfig::new::<NoLoadingValidation, NoLoadingValidation, MP, R>(
211                client,
212                stage_config,
213                stage_config.num_main_flow_planes(),
214                true,
215                true,
216                true,
217                stage_k,
218                selection.loading_precompute_strategy,
219                selection.loader_mode,
220            )?,
221            &problem.kernel_size,
222            &problem.stride,
223            &problem.dilation,
224            &problem.padding,
225            problem.dimensionality,
226            1,
227        )
228    }
229}
230
231impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
232    ConvolutionLaunch for SimpleConvolutionFamily<SMM>
233{
234    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
235        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
236        cube_dim: CubeDim,
237        cube_count: CubeCount,
238        input: InputRuntimeArg<'a, MS, R>,
239        bias: Option<TensorArg<'a, R>>,
240        output: OutputRuntimeArg<'a, MS, R>,
241        problem: &ConvolutionProblem,
242        config: <Self as ConvolutionConfigFactory>::Config,
243    ) {
244        let runtime_args = RuntimeArgsLaunch::new(
245            ScalarArg::new(problem.m as u32),
246            ScalarArg::new(problem.n as u32),
247            ScalarArg::new(problem.k as u32),
248            FastDivmodArgs::new(client, problem.channels as u32),
249            shape_divmod(client, &problem.out_shape),
250        );
251
252        unsafe {
253            implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
254                client,
255                cube_count,
256                cube_dim,
257                input,
258                bias.into(),
259                output,
260                runtime_args,
261                config,
262            );
263        }
264    }
265}