cubecl_convolution/kernels/layered/algorithm/
simple.rs

1use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
2use cubecl_matmul::components::stage::NumStages;
3use cubecl_matmul::components::{
4    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
5};
6use cubecl_matmul::components::{
7    MatmulIdent, global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
8};
9use cubecl_std::{
10    CubeOption,
11    tensor::{TensorHandle, into_contiguous},
12};
13use std::marker::PhantomData;
14
15use crate::components::{
16    ConvolutionProblem, convolution_matmul_selection,
17    global::single_stage::simple::SimpleConvolutionFamily,
18};
19
20use super::Algorithm;
21
22/// Cmma convolution
23pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
24    _tmm: PhantomData<TMM>,
25}
26
27impl<
28    TMM: TileMatmulFamily<
29            LhsTile = Strided,
30            RhsTile = Strided,
31            AccTile = CubeOption<Strided>,
32            OutTile = Strided,
33        >,
34> Algorithm for SimpleConvAlgorithm<TMM>
35{
36    type TileMatmul = TMM;
37    type StageMatmul = PlaneMatmulFamily<
38        Self::TileMatmul,
39        StridedStageFamily,
40        StridedStageFamily,
41        Option<StridedStageFamily>,
42    >;
43    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
44
45    type Args = TensorArgs;
46
47    fn into_tensor_handle<R: Runtime>(
48        client: &ComputeClient<R::Server>,
49        handle: &TensorHandleRef<'_, R>,
50        ident: MatmulIdent,
51        dtype: StorageType,
52    ) -> TensorHandle<R> {
53        if has_valid_layout(handle, ident) {
54            TensorHandle::from_ref(handle, dtype)
55        } else {
56            into_contiguous(client, handle, dtype)
57        }
58    }
59
60    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
61    fn num_stages() -> NumStages {
62        (1, 1).into()
63    }
64
65    fn selection<R: Runtime>(
66        client: &ComputeClient<R::Server>,
67        problem: &ConvolutionProblem,
68        plane_dim: u32,
69        dtypes: &mut MatmulElems,
70    ) -> Result<MatmulSelection, MatmulSetupError> {
71        Ok(convolution_matmul_selection::<TMM, R>(
72            client, problem, plane_dim, dtypes,
73        )?)
74    }
75}
76
77fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
78    let rank = handle.shape.len();
79    let dim_c = rank - 1;
80    match ident {
81        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
82        MatmulIdent::Rhs => handle.strides[dim_c] == 1,
83        MatmulIdent::Out => unreachable!(),
84    }
85}