// cubecl_convolution/homogeneous/base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4    CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs,
5    tensor::r#virtual::{ReadWrite, VirtualTensor},
6};
7
8use crate::base::{Convolution, ConvolutionFamily, RuntimeArgs};
9
10use cubecl_matmul::components::{
11    Ident,
12    global::{
13        GlobalConfig,
14        args::{MatmulArgs, TensorInput, TensorInputIdent, TensorOutput},
15    },
16};
17
/// Input argument type that `Args` (a [`MatmulArgs`] implementation) exposes for element type `EI`.
type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
/// Output argument type that `Args` (a [`MatmulArgs`] implementation) exposes for element type `EO`.
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
20
/// Implicit-GEMM convolution kernel entry point.
///
/// Builds `VirtualTensor` views over the lhs/rhs inputs, the output and the optional bias,
/// computes the output offsets handled by this cube, and hands everything to
/// `GMM::Convolution::execute` together with freshly initialized lhs/rhs/bias loaders,
/// writer and accumulator.
///
/// * `inputs` / `output` - launch arguments, interpreted through `Args: MatmulArgs`.
/// * `bias` - optional bias tensor of `Line<EO>`.
/// * `runtime_args` - runtime problem description; `size_k` bounds the reduction range and
///   the whole struct is forwarded to the lhs/rhs loader constructors.
/// * `config` - comptime configuration of the convolution family `GMM`.
///
/// `EI`, `ES`, `EA` and `EO` are forwarded as the precision tuple `(EI, ES, EA, EO)`.
/// `EI` is the element type of the input views and `EO` that of the output/bias views
/// (NOTE(review): `ES`/`EA` are presumably the stage and accumulator types — confirm
/// against `ConvolutionFamily`).
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    EI: Numeric,
    ES: Numeric,
    EA: Numeric,
    EO: Numeric,
    GMM: ConvolutionFamily,
