// cubecl_linalg/convolution/homogeneous/base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4    CubeOption, CubeOptionExpand,
5    tensor::r#virtual::{ReadWrite, VirtualTensor},
6};
7
8use crate::{
9    convolution::base::{Convolution, ConvolutionFamily, RuntimeArgs},
10    matmul::components::{
11        Ident,
12        global::{
13            GlobalConfig,
14            args::{MatmulArgs, TensorInput, TensorInputIdent, TensorOutput},
15        },
16    },
17};
18
/// Alias for the `Input` associated type of a [`MatmulArgs`] implementation.
type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
/// Alias for the `Output` associated type of a [`MatmulArgs`] implementation.
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
21
/// Entry kernel for an implicit-GEMM convolution, launched over the cube grid.
///
/// Each cube computes the output tile selected by its `(CUBE_POS_X, CUBE_POS_Y)`
/// coordinates: `x_offset` is derived from the Lhs tiling's total rows and
/// `y_offset` from the Rhs tiling's total columns. Every cube traverses the
/// full reduction range `0..runtime_args.size_k`.
///
/// Type parameters: `EI`/`EO` are the global input/output element types;
/// `ES`/`EA` are presumably the stage/accumulator precisions consumed by
/// `GMM::Convolution` — confirm against the `ConvolutionFamily` definition.
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    EI: Numeric,
    ES: Numeric,
    EA: Numeric,
    EO: Numeric,
    GMM: ConvolutionFamily,
>(
    inputs: &Input<Args, EI>,
    bias: &CubeOption<Tensor<Line<EO>>>,
    output: &mut Output<Args, EO>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
) {
    // Shared argument state from which the Lhs/Rhs/output handles are derived.
    let mut state = Args::init_state(inputs, output);

    let lhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Lhs);
    let rhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Rhs);
    let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);

    // Erase the concrete handle types behind read-only (lhs/rhs) and
    // read-write (out) virtual-tensor wrappers.
    let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
    let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
    let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);

    // This cube's starting row (from Lhs tiling) and column (from Rhs tiling)
    // in the global problem.
    let x_offset = CUBE_POS_X * config.tiling_dimensions(Ident::Lhs).total_row();
    let y_offset = CUBE_POS_Y * config.tiling_dimensions(Ident::Rhs).total_col();
    // Reduction dimension is traversed in full by every cube.
    let k_range = (0, runtime_args.size_k);

    // Re-wrap the optional bias as a virtual tensor, preserving Some/None.
    let bias = match bias {
        CubeOption::Some(bias) => {
            CubeOption::new_Some(VirtualTensor::<EO>::new::<Tensor<Line<EO>>>(bias))
        }
        CubeOption::None => CubeOption::new_None(),
    };

    // Build the loaders, unloader, and accumulator, then run the convolution.
    GMM::Convolution::<(EI, ES, EA, EO)>::execute(
        GMM::Convolution::<(EI, ES, EA, EO)>::init_lhs_loader(
            lhs,
            x_offset,
            k_range.0,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_rhs_loader(
            rhs,
            k_range.0,
            y_offset,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_bias_loader(bias, y_offset, config),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_unloader(out, x_offset, y_offset),
        &mut GMM::Convolution::<(EI, ES, EA, EO)>::init_accumulator(config),
        k_range,
        config,
    );
}
80
81pub mod config {
82    use std::ops::Deref;
83
84    use crate::{
85        convolution::ConvGemmConfig,
86        matmul::components::{
87            MatmulConfig, MatrixLayout, TilingDimensions, global::GlobalConfig,
88            global::PRECOMPUTE_JOB,
89        },
90    };
91
92    use super::*;
93
    /// Global-level configuration for the homogeneous implicit-GEMM convolution.
    ///
    /// Wraps an inner matmul [`GlobalConfig`] (also reachable via `Deref`) and
    /// adds the convolution geometry. Each geometry pair is read by the
    /// [`ConvGemmConfig`](crate::convolution::ConvGemmConfig) accessors as
    /// `dim 0 -> .0`, `dim 1 -> .1` — presumably (height, width); confirm
    /// against the launch code.
    #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
    pub struct HomogeneousConfig<M: GlobalConfig> {
        // Inner matmul config; most `GlobalConfig` methods delegate to it.
        matmul: M,
        // Filter size per spatial dim.
        kernel_size: (u32, u32),
        // Stride per spatial dim.
        stride: (u32, u32),
        // Dilation per spatial dim.
        dilation: (u32, u32),
        // Padding per spatial dim; signed, so negative padding is representable.
        padding: (i32, i32),
    }
102
103    impl<M: GlobalConfig> Deref for HomogeneousConfig<M> {
104        type Target = M;
105
106        fn deref(&self) -> &Self::Target {
107            &self.matmul
108        }
109    }
110
111    impl<M: GlobalConfig> GlobalConfig for HomogeneousConfig<M> {
112        type SmmConfig = M::SmmConfig;
113
114        fn to_smm_config(&self) -> Self::SmmConfig {
115            self.matmul.to_smm_config()
116        }
117
118        fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
119            self.matmul.global_line_size(ident)
120        }
121
122        fn tiling_dimensions<I: Into<Ident>>(&self, ident: I) -> TilingDimensions {
123            self.matmul.tiling_dimensions(ident)
124        }
125
126        fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
127            self.matmul.matrix_layout(ident)
128        }
129
130        fn num_planes(&self) -> u32 {
131            self.matmul.num_planes()
132        }
133
134        fn plane_dim(&self) -> u32 {
135            self.matmul.plane_dim()
136        }
137
138        fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
139            self.matmul.check_row_bounds(ident)
140        }
141
142        fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
143            self.matmul.check_col_bounds(ident)
144        }
145
146        fn check_k_bounds(&self) -> bool {
147            self.matmul.check_k_bounds()
148        }
149
150        fn precompute_job(&self) -> bool {
151            PRECOMPUTE_JOB
152        }
153    }
154
155    impl<M: GlobalConfig> ConvGemmConfig for HomogeneousConfig<M> {
156        fn kernel_size(&self, dim: u32) -> u32 {
157            match dim {
158                0 => self.kernel_size.0,
159                1 => self.kernel_size.1,
160                _ => unreachable!(),
161            }
162        }
163
164        fn dilation(&self, dim: u32) -> u32 {
165            match dim {
166                0 => self.dilation.0,
167                1 => self.dilation.1,
168                _ => unreachable!(),
169            }
170        }
171
172        fn stride(&self, dim: u32) -> u32 {
173            match dim {
174                0 => self.stride.0,
175                1 => self.stride.1,
176                _ => unreachable!(),
177            }
178        }
179
180        fn padding(&self, dim: u32) -> i32 {
181            match dim {
182                0 => self.padding.0,
183                1 => self.padding.1,
184                _ => unreachable!(),
185            }
186        }
187    }
188
189    impl<M: GlobalConfig> MatmulConfig for HomogeneousConfig<M> {}
190
191    impl<M: GlobalConfig> HomogeneousConfig<M> {
192        #[allow(clippy::too_many_arguments)]
193        pub fn new(
194            matmul: M,
195            kernel_size: (u32, u32),
196            stride: (u32, u32),
197            dilation: (u32, u32),
198            padding: (i32, i32),
199        ) -> Self {
200            Self {
201                matmul,
202                kernel_size,
203                stride,
204                dilation,
205                padding,
206            }
207        }
208
209        pub fn to_matmul_config(self) -> M {
210            self.matmul
211        }
212    }
213}