Skip to main content

cubek_convolution/kernels/algorithm/
specialized.rs

1use cubecl::{
2    Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding, server::LaunchError,
3};
4use cubek_matmul::{
5    components::{
6        global::read::{AsyncPartialLoadingStrategy, async_partial_tma::AsyncPartialTmaLoading},
7        tile::TileMatmulFamily,
8    },
9    definition::AvailableVectorSizes,
10    launch::{TensorArgs, TensorMapArgs},
11    routines::specialized::SpecializedAlgorithm,
12};
13use cubek_std::tile::Strided;
14use std::marker::PhantomData;
15
16use crate::{
17    algorithm::{Algorithm, contiguous_pitched_layout, into_tensor_handle_tma},
18    components::{
19        ConvolutionOperation,
20        global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
21    },
22};
23
24/// Cmma convolution
25pub struct SpecializedConv<TMM: TileMatmulFamily, L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
26    _tmm: PhantomData<TMM>,
27    _loader: PhantomData<L>,
28}
29
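// Disabled convenience aliases for `SpecializedConv` with the cyclic and strided
// async partial loading strategies:
// pub type SpecializedCyclicConv<TMM> =
//     SpecializedConv<TMM, AsyncPartialCyclicLoading<ColMajorTilingOrder>>;
// pub type SpecializedStridedConv<TMM> = SpecializedConv<TMM, AsyncPartialStridedLoading>;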
33
34pub struct SpecializedTmaConv<TMM: TileMatmulFamily> {
35    _tmm: PhantomData<TMM>,
36}
37
38impl<
39    TMM: TileMatmulFamily<
40            LhsTile = Strided,
41            RhsTile = Strided,
42            AccTile = Option<Strided>,
43            OutTile = Strided,
44        >,
45    L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>,
46> Algorithm for SpecializedConv<TMM, L>
47{
48    type Routine = SpecializedAlgorithm<TMM, L, SyncBiasLoading>;
49    type Args = TensorArgs<RuntimeArgs>;
50    const IS_SPECIALIZED: bool = true;
51
52    fn correct_layout<R: Runtime>(
53        client: &ComputeClient<R>,
54        handle: TensorBinding<R>,
55        dtype: StorageType,
56        _operation: ConvolutionOperation,
57    ) -> Result<TensorBinding<R>, LaunchError> {
58        contiguous_pitched_layout(client, handle, dtype)
59    }
60}
61
62impl<
63    TMM: TileMatmulFamily<
64            LhsTile = Strided,
65            RhsTile = Strided,
66            AccTile = Option<Strided>,
67            OutTile = Strided,
68        >,
69> Algorithm for SpecializedTmaConv<TMM>
70{
71    type Routine = SpecializedAlgorithm<TMM, AsyncPartialTmaLoading, SyncBiasLoading>;
72    type Args = TensorMapArgs<RuntimeArgs>;
73    const IS_SPECIALIZED: bool = true;
74
75    fn correct_layout<R: Runtime>(
76        client: &ComputeClient<R>,
77        handle: TensorBinding<R>,
78        dtype: StorageType,
79        operation: ConvolutionOperation,
80    ) -> Result<TensorBinding<R>, LaunchError> {
81        into_tensor_handle_tma(client, handle, dtype, operation)
82    }
83
84    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
85        AvailableVectorSizes {
86            lhs: vec![1],
87            rhs: vec![1],
88            out: vector_sizes.out,
89        }
90    }
91}