cubecl_convolution/algorithm/
simple.rs1use std::marker::PhantomData;
2
3use cubecl_core::ir::Elem;
4use cubecl_core::{
5 Runtime,
6 client::ComputeClient,
7 prelude::{Numeric, TensorHandleRef},
8};
9use cubecl_matmul::components::MatmulSelection;
10
11use crate::{
12 base::ConvolutionProblem, homogeneous::simple::SimpleConvolutionFamily,
13 selection::convolution_matmul_selection,
14};
15use cubecl_matmul::components::stage::NumStages;
16use cubecl_matmul::components::{
17 InputIdent,
18 global::args::TensorArgs,
19 stage::{FullReaderFamily, PlaneMatmulFamily},
20 tile::TileMatmulFamily,
21};
22
23use cubecl_std::tensor::{TensorHandle, into_contiguous};
24
25use super::Algorithm;
26
27pub struct SimpleConvAlgorithm<TMM: TileMatmulFamily> {
29 _tmm: PhantomData<TMM>,
30}
31
32impl<TMM: TileMatmulFamily> Algorithm for SimpleConvAlgorithm<TMM> {
33 type TileMatmul = TMM;
34 type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, FullReaderFamily>;
35 type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
36
37 type Args = TensorArgs;
38
39 fn into_tensor_handle<R: Runtime, E: Numeric>(
40 client: &ComputeClient<R::Server, R::Channel>,
41 handle: &TensorHandleRef<'_, R>,
42 ident: cubecl_matmul::components::InputIdent,
43 ) -> TensorHandle<R, E> {
44 let rank = handle.shape.len();
45 let dim_c = rank - 1;
46 let mut handle = if has_valid_layout(handle, ident) {
47 TensorHandle::from_ref(handle)
48 } else {
49 into_contiguous(client, handle)
50 };
51 match ident {
52 InputIdent::Lhs => handle,
53 InputIdent::Rhs => {
54 handle.shape = vec![handle.shape[1..].iter().product(), handle.shape[0]];
56 handle.strides = vec![handle.strides[dim_c], handle.strides[0]];
57 handle
58 }
59 }
60 }
61
62 fn num_stages() -> NumStages {
64 (1, 1).into()
65 }
66
67 fn selection<R: Runtime>(
68 client: &ComputeClient<R::Server, R::Channel>,
69 problem: &ConvolutionProblem,
70 plane_dim: u32,
71 elem_stage: Elem,
72 elem_acc: Elem,
73 ) -> MatmulSelection {
74 convolution_matmul_selection::<TMM, R>(client, problem, plane_dim, elem_stage, elem_acc)
75 }
76}
77
78fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
79 let rank = handle.shape.len();
80 let dim_c = rank - 1;
81 match ident {
82 InputIdent::Lhs => handle.strides[dim_c] == 1,
83 InputIdent::Rhs => {
84 let mut strides = handle.strides.to_vec();
85 strides.sort();
86 let ordered = handle.strides == strides;
87 let mut contiguous_k = true;
88 for i in 1..dim_c {
89 contiguous_k &= strides[i] == strides[i + 1] * handle.shape[i + 1];
90 }
91 ordered && contiguous_k
92 }
93 }
94}