cubek_convolution/kernels/forward/algorithm/
simple.rs

1use cubecl::server::LaunchError;
2use cubecl::std::{CubeOption, tensor::TensorHandle};
3use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
4use cubek_matmul::components::tile::TileMatmulFamily;
5use cubek_matmul::components::{
6    global::read::sync_full_cyclic::SyncFullCyclicLoading,
7    stage::{ColMajorTilingOrder, RowMajorTilingOrder},
8};
9use cubek_matmul::components::{
10    global::read::{
11        async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
12        sync_full_tilewise::SyncFullTilewiseLoading,
13    },
14    stage::StridedStageFamily,
15    tile::io::Strided,
16};
17use cubek_matmul::definition::{MatmulElems, MatmulLineSizes, MatmulSetupError, TilingBlueprint};
18use cubek_matmul::launch::{TensorArgs, TensorMapArgs};
19use cubek_matmul::{components::stage::PlaneMatmulFamily, definition::AvailableLineSizes};
20use std::marker::PhantomData;
21
22use crate::{
23    components::{
24        ConvolutionOperation, ConvolutionProblem, convolution_matmul_selection,
25        global::{
26            read::{
27                full_reader::FullLoadingStrategy,
28                strategy::{
29                    async_full_cyclic::AsyncFullCyclicLoading,
30                    async_full_strided::AsyncFullStridedLoading,
31                },
32            },
33            single_stage::simple::SimpleConvolutionFamily,
34        },
35    },
36    kernels::forward::{into_tensor_handle, into_tensor_handle_tma},
37};
38
39use super::Algorithm;
40
41/// Cmma convolution
42pub struct SimpleConv<TMM: TileMatmulFamily, LL: FullLoadingStrategy, LR: FullLoadingStrategy> {
43    _tmm: PhantomData<TMM>,
44    _loader: PhantomData<(LL, LR)>,
45}
46
/// [`SimpleConv`] with [`SyncFullCyclicLoading`] for both loaders: the first uses
/// [`RowMajorTilingOrder`], the second [`ColMajorTilingOrder`].
pub type SimpleSyncCyclicConv<TMM> = SimpleConv<
    TMM,
    SyncFullCyclicLoading<RowMajorTilingOrder>,
    SyncFullCyclicLoading<ColMajorTilingOrder>,
>;
/// [`SimpleConv`] with [`SyncFullStridedLoading`] for both loaders.
pub type SimpleSyncStridedConv<TMM> =
    SimpleConv<TMM, SyncFullStridedLoading, SyncFullStridedLoading>;
/// [`SimpleConv`] with [`SyncFullTilewiseLoading`] for both loaders: the first uses
/// [`RowMajorTilingOrder`], the second [`ColMajorTilingOrder`].
pub type SimpleSyncTilewiseConv<TMM> = SimpleConv<
    TMM,
    SyncFullTilewiseLoading<RowMajorTilingOrder>,
    SyncFullTilewiseLoading<ColMajorTilingOrder>,
>;
/// [`SimpleConv`] with [`AsyncFullCyclicLoading`] for both loaders: the first uses
/// [`RowMajorTilingOrder`], the second [`ColMajorTilingOrder`].
pub type SimpleAsyncCyclicConv<TMM> = SimpleConv<
    TMM,
    AsyncFullCyclicLoading<RowMajorTilingOrder>,
    AsyncFullCyclicLoading<ColMajorTilingOrder>,
>;
/// [`SimpleConv`] with [`AsyncFullStridedLoading`] for both loaders.
pub type SimpleAsyncStridedConv<TMM> =
    SimpleConv<TMM, AsyncFullStridedLoading, AsyncFullStridedLoading>;
66
67pub struct SimpleAsyncTmaConv<TMM: TileMatmulFamily> {
68    _tmm: PhantomData<TMM>,
69}
70
71impl<
72    TMM: TileMatmulFamily<
73            LhsTile = Strided,
74            RhsTile = Strided,
75            AccTile = CubeOption<Strided>,
76            OutTile = Strided,
77        >,
78    LL: FullLoadingStrategy,
79    LR: FullLoadingStrategy<SyncStrategy = LL::SyncStrategy>,
80> Algorithm for SimpleConv<TMM, LL, LR>
81{
82    type TileMatmul = TMM;
83    type StageMatmul = PlaneMatmulFamily<
84        Self::TileMatmul,
85        StridedStageFamily,
86        StridedStageFamily,
87        Option<StridedStageFamily>,
88    >;
89    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul, LL, LR>;
90
91    type Args = TensorArgs;
92
93    fn into_tensor_handle<R: Runtime>(
94        client: &ComputeClient<R>,
95        handle: &TensorHandleRef<'_, R>,
96        dtype: StorageType,
97        _operation: ConvolutionOperation,
98    ) -> Result<TensorHandle<R>, LaunchError> {
99        into_tensor_handle(client, handle, dtype)
100    }
101
102    fn selection<R: Runtime>(
103        client: &ComputeClient<R>,
104        problem: &ConvolutionProblem,
105        plane_dim: u32,
106        line_sizes: &MatmulLineSizes,
107        dtypes: &mut MatmulElems,
108    ) -> Result<TilingBlueprint, MatmulSetupError> {
109        Ok(convolution_matmul_selection::<TMM, R>(
110            client,
111            problem,
112            plane_dim,
113            TMM::should_swizzle(client),
114            line_sizes,
115            dtypes,
116        )?)
117    }
118}
119
120impl<
121    TMM: TileMatmulFamily<
122            LhsTile = Strided,
123            RhsTile = Strided,
124            AccTile = CubeOption<Strided>,
125            OutTile = Strided,
126        >,
127> Algorithm for SimpleAsyncTmaConv<TMM>
128{
129    type TileMatmul = TMM;
130    type StageMatmul = PlaneMatmulFamily<
131        Self::TileMatmul,
132        StridedStageFamily,
133        StridedStageFamily,
134        Option<StridedStageFamily>,
135    >;
136    type GlobalConvolution =
137        SimpleConvolutionFamily<Self::StageMatmul, AsyncFullTmaLoading, AsyncFullTmaLoading>;
138
139    type Args = TensorMapArgs;
140
141    fn into_tensor_handle<R: Runtime>(
142        client: &ComputeClient<R>,
143        handle: &TensorHandleRef<'_, R>,
144        dtype: StorageType,
145        operation: ConvolutionOperation,
146    ) -> Result<TensorHandle<R>, LaunchError> {
147        into_tensor_handle_tma(client, handle, dtype, operation)
148    }
149
150    fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
151        AvailableLineSizes {
152            lhs: vec![1],
153            rhs: vec![1],
154            out: line_sizes.out,
155        }
156    }
157
158    fn selection<R: Runtime>(
159        client: &ComputeClient<R>,
160        problem: &ConvolutionProblem,
161        plane_dim: u32,
162        line_sizes: &MatmulLineSizes,
163        dtypes: &mut MatmulElems,
164    ) -> Result<TilingBlueprint, MatmulSetupError> {
165        if line_sizes.lhs > 1 || line_sizes.rhs > 1 {
166            return Err(MatmulSetupError::InvalidConfig(Box::new(
167                "Not available with input line sizes > 1",
168            )));
169        }
170
171        Ok(convolution_matmul_selection::<TMM, R>(
172            client, problem, plane_dim, false, line_sizes, dtypes,
173        )?)
174    }
175}