//! cubecl_linalg/convolution/homogeneous/simple_tma.rs

1use std::{any::TypeId, marker::PhantomData};
2
3use crate::{
4    convolution::{
5        ConvGemmConfig,
6        base::{
7            Convolution, ConvolutionConfigFactory, ConvolutionFamily, ConvolutionLaunch,
8            ConvolutionProblem, RuntimeArgs, RuntimeArgsLaunch,
9        },
10        loader::{
11            bias::BiasLoader,
12            im2col_tma::{TmaIm2colLoader, TmaIm2colTiling},
13            weight_tma::{TmaWeightLoader, TmaWeightTiling},
14        },
15    },
16    matmul::{
17        components::{
18            EA, EI, EO, ES, Ident, InputRuntimeArg, InvalidConfigError, MatmulPrecision,
19            MatmulSpec, OutputRuntimeArg,
20            global::{AccumulatorLoader, GlobalConfig, output_loader::Unloader, single_stage},
21            stage::{FullReader, FullReaderFamily, StageMatmul, StageMatmulFamily},
22        },
23        kernels::MatmulAvailabilityError,
24    },
25};
26use cubecl_core::prelude::*;
27use cubecl_core::{
28    self as cubecl,
29    prelude::barrier::{Barrier, BarrierLevel},
30};
31use cubecl_std::{
32    CubeOption, FastDivmodArgs,
33    tensor::r#virtual::{ReadWrite, VirtualTensor},
34};
35
36use super::base::{
37    config::{self, HomogeneousConfig},
38    implicit_conv,
39};
40
/// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities
/// - All planes load data to the stage
/// - All planes are used in the stage matmul computation
///
/// Stage data is transferred via TMA (Tensor Memory Accelerator): the Lhs is
/// loaded through an im2col TMA loader and the Rhs through a TMA weight loader
/// (see the trait impl below).
pub struct SimpleTmaConvolution<MP: MatmulPrecision, SMM: StageMatmul<MP>> {
    // Zero-sized markers tying the struct to its precision and stage-matmul parameters.
    _cs: PhantomData<MP>,
    _stage_matmul: PhantomData<SMM>,
}
48
#[cube]
impl<MP: MatmulPrecision, SMM> Convolution<MP> for SimpleTmaConvolution<MP, SMM>
where
    SMM: StageMatmul<
            MP,
            LhsReader = FullReader<MP::ES, TmaIm2colTiling>,
            RhsReader = FullReader<MP::ES, TmaWeightTiling>,
        >,
{
    // Lhs (input feature map) is fetched with a TMA im2col loader; Rhs (weights)
    // with a TMA weight loader. Both fill a single stage per k-iteration.
    type LhsLoader = TmaIm2colLoader<MP, Self::Config>;
    type Config = HomogeneousConfig<single_stage::Config<SMM::Config>>;
    type RhsLoader = TmaWeightLoader<MP, SMM::Config>;
    // Optional bias is loaded into the accumulator before the k-loop starts.
    type AccumulatorLoader = BiasLoader<MP>;

    type Out = Unloader<MP::EO>;
    type Accumulator = SMM::Accumulator;

    // Main convolution loop: repeatedly fill the Lhs/Rhs stages via TMA,
    // synchronize on the TMA barrier, run the stage matmul, then advance the
    // k-views. Finally writes the accumulator out through the unloader.
    fn execute(
        mut lhs_loader: Self::LhsLoader,
        mut rhs_loader: Self::RhsLoader,
        mut acc_loader: Self::AccumulatorLoader,
        mut out_unloader: Self::Out,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    ) {
        let k_step = config.k_step;
        let range = k_range.1 - k_range.0;
        // num_loops = ceil(range / k_step); a partial final step is still executed.
        #[allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
        #[allow(clippy::manual_div_ceil)]
        let num_loops = (range + k_step - 1) / k_step;

        // Stage the bias (if any) and seed the accumulator with it before the k-loop.
        Self::AccumulatorLoader::fill_stage::<Self::Config>(&mut acc_loader, config);
        let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.to_smm_config());

        sync_units();

        SMM::fill_accumulator::<Self::AccumulatorLoader>(
            &mut acc_loader,
            acc,
            config.to_smm_config(),
        );

        // Cube-cooperative barrier used to track TMA transfer completion.
        let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32));

        for _ in 0..num_loops {
            sync_units();

            // Issue the async TMA copies for both stages.
            Self::LhsLoader::fill_stage(&mut lhs_loader, &barrier, config);
            Self::RhsLoader::fill_stage(&mut rhs_loader, &barrier, config.to_smm_config());

            if UNIT_POS == 0 {
                // Only one unit declares the expected transaction byte count:
                // the combined Lhs + Rhs stage sizes, in bytes of the stage element type.
                let total_stage = config.tiling_dimensions(Ident::Rhs).total_size()
                    + config.tiling_dimensions(Ident::Lhs).total_size();
                barrier.arrive_tx(1, total_stage * MP::ES::elem_size());
            } else {
                // Every other unit just signals arrival.
                barrier.arrive();
            }

            // Block until both TMA transfers have landed in the stage.
            barrier.wait();

            let lhs_stage_reader = &Self::LhsLoader::reader(&lhs_loader);
            let rhs_stage_reader = &Self::RhsLoader::reader(&rhs_loader);

            SMM::execute(
                lhs_stage_reader,
                rhs_stage_reader,
                &mut lhs_tile,
                &mut rhs_tile,
                acc,
                config.to_smm_config(),
            );

            // Move both loaders forward along the k dimension for the next iteration.
            Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
            Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
        }

        sync_units();

        // Write the finished accumulator to global memory.
        SMM::read_accumulator::<Self::Out, Self::Config>(
            acc,
            &mut out_unloader,
            config.to_smm_config(),
            config,
        );
    }

    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader {
        Self::LhsLoader::new(lhs, x_offset, y_offset, runtime_args, config)
    }

    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader {
        // The weight loader consumes the tensor as a TensorMap; no quantization
        // scaling is supplied here (CubeOption::new_None()).
        Self::RhsLoader::new::<Self::Config>(
            rhs.as_tensor_map(),
            x_offset,
            y_offset,
            CubeOption::new_None(),
            runtime_args,
            config,
        )
    }

    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader {
        Self::AccumulatorLoader::new::<Self::Config>(bias, n_offset, config)
    }

    fn init_unloader(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Out {
        // Batch offset is fixed to 0; the x/y offsets fully locate the output tile.
        Self::Out::new(out, x_offset, y_offset, 0)
    }

    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
        SMM::init_accumulator(config.to_smm_config())
    }
}
183
/// Family (type-level factory) for [`SimpleTmaConvolution`], parameterized over a
/// stage matmul family.
pub struct SimpleTmaConvolutionFamily<SMM: StageMatmulFamily> {
    _smm: PhantomData<SMM>,
}
187
impl<SMM> ConvolutionFamily for SimpleTmaConvolutionFamily<SMM>
where
    SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>,
{
    // Instantiates the concrete convolution with TMA im2col / TMA weight tilings.
    type Convolution<MP: MatmulPrecision> =
        SimpleTmaConvolution<MP, SMM::Matmul<MP, TmaIm2colTiling, TmaWeightTiling>>;
}
195
196impl<SMM> ConvolutionConfigFactory for SimpleTmaConvolutionFamily<SMM>
197where
198    SMM: StageMatmulFamily,
199{
200    type Config = config::HomogeneousConfig<single_stage::Config<SMM::Config>>;
201    type Input = SMM::Input;
202
203    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
204        SMM::check_config(&config.to_smm_config())
205    }
206
207    fn make_config(
208        input: Self::Input,
209        problem: &ConvolutionProblem,
210        cube_dim: &CubeDim,
211        cube_count: &CubeCount,
212    ) -> Self::Config {
213        let mut problem = problem.clone();
214
215        // We need smem to be unlined so slicing is simpler. TMA doesn't use the vector
216        // type anyways and treats it as a void* with the actual type being set by the `TensorMap`
217        problem.lhs_line_size = 1;
218        problem.rhs_line_size = 1;
219
220        let smm_config = SMM::make_config(
221            input,
222            &problem.as_matmul_problem(),
223            cube_dim,
224            cube_count,
225            false,
226        );
227        let size = SMM::stage_shape(&smm_config);
228
229        config::HomogeneousConfig::new(
230            single_stage::Config::new(
231                smm_config,
232                // TODO: Find the correct condition to avoid check bounds.
233                true,
234                true,
235                true,
236                problem.lhs_layout,
237                problem.rhs_layout,
238                problem.lhs_line_size as u32,
239                problem.rhs_line_size as u32,
240                problem.out_line_size as u32,
241                size.k,
242            ),
243            problem.kernel_size,
244            problem.stride,
245            problem.dilation,
246            problem.padding,
247        )
248    }
249
250    fn check_availability<R: Runtime, MP: MatmulPrecision>(
251        client: &ComputeClient<R::Server, R::Channel>,
252        config: &Self::Config,
253    ) -> Result<(), MatmulAvailabilityError> {
254        let id_ei = TypeId::of::<MP::EI>();
255        let id_es = TypeId::of::<MP::ES>();
256        let is_tf32 = id_ei == TypeId::of::<f32>() && id_es == TypeId::of::<tf32>();
257
258        if id_ei != id_es && !is_tf32 {
259            return Err(MatmulAvailabilityError::TmaUnavailable);
260        }
261
262        SMM::check_availability::<R, MP>(client, &config.to_smm_config())
263    }
264}
265
266impl<SMM: StageMatmulFamily<LhsReader = FullReaderFamily, RhsReader = FullReaderFamily>>
267    ConvolutionLaunch for SimpleTmaConvolutionFamily<SMM>
268{
269    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
270        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
271        cube_dim: CubeDim,
272        cube_count: CubeCount,
273        input: InputRuntimeArg<'a, MS, R>,
274        bias: Option<TensorArg<'a, R>>,
275        output: OutputRuntimeArg<'a, MS, R>,
276        problem: &ConvolutionProblem,
277        config: <Self as ConvolutionConfigFactory>::Config,
278    ) {
279        let tiling_dims = config.tiling_dimensions(Ident::Lhs);
280        let padded_channels =
281            (problem.channels as u32).next_multiple_of(tiling_dims.tile_shape_col());
282
283        let size_m = problem.batches * problem.out_h * problem.out_w;
284        let size_n = problem.n;
285        let size_k = config.kernel_size(0) * config.kernel_size(1) * padded_channels;
286
287        let runtime_args = RuntimeArgsLaunch::new(
288            ScalarArg::new(size_m as u32),
289            ScalarArg::new(size_n as u32),
290            ScalarArg::new(size_k),
291            FastDivmodArgs::new(client, padded_channels),
292            FastDivmodArgs::new(client, problem.out_h as u32),
293            FastDivmodArgs::new(client, problem.out_w as u32),
294        );
295
296        unsafe {
297            implicit_conv::launch_unchecked::<MS::Args, EI<MS>, ES<MS>, EA<MS>, EO<MS>, Self, R>(
298                client,
299                cube_count,
300                cube_dim,
301                input,
302                bias.into(),
303                output,
304                runtime_args,
305                config,
306            );
307        }
308    }
309}