Skip to main content

cubek_convolution/routines/simple.rs

1use cubecl::{
2    server::LaunchError,
3    {Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding},
4};
5use cubek_matmul::components::global::read::FullLoadingStrategy;
6use cubek_matmul::components::{
7    global::read::sync_full_cyclic::SyncFullCyclicLoading,
8    stage::{ColMajorTilingOrder, RowMajorTilingOrder},
9};
10use cubek_matmul::{
11    components::global::read::{
12        async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
13        sync_full_tilewise::SyncFullTilewiseLoading,
14    },
15    routines::simple::{SimpleAlgorithm, SimpleArgs},
16};
17use cubek_matmul::{
18    definition::{AvailableVectorSizes, TilingBlueprint},
19    launch::{TensorArgs, TensorMapArgs},
20};
21use cubek_std::tile::Strided;
22use std::marker::PhantomData;
23
24use crate::{
25    components::{
26        ConvolutionOperation,
27        global::{
28            args::RuntimeArgs,
29            read::strategy::{
30                async_full_cyclic::AsyncFullCyclicLoading,
31                async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
32            },
33        },
34    },
35    routines::{Routine, contiguous_pitched_layout, into_tensor_handle_tma},
36};
37
38/// Cmma convolution
39pub struct SimpleConv<LL: FullLoadingStrategy<RuntimeArgs>, LR: FullLoadingStrategy<RuntimeArgs>> {
40    _loader: PhantomData<(LL, LR)>,
41}
42
43pub type SimpleSyncCyclicConv = SimpleConv<
44    SyncFullCyclicLoading<RowMajorTilingOrder>,
45    SyncFullCyclicLoading<ColMajorTilingOrder>,
46>;
47pub type SimpleSyncStridedConv = SimpleConv<SyncFullStridedLoading, SyncFullStridedLoading>;
48pub type SimpleSyncTilewiseConv = SimpleConv<
49    SyncFullTilewiseLoading<RowMajorTilingOrder>,
50    SyncFullTilewiseLoading<ColMajorTilingOrder>,
51>;
52pub type SimpleAsyncCyclicConv = SimpleConv<
53    AsyncFullCyclicLoading<RowMajorTilingOrder>,
54    AsyncFullCyclicLoading<ColMajorTilingOrder>,
55>;
56pub type SimpleAsyncStridedConv = SimpleConv<AsyncFullStridedLoading, AsyncFullStridedLoading>;
57
/// Simple convolution routine whose lhs and rhs operands are both loaded with
/// [`AsyncFullTmaLoading`].
///
/// This is a fieldless unit marker type; all behavior lives in its [`Routine`] impl.
pub struct SimpleAsyncTmaConv;
59
60impl<
61    LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
62    LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
63> Routine for SimpleConv<LL, LR>
64{
65    type Blueprint = TilingBlueprint;
66    type Strategy = SimpleArgs;
67    type MatmulRoutine = SimpleAlgorithm<LL, LR, SyncBiasLoading>;
68    type Args = TensorArgs<RuntimeArgs>;
69
70    fn correct_layout<R: Runtime>(
71        client: &ComputeClient<R>,
72        handle: TensorBinding<R>,
73        dtype: StorageType,
74        _operation: ConvolutionOperation,
75    ) -> Result<TensorBinding<R>, LaunchError> {
76        contiguous_pitched_layout(client, handle, dtype)
77    }
78}
79
80impl Routine for SimpleAsyncTmaConv {
81    type Blueprint = TilingBlueprint;
82    type Strategy = SimpleArgs;
83    type MatmulRoutine = SimpleAlgorithm<AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
84    type Args = TensorMapArgs<RuntimeArgs>;
85
86    fn correct_layout<R: Runtime>(
87        client: &ComputeClient<R>,
88        handle: TensorBinding<R>,
89        dtype: StorageType,
90        operation: ConvolutionOperation,
91    ) -> Result<TensorBinding<R>, LaunchError> {
92        into_tensor_handle_tma(client, handle, dtype, operation)
93    }
94
95    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
96        AvailableVectorSizes {
97            lhs: vec![1],
98            rhs: vec![1],
99            out: vector_sizes.out,
100        }
101    }
102}