cubek_convolution/kernels/algorithm/
simple.rs1use cubecl::server::LaunchError;
2use cubecl::std::{CubeOption, tensor::TensorHandle};
3use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
4use cubek_matmul::components::{global::read::FullLoadingStrategy, tile::TileMatmulFamily};
5use cubek_matmul::components::{
6 global::read::sync_full_cyclic::SyncFullCyclicLoading,
7 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
8};
9use cubek_matmul::definition::AvailableLineSizes;
10use cubek_matmul::launch::{TensorArgs, TensorMapArgs};
11use cubek_matmul::{
12 components::{
13 global::read::{
14 async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
15 sync_full_tilewise::SyncFullTilewiseLoading,
16 },
17 tile::io::Strided,
18 },
19 routines::simple::SimpleAlgorithm,
20};
21use std::marker::PhantomData;
22
23use crate::{
24 algorithm::{into_tensor_handle, into_tensor_handle_tma},
25 components::{
26 ConvolutionOperation,
27 global::{
28 args::RuntimeArgs,
29 read::strategy::{
30 async_full_cyclic::AsyncFullCyclicLoading,
31 async_full_strided::AsyncFullStridedLoading, sync_bias::SyncBiasLoading,
32 },
33 },
34 },
35};
36
37use super::Algorithm;
38
39pub struct SimpleConv<
41 TMM: TileMatmulFamily,
42 LL: FullLoadingStrategy<RuntimeArgs>,
43 LR: FullLoadingStrategy<RuntimeArgs>,
44> {
45 _tmm: PhantomData<TMM>,
46 _loader: PhantomData<(LL, LR)>,
47}
48
49pub type SimpleSyncCyclicConv<TMM> = SimpleConv<
50 TMM,
51 SyncFullCyclicLoading<RowMajorTilingOrder>,
52 SyncFullCyclicLoading<ColMajorTilingOrder>,
53>;
54pub type SimpleSyncStridedConv<TMM> =
55 SimpleConv<TMM, SyncFullStridedLoading, SyncFullStridedLoading>;
56pub type SimpleSyncTilewiseConv<TMM> = SimpleConv<
57 TMM,
58 SyncFullTilewiseLoading<RowMajorTilingOrder>,
59 SyncFullTilewiseLoading<ColMajorTilingOrder>,
60>;
61pub type SimpleAsyncCyclicConv<TMM> = SimpleConv<
62 TMM,
63 AsyncFullCyclicLoading<RowMajorTilingOrder>,
64 AsyncFullCyclicLoading<ColMajorTilingOrder>,
65>;
66pub type SimpleAsyncStridedConv<TMM> =
67 SimpleConv<TMM, AsyncFullStridedLoading, AsyncFullStridedLoading>;
68
69pub struct SimpleAsyncTmaConv<TMM: TileMatmulFamily> {
70 _tmm: PhantomData<TMM>,
71}
72
73impl<
74 TMM: TileMatmulFamily<
75 LhsTile = Strided,
76 RhsTile = Strided,
77 AccTile = CubeOption<Strided>,
78 OutTile = Strided,
79 >,
80 LL: FullLoadingStrategy<RuntimeArgs, TileKind = Strided>,
81 LR: FullLoadingStrategy<RuntimeArgs, TileKind = Strided, SyncStrategy = LL::SyncStrategy>,
82> Algorithm for SimpleConv<TMM, LL, LR>
83{
84 type Routine = SimpleAlgorithm<TMM, LL, LR, SyncBiasLoading>;
85 type Args = TensorArgs<RuntimeArgs>;
86
87 fn into_tensor_handle<R: Runtime>(
88 client: &ComputeClient<R>,
89 handle: &TensorHandleRef<'_, R>,
90 dtype: StorageType,
91 _operation: ConvolutionOperation,
92 ) -> Result<TensorHandle<R>, LaunchError> {
93 into_tensor_handle(client, handle, dtype)
94 }
95}
96
97impl<
98 TMM: TileMatmulFamily<
99 LhsTile = Strided,
100 RhsTile = Strided,
101 AccTile = CubeOption<Strided>,
102 OutTile = Strided,
103 >,
104> Algorithm for SimpleAsyncTmaConv<TMM>
105{
106 type Routine = SimpleAlgorithm<TMM, AsyncFullTmaLoading, AsyncFullTmaLoading, SyncBiasLoading>;
107
108 type Args = TensorMapArgs<RuntimeArgs>;
109
110 fn into_tensor_handle<R: Runtime>(
111 client: &ComputeClient<R>,
112 handle: &TensorHandleRef<'_, R>,
113 dtype: StorageType,
114 operation: ConvolutionOperation,
115 ) -> Result<TensorHandle<R>, LaunchError> {
116 into_tensor_handle_tma(client, handle, dtype, operation)
117 }
118
119 fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
120 AvailableLineSizes {
121 lhs: vec![1],
122 rhs: vec![1],
123 out: line_sizes.out,
124 }
125 }
126}