Skip to main content

cubek_convolution/routines/
specialized.rs

1use cubecl::{
2    Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding, server::LaunchError,
3};
4use cubek_matmul::{
5    components::{
6        global::read::{
7            AsyncPartialLoadingStrategy, async_partial_cyclic::AsyncPartialCyclicLoading,
8            async_partial_strided::AsyncPartialStridedLoading,
9            async_partial_tma::AsyncPartialTmaLoading,
10        },
11        stage::ColMajorTilingOrder,
12    },
13    definition::{AvailableVectorSizes, TilingBlueprint},
14    launch::{TensorArgs, TensorMapArgs},
15    routines::specialized::{SpecializedAlgorithm, SpecializedStrategy},
16};
17use cubek_std::tile::Strided;
18use std::marker::PhantomData;
19
20use crate::{
21    components::{
22        ConvolutionOperation,
23        global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
24    },
25    routines::{Routine, contiguous_pitched_layout, into_tensor_handle_tma},
26};
27
28/// Cmma convolution with a partial async loading strategy.
29pub struct SpecializedConv<L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
30    _loader: PhantomData<L>,
31}
32
33pub type SpecializedAsyncCyclicConv =
34    SpecializedConv<AsyncPartialCyclicLoading<ColMajorTilingOrder>>;
35pub type SpecializedAsyncStridedConv = SpecializedConv<AsyncPartialStridedLoading>;
36
/// Cmma specialized convolution routine that loads inputs with
/// `AsyncPartialTmaLoading` and launches with `TensorMapArgs`.
///
/// This is a zero-sized marker type, so it cheaply implements the common
/// standard traits.
#[derive(Debug, Clone, Copy, Default)]
pub struct SpecializedTmaConv;
38
39impl<L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>> Routine
40    for SpecializedConv<L>
41{
42    type Blueprint = TilingBlueprint;
43    type Strategy = SpecializedStrategy;
44    type MatmulRoutine = SpecializedAlgorithm<L, SyncBiasLoading>;
45    type Args = TensorArgs<RuntimeArgs>;
46    const IS_SPECIALIZED: bool = true;
47
48    fn correct_layout<R: Runtime>(
49        client: &ComputeClient<R>,
50        handle: TensorBinding<R>,
51        dtype: StorageType,
52        _operation: ConvolutionOperation,
53    ) -> Result<TensorBinding<R>, LaunchError> {
54        contiguous_pitched_layout(client, handle, dtype)
55    }
56}
57
58impl Routine for SpecializedTmaConv {
59    type Blueprint = TilingBlueprint;
60    type Strategy = SpecializedStrategy;
61    type MatmulRoutine = SpecializedAlgorithm<AsyncPartialTmaLoading, SyncBiasLoading>;
62    type Args = TensorMapArgs<RuntimeArgs>;
63    const IS_SPECIALIZED: bool = true;
64
65    fn correct_layout<R: Runtime>(
66        client: &ComputeClient<R>,
67        handle: TensorBinding<R>,
68        dtype: StorageType,
69        operation: ConvolutionOperation,
70    ) -> Result<TensorBinding<R>, LaunchError> {
71        into_tensor_handle_tma(client, handle, dtype, operation)
72    }
73
74    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
75        AvailableVectorSizes {
76            lhs: vec![1],
77            rhs: vec![1],
78            out: vector_sizes.out,
79        }
80    }
81}