//! cubek_convolution/kernels/algorithm/simple.rs — simple convolution algorithms.
1use cubecl::{
2    server::LaunchError,
3    {Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding},
4};
5use cubek_matmul::components::{global::read::FullLoadingStrategy, tile::TileMatmulFamily};
6use cubek_matmul::components::{
7    global::read::sync_full_cyclic::SyncFullCyclicLoading,
8    stage::{ColMajorTilingOrder, RowMajorTilingOrder},
9};
10use cubek_matmul::{
11    components::global::read::{
12        async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
13        sync_full_tilewise::SyncFullTilewiseLoading,
14    },
15    routines::simple::SimpleAlgorithm,
16};
17use cubek_matmul::{
18    definition::AvailableVectorSizes,
19    launch::{TensorArgs, TensorMapArgs},
20};
21use cubek_std::tile::Strided;
22use std::marker::PhantomData;
23
24use crate::{
25    algorithm::{contiguous_pitched_layout, into_tensor_handle_tma},
26    components::{
27        ConvolutionOperation,
28        global::{
29            args::RuntimeArgs,
30            read::strategy::{
31                async_full_cyclic::AsyncFullCyclicLoading,
32                async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
33            },
34        },
35    },
36};
37
38use super::Algorithm;
39
/// Cmma convolution
///
/// Simple convolution parameterised over a tile matmul family and one full
/// loading strategy per operand. The struct holds no data: its type
/// parameters only select the components used by the `Algorithm` impl below.
pub struct SimpleConv<
    TMM: TileMatmulFamily,
    LL: FullLoadingStrategy<RuntimeArgs>,
    LR: FullLoadingStrategy<RuntimeArgs>,
> {
    // Zero-sized markers so the type parameters are considered "used".
    _tmm: PhantomData<TMM>,
    _loader: PhantomData<(LL, LR)>,
}
49
/// Simple convolution with synchronous cyclic loading on both operands
/// (row-major tiling order for the first operand, column-major for the second).
pub type SimpleSyncCyclicConv<TMM> = SimpleConv<
    TMM,
    SyncFullCyclicLoading<RowMajorTilingOrder>,
    SyncFullCyclicLoading<ColMajorTilingOrder>,
>;
/// Simple convolution with synchronous strided loading on both operands.
pub type SimpleSyncStridedConv<TMM> =
    SimpleConv<TMM, SyncFullStridedLoading, SyncFullStridedLoading>;
/// Simple convolution with synchronous tilewise loading on both operands
/// (row-major tiling order for the first operand, column-major for the second).
pub type SimpleSyncTilewiseConv<TMM> = SimpleConv<
    TMM,
    SyncFullTilewiseLoading<RowMajorTilingOrder>,
    SyncFullTilewiseLoading<ColMajorTilingOrder>,
>;
/// Simple convolution with asynchronous cyclic loading on both operands
/// (row-major tiling order for the first operand, column-major for the second).
pub type SimpleAsyncCyclicConv<TMM> = SimpleConv<
    TMM,
    AsyncFullCyclicLoading<RowMajorTilingOrder>,
    AsyncFullCyclicLoading<ColMajorTilingOrder>,
>;
/// Simple convolution with asynchronous strided loading on both operands.
pub type SimpleAsyncStridedConv<TMM> =
    SimpleConv<TMM, AsyncFullStridedLoading, AsyncFullStridedLoading>;
69
/// Simple convolution that loads both operands via TMA loading and launches
/// with tensor-map arguments; it has its own `Algorithm` impl rather than
/// reusing `SimpleConv`'s generic loading parameters.
pub struct SimpleAsyncTmaConv<TMM: TileMatmulFamily> {
    // Zero-sized marker so the type parameter is considered "used".
    _tmm: PhantomData<TMM>,
}
73
// `SimpleConv` runs the simple matmul routine with synchronous bias loading,
// launched with plain tensor arguments. Both loading strategies must yield
// strided tiles (matching the strided tiles the tile matmul consumes) and
// share the same synchronization strategy.
impl<
    TMM: TileMatmulFamily<
            LhsTile = Strided,
            RhsTile = Strided,
            AccTile = Option<Strided>,
            OutTile = Strided,
        >,
    LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
    LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
> Algorithm for SimpleConv<TMM, LL, LR>
{
    type Routine = SimpleAlgorithm<TMM, LL, LR, SyncBiasLoading>;
    type Args = TensorArgs<RuntimeArgs>;

    /// Put `handle` into the layout the kernel expects by delegating to
    /// [`contiguous_pitched_layout`]. The convolution operation is unused
    /// here — only the TMA variant's layout conversion needs it.
    fn correct_layout<R: Runtime>(
        client: &ComputeClient<R>,
        handle: TensorBinding<R>,
        dtype: StorageType,
        _operation: ConvolutionOperation,
    ) -> Result<TensorBinding<R>, LaunchError> {
        contiguous_pitched_layout(client, handle, dtype)
    }
}
97
// TMA variant: both operands are read with `AsyncFullTmaLoading` and the
// kernel is launched with tensor-map arguments instead of plain tensors.
impl<
    TMM: TileMatmulFamily<
            LhsTile = Strided,
            RhsTile = Strided,
            AccTile = Option<Strided>,
            OutTile = Strided,
        >,
> Algorithm for SimpleAsyncTmaConv<TMM>
{
    type Routine = SimpleAlgorithm<TMM, AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;

    type Args = TensorMapArgs<RuntimeArgs>;

    /// Convert `handle` into a TMA-compatible tensor handle via
    /// [`into_tensor_handle_tma`]. Unlike the non-TMA impl, the convolution
    /// operation is forwarded because the TMA conversion depends on it.
    fn correct_layout<R: Runtime>(
        client: &ComputeClient<R>,
        handle: TensorBinding<R>,
        dtype: StorageType,
        operation: ConvolutionOperation,
    ) -> Result<TensorBinding<R>, LaunchError> {
        into_tensor_handle_tma(client, handle, dtype, operation)
    }

    /// Narrow the candidate vector sizes: force a vectorization of 1 on both
    /// inputs while keeping the available output sizes unchanged.
    /// NOTE(review): presumably TMA transfers make input vectorization
    /// irrelevant — confirm against the loading strategy.
    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
        AvailableVectorSizes {
            lhs: vec![1],
            rhs: vec![1],
            out: vector_sizes.out,
        }
    }
}
127}