>(
    inputs: &Input<Args, EI>,
    bias: &CubeOption<Tensor<Line<EO>>>,
    output: &mut Output<Args, EO>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
) {
    let mut state = Args::init_state(inputs, output);

    // Typed accessors into the argument state for each operand.
    let lhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Lhs);
    let rhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Rhs);
    let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);

    // Erase the concrete argument types behind `VirtualTensor`; the output view is ReadWrite.
    let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
    let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
    let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);

    // Offsets of this cube's output tile: CUBE_POS_X stage-sizes along m,
    // CUBE_POS_Y stage-sizes along n.
    let x_offset = CUBE_POS_X * config.tiling_scheme().elements_in_stage_m();
    let y_offset = CUBE_POS_Y * config.tiling_scheme().elements_in_stage_n();
    // The whole reduction dimension [0, size_k) is processed by this cube.
    let k_range = (0, runtime_args.size_k);

    // Re-wrap the optional bias tensor as an optional `VirtualTensor`.
    let bias = match bias {
        CubeOption::Some(bias) => {
            CubeOption::new_Some(VirtualTensor::<EO>::new::<Tensor<Line<EO>>>(bias))
        }
        CubeOption::None => CubeOption::new_None(),
    };

    // The lhs loader starts at (x_offset, k_range.0) and the rhs loader at (k_range.0, y_offset).
    // The bias loader and writer are positioned at this cube's output offsets.
    // The accumulator is passed as `&mut` to a temporary (kept as-is on purpose: binding it
    // to a `let mut` would expand differently under `#[cube]`).
    GMM::Convolution::<(EI, ES, EA, EO)>::execute(
        GMM::Convolution::<(EI, ES, EA, EO)>::init_lhs_loader(
            lhs,
            x_offset,
            k_range.0,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_rhs_loader(
            rhs,
            k_range.0,
            y_offset,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_bias_loader(bias, y_offset, config),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_writer(out, x_offset, y_offset),
        &mut GMM::Convolution::<(EI, ES, EA, EO)>::init_accumulator(config),
        k_range,
        config,
    );
}
79
80pub(crate) fn shape_divmod<'a, R: Runtime>(
81    client: &ComputeClient<R::Server, R::Channel>,
82    shape: &[usize],
83) -> SequenceArg<'a, R, FastDivmod> {
84    let shape = shape
85        .iter()
86        .map(|s| FastDivmodArgs::new(client, *s as u32))
87        .collect();
88    SequenceArg { values: shape }
89}
90
91pub mod config {
92    use std::ops::Deref;
93
94    use crate::{ConvGemmConfig, base::Dimensionality};
95    use cubecl_matmul::components::{
96        InputIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
97        global::{
98            GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, load::LoaderMode,
99            multi_stage::EventLoadingMode,
100        },
101    };
102
103    use super::*;
104
105    #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
106    pub struct ConvolutionConfig<M: GlobalConfig> {
107        matmul: M,
108        kernel_size: [u32; 3],
109        stride: [u32; 3],
110        dilation: [u32; 3],
111        padding: [i32; 3],
112        dimensionality: Dimensionality,
113        num_stages: u32,
114    }
115
116    impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
117        type Target = M;
118
119        fn deref(&self) -> &Self::Target {
120            &self.matmul
121        }
122    }
123
124    impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
125        type StageConfig = M::StageConfig;
126
127        fn stage_config(&self) -> Self::StageConfig {
128            self.matmul.stage_config()
129        }
130
131        fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
132            self.matmul.global_line_size(ident)
133        }
134
135        fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
136            self.matmul.matrix_layout(ident)
137        }
138
139        fn num_loading_planes<I: Into<Ident>>(&self, ident: I) -> u32 {
140            self.matmul.num_loading_planes(ident)
141        }
142
143        fn plane_dim(&self) -> u32 {
144            self.matmul.plane_dim()
145        }
146
147        fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
148            self.matmul.check_row_bounds(ident)
149        }
150
151        fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
152            self.matmul.check_col_bounds(ident)
153        }
154
155        fn check_k_bounds(&self) -> bool {
156            self.matmul.check_k_bounds()
157        }
158
159        fn precompute_job(&self) -> bool {
160            self.matmul.precompute_job()
161        }
162
163        fn num_stages(&self, _ident: InputIdent) -> u32 {
164            self.num_stages
165        }
166
167        fn loader_mode(&self) -> LoaderMode {
168            self.matmul.loader_mode()
169        }
170
171        fn tiling_scheme(&self) -> TilingScheme {
172            self.matmul.tiling_scheme()
173        }
174
175        fn event_loading_mode(&self, ident: InputIdent) -> EventLoadingMode {
176            self.matmul.event_loading_mode(ident)
177        }
178
179        fn plane_role_config(&self) -> PlaneRoleConfig {
180            self.matmul.plane_role_config()
181        }
182
183        fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
184            self.matmul.specialized_loading_sides()
185        }
186
187        fn cube_dim(&self) -> CubeDim {
188            CubeDim::new(self.plane_dim(), self.tiling_scheme().tiles_in_stage_m(), 1)
189        }
190    }
191
192    impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
193        fn kernel_size(&self, dim: u32) -> u32 {
194            self.kernel_size[dim as usize]
195        }
196
197        fn dilation(&self, dim: u32) -> u32 {
198            self.dilation[dim as usize]
199        }
200
201        fn stride(&self, dim: u32) -> u32 {
202            self.stride[dim as usize]
203        }
204
205        fn padding(&self, dim: u32) -> i32 {
206            self.padding[dim as usize]
207        }
208
209        fn dimensionality(&self) -> Dimensionality {
210            self.dimensionality
211        }
212
213        fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
214            MatmulLineSizes {
215                lhs: self.global_line_size(Ident::Lhs) as u8,
216                rhs: self.global_line_size(Ident::Rhs) as u8,
217                out: self.global_line_size(Ident::Out) as u8,
218            }
219        }
220    }
221
222    impl<M: GlobalConfig> ConvolutionConfig<M> {
223        #[allow(clippy::too_many_arguments)]
224        pub fn new(
225            matmul: M,
226            kernel_size: &[u32],
227            stride: &[u32],
228            dilation: &[u32],
229            padding: &[i32],
230            dim: Dimensionality,
231            num_stages: u32,
232        ) -> Result<Self, MatmulSetupError> {
233            let dims = kernel_size.len();
234
235            let mut this = Self {
236                matmul,
237                kernel_size: [0; 3],
238                stride: [0; 3],
239                dilation: [0; 3],
240                padding: [0; 3],
241                dimensionality: dim,
242                num_stages,
243            };
244            this.kernel_size[0..dims].copy_from_slice(kernel_size);
245            this.stride[0..dims].copy_from_slice(stride);
246            this.dilation[0..dims].copy_from_slice(dilation);
247            this.padding[0..dims].copy_from_slice(padding);
248            Ok(this)
249        }
250
251        pub fn to_matmul_config(self) -> M {
252            self.matmul
253        }
254    }
255}