cubecl_linalg/convolution/algorithm/simple.rs

use std::marker::PhantomData;

use cubecl_core::{
    CubeCount, CubeDim, Runtime,
    client::ComputeClient,
    prelude::{Numeric, TensorHandleRef},
};

use crate::{
    convolution::{base::ConvolutionProblem, homogeneous::simple::SimpleConvolutionFamily},
    matmul::components::{
        InputIdent, MatmulSelection,
        global::args::TensorArgs,
        stage::{FullReaderFamily, plane_matmul::PlaneMatmulFamily},
        tile::TileMatmulFamily,
    },
    tensor::{TensorHandle, into_contiguous},
};

use super::Algorithm;

/// Simple convolution algorithm built on CMMA (cooperative matrix-multiply-accumulate)
/// tile matmuls.
pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
    _tmm: PhantomData<TMM>,
}

impl<TMM: TileMatmulFamily> Algorithm for SimpleConvAlgorithm<TMM> {
    type TileMatmul = TMM;
    type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily>;
    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;

    type Args = TensorArgs;

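    /// Cube dimensions: `x` is the plane size and `y` the number of tiles along `m`
    /// in a stage, so each plane works on one row of `m` tiles of the stage matmul.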
    fn cube_dim(selection: &MatmulSelection) -> CubeDim {
        CubeDim::new(selection.plane_dim, selection.tile_count.m, 1)
    }

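    /// One cube per output stage: a stage covers `tile_count * tile_shape` elements
    /// along `m` and `n`, so the problem size is divided by the stage size in each
    /// dimension, rounding up. For example, with `tile_shape.m = 16` and
    /// `tile_count.m = 4` a stage covers 64 rows, so a problem with `m = 100` needs
    /// two cubes along `x`.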
    fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
        let m_stage = selection.tile_count.m * selection.tile_shape.m;
        let n_stage = selection.tile_count.n * selection.tile_shape.n;
        let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
        let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);

        CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
    }

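    /// Prepares an input tensor for the kernel. If its layout cannot be used directly
    /// (see `has_valid_layout`), the tensor is first made contiguous. The Rhs (weights)
    /// handle is then reinterpreted as a `(K, N)` matrix, assuming a channels-last
    /// weight layout of `[out_channels, kh, kw, in_channels]`: the last three
    /// dimensions are collapsed into `K` and `out_channels` becomes `N`.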
    fn into_tensor_handle<R: Runtime, E: Numeric>(
        client: &ComputeClient<R::Server, R::Channel>,
        handle: &TensorHandleRef<'_, R>,
        ident: crate::matmul::components::InputIdent,
    ) -> TensorHandle<R, E> {
        let mut handle = if has_valid_layout(handle, ident) {
            TensorHandle::from_ref(handle)
        } else {
            into_contiguous(client, handle)
        };
        match ident {
            InputIdent::Lhs => handle,
            InputIdent::Rhs => {
                // Reshape to (K, N) so the loader knows how to handle it
                handle.shape = vec![handle.shape[1..].iter().product(), handle.shape[0]];
                handle.strides = vec![handle.strides[3], handle.strides[0]];
                handle
            }
        }
    }
}

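/// Whether the tensor can be fed to the kernel without being made contiguous first.
///
/// - Lhs: the last dimension (the channels, for an NHWC input) must have unit stride.
/// - Rhs: the strides must already be in sorted order and dimensions 1..=3 must be
///   contiguous with one another, so `into_tensor_handle` can collapse them into a
///   single `K` dimension without copying.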
fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
    match ident {
        InputIdent::Lhs => handle.strides[3] == 1,
        InputIdent::Rhs => {
            let mut strides = handle.strides.to_vec();
            strides.sort();
            let ordered = handle.strides == strides;
            let contiguous_k = strides[3] * handle.shape[3] == strides[2]
                && strides[2] * handle.shape[2] == handle.strides[1];
            ordered && contiguous_k
        }
    }
}