cubek_convolution/kernels/algorithm/
simple.rs1use cubecl::{
2 server::LaunchError,
3 {Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding},
4};
5use cubek_matmul::components::{global::read::FullLoadingStrategy, tile::TileMatmulFamily};
6use cubek_matmul::components::{
7 global::read::sync_full_cyclic::SyncFullCyclicLoading,
8 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
9};
10use cubek_matmul::{
11 components::global::read::{
12 async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
13 sync_full_tilewise::SyncFullTilewiseLoading,
14 },
15 routines::simple::SimpleAlgorithm,
16};
17use cubek_matmul::{
18 definition::AvailableVectorSizes,
19 launch::{TensorArgs, TensorMapArgs},
20};
21use cubek_std::tile::Strided;
22use std::marker::PhantomData;
23
24use crate::{
25 algorithm::{contiguous_pitched_layout, into_tensor_handle_tma},
26 components::{
27 ConvolutionOperation,
28 global::{
29 args::RuntimeArgs,
30 read::strategy::{
31 async_full_cyclic::AsyncFullCyclicLoading,
32 async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
33 },
34 },
35 },
36};
37
38use super::Algorithm;
39
40pub struct SimpleConv<
42 TMM: TileMatmulFamily,
43 LL: FullLoadingStrategy<RuntimeArgs>,
44 LR: FullLoadingStrategy<RuntimeArgs>,
45> {
46 _tmm: PhantomData<TMM>,
47 _loader: PhantomData<(LL, LR)>,
48}
49
50pub type SimpleSyncCyclicConv<TMM> = SimpleConv<
51 TMM,
52 SyncFullCyclicLoading<RowMajorTilingOrder>,
53 SyncFullCyclicLoading<ColMajorTilingOrder>,
54>;
55pub type SimpleSyncStridedConv<TMM> =
56 SimpleConv<TMM, SyncFullStridedLoading, SyncFullStridedLoading>;
57pub type SimpleSyncTilewiseConv<TMM> = SimpleConv<
58 TMM,
59 SyncFullTilewiseLoading<RowMajorTilingOrder>,
60 SyncFullTilewiseLoading<ColMajorTilingOrder>,
61>;
62pub type SimpleAsyncCyclicConv<TMM> = SimpleConv<
63 TMM,
64 AsyncFullCyclicLoading<RowMajorTilingOrder>,
65 AsyncFullCyclicLoading<ColMajorTilingOrder>,
66>;
67pub type SimpleAsyncStridedConv<TMM> =
68 SimpleConv<TMM, AsyncFullStridedLoading, AsyncFullStridedLoading>;
69
70pub struct SimpleAsyncTmaConv<TMM: TileMatmulFamily> {
71 _tmm: PhantomData<TMM>,
72}
73
74impl<
75 TMM: TileMatmulFamily<
76 LhsTile = Strided,
77 RhsTile = Strided,
78 AccTile = Option<Strided>,
79 OutTile = Strided,
80 >,
81 LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
82 LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
83> Algorithm for SimpleConv<TMM, LL, LR>
84{
85 type Routine = SimpleAlgorithm<TMM, LL, LR, SyncBiasLoading>;
86 type Args = TensorArgs<RuntimeArgs>;
87
88 fn correct_layout<R: Runtime>(
89 client: &ComputeClient<R>,
90 handle: TensorBinding<R>,
91 dtype: StorageType,
92 _operation: ConvolutionOperation,
93 ) -> Result<TensorBinding<R>, LaunchError> {
94 contiguous_pitched_layout(client, handle, dtype)
95 }
96}
97
98impl<
99 TMM: TileMatmulFamily<
100 LhsTile = Strided,
101 RhsTile = Strided,
102 AccTile = Option<Strided>,
103 OutTile = Strided,
104 >,
105> Algorithm for SimpleAsyncTmaConv<TMM>
106{
107 type Routine = SimpleAlgorithm<TMM, AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
108
109 type Args = TensorMapArgs<RuntimeArgs>;
110
111 fn correct_layout<R: Runtime>(
112 client: &ComputeClient<R>,
113 handle: TensorBinding<R>,
114 dtype: StorageType,
115 operation: ConvolutionOperation,
116 ) -> Result<TensorBinding<R>, LaunchError> {
117 into_tensor_handle_tma(client, handle, dtype, operation)
118 }
119
120 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
121 AvailableVectorSizes {
122 lhs: vec![1],
123 rhs: vec![1],
124 out: vector_sizes.out,
125 }
126 }
127}