Skip to main content

cubek_convolution/kernels/algorithm/
simple.rs

1use cubecl::{
2    server::LaunchError,
3    {Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding},
4};
5use cubek_matmul::components::global::read::FullLoadingStrategy;
6use cubek_matmul::components::{
7    global::read::sync_full_cyclic::SyncFullCyclicLoading,
8    stage::{ColMajorTilingOrder, RowMajorTilingOrder},
9};
10use cubek_matmul::{
11    components::global::read::{
12        async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
13        sync_full_tilewise::SyncFullTilewiseLoading,
14    },
15    routines::simple::SimpleAlgorithm,
16};
17use cubek_matmul::{
18    definition::AvailableVectorSizes,
19    launch::{TensorArgs, TensorMapArgs},
20};
21use cubek_std::tile::Strided;
22use std::marker::PhantomData;
23
24use crate::{
25    algorithm::{contiguous_pitched_layout, into_tensor_handle_tma},
26    components::{
27        ConvolutionOperation,
28        global::{
29            args::RuntimeArgs,
30            read::strategy::{
31                async_full_cyclic::AsyncFullCyclicLoading,
32                async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
33            },
34        },
35    },
36};
37
38use super::Algorithm;
39
40/// Cmma convolution
41pub struct SimpleConv<LL: FullLoadingStrategy<RuntimeArgs>, LR: FullLoadingStrategy<RuntimeArgs>> {
42    _loader: PhantomData<(LL, LR)>,
43}
44
45pub type SimpleSyncCyclicConv = SimpleConv<
46    SyncFullCyclicLoading<RowMajorTilingOrder>,
47    SyncFullCyclicLoading<ColMajorTilingOrder>,
48>;
49pub type SimpleSyncStridedConv = SimpleConv<SyncFullStridedLoading, SyncFullStridedLoading>;
50pub type SimpleSyncTilewiseConv = SimpleConv<
51    SyncFullTilewiseLoading<RowMajorTilingOrder>,
52    SyncFullTilewiseLoading<ColMajorTilingOrder>,
53>;
54pub type SimpleAsyncCyclicConv = SimpleConv<
55    AsyncFullCyclicLoading<RowMajorTilingOrder>,
56    AsyncFullCyclicLoading<ColMajorTilingOrder>,
57>;
58pub type SimpleAsyncStridedConv = SimpleConv<AsyncFullStridedLoading, AsyncFullStridedLoading>;
59
60pub struct SimpleAsyncTmaConv;
61
62impl<
63    LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
64    LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
65> Algorithm for SimpleConv<LL, LR>
66{
67    type Routine = SimpleAlgorithm<LL, LR, SyncBiasLoading>;
68    type Args = TensorArgs<RuntimeArgs>;
69
70    fn correct_layout<R: Runtime>(
71        client: &ComputeClient<R>,
72        handle: TensorBinding<R>,
73        dtype: StorageType,
74        _operation: ConvolutionOperation,
75    ) -> Result<TensorBinding<R>, LaunchError> {
76        contiguous_pitched_layout(client, handle, dtype)
77    }
78}
79
80impl Algorithm for SimpleAsyncTmaConv {
81    type Routine = SimpleAlgorithm<AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
82
83    type Args = TensorMapArgs<RuntimeArgs>;
84
85    fn correct_layout<R: Runtime>(
86        client: &ComputeClient<R>,
87        handle: TensorBinding<R>,
88        dtype: StorageType,
89        operation: ConvolutionOperation,
90    ) -> Result<TensorBinding<R>, LaunchError> {
91        into_tensor_handle_tma(client, handle, dtype, operation)
92    }
93
94    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
95        AvailableVectorSizes {
96            lhs: vec![1],
97            rhs: vec![1],
98            out: vector_sizes.out,
99        }
100    }
101}