use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_matmul::components::{
    AvailableLineSizes, InputRuntimeArg, MatmulLineSizes, MatmulPrecision, MatmulProblem,
    MatmulSelection, MatmulSetupError, MatmulSpec, MatrixLayout, OutputRuntimeArg,
    global::{AccumulatorLoader, GlobalWriter},
};
use cubecl_std::{
    CubeOption, FastDivmod,
    tensor::r#virtual::{ReadWrite, VirtualTensor},
};

use super::ConvGemmConfig;

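/// Scalar arguments available to the kernel at runtime.
///
/// `size_m`, `size_n` and `size_k` are the dimensions of the implicit GEMM;
/// the [`FastDivmod`] fields hold precomputed divisors used to decompose flat
/// indices without runtime integer division.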
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    pub size_m: u32,
    pub size_n: u32,
    pub size_k: u32,
    /// Padded channel count, stored as a precomputed fast divisor.
    pub padded_channels: FastDivmod,
    /// Output spatial shape, one fast divisor per dimension.
    pub out_shape: Sequence<FastDivmod>,
}

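/// A family of convolution kernels sharing one setup and launch path,
/// instantiable as a concrete [`Convolution`] for any [`MatmulPrecision`].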
pub trait ConvolutionFamily:
    ConvolutionConfigFactory<Config: ConvGemmConfig> + ConvolutionLaunch
{
    /// The concrete convolution kernel for the precision `MP`.
    type Convolution<MP: MatmulPrecision>: Convolution<MP, Config = Self::Config>;

    /// Filters the available line sizes down to the ones this family supports.
    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
}

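/// A complete convolution kernel: loaders for both operands and the bias, an
/// accumulator, and a writer for the output, driven by [`Convolution::execute`].
///
/// A minimal sketch of the intended call sequence (offsets and tensor names
/// are illustrative, not part of this trait):
///
/// ```ignore
/// let mut acc = C::init_accumulator(config);
/// let lhs = C::init_lhs_loader(input, x_offset, k_offset, &args, config);
/// let rhs = C::init_rhs_loader(weight, k_offset, y_offset, &args, config);
/// let bias = C::init_bias_loader(bias, y_offset, config);
/// let writer = C::init_writer(out, x_offset, y_offset);
/// C::execute(lhs, rhs, bias, writer, &mut acc, (0, args.size_k), config);
/// ```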
#[cube]
pub trait Convolution<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Loader for the LHS operand.
    type LhsLoader: CubeType;
    /// Loader for the RHS operand.
    type RhsLoader: CubeType;
    type Config: ConvGemmConfig;
    /// Loader that seeds the accumulator, e.g. from a bias.
    type AccumulatorLoader: AccumulatorLoader<MP>;

    /// Writer used to store results to the output tensor.
    type Writer: GlobalWriter<MP::EO>;
    /// Accumulator for the partial results along `k`.
    type Accumulator: CubeType;

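    /// Performs the convolution over the data supplied by the LHS and RHS
    /// loaders for the given `k_range`, accumulating into `acc` and writing
    /// the result through `writer`.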
    fn execute(
        lhs_loader: Self::LhsLoader,
        rhs_loader: Self::RhsLoader,
        acc_loader: Self::AccumulatorLoader,
        writer: Self::Writer,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

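    /// Initializes the LHS loader (in an implicit GEMM convolution this is
    /// typically the input feature map) starting at the given offsets.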
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader;

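    /// Initializes the RHS loader (typically the weight tensor) starting at
    /// the given offsets.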
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader;

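    /// Initializes the loader that fills the accumulator from the optional
    /// bias tensor.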
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader;

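    /// Initializes the writer that stores finished results to the output
    /// tensor at the given offsets.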
    fn init_writer(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Writer;

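    /// Initializes the accumulator, without loading any data.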
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
}

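/// Creates the [`ConvGemmConfig`] for a convolution algorithm from the problem
/// description, the matmul selection, and the chosen line sizes.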
pub trait ConvolutionConfigFactory: Send + Sync + 'static {
    /// Configuration type produced by this factory.
    type Config: ConvGemmConfig;

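    /// Creates the configuration for this algorithm, or returns a
    /// [`MatmulSetupError`] if the problem is not supported.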
    fn setup<R: Runtime, MP: MatmulPrecision>(
        client: &ComputeClient<R::Server, R::Channel>,
        problem: &ConvolutionProblem,
        selection: &MatmulSelection,
        line_sizes: &MatmulLineSizes,
    ) -> Result<Self::Config, MatmulSetupError>;
}

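/// Provides the launch entry point for a convolution kernel.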
pub trait ConvolutionLaunch: ConvolutionConfigFactory {
    #[allow(clippy::too_many_arguments)]
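    /// Launches the convolution kernel on the given client.
    ///
    /// # Safety
    ///
    /// No bounds checks are performed; the caller must ensure the arguments
    /// are consistent with `config` and `problem`, or out-of-bounds accesses
    /// may occur.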
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        bias: Option<TensorArg<'a, R>>,
        output: OutputRuntimeArg<'a, MS, R>,
        problem: &ConvolutionProblem,
        config: <Self as ConvolutionConfigFactory>::Config,
    );
}

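/// Description of a convolution problem, both as the implicit GEMM it lowers
/// to (`m`, `n`, `k` and the operand layouts) and in its original spatial
/// terms (kernel size, stride, padding, dilation, shapes).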
#[derive(Clone, Debug)]
pub struct ConvolutionProblem {
    // Implicit GEMM view.
    pub m: usize,
    pub n: usize,
    pub k: usize,
    pub lhs_layout: MatrixLayout,
    pub rhs_layout: MatrixLayout,

    // Per-spatial-dimension convolution parameters.
    pub kernel_size: Vec<u32>,
    pub stride: Vec<u32>,
    pub padding: Vec<i32>,
    pub dilation: Vec<u32>,

    // Original tensor extents.
    pub batches: usize,
    pub channels: usize,
    pub shape: Vec<usize>,
    pub out_shape: Vec<usize>,

    pub dimensionality: Dimensionality,
}

impl ConvolutionProblem {
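    /// Views this convolution as the implicit GEMM it computes. Only the GEMM
    /// dimensions and operand layouts carry over; there are no separate batch
    /// dimensions.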
    pub fn as_matmul_problem(&self) -> MatmulProblem {
        MatmulProblem {
            m: self.m,
            n: self.n,
            k: self.k,
            lhs_batches: vec![],
            rhs_batches: vec![],
            lhs_layout: self.lhs_layout,
            rhs_layout: self.rhs_layout,
        }
    }
}

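/// Spatial dimensionality of the convolution (1D, 2D or 3D).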
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Dimensionality {
    Dim1,
    Dim2,
    Dim3,
}