cubek_convolution/kernels/algorithm/
specialized.rs1use cubecl::{
2 Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding, server::LaunchError,
3};
4use cubek_matmul::{
5 components::{
6 global::read::{AsyncPartialLoadingStrategy, async_partial_tma::AsyncPartialTmaLoading},
7 tile::TileMatmulFamily,
8 },
9 definition::AvailableVectorSizes,
10 launch::{TensorArgs, TensorMapArgs},
11 routines::specialized::SpecializedAlgorithm,
12};
13use cubek_std::tile::Strided;
14use std::marker::PhantomData;
15
16use crate::{
17 algorithm::{Algorithm, contiguous_pitched_layout, into_tensor_handle_tma},
18 components::{
19 ConvolutionOperation,
20 global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
21 },
22};
23
24pub struct SpecializedConv<TMM: TileMatmulFamily, L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
26 _tmm: PhantomData<TMM>,
27 _loader: PhantomData<L>,
28}
29
30pub struct SpecializedTmaConv<TMM: TileMatmulFamily> {
35 _tmm: PhantomData<TMM>,
36}
37
38impl<
39 TMM: TileMatmulFamily<
40 LhsTile = Strided,
41 RhsTile = Strided,
42 AccTile = Option<Strided>,
43 OutTile = Strided,
44 >,
45 L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>,
46> Algorithm for SpecializedConv<TMM, L>
47{
48 type Routine = SpecializedAlgorithm<TMM, L, SyncBiasLoading>;
49 type Args = TensorArgs<RuntimeArgs>;
50 const IS_SPECIALIZED: bool = true;
51
52 fn correct_layout<R: Runtime>(
53 client: &ComputeClient<R>,
54 handle: TensorBinding<R>,
55 dtype: StorageType,
56 _operation: ConvolutionOperation,
57 ) -> Result<TensorBinding<R>, LaunchError> {
58 contiguous_pitched_layout(client, handle, dtype)
59 }
60}
61
62impl<
63 TMM: TileMatmulFamily<
64 LhsTile = Strided,
65 RhsTile = Strided,
66 AccTile = Option<Strided>,
67 OutTile = Strided,
68 >,
69> Algorithm for SpecializedTmaConv<TMM>
70{
71 type Routine = SpecializedAlgorithm<TMM, AsyncPartialTmaLoading, SyncBiasLoading>;
72 type Args = TensorMapArgs<RuntimeArgs>;
73 const IS_SPECIALIZED: bool = true;
74
75 fn correct_layout<R: Runtime>(
76 client: &ComputeClient<R>,
77 handle: TensorBinding<R>,
78 dtype: StorageType,
79 operation: ConvolutionOperation,
80 ) -> Result<TensorBinding<R>, LaunchError> {
81 into_tensor_handle_tma(client, handle, dtype, operation)
82 }
83
84 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
85 AvailableVectorSizes {
86 lhs: vec![1],
87 rhs: vec![1],
88 out: vector_sizes.out,
89 }
90 }
91}