// cubecl_linalg/convolution/homogeneous/base.rs
use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{
4 CubeOption, CubeOptionExpand,
5 tensor::r#virtual::{ReadWrite, VirtualTensor},
6};
7
8use crate::{
9 convolution::base::{Convolution, ConvolutionFamily, RuntimeArgs},
10 matmul::components::{
11 Ident,
12 global::{
13 GlobalConfig,
14 args::{MatmulArgs, TensorInput, TensorInputIdent, TensorOutput},
15 },
16 },
17};
18
// Shorthands for the input/output tensor types declared by the `MatmulArgs`
// family, keeping the kernel signature below readable.
type Input<Args, EI> = <Args as MatmulArgs>::Input<EI>;
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
21
/// Implicit-GEMM convolution kernel: each cube computes one output tile of the
/// equivalent matmul, with `Lhs`/`Rhs` roles assigned via `TensorInputIdent`
/// (presumably feature map and weights respectively — confirm against callers).
///
/// * `inputs` / `output` — tensors accessed through the `MatmulArgs` abstraction.
/// * `bias` — optional bias tensor, forwarded to the bias loader when present.
/// * `runtime_args` — runtime shape information (e.g. the reduction size `size_k`).
/// * `config` — comptime global convolution/matmul configuration.
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    EI: Numeric,
    ES: Numeric,
    EA: Numeric,
    EO: Numeric,
    GMM: ConvolutionFamily,
