cubecl_convolution/homogeneous/
base.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4 CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs,
5 tensor::r#virtual::{ReadWrite, VirtualTensor},
6};
7
8use crate::base::{Convolution, ConvolutionFamily, RuntimeArgs};
9
10use cubecl_matmul::components::{
11 Ident,
12 global::{
13 GlobalConfig,
14 args::{MatmulArgs, TensorInput, TensorInputIdent, TensorOutput},
15 },
16};
17
/// Alias for the concrete input type the `Args` implementation of [`MatmulArgs`] selects.
type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
/// Alias for the concrete output type the `Args` implementation of [`MatmulArgs`] selects.
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
20
/// Kernel entry point for convolution expressed as an implicit GEMM.
///
/// Wraps the launch inputs/output into virtual tensors, derives this cube's
/// row/column offsets from its grid position, then drives
/// `GMM::Convolution::execute` with freshly initialized loaders, a writer
/// and an accumulator.
///
/// Element type parameters: EI/ES/EA/EO — presumably input, stage,
/// accumulator and output precisions (TODO confirm against `ConvolutionFamily`).
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    EI: Numeric,
    ES: Numeric,
    EA: Numeric,
    EO: Numeric,
    GMM: ConvolutionFamily,
>(
    inputs: &Input<Args, EI>,
    bias: &CubeOption<Tensor<Line<EO>>>,
    output: &mut Output<Args, EO>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
) {
    let mut state = Args::init_state(inputs, output);

    // Typed handles over the shared args state.
    let lhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Lhs);
    let rhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Rhs);
    let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);

    // Erase the concrete tensor types: lhs/rhs are read-only, out is read-write.
    let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
    let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
    let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);

    // Each cube owns one stage-sized block: CUBE_POS_X indexes rows (m),
    // CUBE_POS_Y indexes columns (n).
    let x_offset = CUBE_POS_X * config.tiling_scheme().elements_in_stage_m();
    let y_offset = CUBE_POS_Y * config.tiling_scheme().elements_in_stage_n();
    // This cube sweeps the entire reduction dimension.
    let k_range = (0, runtime_args.size_k);

    // Lift the optional bias tensor into a virtual tensor as well.
    let bias = match bias {
        CubeOption::Some(bias) => {
            CubeOption::new_Some(VirtualTensor::<EO>::new::<Tensor<Line<EO>>>(bias))
        }
        CubeOption::None => CubeOption::new_None(),
    };

    GMM::Convolution::<(EI, ES, EA, EO)>::execute(
        GMM::Convolution::<(EI, ES, EA, EO)>::init_lhs_loader(
            lhs,
            x_offset,
            k_range.0,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_rhs_loader(
            rhs,
            k_range.0,
            y_offset,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_bias_loader(bias, y_offset, config),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_writer(out, x_offset, y_offset),
        &mut GMM::Convolution::<(EI, ES, EA, EO)>::init_accumulator(config),
        k_range,
        config,
    );
}
79
80pub(crate) fn shape_divmod<'a, R: Runtime>(
81 client: &ComputeClient<R::Server, R::Channel>,
82 shape: &[usize],
83) -> SequenceArg<'a, R, FastDivmod> {
84 let shape = shape
85 .iter()
86 .map(|s| FastDivmodArgs::new(client, *s as u32))
87 .collect();
88 SequenceArg { values: shape }
89}
90
91pub mod config {
92 use std::ops::Deref;
93
94 use crate::{ConvGemmConfig, base::Dimensionality};
95 use cubecl_matmul::components::{
96 InputIdent, MatmulLineSizes, MatmulSetupError, MatrixLayout, TilingScheme,
97 global::{
98 GlobalConfig, PlaneRoleConfig, SpecializedLoadingSides, load::LoaderMode,
99 multi_stage::EventLoadingMode,
100 },
101 };
102
103 use super::*;
104
    /// Configuration for convolution kernels: wraps a matmul [`GlobalConfig`]
    /// and adds the convolution-only parameters.
    #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
    pub struct ConvolutionConfig<M: GlobalConfig> {
        /// Inner matmul config; most `GlobalConfig` queries delegate to it.
        matmul: M,
        /// Kernel extent per spatial dimension; unused trailing slots stay 0.
        kernel_size: [u32; 3],
        /// Stride per spatial dimension; unused trailing slots stay 0.
        stride: [u32; 3],
        /// Dilation per spatial dimension; unused trailing slots stay 0.
        dilation: [u32; 3],
        /// Signed padding per spatial dimension; unused trailing slots stay 0.
        padding: [i32; 3],
        /// Spatial dimensionality of the convolution (see [`Dimensionality`]).
        dimensionality: Dimensionality,
        /// Stage count reported by `num_stages()` for both inputs.
        num_stages: u32,
    }
115
116 impl<M: GlobalConfig> Deref for ConvolutionConfig<M> {
117 type Target = M;
118
119 fn deref(&self) -> &Self::Target {
120 &self.matmul
121 }
122 }
123
124 impl<M: GlobalConfig> GlobalConfig for ConvolutionConfig<M> {
125 type StageConfig = M::StageConfig;
126
127 fn stage_config(&self) -> Self::StageConfig {
128 self.matmul.stage_config()
129 }
130
131 fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
132 self.matmul.global_line_size(ident)
133 }
134
135 fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
136 self.matmul.matrix_layout(ident)
137 }
138
139 fn num_loading_planes<I: Into<Ident>>(&self, ident: I) -> u32 {
140 self.matmul.num_loading_planes(ident)
141 }
142
143 fn plane_dim(&self) -> u32 {
144 self.matmul.plane_dim()
145 }
146
147 fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
148 self.matmul.check_row_bounds(ident)
149 }
150
151 fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
152 self.matmul.check_col_bounds(ident)
153 }
154
155 fn check_k_bounds(&self) -> bool {
156 self.matmul.check_k_bounds()
157 }
158
159 fn precompute_job(&self) -> bool {
160 self.matmul.precompute_job()
161 }
162
163 fn num_stages(&self, _ident: InputIdent) -> u32 {
164 self.num_stages
165 }
166
167 fn loader_mode(&self) -> LoaderMode {
168 self.matmul.loader_mode()
169 }
170
171 fn tiling_scheme(&self) -> TilingScheme {
172 self.matmul.tiling_scheme()
173 }
174
175 fn event_loading_mode(&self, ident: InputIdent) -> EventLoadingMode {
176 self.matmul.event_loading_mode(ident)
177 }
178
179 fn plane_role_config(&self) -> PlaneRoleConfig {
180 self.matmul.plane_role_config()
181 }
182
183 fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
184 self.matmul.specialized_loading_sides()
185 }
186
187 fn cube_dim(&self) -> CubeDim {
188 CubeDim::new(self.plane_dim(), self.tiling_scheme().tiles_in_stage_m(), 1)
189 }
190 }
191
192 impl<M: GlobalConfig> ConvGemmConfig for ConvolutionConfig<M> {
193 fn kernel_size(&self, dim: u32) -> u32 {
194 self.kernel_size[dim as usize]
195 }
196
197 fn dilation(&self, dim: u32) -> u32 {
198 self.dilation[dim as usize]
199 }
200
201 fn stride(&self, dim: u32) -> u32 {
202 self.stride[dim as usize]
203 }
204
205 fn padding(&self, dim: u32) -> i32 {
206 self.padding[dim as usize]
207 }
208
209 fn dimensionality(&self) -> Dimensionality {
210 self.dimensionality
211 }
212
213 fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes {
214 MatmulLineSizes {
215 lhs: self.global_line_size(Ident::Lhs) as u8,
216 rhs: self.global_line_size(Ident::Rhs) as u8,
217 out: self.global_line_size(Ident::Out) as u8,
218 }
219 }
220 }
221
222 impl<M: GlobalConfig> ConvolutionConfig<M> {
223 #[allow(clippy::too_many_arguments)]
224 pub fn new(
225 matmul: M,
226 kernel_size: &[u32],
227 stride: &[u32],
228 dilation: &[u32],
229 padding: &[i32],
230 dim: Dimensionality,
231 num_stages: u32,
232 ) -> Result<Self, MatmulSetupError> {
233 let dims = kernel_size.len();
234
235 let mut this = Self {
236 matmul,
237 kernel_size: [0; 3],
238 stride: [0; 3],
239 dilation: [0; 3],
240 padding: [0; 3],
241 dimensionality: dim,
242 num_stages,
243 };
244 this.kernel_size[0..dims].copy_from_slice(kernel_size);
245 this.stride[0..dims].copy_from_slice(stride);
246 this.dilation[0..dims].copy_from_slice(dilation);
247 this.padding[0..dims].copy_from_slice(padding);
248 Ok(this)
249 }
250
251 pub fn to_matmul_config(self) -> M {
252 self.matmul
253 }
254 }
255}