cubecl_convolution/kernels/layered/algorithm/simple.rs

1use std::marker::PhantomData;
2
3use cubecl_core::{
4    Runtime,
5    client::ComputeClient,
6    prelude::{Numeric, TensorHandleRef},
7};
8use cubecl_matmul::components::{
9    MatmulElems, MatmulSelection, MatmulSetupError, stage::StridedStageFamily, tile::io::Strided,
10};
11
12use cubecl_matmul::components::stage::NumStages;
13use cubecl_matmul::components::{
14    MatmulIdent, global::args::TensorArgs, stage::PlaneMatmulFamily, tile::TileMatmulFamily,
15};
16
17use cubecl_std::{
18    CubeOption,
19    tensor::{TensorHandle, into_contiguous},
20};
21
22use crate::components::{
23    ConvolutionProblem, convolution_matmul_selection,
24    global::single_stage::simple::SimpleConvolutionFamily,
25};
26
27use super::Algorithm;
28
29/// Cmma convolution
30pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
31    _tmm: PhantomData<TMM>,
32}
33
34impl<
35    TMM: TileMatmulFamily<
36            LhsTile = Strided,
37            RhsTile = Strided,
38            AccTile = CubeOption<Strided>,
39            OutTile = Strided,
40        >,
41> Algorithm for SimpleConvAlgorithm<TMM>
42{
43    type TileMatmul = TMM;
44    type StageMatmul = PlaneMatmulFamily<
45        Self::TileMatmul,
46        StridedStageFamily,
47        StridedStageFamily,
48        Option<StridedStageFamily>,
49    >;
50    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
51
52    type Args = TensorArgs;
53
54    fn into_tensor_handle<R: Runtime, E: Numeric>(
55        client: &ComputeClient<R::Server>,
56        handle: &TensorHandleRef<'_, R>,
57        ident: MatmulIdent,
58    ) -> TensorHandle<R, E> {
59        if has_valid_layout(handle, ident) {
60            TensorHandle::from_ref(handle)
61        } else {
62            into_contiguous(client, handle)
63        }
64    }
65
66    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
67    fn num_stages() -> NumStages {
68        (1, 1).into()
69    }
70
71    fn selection<R: Runtime>(
72        client: &ComputeClient<R::Server>,
73        problem: &ConvolutionProblem,
74        plane_dim: u32,
75        matmul_elems: MatmulElems,
76    ) -> Result<MatmulSelection, MatmulSetupError> {
77        Ok(convolution_matmul_selection::<TMM, R>(
78            client,
79            problem,
80            plane_dim,
81            matmul_elems,
82        ))
83    }
84}
85
86fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
87    let rank = handle.shape.len();
88    let dim_c = rank - 1;
89    match ident {
90        MatmulIdent::Lhs => handle.strides[dim_c] == 1,
91        MatmulIdent::Rhs => handle.strides[dim_c] == 1,
92        MatmulIdent::Out => unreachable!(),
93    }
94}