cubek_convolution/kernels/algorithm/
simple.rs1use cubecl::{
2 server::LaunchError,
3 {Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding},
4};
5use cubek_matmul::components::global::read::FullLoadingStrategy;
6use cubek_matmul::components::{
7 global::read::sync_full_cyclic::SyncFullCyclicLoading,
8 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
9};
10use cubek_matmul::{
11 components::global::read::{
12 async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
13 sync_full_tilewise::SyncFullTilewiseLoading,
14 },
15 routines::simple::SimpleAlgorithm,
16};
17use cubek_matmul::{
18 definition::AvailableVectorSizes,
19 launch::{TensorArgs, TensorMapArgs},
20};
21use cubek_std::tile::Strided;
22use std::marker::PhantomData;
23
24use crate::{
25 algorithm::{contiguous_pitched_layout, into_tensor_handle_tma},
26 components::{
27 ConvolutionOperation,
28 global::{
29 args::RuntimeArgs,
30 read::strategy::{
31 async_full_cyclic::AsyncFullCyclicLoading,
32 async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
33 },
34 },
35 },
36};
37
38use super::Algorithm;
39
40pub struct SimpleConv<LL: FullLoadingStrategy<RuntimeArgs>, LR: FullLoadingStrategy<RuntimeArgs>> {
42 _loader: PhantomData<(LL, LR)>,
43}
44
45pub type SimpleSyncCyclicConv = SimpleConv<
46 SyncFullCyclicLoading<RowMajorTilingOrder>,
47 SyncFullCyclicLoading<ColMajorTilingOrder>,
48>;
49pub type SimpleSyncStridedConv = SimpleConv<SyncFullStridedLoading, SyncFullStridedLoading>;
50pub type SimpleSyncTilewiseConv = SimpleConv<
51 SyncFullTilewiseLoading<RowMajorTilingOrder>,
52 SyncFullTilewiseLoading<ColMajorTilingOrder>,
53>;
54pub type SimpleAsyncCyclicConv = SimpleConv<
55 AsyncFullCyclicLoading<RowMajorTilingOrder>,
56 AsyncFullCyclicLoading<ColMajorTilingOrder>,
57>;
58pub type SimpleAsyncStridedConv = SimpleConv<AsyncFullStridedLoading, AsyncFullStridedLoading>;
59
/// Marker type for the simple convolution algorithm that loads both operands
/// with `AsyncFullTmaLoading` and takes its arguments as `TensorMapArgs`
/// (see its [`Algorithm`] implementation).
pub struct SimpleAsyncTmaConv;
61
62impl<
63 LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
64 LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
65> Algorithm for SimpleConv<LL, LR>
66{
67 type Routine = SimpleAlgorithm<LL, LR, SyncBiasLoading>;
68 type Args = TensorArgs<RuntimeArgs>;
69
70 fn correct_layout<R: Runtime>(
71 client: &ComputeClient<R>,
72 handle: TensorBinding<R>,
73 dtype: StorageType,
74 _operation: ConvolutionOperation,
75 ) -> Result<TensorBinding<R>, LaunchError> {
76 contiguous_pitched_layout(client, handle, dtype)
77 }
78}
79
80impl Algorithm for SimpleAsyncTmaConv {
81 type Routine = SimpleAlgorithm<AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
82
83 type Args = TensorMapArgs<RuntimeArgs>;
84
85 fn correct_layout<R: Runtime>(
86 client: &ComputeClient<R>,
87 handle: TensorBinding<R>,
88 dtype: StorageType,
89 operation: ConvolutionOperation,
90 ) -> Result<TensorBinding<R>, LaunchError> {
91 into_tensor_handle_tma(client, handle, dtype, operation)
92 }
93
94 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
95 AvailableVectorSizes {
96 lhs: vec![1],
97 rhs: vec![1],
98 out: vector_sizes.out,
99 }
100 }
101}