// cubecl_convolution/kernels/layered/algorithm/simple.rs

1use cubecl_core::server::LaunchError;
2use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
3use cubecl_matmul::components::{
4    MatmulElems, MatmulLineSizes, MatmulSelection, MatmulSetupError, stage::StridedStageFamily,
5    tile::io::Strided,
6};
7use cubecl_matmul::components::{
8    global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
9};
10use cubecl_matmul::components::{
11    global::read::sync_full_cyclic::SyncFullCyclicLoading,
12    stage::{ColMajorTilingOrder, NumStages, RowMajorTilingOrder},
13};
14use cubecl_std::{
15    CubeOption,
16    tensor::{TensorHandle, into_contiguous_pitched},
17};
18use std::marker::PhantomData;
19
20use crate::components::{
21    ConvolutionProblem, convolution_matmul_selection,
22    global::{
23        read::full_reader::FullLoadingStrategy, single_stage::simple::SimpleConvolutionFamily,
24    },
25};
26
27use super::Algorithm;
28
29/// Cmma convolution
30pub struct SimpleConvAlgorithm<
31    TMM: TileMatmulFamily,
32    LL: FullLoadingStrategy = SyncFullCyclicLoading<RowMajorTilingOrder>,
33    LR: FullLoadingStrategy = SyncFullCyclicLoading<ColMajorTilingOrder>,
34> {
35    _tmm: PhantomData<TMM>,
36    _loader: PhantomData<(LL, LR)>,
37}
38
39impl<
40    TMM: TileMatmulFamily<
41            LhsTile = Strided,
42            RhsTile = Strided,
43            AccTile = CubeOption<Strided>,
44            OutTile = Strided,
45        >,
46    LL: FullLoadingStrategy,
47    LR: FullLoadingStrategy<SyncStrategy = LL::SyncStrategy>,
48> Algorithm for SimpleConvAlgorithm<TMM, LL, LR>
49{
50    type TileMatmul = TMM;
51    type StageMatmul = PlaneMatmulFamily<
52        Self::TileMatmul,
53        StridedStageFamily,
54        StridedStageFamily,
55        Option<StridedStageFamily>,
56    >;
57    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul, LL, LR>;
58
59    type Args = TensorArgs;
60
61    fn into_tensor_handle<R: Runtime>(
62        client: &ComputeClient<R>,
63        handle: &TensorHandleRef<'_, R>,
64        dtype: StorageType,
65    ) -> Result<TensorHandle<R>, LaunchError> {
66        if has_valid_layout(handle) {
67            Ok(TensorHandle::from_ref(handle, dtype))
68        } else {
69            into_contiguous_pitched(client, handle, dtype)
70        }
71    }
72
73    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
74    fn num_stages() -> NumStages {
75        (1, 1).into()
76    }
77
78    fn selection<R: Runtime>(
79        client: &ComputeClient<R>,
80        problem: &ConvolutionProblem,
81        plane_dim: u32,
82        line_sizes: &MatmulLineSizes,
83        dtypes: &mut MatmulElems,
84    ) -> Result<MatmulSelection, MatmulSetupError> {
85        Ok(convolution_matmul_selection::<TMM, R>(
86            client,
87            problem,
88            plane_dim,
89            TMM::should_swizzle(client),
90            line_sizes,
91            dtypes,
92        )?)
93    }
94}
95
96fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
97    let rank = handle.shape.len();
98    let dim_c = rank - 1;
99    handle.strides[dim_c] == 1
100}