cubecl_convolution/kernels/layered/algorithm/
simple.rs1use std::marker::PhantomData;
2
3use cubecl_core::{
4 Runtime,
5 client::ComputeClient,
6 prelude::{Numeric, TensorHandleRef},
7};
8use cubecl_matmul::components::{
9 MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
10};
11
12use cubecl_matmul::components::stage::NumStages;
13use cubecl_matmul::components::{
14 MatmulIdent, global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
15};
16
17use cubecl_std::{
18 CubeOption,
19 tensor::{TensorHandle, into_contiguous},
20};
21
22use crate::components::{
23 ConvolutionProblem, convolution_matmul_selection,
24 global::single_stage::simple::SimpleConvolutionFamily,
25};
26
27use super::Algorithm;
28
29pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
31 _tmm: PhantomData<TMM>,
32}
33
34impl<
35 TMM: TileMatmulFamily<
36 LhsTile = Strided,
37 RhsTile = Strided,
38 AccTile = CubeOption<Strided>,
39 OutTile = Strided,
40 >,
41> Algorithm for SimpleConvAlgorithm<TMM>
42{
43 type TileMatmul = TMM;
44 type StageMatmul = PlaneMatmulFamily<
45 Self::TileMatmul,
46 StridedStageFamily,
47 StridedStageFamily,
48 Option<StridedStageFamily>,
49 >;
50 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
51
52 type Args = TensorArgs;
53
54 fn into_tensor_handle<R: Runtime, E: Numeric>(
55 client: &ComputeClient<R::Server>,
56 handle: &TensorHandleRef<'_, R>,
57 ident: MatmulIdent,
58 ) -> TensorHandle<R, E> {
59 if has_valid_layout(handle, ident) {
60 TensorHandle::from_ref(handle)
61 } else {
62 into_contiguous(client, handle)
63 }
64 }
65
66 fn num_stages() -> NumStages {
68 (1, 1).into()
69 }
70
71 fn selection<R: Runtime>(
72 client: &ComputeClient<R::Server>,
73 problem: &ConvolutionProblem,
74 plane_dim: u32,
75 matmul_elems: MatmulElems,
76 ) -> Result<MatmulSelection, MatmulSetupError> {
77 Ok(convolution_matmul_selection::<TMM, R>(
78 client,
79 problem,
80 plane_dim,
81 matmul_elems,
82 ))
83 }
84}
85
86fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
87 let rank = handle.shape.len();
88 let dim_c = rank - 1;
89 match ident {
90 MatmulIdent::Lhs => handle.strides[dim_c] == 1,
91 MatmulIdent::Rhs => handle.strides[dim_c] == 1,
92 MatmulIdent::Out => unreachable!(),
93 }
94}