//! cubecl_convolution/homogeneous/simple_tma.rs

1use std::marker::PhantomData;
2
3use crate::{
4    algorithm::simple_tma::check_problem_tma,
5    base::{
6        Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
7        ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
8    },
9    loader::{
10        bias::BiasLoader,
11        im2col_tma::{TmaIm2colLoader, TmaIm2colTiling},
12        weight_tma::{TmaWeightLoader, TmaWeightTiling},
13    },
14};
15use cubecl_core::prelude::*;
16use cubecl_core::{
17    self as cubecl,
18    prelude::barrier::{Barrier, BarrierLevel},
19};
20use cubecl_matmul::components::{
21    AvailableLineSizes, EA, EI, EO, ES, InputRuntimeArg, MatmulLineSizes, MatmulPrecision,
22    MatmulSelection, MatmulSetupError, MatmulSpec, OutputRuntimeArg,
23    global::{
24        AccumulatorLoader, GlobalConfig,
25        load::{NoLoadingValidation, arrive_tma},
26        single_stage::tma::SimpleTmaConfig,
27    },
28    stage::{FullReaderFamily, FullStageToTileReader, StageConfig, StageMatmul, StageMatmulFamily},
29};
30use cubecl_std::{
31    CubeOption, FastDivmodArgs,
32    tensor::r#virtual::{ReadWrite, VirtualTensor},
33};
34
35use super::base::{
36    config::{self, ConvolutionConfig},
37    implicit_conv, shape_divmod,
38};
39
/// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities
/// - All planes load data to the stage
/// - All planes are used in the stage matmul computation
///
/// Zero-sized marker type: the precision (`MP`) and stage matmul (`SMM`)
/// are carried purely at the type level via `PhantomData`.
pub struct SimpleTmaConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
    // Marker for the matmul precision bundle (EI/ES/EA/EO element types).
    _cs: PhantomData<MP>,
    // Marker for the stage-level matmul implementation.
    _stage_matmul: PhantomData<SMM>,
}
47
#[cube]
impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleTmaConvolution<MP, SMM>
where
    SMM: StageMatmul<
            MP,
            LhsReader = FullStageToTileReader<MP::ES, TmaIm2colTiling>,
            RhsReader = FullStageToTileReader<MP::ES, TmaWeightTiling>,
        >,
{
    // Lhs is loaded as im2col patches through TMA; Rhs is the weight tensor,
    // also loaded through TMA.
    type LhsLoader = TmaIm2colLoader<MP, Self::Config>;
    type Config = ConvolutionConfig<SimpleTmaConfig<SMM::Config>>;
    type RhsLoader = TmaWeightLoader<MP, SMM::Config>;
    // Optional bias is staged into the accumulator before the main loop.
    type AccumulatorLoader = BiasLoader<MP>;

    type Writer = SMM::Writer;
    type Accumulator = SMM::Accumulator;

    /// Main device-side loop: repeatedly fill both stages via TMA, wait on the
    /// barrier, run the stage matmul, and advance the k views; finally write
    /// the accumulated results out.
    fn execute(
        mut lhs_loader: Self::LhsLoader,
        mut rhs_loader: Self::RhsLoader,
        mut acc_loader: Self::AccumulatorLoader,
        mut out_writer: Self::Writer,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    ) {
        let k_step = config.k_step;
        let range = k_range.1 - k_range.0;
        #[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
        #[allow(clippy::manual_div_ceil)]
        // Ceil-divide: one loop iteration per k-step, including a partial last step.
        let num_loops = (range + k_step - 1) / k_step;

        // Total elements transferred per iteration (lhs MK stage + rhs NK stage);
        // used to size the TMA barrier arrival below.
        let total_stage_elems = config.tiling_scheme().elements_in_stage_mk()
            + config.tiling_scheme().elements_in_stage_nk();

        // Stage the bias (if any), then seed the accumulator with it.
        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.stage_config());

        sync_cube();

        SMM::fill_accumulator::<Self::AccumulatorLoader>(
            &mut acc_loader,
            acc,
            config.stage_config(),
        );

        // Cube-cooperative barrier backed by the TMA proxy; all TMA fills below
        // report their completion through it.
        let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32));

        for _ in 0..num_loops {
            // Ensure the previous iteration's stage reads are done before overwriting.
            sync_cube();

            // Stage index 0 for both fills — single-stage (non-double-buffered)
            // pipeline, presumably; confirm against the loader implementations.
            Self::LhsLoader::fill_stage(&mut lhs_loader, &barrier, 0u32, config);
            Self::RhsLoader::fill_stage(&mut rhs_loader, &barrier, 0u32, config.stage_config());

            // Register the expected element count for this iteration's TMA
            // transactions, then block until both stages have landed.
            arrive_tma::<MP::ES>(&barrier, total_stage_elems);

            barrier.wait();

            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader, 0u32);
            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader, 0u32);

            SMM::execute(
                lhs_stage_reader,
                rhs_stage_reader,
                &mut lhs_tile,
                &mut rhs_tile,
                acc,
                config.stage_config(),
            );

            // Advance both global views by one k-step for the next iteration.
            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
        }

        sync_cube();

        SMM::write_results::<Self::Config>(acc, &mut out_writer, config.stage_config(), config);
    }

    /// Builds the im2col TMA loader over the input tensor at the given offsets.
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader {
        // `1u32` is the number of stages — matches the single-stage pipeline above.
        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, 1u32, config)
    }

    /// Builds the weight TMA loader; the tensor is reinterpreted as a `TensorMap`
    /// as required by TMA transfers.
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader {
        Self::RhsLoader::new::<Self::Config>(
            rhs.as_tensor_map(),
            x_offset,
            y_offset,
            // No auxiliary tensor for this variant.
            CubeOption::new_None(),
            runtime_args,
            1u32,
            config,
        )
    }

    /// Builds the bias loader; `bias` may be `None`, in which case the
    /// accumulator is presumably zero-filled (see `BiasLoader`).
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader {
        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
    }

    /// Builds the output writer at the given offsets. The trailing `0` is the
    /// batch offset passed to the stage matmul writer — TODO confirm.
    fn init_writer(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Writer {
        SMM::init_writer(out, x_offset, y_offset, 0)
    }

    /// Allocates the stage-matmul accumulator fragments.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
        SMM::init_accumulator(config.stage_config())
    }
}
175
/// Type-level family for [`SimpleTmaConvolution`]: selects the concrete
/// convolution once a precision is chosen (see `ConvolutionFamily::Convolution`).
pub struct SimpleTmaConvolutionFamily<SMM: StageMatmulFamily> {
    // Marker for the stage-matmul family; never instantiated at runtime.
    _smm: PhantomData<SMM>,
}
179
180impl<SMM> ConvolutionFamily for SimpleTmaConvolutionFamily<SMM>
181where
182    SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
183{
184    type Convolution<MP: MatmulPrecision> =
185        SimpleTmaConvolution<MP, SMM::Matmul<MP, TmaIm2colTiling, TmaWeightTiling>>;
186
187    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
188        available_line_sizes
189            .filter_lhs(|ls| *ls == 1)
190            .filter_rhs(|ls| *ls == 1)
191    }
192}
193
194impl<SMM> ConvolutionConfigFactory for SimpleTmaConvolutionFamily<SMM>
195where
196    SMM: StageMatmulFamily,
197{
198    type Config = config::ConvolutionConfig<SimpleTmaConfig<SMM::Config>>;
199
200    fn setup<R: Runtime, MP: MatmulPrecision>(
201        client: &ComputeClient<R::Server, R::Channel>,
202        problem: &ConvolutionProblem,
203        selection: &MatmulSelection,
204        line_sizes: &MatmulLineSizes,
205    ) -> Result<Self::Config, MatmulSetupError> {
206        check_problem_tma(problem)?;
207
208        // We need smem to be unlined so slicing is simpler. TMA doesn't use the vector
209        // type anyways and treats it as a void* with the actual type being set by the `TensorMap`
210        assert!(line_sizes.lhs == 1);
211        assert!(line_sizes.rhs == 1);
212
213        let stage_config = SMM::setup::<MP, R>(
214            client,
215            &problem.as_matmul_problem(),
216            selection,
217            line_sizes,
218            (1, 1).into(),
219            None,
220            false,
221        )?;
222        let stage_k = stage_config.tiling_scheme().elements_in_stage_k();
223
224        config::ConvolutionConfig::new(
225            SimpleTmaConfig::new::<NoLoadingValidation, NoLoadingValidation, MP, R>(
226                client,
227                stage_config,
228                stage_config.num_main_flow_planes(),
229                // TODO: Find the correct condition to avoid check bounds.
230                true,
231                true,
232                true,
233                stage_k,
234                selection.loading_precompute_strategy,
235                selection.loader_mode,
236            )?,
237            &problem.kernel_size,
238            &problem.stride,
239            &problem.dilation,
240            &problem.padding,
241            problem.dimensionality,
242            1,
243        )
244    }
245}
246
247impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
248    ConvolutionLaunch for SimpleTmaConvolutionFamily<SMM>
249{
250    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
251        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
252        cube_dim: CubeDim,
253        cube_count: CubeCount,
254        input: InputRuntimeArg<'a, MS, R>,
255        bias: Option<TensorArg<'a, R>>,
256        output: OutputRuntimeArg<'a, MS, R>,
257        problem: &ConvolutionProblem,
258        config: <Self as ConvolutionConfigFactory>::Config,
259    ) {
260        let padded_channels =
261            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());
262
263        let size_k = problem.kernel_size.iter().product::<u32>() * padded_channels;
264
265        let runtime_args = RuntimeArgsLaunch::new(
266            ScalarArg::new(problem.m as u32),
267            ScalarArg::new(problem.n as u32),
268            ScalarArg::new(size_k),
269            FastDivmodArgs::new(client, padded_channels),
270            shape_divmod(client, &problem.out_shape),
271        );
272
273        unsafe {
274            implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
275                client,
276                cube_count,
277                cube_dim,
278                input,
279                bias.into(),
280                output,
281                runtime_args,
282                config,
283            );
284        }
285    }
286}