cubecl_convolution/kernels/layered/algorithm/
simple.rs1use cubecl_core::server::LaunchError;
2use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
3use cubecl_matmul::components::stage::NumStages;
4use cubecl_matmul::components::{
5 MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
6};
7use cubecl_matmul::components::{
8 MatmulIdent, global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
9};
10use cubecl_std::{
11 CubeOption,
12 tensor::{TensorHandle, into_contiguous},
13};
14use std::marker::PhantomData;
15
16use crate::components::{
17 ConvolutionProblem, convolution_matmul_selection,
18 global::single_stage::simple::SimpleConvolutionFamily,
19};
20
21use super::Algorithm;
22
23pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
25 _tmm: PhantomData<TMM>,
26}
27
28impl<
29 TMM: TileMatmulFamily<
30 LhsTile = Strided,
31 RhsTile = Strided,
32 AccTile = CubeOption<Strided>,
33 OutTile = Strided,
34 >,
35> Algorithm for SimpleConvAlgorithm<TMM>
36{
37 type TileMatmul = TMM;
38 type StageMatmul = PlaneMatmulFamily<
39 Self::TileMatmul,
40 StridedStageFamily,
41 StridedStageFamily,
42 Option<StridedStageFamily>,
43 >;
44 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
45
46 type Args = TensorArgs;
47
48 fn into_tensor_handle<R: Runtime>(
49 client: &ComputeClient<R>,
50 handle: &TensorHandleRef<'_, R>,
51 ident: MatmulIdent,
52 dtype: StorageType,
53 ) -> Result<TensorHandle<R>, LaunchError> {
54 if has_valid_layout(handle, ident) {
55 Ok(TensorHandle::from_ref(handle, dtype))
56 } else {
57 into_contiguous(client, handle, dtype)
58 }
59 }
60
61 fn num_stages() -> NumStages {
63 (1, 1).into()
64 }
65
66 fn selection<R: Runtime>(
67 client: &ComputeClient<R>,
68 problem: &ConvolutionProblem,
69 plane_dim: u32,
70 dtypes: &mut MatmulElems,
71 ) -> Result<MatmulSelection, MatmulSetupError> {
72 Ok(convolution_matmul_selection::<TMM, R>(
73 client, problem, plane_dim, dtypes,
74 )?)
75 }
76}
77
78fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
79 let rank = handle.shape.len();
80 let dim_c = rank - 1;
81 match ident {
82 MatmulIdent::Lhs => handle.strides[dim_c] == 1,
83 MatmulIdent::Rhs => handle.strides[dim_c] == 1,
84 MatmulIdent::Out => unreachable!(),
85 }
86}