cubecl_convolution/kernels/layered/algorithm/
simple.rs

1use cubecl_core::server::LaunchError;
2use cubecl_core::{Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef};
3use cubecl_matmul::components::stage::NumStages;
4use cubecl_matmul::components::{
5    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
6};
7use cubecl_matmul::components::{
8    MatmulIdent, global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
9};
10use cubecl_std::{
11    CubeOption,
12    tensor::{TensorHandle, into_contiguous},
13};
14use std::marker::PhantomData;
15
16use crate::components::{
17    ConvolutionProblem, convolution_matmul_selection,
18    global::single_stage::simple::SimpleConvolutionFamily,
19};
20
21use super::Algorithm;
22
23/// Cmma convolution
24pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
25    _tmm: PhantomData<TMM>,
26}
27
28impl<
29    TMM: TileMatmulFamily<
30            LhsTile = Strided,
31            RhsTile = Strided,
32            AccTile = CubeOption<Strided>,
33            OutTile = Strided,
34        >,
35> Algorithm for SimpleConvAlgorithm<TMM>
36{
37    type TileMatmul = TMM;
38    type StageMatmul = PlaneMatmulFamily<
39        Self::TileMatmul,
40        StridedStageFamily,
41        StridedStageFamily,
42        Option<StridedStageFamily>,
43    >;
44    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
45
46    type Args = TensorArgs;
47
48    fn into_tensor_handle<R: Runtime>(
49        client: &ComputeClient<R>,
50        handle: &TensorHandleRef<'_, R>,
51        ident: MatmulIdent,
52        dtype: StorageType,
53    ) -> Result<TensorHandle<R>, LaunchError> {
54        if has_valid_layout(handle, ident) {
55            Ok(TensorHandle::from_ref(handle, dtype))
56        } else {
57            into_contiguous(client, handle, dtype)
58        }
59    }
60
61    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
62    fn num_stages() -> NumStages {
63        (1, 1).into()
64    }
65
66    fn selection<R: Runtime>(
67        client: &ComputeClient<R>,
68        problem: &ConvolutionProblem,
69        plane_dim: u32,
70        dtypes: &mut MatmulElems,
71    ) -> Result<MatmulSelection, MatmulSetupError> {
72        Ok(convolution_matmul_selection::<TMM, R>(
73            client, problem, plane_dim, dtypes,
74        )?)
75    }
76}
77
78fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
79    let rank = handle.shape.len();
80    let dim_c = rank - 1;
81    match ident {
82        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
83        MatmulIdent::Rhs => handle.strides[dim_c] == 1,
84        MatmulIdent::Out => unreachable!(),
85    }
86}