Skip to main content

cubek_convolution/kernels/algorithm/
specialized.rs

1use std::marker::PhantomData;
2
3use cubecl::{
4    Runtime,
5    client::ComputeClient,
6    ir::StorageType,
7    prelude::TensorHandleRef,
8    server::LaunchError,
9    std::{CubeOption, tensor::TensorHandle},
10};
11use cubek_matmul::{
12    components::{
13        global::read::{AsyncPartialLoadingStrategy, async_partial_tma::AsyncPartialTmaLoading},
14        tile::{TileMatmulFamily, io::Strided},
15    },
16    definition::AvailableLineSizes,
17    launch::{TensorArgs, TensorMapArgs},
18    routines::specialized::SpecializedAlgorithm,
19};
20
21use crate::{
22    algorithm::{Algorithm, into_tensor_handle, into_tensor_handle_tma},
23    components::{
24        ConvolutionOperation,
25        global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
26    },
27};
28
29/// Cmma convolution
30pub struct SpecializedConv<TMM: TileMatmulFamily, L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
31    _tmm: PhantomData<TMM>,
32    _loader: PhantomData<L>,
33}
34
35// pub type SpecializedCyclicConv<TMM> =
36//     SpecializedConv<TMM, AsyncPartialCyclicLoading<ColMajorTilingOrder>>;
37// pub type SpecializedStridedConv<TMM> = SpecializedConv<TMM, AsyncPartialStridedLoading>;
38
39pub struct SpecializedTmaConv<TMM: TileMatmulFamily> {
40    _tmm: PhantomData<TMM>,
41}
42
43impl<
44    TMM: TileMatmulFamily<
45            LhsTile = Strided,
46            RhsTile = Strided,
47            AccTile = CubeOption<Strided>,
48            OutTile = Strided,
49        >,
50    L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>,
51> Algorithm for SpecializedConv<TMM, L>
52{
53    type Routine = SpecializedAlgorithm<TMM, L, SyncBiasLoading>;
54    type Args = TensorArgs<RuntimeArgs>;
55    const IS_SPECIALIZED: bool = true;
56
57    fn into_tensor_handle<R: Runtime>(
58        client: &ComputeClient<R>,
59        handle: &TensorHandleRef<'_, R>,
60        dtype: StorageType,
61        _operation: ConvolutionOperation,
62    ) -> Result<TensorHandle<R>, LaunchError> {
63        into_tensor_handle(client, handle, dtype)
64    }
65}
66
67impl<
68    TMM: TileMatmulFamily<
69            LhsTile = Strided,
70            RhsTile = Strided,
71            AccTile = CubeOption<Strided>,
72            OutTile = Strided,
73        >,
74> Algorithm for SpecializedTmaConv<TMM>
75{
76    type Routine = SpecializedAlgorithm<TMM, AsyncPartialTmaLoading, SyncBiasLoading>;
77    type Args = TensorMapArgs<RuntimeArgs>;
78    const IS_SPECIALIZED: bool = true;
79
80    fn into_tensor_handle<R: Runtime>(
81        client: &ComputeClient<R>,
82        handle: &TensorHandleRef<'_, R>,
83        dtype: StorageType,
84        operation: ConvolutionOperation,
85    ) -> Result<TensorHandle<R>, LaunchError> {
86        into_tensor_handle_tma(client, handle, dtype, operation)
87    }
88
89    fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
90        AvailableLineSizes {
91            lhs: vec![1],
92            rhs: vec![1],
93            out: line_sizes.out,
94        }
95    }
96}