Skip to main content

cubek_convolution/kernels/algorithm/specialized.rs

1use cubecl::{
2    Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding, server::LaunchError,
3};
4use cubek_matmul::{
5    components::global::read::{
6        AsyncPartialLoadingStrategy, async_partial_tma::AsyncPartialTmaLoading,
7    },
8    definition::AvailableVectorSizes,
9    launch::{TensorArgs, TensorMapArgs},
10    routines::specialized::SpecializedAlgorithm,
11};
12use cubek_std::tile::Strided;
13use std::marker::PhantomData;
14
15use crate::{
16    algorithm::{Algorithm, contiguous_pitched_layout, into_tensor_handle_tma},
17    components::{
18        ConvolutionOperation,
19        global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
20    },
21};
22
23/// Cmma convolution
24pub struct SpecializedConv<L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
25    _loader: PhantomData<L>,
26}
27
/// Specialized convolution variant whose `Algorithm` impl uses `AsyncPartialTmaLoading`
/// for its loads and `TensorMapArgs` for its kernel arguments.
pub struct SpecializedTmaConv;
29
30impl<L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>> Algorithm
31    for SpecializedConv<L>
32{
33    type Routine = SpecializedAlgorithm<L, SyncBiasLoading>;
34    type Args = TensorArgs<RuntimeArgs>;
35    const IS_SPECIALIZED: bool = true;
36
37    fn correct_layout<R: Runtime>(
38        client: &ComputeClient<R>,
39        handle: TensorBinding<R>,
40        dtype: StorageType,
41        _operation: ConvolutionOperation,
42    ) -> Result<TensorBinding<R>, LaunchError> {
43        contiguous_pitched_layout(client, handle, dtype)
44    }
45}
46
47impl Algorithm for SpecializedTmaConv {
48    type Routine = SpecializedAlgorithm<AsyncPartialTmaLoading, SyncBiasLoading>;
49    type Args = TensorMapArgs<RuntimeArgs>;
50    const IS_SPECIALIZED: bool = true;
51
52    fn correct_layout<R: Runtime>(
53        client: &ComputeClient<R>,
54        handle: TensorBinding<R>,
55        dtype: StorageType,
56        operation: ConvolutionOperation,
57    ) -> Result<TensorBinding<R>, LaunchError> {
58        into_tensor_handle_tma(client, handle, dtype, operation)
59    }
60
61    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
62        AvailableVectorSizes {
63            lhs: vec![1],
64            rhs: vec![1],
65            out: vector_sizes.out,
66        }
67    }
68}