cubek_convolution/routines/
simple.rs1use cubecl::{
2 server::LaunchError,
3 {Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding},
4};
5use cubek_matmul::components::global::read::FullLoadingStrategy;
6use cubek_matmul::components::{
7 global::read::sync_full_cyclic::SyncFullCyclicLoading,
8 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
9};
10use cubek_matmul::{
11 components::global::read::{
12 async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
13 sync_full_tilewise::SyncFullTilewiseLoading,
14 },
15 routines::simple::{SimpleAlgorithm, SimpleArgs},
16};
17use cubek_matmul::{
18 definition::{AvailableVectorSizes, TilingBlueprint},
19 launch::{TensorArgs, TensorMapArgs},
20};
21use cubek_std::tile::Strided;
22use std::marker::PhantomData;
23
24use crate::{
25 components::{
26 ConvolutionOperation,
27 global::{
28 args::RuntimeArgs,
29 read::strategy::{
30 async_full_cyclic::AsyncFullCyclicLoading,
31 async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
32 },
33 },
34 },
35 routines::{Routine, contiguous_pitched_layout, into_tensor_handle_tma},
36};
37
38pub struct SimpleConv<LL: FullLoadingStrategy<RuntimeArgs>, LR: FullLoadingStrategy<RuntimeArgs>> {
40 _loader: PhantomData<(LL, LR)>,
41}
42
43pub type SimpleSyncCyclicConv = SimpleConv<
44 SyncFullCyclicLoading<RowMajorTilingOrder>,
45 SyncFullCyclicLoading<ColMajorTilingOrder>,
46>;
47pub type SimpleSyncStridedConv = SimpleConv<SyncFullStridedLoading, SyncFullStridedLoading>;
48pub type SimpleSyncTilewiseConv = SimpleConv<
49 SyncFullTilewiseLoading<RowMajorTilingOrder>,
50 SyncFullTilewiseLoading<ColMajorTilingOrder>,
51>;
52pub type SimpleAsyncCyclicConv = SimpleConv<
53 AsyncFullCyclicLoading<RowMajorTilingOrder>,
54 AsyncFullCyclicLoading<ColMajorTilingOrder>,
55>;
56pub type SimpleAsyncStridedConv = SimpleConv<AsyncFullStridedLoading, AsyncFullStridedLoading>;
57
/// Marker type for the simple convolution routine whose `Routine` impl uses
/// `AsyncFullTmaLoading` for both loading strategies and `TensorMapArgs` as
/// its argument kind.
///
/// It is a unit struct: it holds no data and exists only to carry the
/// `Routine` implementation.
pub struct SimpleAsyncTmaConv;
59
60impl<
61 LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
62 LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
63> Routine for SimpleConv<LL, LR>
64{
65 type Blueprint = TilingBlueprint;
66 type Strategy = SimpleArgs;
67 type MatmulRoutine = SimpleAlgorithm<LL, LR, SyncBiasLoading>;
68 type Args = TensorArgs<RuntimeArgs>;
69
70 fn correct_layout<R: Runtime>(
71 client: &ComputeClient<R>,
72 handle: TensorBinding<R>,
73 dtype: StorageType,
74 _operation: ConvolutionOperation,
75 ) -> Result<TensorBinding<R>, LaunchError> {
76 contiguous_pitched_layout(client, handle, dtype)
77 }
78}
79
80impl Routine for SimpleAsyncTmaConv {
81 type Blueprint = TilingBlueprint;
82 type Strategy = SimpleArgs;
83 type MatmulRoutine = SimpleAlgorithm<AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
84 type Args = TensorMapArgs<RuntimeArgs>;
85
86 fn correct_layout<R: Runtime>(
87 client: &ComputeClient<R>,
88 handle: TensorBinding<R>,
89 dtype: StorageType,
90 operation: ConvolutionOperation,
91 ) -> Result<TensorBinding<R>, LaunchError> {
92 into_tensor_handle_tma(client, handle, dtype, operation)
93 }
94
95 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
96 AvailableVectorSizes {
97 lhs: vec![1],
98 rhs: vec![1],
99 out: vector_sizes.out,
100 }
101 }
102}