cubecl_convolution/algorithm/
simple.rs

1use std::marker::PhantomData;
2
3use cubecl_core::ir::Elem;
4use cubecl_core::{
5    Runtime,
6    client::ComputeClient,
7    prelude::{Numeric, TensorHandleRef},
8};
9use cubecl_matmul::components::MatmulSelection;
10
11use crate::{
12    base::ConvolutionProblem, homogeneous::simple::SimpleConvolutionFamily,
13    selection::convolution_matmul_selection,
14};
15use cubecl_matmul::components::stage::NumStages;
16use cubecl_matmul::components::{
17    InputIdent,
18    global::args::TensorArgs,
19    stage::{FullReaderFamily, PlaneMatmulFamily},
20    tile::TileMatmulFamily,
21};
22
23use cubecl_std::tensor::{TensorHandle, into_contiguous};
24
25use super::Algorithm;
26
/// Simple (single-stage) convolution algorithm, parameterized by the tile
/// matmul family (e.g. a CMMA-backed tile matmul) used for the inner product.
///
/// Carries no runtime state; the type parameter only selects the kernel
/// components via the [`Algorithm`] impl below.
pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
    // Zero-sized marker tying the algorithm to its tile matmul family.
    _tmm: PhantomData<TMM>,
}
31
32impl<TMM: TileMatmulFamily> Algorithm for SimpleConvAlgorithm<TMM> {
33    type TileMatmul = TMM;
34    type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, FullReaderFamily>;
35    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
36
37    type Args = TensorArgs;
38
39    fn into_tensor_handle<R: Runtime, E: Numeric>(
40        client: &ComputeClient<R::Server, R::Channel>,
41        handle: &TensorHandleRef<'_, R>,
42        ident: cubecl_matmul::components::InputIdent,
43    ) -> TensorHandle<R, E> {
44        let rank = handle.shape.len();
45        let dim_c = rank - 1;
46        let mut handle = if has_valid_layout(handle, ident) {
47            TensorHandle::from_ref(handle)
48        } else {
49            into_contiguous(client, handle)
50        };
51        match ident {
52            InputIdent::Lhs => handle,
53            InputIdent::Rhs => {
54                // Reshape to (K, N) so the loader knows how to handle it
55                handle.shape = vec![handle.shape[1..].iter().product(), handle.shape[0]];
56                handle.strides = vec![handle.strides[dim_c], handle.strides[0]];
57                handle
58            }
59        }
60    }
61
62    // TODO this is not the same as tma stages, it's stages in the sense of double buffering in matmul
63    fn num_stages() -> NumStages {
64        (1, 1).into()
65    }
66
67    fn selection<R: Runtime>(
68        client: &ComputeClient<R::Server, R::Channel>,
69        problem: &ConvolutionProblem,
70        plane_dim: u32,
71        elem_stage: Elem,
72        elem_acc: Elem,
73    ) -> MatmulSelection {
74        convolution_matmul_selection::<TMM, R>(client, problem, plane_dim, elem_stage, elem_acc)
75    }
76}
77
78fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
79    let rank = handle.shape.len();
80    let dim_c = rank - 1;
81    match ident {
82        InputIdent::Lhs => handle.strides[dim_c] == 1,
83        InputIdent::Rhs => {
84            let mut strides = handle.strides.to_vec();
85            strides.sort();
86            let ordered = handle.strides == strides;
87            let mut contiguous_k = true;
88            for i in 1..dim_c {
89                contiguous_k &= strides[i] == strides[i + 1] * handle.shape[i + 1];
90            }
91            ordered && contiguous_k
92        }
93    }
94}