// cubecl_convolution/kernels/layered/algorithm/simple.rs
use cubecl_core::server::LaunchError;
2use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
3use cubecl_matmul::components::{
4 MatmulElems, MatmulLineSizes, MatmulSelection, MatmulSetupError, stage::StridedStageFamily,
5 tile::io::Strided,
6};
7use cubecl_matmul::components::{
8 global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
9};
10use cubecl_matmul::components::{
11 global::read::sync_full_cyclic::SyncFullCyclicLoading,
12 stage::{ColMajorTilingOrder, NumStages, RowMajorTilingOrder},
13};
14use cubecl_std::{
15 CubeOption,
16 tensor::{TensorHandle, into_contiguous_pitched},
17};
18use std::marker::PhantomData;
19
20use crate::components::{
21 ConvolutionProblem, convolution_matmul_selection,
22 global::{
23 read::full_reader::FullLoadingStrategy, single_stage::simple::SimpleConvolutionFamily,
24 },
25};
26
27use super::Algorithm;
28
29pub struct SimpleConvAlgorithm<
31 TMM: TileMatmulFamily,
32 LL: FullLoadingStrategy = SyncFullCyclicLoading<RowMajorTilingOrder>,
33 LR: FullLoadingStrategy = SyncFullCyclicLoading<ColMajorTilingOrder>,
34> {
35 _tmm: PhantomData<TMM>,
36 _loader: PhantomData<(LL, LR)>,
37}
38
39impl<
40 TMM: TileMatmulFamily<
41 LhsTile = Strided,
42 RhsTile = Strided,
43 AccTile = CubeOption<Strided>,
44 OutTile = Strided,
45 >,
46 LL: FullLoadingStrategy,
47 LR: FullLoadingStrategy<SyncStrategy = LL::SyncStrategy>,
48> Algorithm for SimpleConvAlgorithm<TMM, LL, LR>
49{
50 type TileMatmul = TMM;
51 type StageMatmul = PlaneMatmulFamily<
52 Self::TileMatmul,
53 StridedStageFamily,
54 StridedStageFamily,
55 Option<StridedStageFamily>,
56 >;
57 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul, LL, LR>;
58
59 type Args = TensorArgs;
60
61 fn into_tensor_handle<R: Runtime>(
62 client: &ComputeClient<R>,
63 handle: &TensorHandleRef<'_, R>,
64 dtype: StorageType,
65 ) -> Result<TensorHandle<R>, LaunchError> {
66 if has_valid_layout(handle) {
67 Ok(TensorHandle::from_ref(handle, dtype))
68 } else {
69 into_contiguous_pitched(client, handle, dtype)
70 }
71 }
72
73 fn num_stages() -> NumStages {
75 (1, 1).into()
76 }
77
78 fn selection<R: Runtime>(
79 client: &ComputeClient<R>,
80 problem: &ConvolutionProblem,
81 plane_dim: u32,
82 line_sizes: &MatmulLineSizes,
83 dtypes: &mut MatmulElems,
84 ) -> Result<MatmulSelection, MatmulSetupError> {
85 Ok(convolution_matmul_selection::<TMM, R>(
86 client,
87 problem,
88 plane_dim,
89 TMM::should_swizzle(client),
90 line_sizes,
91 dtypes,
92 )?)
93 }
94}
95
96fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
97 let rank = handle.shape.len();
98 let dim_c = rank - 1;
99 handle.strides[dim_c] == 1
100}