cubek_convolution/kernels/forward/algorithm/
simple.rs1use cubecl::server::LaunchError;
2use cubecl::std::{CubeOption, tensor::TensorHandle};
3use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
4use cubek_matmul::components::tile::TileMatmulFamily;
5use cubek_matmul::components::{
6 global::read::sync_full_cyclic::SyncFullCyclicLoading,
7 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
8};
9use cubek_matmul::components::{
10 global::read::{
11 async_full_tma::AsyncFullTmaLoading, sync_full_strided::SyncFullStridedLoading,
12 sync_full_tilewise::SyncFullTilewiseLoading,
13 },
14 stage::StridedStageFamily,
15 tile::io::Strided,
16};
17use cubek_matmul::definition::{MatmulElems, MatmulLineSizes, MatmulSetupError, TilingBlueprint};
18use cubek_matmul::launch::{TensorArgs, TensorMapArgs};
19use cubek_matmul::{components::stage::PlaneMatmulFamily, definition::AvailableLineSizes};
20use std::marker::PhantomData;
21
22use crate::{
23 components::{
24 ConvolutionOperation, ConvolutionProblem, convolution_matmul_selection,
25 global::{
26 read::{
27 full_reader::FullLoadingStrategy,
28 strategy::{
29 async_full_cyclic::AsyncFullCyclicLoading,
30 async_full_strided::AsyncFullStridedLoading,
31 },
32 },
33 single_stage::simple::SimpleConvolutionFamily,
34 },
35 },
36 kernels::forward::{into_tensor_handle, into_tensor_handle_tma},
37};
38
39use super::Algorithm;
40
/// Simple convolution algorithm built from a tile-matmul family `TMM` and two
/// full-loading strategies `LL` and `LR` (presumably the lhs and rhs loaders —
/// they are forwarded in that order to `SimpleConvolutionFamily`).
///
/// This is a zero-sized marker type: its type parameters are only carried via
/// `PhantomData`. Trait bounds are deliberately not placed on the struct itself;
/// the `Algorithm` impl states the bounds it actually needs.
pub struct SimpleConv<TMM, LL, LR> {
    _tmm: PhantomData<TMM>,
    _loader: PhantomData<(LL, LR)>,
}
46
47pub type SimpleSyncCyclicConv<TMM> = SimpleConv<
48 TMM,
49 SyncFullCyclicLoading<RowMajorTilingOrder>,
50 SyncFullCyclicLoading<ColMajorTilingOrder>,
51>;
52pub type SimpleSyncStridedConv<TMM> =
53 SimpleConv<TMM, SyncFullStridedLoading, SyncFullStridedLoading>;
54pub type SimpleSyncTilewiseConv<TMM> = SimpleConv<
55 TMM,
56 SyncFullTilewiseLoading<RowMajorTilingOrder>,
57 SyncFullTilewiseLoading<ColMajorTilingOrder>,
58>;
59pub type SimpleAsyncCyclicConv<TMM> = SimpleConv<
60 TMM,
61 AsyncFullCyclicLoading<RowMajorTilingOrder>,
62 AsyncFullCyclicLoading<ColMajorTilingOrder>,
63>;
64pub type SimpleAsyncStridedConv<TMM> =
65 SimpleConv<TMM, AsyncFullStridedLoading, AsyncFullStridedLoading>;
66
/// Simple convolution algorithm whose `Algorithm` impl loads both operands with
/// `AsyncFullTmaLoading` and launches with `TensorMapArgs`.
///
/// This is a zero-sized marker type over the tile-matmul family `TMM`. No trait
/// bound is placed on the struct; the `Algorithm` impl states the bounds it needs.
pub struct SimpleAsyncTmaConv<TMM> {
    _tmm: PhantomData<TMM>,
}
70
71impl<
72 TMM: TileMatmulFamily<
73 LhsTile = Strided,
74 RhsTile = Strided,
75 AccTile = CubeOption<Strided>,
76 OutTile = Strided,
77 >,
78 LL: FullLoadingStrategy,
79 LR: FullLoadingStrategy<SyncStrategy = LL::SyncStrategy>,
80> Algorithm for SimpleConv<TMM, LL, LR>
81{
82 type TileMatmul = TMM;
83 type StageMatmul = PlaneMatmulFamily<
84 Self::TileMatmul,
85 StridedStageFamily,
86 StridedStageFamily,
87 Option<StridedStageFamily>,
88 >;
89 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul, LL, LR>;
90
91 type Args = TensorArgs;
92
93 fn into_tensor_handle<R: Runtime>(
94 client: &ComputeClient<R>,
95 handle: &TensorHandleRef<'_, R>,
96 dtype: StorageType,
97 _operation: ConvolutionOperation,
98 ) -> Result<TensorHandle<R>, LaunchError> {
99 into_tensor_handle(client, handle, dtype)
100 }
101
102 fn selection<R: Runtime>(
103 client: &ComputeClient<R>,
104 problem: &ConvolutionProblem,
105 plane_dim: u32,
106 line_sizes: &MatmulLineSizes,
107 dtypes: &mut MatmulElems,
108 ) -> Result<TilingBlueprint, MatmulSetupError> {
109 Ok(convolution_matmul_selection::<TMM, R>(
110 client,
111 problem,
112 plane_dim,
113 TMM::should_swizzle(client),
114 line_sizes,
115 dtypes,
116 )?)
117 }
118}
119
120impl<
121 TMM: TileMatmulFamily<
122 LhsTile = Strided,
123 RhsTile = Strided,
124 AccTile = CubeOption<Strided>,
125 OutTile = Strided,
126 >,
127> Algorithm for SimpleAsyncTmaConv<TMM>
128{
129 type TileMatmul = TMM;
130 type StageMatmul = PlaneMatmulFamily<
131 Self::TileMatmul,
132 StridedStageFamily,
133 StridedStageFamily,
134 Option<StridedStageFamily>,
135 >;
136 type GlobalConvolution =
137 SimpleConvolutionFamily<Self::StageMatmul, AsyncFullTmaLoading, AsyncFullTmaLoading>;
138
139 type Args = TensorMapArgs;
140
141 fn into_tensor_handle<R: Runtime>(
142 client: &ComputeClient<R>,
143 handle: &TensorHandleRef<'_, R>,
144 dtype: StorageType,
145 operation: ConvolutionOperation,
146 ) -> Result<TensorHandle<R>, LaunchError> {
147 into_tensor_handle_tma(client, handle, dtype, operation)
148 }
149
150 fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
151 AvailableLineSizes {
152 lhs: vec![1],
153 rhs: vec![1],
154 out: line_sizes.out,
155 }
156 }
157
158 fn selection<R: Runtime>(
159 client: &ComputeClient<R>,
160 problem: &ConvolutionProblem,
161 plane_dim: u32,
162 line_sizes: &MatmulLineSizes,
163 dtypes: &mut MatmulElems,
164 ) -> Result<TilingBlueprint, MatmulSetupError> {
165 if line_sizes.lhs > 1 || line_sizes.rhs > 1 {
166 return Err(MatmulSetupError::InvalidConfig(Box::new(
167 "Not available with input line sizes > 1",
168 )));
169 }
170
171 Ok(convolution_matmul_selection::<TMM, R>(
172 client, problem, plane_dim, false, line_sizes, dtypes,
173 )?)
174 }
175}