cubecl_convolution/kernels/layered/algorithm/
simple.rs1use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
2use cubecl_matmul::components::stage::NumStages;
3use cubecl_matmul::components::{
4 MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
5};
6use cubecl_matmul::components::{
7 MatmulIdent, global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
8};
9use cubecl_std::{
10 CubeOption,
11 tensor::{TensorHandle, into_contiguous},
12};
13use std::marker::PhantomData;
14
15use crate::components::{
16 ConvolutionProblem, convolution_matmul_selection,
17 global::single_stage::simple::SimpleConvolutionFamily,
18};
19
20use super::Algorithm;
21
22pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
24 _tmm: PhantomData<TMM>,
25}
26
27impl<
28 TMM: TileMatmulFamily<
29 LhsTile = Strided,
30 RhsTile = Strided,
31 AccTile = CubeOption<Strided>,
32 OutTile = Strided,
33 >,
34> Algorithm for SimpleConvAlgorithm<TMM>
35{
36 type TileMatmul = TMM;
37 type StageMatmul = PlaneMatmulFamily<
38 Self::TileMatmul,
39 StridedStageFamily,
40 StridedStageFamily,
41 Option<StridedStageFamily>,
42 >;
43 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
44
45 type Args = TensorArgs;
46
47 fn into_tensor_handle<R: Runtime>(
48 client: &ComputeClient<R::Server>,
49 handle: &TensorHandleRef<'_, R>,
50 ident: MatmulIdent,
51 dtype: StorageType,
52 ) -> TensorHandle<R> {
53 if has_valid_layout(handle, ident) {
54 TensorHandle::from_ref(handle, dtype)
55 } else {
56 into_contiguous(client, handle, dtype)
57 }
58 }
59
60 fn num_stages() -> NumStages {
62 (1, 1).into()
63 }
64
65 fn selection<R: Runtime>(
66 client: &ComputeClient<R::Server>,
67 problem: &ConvolutionProblem,
68 plane_dim: u32,
69 dtypes: &mut MatmulElems,
70 ) -> Result<MatmulSelection, MatmulSetupError> {
71 Ok(convolution_matmul_selection::<TMM, R>(
72 client, problem, plane_dim, dtypes,
73 )?)
74 }
75}
76
77fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
78 let rank = handle.shape.len();
79 let dim_c = rank - 1;
80 match ident {
81 MatmulIdent::Lhs => handle.strides[dim_c] == 1,
82 MatmulIdent::Rhs => handle.strides[dim_c] == 1,
83 MatmulIdent::Out => unreachable!(),
84 }
85}