Skip to main content

cubek_convolution/kernels/algorithm/
simple.rs

1use cubecl::server::LaunchError;
2use cubecl::std::{CubeOption, tensor::TensorHandle};
3use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
4use cubek_matmul::components::{global::read::FullLoadingStrategy, tile::TileMatmulFamily};
5use cubek_matmul::components::{
6    global::read::sync_full_cyclic::SyncFullCyclicLoading,
7    stage::{ColMajorTilingOrder, RowMajorTilingOrder},
8};
9use cubek_matmul::definition::AvailableLineSizes;
10use cubek_matmul::launch::{TensorArgs, TensorMapArgs};
11use cubek_matmul::{
12    components::{
13        global::read::{
14            async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
15            sync_full_tilewise::SyncFullTilewiseLoading,
16        },
17        tile::io::Strided,
18    },
19    routines::simple::SimpleAlgorithm,
20};
21use std::marker::PhantomData;
22
23use crate::{
24    algorithm::{into_tensor_handle, into_tensor_handle_tma},
25    components::{
26        ConvolutionOperation,
27        global::{
28            args::RuntimeArgs,
29            read::strategy::{
30                async_full_cyclic::AsyncFullCyclicLoading,
31                async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
32            },
33        },
34    },
35};
36
37use super::Algorithm;
38
39/// Cmma convolution
40pub struct SimpleConv<
41    TMM: TileMatmulFamily,
42    LL: FullLoadingStrategy<RuntimeArgs>,
43    LR: FullLoadingStrategy<RuntimeArgs>,
44> {
45    _tmm: PhantomData<TMM>,
46    _loader: PhantomData<(LL, LR)>,
47}
48
49pub type SimpleSyncCyclicConv<TMM> = SimpleConv<
50    TMM,
51    SyncFullCyclicLoading<RowMajorTilingOrder>,
52    SyncFullCyclicLoading<ColMajorTilingOrder>,
53>;
54pub type SimpleSyncStridedConv<TMM> =
55    SimpleConv<TMM, SyncFullStridedLoading, SyncFullStridedLoading>;
56pub type SimpleSyncTilewiseConv<TMM> = SimpleConv<
57    TMM,
58    SyncFullTilewiseLoading<RowMajorTilingOrder>,
59    SyncFullTilewiseLoading<ColMajorTilingOrder>,
60>;
61pub type SimpleAsyncCyclicConv<TMM> = SimpleConv<
62    TMM,
63    AsyncFullCyclicLoading<RowMajorTilingOrder>,
64    AsyncFullCyclicLoading<ColMajorTilingOrder>,
65>;
66pub type SimpleAsyncStridedConv<TMM> =
67    SimpleConv<TMM, AsyncFullStridedLoading, AsyncFullStridedLoading>;
68
69pub struct SimpleAsyncTmaConv<TMM: TileMatmulFamily> {
70    _tmm: PhantomData<TMM>,
71}
72
73impl<
74    TMM: TileMatmulFamily<
75            LhsTile = Strided,
76            RhsTile = Strided,
77            AccTile = CubeOption<Strided>,
78            OutTile = Strided,
79        >,
80    LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
81    LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
82> Algorithm for SimpleConv<TMM, LL, LR>
83{
84    type Routine = SimpleAlgorithm<TMM, LL, LR, SyncBiasLoading>;
85    type Args = TensorArgs<RuntimeArgs>;
86
87    fn into_tensor_handle<R: Runtime>(
88        client: &ComputeClient<R>,
89        handle: &TensorHandleRef<'_, R>,
90        dtype: StorageType,
91        _operation: ConvolutionOperation,
92    ) -> Result<TensorHandle<R>, LaunchError> {
93        into_tensor_handle(client, handle, dtype)
94    }
95}
96
97impl<
98    TMM: TileMatmulFamily<
99            LhsTile = Strided,
100            RhsTile = Strided,
101            AccTile = CubeOption<Strided>,
102            OutTile = Strided,
103        >,
104> Algorithm for SimpleAsyncTmaConv<TMM>
105{
106    type Routine = SimpleAlgorithm<TMM, AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
107
108    type Args = TensorMapArgs<RuntimeArgs>;
109
110    fn into_tensor_handle<R: Runtime>(
111        client: &ComputeClient<R>,
112        handle: &TensorHandleRef<'_, R>,
113        dtype: StorageType,
114        operation: ConvolutionOperation,
115    ) -> Result<TensorHandle<R>, LaunchError> {
116        into_tensor_handle_tma(client, handle, dtype, operation)
117    }
118
119    fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
120        AvailableLineSizes {
121            lhs: vec![1],
122            rhs: vec![1],
123            out: line_sizes.out,
124        }
125    }
126}