// cubecl_linalg/convolution/algorithm/simple.rs
use std::marker::PhantomData;
2
3use cubecl_core::{
4 CubeCount, CubeDim, Runtime,
5 client::ComputeClient,
6 prelude::{Numeric, TensorHandleRef},
7};
8
9use crate::{
10 convolution::{base::ConvolutionProblem, homogeneous::simple::SimpleConvolutionFamily},
11 matmul::components::{
12 InputIdent, MatmulSelection,
13 global::args::TensorArgs,
14 stage::{FullReaderFamily, plane_matmul::PlaneMatmulFamily},
15 tile::TileMatmulFamily,
16 },
17 tensor::{TensorHandle, into_contiguous},
18};
19
20use super::Algorithm;
21
22pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
24 _tmm: PhantomData<TMM>,
25}
26
27impl<TMM: TileMatmulFamily> Algorithm for SimpleConvAlgorithm<TMM> {
28 type TileMatmul = TMM;
29 type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily>;
30 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
31
32 type Args = TensorArgs;
33
34 fn cube_dim(selection: &MatmulSelection) -> CubeDim {
35 CubeDim::new(selection.plane_dim, selection.tile_count.m, 1)
36 }
37
38 fn cube_count(selection: &MatmulSelection, problem: &ConvolutionProblem) -> CubeCount {
39 let m_stage = selection.tile_count.m * selection.tile_shape.m;
40 let n_stage = selection.tile_count.n * selection.tile_shape.n;
41 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
42 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
43
44 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
45 }
46
47 fn into_tensor_handle<R: Runtime, E: Numeric>(
48 client: &ComputeClient<R::Server, R::Channel>,
49 handle: &TensorHandleRef<'_, R>,
50 ident: crate::matmul::components::InputIdent,
51 ) -> TensorHandle<R, E> {
52 let mut handle = if has_valid_layout(handle, ident) {
53 TensorHandle::from_ref(handle)
54 } else {
55 into_contiguous(client, handle)
56 };
57 match ident {
58 InputIdent::Lhs => handle,
59 InputIdent::Rhs => {
60 handle.shape = vec![handle.shape[1..].iter().product(), handle.shape[0]];
62 handle.strides = vec![handle.strides[3], handle.strides[0]];
63 handle
64 }
65 }
66 }
67}
68
69fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
70 match ident {
71 InputIdent::Lhs => handle.strides[3] == 1,
72 InputIdent::Rhs => {
73 let mut strides = handle.strides.to_vec();
74 strides.sort();
75 let ordered = handle.strides == strides;
76 let contiguous_k = strides[3] * handle.shape[3] == strides[2]
77 && strides[2] * handle.shape[2] == handle.strides[1];
78 ordered && contiguous_k
79 }
80 }
81}