>(
    inputs: &Input<Args, EI>,
    bias: &CubeOption<Tensor<Line<EO>>>,
    output: &mut Output<Args, EO>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
) {
    let mut state = Args::init_state(inputs, output);

    // Wrap the raw args into tensor views, then erase the concrete types
    // behind `VirtualTensor` so the generic matmul components can consume them.
    let lhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Lhs);
    let rhs = TensorInput::<EI, EO, Args>::new(&state, TensorInputIdent::Rhs);
    let mut out = TensorOutput::<EI, EO, Args>::new(&mut state);

    let lhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&lhs);
    let rhs = VirtualTensor::<EI>::new::<TensorInput<EI, EO, Args>>(&rhs);
    let out = VirtualTensor::<EO, ReadWrite>::new::<TensorOutput<EI, EO, Args>>(&mut out);

    // Each cube owns one (row, col) tile of the output: CUBE_POS_X selects the
    // Lhs row block, CUBE_POS_Y the Rhs column block.
    let x_offset = CUBE_POS_X * config.tiling_dimensions(Ident::Lhs).total_row();
    let y_offset = CUBE_POS_Y * config.tiling_dimensions(Ident::Rhs).total_col();
    // Every cube traverses the full reduction (k) dimension.
    let k_range = (0, runtime_args.size_k);

    // Lift the optional bias into a `VirtualTensor`, preserving Some/None.
    let bias = match bias {
        CubeOption::Some(bias) => {
            CubeOption::new_Some(VirtualTensor::<EO>::new::<Tensor<Line<EO>>>(bias))
        }
        CubeOption::None => CubeOption::new_None(),
    };

    // Assemble the loaders/unloader/accumulator and run the convolution.
    GMM::Convolution::<(EI, ES, EA, EO)>::execute(
        GMM::Convolution::<(EI, ES, EA, EO)>::init_lhs_loader(
            lhs,
            x_offset,
            k_range.0,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_rhs_loader(
            rhs,
            k_range.0,
            y_offset,
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_bias_loader(bias, y_offset, config),
        GMM::Convolution::<(EI, ES, EA, EO)>::init_unloader(out, x_offset, y_offset),
        &mut GMM::Convolution::<(EI, ES, EA, EO)>::init_accumulator(config),
        k_range,
        config,
    );
}
80
81pub mod config {
82 use std::ops::Deref;
83
84 use crate::{
85 convolution::ConvGemmConfig,
86 matmul::components::{
87 MatmulConfig, MatrixLayout, TilingDimensions, global::GlobalConfig,
88 global::PRECOMPUTE_JOB,
89 },
90 };
91
92 use super::*;
93
/// Configuration for the homogeneous (implicit-GEMM) convolution: a global
/// matmul config bundled with the convolution-specific parameters.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct HomogeneousConfig<M: GlobalConfig> {
    // Wrapped global matmul configuration; most queries delegate to it.
    matmul: M,
    // Per-axis kernel size, indexed 0/1 through `kernel_size(dim)`.
    kernel_size: (u32, u32),
    // Per-axis stride.
    stride: (u32, u32),
    // Per-axis dilation.
    dilation: (u32, u32),
    // Per-axis padding; signed, so negative padding is representable.
    padding: (i32, i32),
}
102
// NOTE(review): `Deref` to the inner matmul config is delegation for
// ergonomics, not smart-pointer semantics — generally discouraged, but it
// matches the explicit forwarding done in the `GlobalConfig` impl.
impl<M: GlobalConfig> Deref for HomogeneousConfig<M> {
    type Target = M;

    fn deref(&self) -> &Self::Target {
        &self.matmul
    }
}
110
/// Forwards every `GlobalConfig` query to the wrapped matmul config, with one
/// exception: `precompute_job` is pinned to the crate-wide `PRECOMPUTE_JOB`
/// constant instead of consulting the inner config.
impl<M: GlobalConfig> GlobalConfig for HomogeneousConfig<M> {
    type SmmConfig = M::SmmConfig;

    fn to_smm_config(&self) -> Self::SmmConfig {
        self.matmul.to_smm_config()
    }

    fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
        self.matmul.global_line_size(ident)
    }

    fn tiling_dimensions<I: Into<Ident>>(&self, ident: I) -> TilingDimensions {
        self.matmul.tiling_dimensions(ident)
    }

    fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
        self.matmul.matrix_layout(ident)
    }

    fn num_planes(&self) -> u32 {
        self.matmul.num_planes()
    }

    fn plane_dim(&self) -> u32 {
        self.matmul.plane_dim()
    }

    fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
        self.matmul.check_row_bounds(ident)
    }

    fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
        self.matmul.check_col_bounds(ident)
    }

    fn check_k_bounds(&self) -> bool {
        self.matmul.check_k_bounds()
    }

    // Deliberately ignores `self.matmul`: convolution always uses the global
    // PRECOMPUTE_JOB policy.
    fn precompute_job(&self) -> bool {
        PRECOMPUTE_JOB
    }
}
154
155 impl<M: GlobalConfig> ConvGemmConfig for HomogeneousConfig<M> {
156 fn kernel_size(&self, dim: u32) -> u32 {
157 match dim {
158 0 => self.kernel_size.0,
159 1 => self.kernel_size.1,
160 _ => unreachable!(),
161 }
162 }
163
164 fn dilation(&self, dim: u32) -> u32 {
165 match dim {
166 0 => self.dilation.0,
167 1 => self.dilation.1,
168 _ => unreachable!(),
169 }
170 }
171
172 fn stride(&self, dim: u32) -> u32 {
173 match dim {
174 0 => self.stride.0,
175 1 => self.stride.1,
176 _ => unreachable!(),
177 }
178 }
179
180 fn padding(&self, dim: u32) -> i32 {
181 match dim {
182 0 => self.padding.0,
183 1 => self.padding.1,
184 _ => unreachable!(),
185 }
186 }
187 }
188
189 impl<M: GlobalConfig> MatmulConfig for HomogeneousConfig<M> {}
190
191 impl<M: GlobalConfig> HomogeneousConfig<M> {
192 #[allow(clippy::too_many_arguments)]
193 pub fn new(
194 matmul: M,
195 kernel_size: (u32, u32),
196 stride: (u32, u32),
197 dilation: (u32, u32),
198 padding: (i32, i32),
199 ) -> Self {
200 Self {
201 matmul,
202 kernel_size,
203 stride,
204 dilation,
205 padding,
206 }
207 }
208
209 pub fn to_matmul_config(self) -> M {
210 self.matmul
211 }
212 }
213}