cubecl_linalg/convolution/base.rs

use crate::matmul::{
    components::{
        InputRuntimeArg, InvalidConfigError, MatmulPrecision, MatmulProblem, MatmulSpec,
        MatrixLayout, OutputRuntimeArg,
        global::{AccumulatorLoader, OutputLoader},
    },
    kernels::MatmulAvailabilityError,
};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, FastDivmod,
    tensor::r#virtual::{ReadWrite, VirtualTensor},
};

use super::ConvGemmConfig;

/// Problem arguments that are only known at launch time.
///
/// The `FastDivmod` fields carry precomputed division/modulo state for their
/// runtime divisors, so flattened indices can be decomposed cheaply in the kernel.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    pub size_m: u32,
    pub size_n: u32,
    pub size_k: u32,
    pub padded_channels: FastDivmod,
    pub out_h: FastDivmod,
    pub out_w: FastDivmod,
}
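
// Note (not in the original file): under the usual implicit-GEMM / im2col
// formulation, these sizes are assumed to relate to the convolution shape as
//
//     size_m = batches * out_h * out_w         // one GEMM row per output position
//     size_n = out_channels                    // one GEMM column per filter
//     size_k = kernel_h * kernel_w * channels  // reduction over one input patch
//
// where `out_channels` is a hypothetical name; this matches the m/n/k carried
// by `ConvolutionProblem` below.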

pub trait ConvolutionFamily:
    ConvolutionConfigFactory<Config: ConvGemmConfig> + ConvolutionLaunch
{
    type Convolution<MP: MatmulPrecision>: Convolution<MP, Config = Self::Config>;
}

#[cube]
pub trait Convolution<MP: MatmulPrecision>: 'static + Send + Sync {
    type LhsLoader: CubeType;
    type RhsLoader: CubeType;
    type Config: ConvGemmConfig;
    type AccumulatorLoader: AccumulatorLoader<MP>;

    type Out: OutputLoader<MP::EO>;
    type Accumulator: CubeType;

    /// Performs the convolution over the data loaded by the LHS and RHS
    /// loaders, over the given range of K, and writes the result through
    /// the output unloader.
    ///
    /// To cover the whole K dimension, pass `k_range = (0, K)`, where `K` is
    /// the K extent shared by LHS and RHS.
    fn execute(
        lhs_loader: Self::LhsLoader,
        rhs_loader: Self::RhsLoader,
        acc_loader: Self::AccumulatorLoader,
        unloader: Self::Out,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

    /// Initializes the LHS loader at the given row/column offsets.
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader;

    /// Initializes the RHS loader at the given row/column offsets.
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader;

    /// Initializes the loader that seeds the accumulator with the bias,
    /// if one is present.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader;

    /// Initializes the writer for the output tensor.
    fn init_unloader(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Out;

    /// Allocates the accumulator for this convolution.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
}
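
// Minimal usage sketch (not in the original file): a kernel entry point is
// assumed to wire a `C: Convolution<MP>` together roughly like this, with the
// tensors, offsets, and `runtime_args` in scope:
//
//     let lhs = C::init_lhs_loader(lhs, m_offset, k_start, &runtime_args, config);
//     let rhs = C::init_rhs_loader(rhs, k_start, n_offset, &runtime_args, config);
//     let bias = C::init_bias_loader(bias, n_offset, config);
//     let out = C::init_unloader(out, m_offset, n_offset);
//     let mut acc = C::init_accumulator(config);
//     C::execute(lhs, rhs, bias, out, &mut acc, (0, runtime_args.size_k), config);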

/// Provides configuration for a convolution kernel at any level.
pub trait ConvolutionConfigFactory: Send + Sync + 'static {
    /// Configuration tailored to this convolution implementation.
    type Config: ConvGemmConfig;
    type Input;

    /// Asserts that the configuration will lead to a valid computation.
    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>;

    fn make_config(
        input: Self::Input,
        problem: &ConvolutionProblem,
        cube_dim: &CubeDim,
        cube_count: &CubeCount,
    ) -> Self::Config;

    /// Checks whether the kernel can run on the given client, e.g. that the
    /// required features and element types are supported.
    fn check_availability<R: Runtime, MP: MatmulPrecision>(
        client: &ComputeClient<R::Server, R::Channel>,
        config: &Self::Config,
    ) -> Result<(), MatmulAvailabilityError>;
}

pub trait ConvolutionLaunch: ConvolutionConfigFactory {
    /// Entry point for launching the convolution kernel.
    ///
    /// # Safety
    ///
    /// The caller must ensure the arguments are consistent with the config;
    /// out-of-bounds accesses can otherwise occur.
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        bias: Option<TensorArg<'a, R>>,
        output: OutputRuntimeArg<'a, MS, R>,
        problem: &ConvolutionProblem,
        config: <Self as ConvolutionConfigFactory>::Config,
    );
}
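
// Hypothetical call-site sketch (not in the original file): selection code is
// assumed to build and validate a config, then launch, roughly as
//
//     let config = F::make_config(factory_input, &problem, &cube_dim, &cube_count);
//     F::check_config(&config)?;
//     F::check_availability::<R, MP>(&client, &config)?;
//     unsafe {
//         F::launch_unchecked::<MS, R>(
//             &client, cube_dim, cube_count, input, bias, output, &problem, config,
//         )
//     };
//
// for some `F: ConvolutionLaunch` and in-scope arguments; `factory_input` is a
// placeholder name.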

/// Description of a convolution problem to solve, regardless of actual data.
#[derive(Clone, Debug)]
pub struct ConvolutionProblem {
    pub m: usize,
    pub n: usize,
    pub k: usize,
    pub lhs_layout: MatrixLayout,
    pub rhs_layout: MatrixLayout,
    pub lhs_line_size: u8,
    pub rhs_line_size: u8,
    pub out_line_size: u8,

    pub kernel_size: (u32, u32),
    pub stride: (u32, u32),
    pub padding: (i32, i32),
    pub dilation: (u32, u32),

    pub batches: usize,
    pub height: usize,
    pub width: usize,
    pub channels: usize,

    pub out_h: usize,
    pub out_w: usize,
}

impl ConvolutionProblem {
    /// Reinterprets this convolution as the equivalent matmul problem,
    /// with no batch dimensions.
    pub fn as_matmul_problem(&self) -> MatmulProblem {
        MatmulProblem {
            m: self.m,
            n: self.n,
            k: self.k,
            batches: (vec![], vec![]),
            lhs_layout: self.lhs_layout,
            rhs_layout: self.rhs_layout,
            lhs_line_size: self.lhs_line_size,
            rhs_line_size: self.rhs_line_size,
            out_line_size: self.out_line_size,
        }
    }
}
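
// Worked sketch (not in the original file): under the im2col mapping assumed
// above, a problem with batches = 8, height = width = 32, channels = 64,
// kernel_size = (3, 3), stride = (1, 1), padding = (1, 1), dilation = (1, 1),
// and 128 output filters has out_h = out_w = 32, so
//
//     m = 8 * 32 * 32 = 8192    // batches * out_h * out_w
//     n = 128                   // output channels
//     k = 3 * 3 * 64  = 576     // kernel_h * kernel_w * channels
//
// and `as_matmul_problem` forwards exactly these m/n/k to the matmul pipeline.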