//! `cubecl_convolution/base.rs` — core traits and problem description for
//! convolutions implemented as an implicit GEMM.

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_matmul::components::{
4    AvailableLineSizes, InputRuntimeArg, MatmulLineSizes, MatmulPrecision, MatmulProblem,
5    MatmulSelection, MatmulSetupError, MatmulSpec, MatrixLayout, OutputRuntimeArg,
6    global::{AccumulatorLoader, GlobalWriter},
7};
8use cubecl_std::{
9    CubeOption, FastDivmod,
10    tensor::r#virtual::{ReadWrite, VirtualTensor},
11};
12
13use super::ConvGemmConfig;
14
/// Arguments only known at kernel launch time, passed to the kernel at runtime
/// (as opposed to the comptime [`ConvGemmConfig`]).
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    /// Total size of the implicit GEMM along `m`.
    pub size_m: u32,
    /// Total size of the implicit GEMM along `n`.
    pub size_n: u32,
    /// Total size of the implicit GEMM along `k` (the reduction dimension).
    pub size_k: u32,
    /// Channel count (padded), wrapped in [`FastDivmod`] for cheap
    /// divide/modulo during index decomposition.
    pub padded_channels: FastDivmod,
    /// Output spatial shape — presumably one [`FastDivmod`] per spatial
    /// dimension, matching `ConvolutionProblem::out_shape`; confirm in callers.
    pub out_shape: Sequence<FastDivmod>,
}
23
/// Family of convolution algorithms: instantiates the concrete [`Convolution`]
/// for any [`MatmulPrecision`], and ties together config creation
/// ([`ConvolutionConfigFactory`]) and kernel launching ([`ConvolutionLaunch`]).
pub trait ConvolutionFamily:
    ConvolutionConfigFactory<Config: ConvGemmConfig> + ConvolutionLaunch
{
    /// The convolution implementation for precision `MP`, sharing this family's
    /// config type.
    type Convolution<MP: MatmulPrecision>: Convolution<MP, Config = Self::Config>;

    /// Narrows the candidate line sizes down to those this family supports.
    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
}
31
/// A complete convolution kernel expressed as an implicit GEMM: tile loaders
/// feed the LHS/RHS operands into a matmul accumulator, and a writer stores
/// the result to the output tensor.
#[cube]
pub trait Convolution<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Loader for the LHS operand (created by [`Self::init_lhs_loader`]).
    type LhsLoader: CubeType;
    /// Loader for the RHS operand (created by [`Self::init_rhs_loader`]).
    type RhsLoader: CubeType;
    /// Comptime configuration shared by all stages of the convolution.
    type Config: ConvGemmConfig;
    /// Loader that seeds the accumulator, e.g. from an optional bias
    /// (created by [`Self::init_bias_loader`]).
    type AccumulatorLoader: AccumulatorLoader<MP>;

    /// Writer that stores the finished accumulator to the output tensor.
    type Writer: GlobalWriter<MP::EO>;
    /// Accumulator holding the matmul results.
    type Accumulator: CubeType;

    /// Performs the convolution over data loaded by the
    /// LHS and RHS loaders, over the range given for K, and stores
    /// using the output writer.
    ///
    /// To compute the whole range of k values, use k_range=(0, K) where
    /// K is the K dimension of LHS and RHS.
    fn execute(
        lhs_loader: Self::LhsLoader,
        rhs_loader: Self::RhsLoader,
        acc_loader: Self::AccumulatorLoader,
        writer: Self::Writer,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

    /// Creates the LHS loader over `lhs`, starting at `(x_offset, y_offset)`.
    /// NOTE(review): the exact row/column meaning of the offsets is defined by
    /// the loader implementation — confirm there.
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader;

    /// Creates the RHS loader over `rhs`; offsets mirror [`Self::init_lhs_loader`].
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader;

    /// Creates the accumulator loader from an optional `bias` tensor, offset by
    /// `n_offset` (presumably along the GEMM `n` dimension — confirm).
    /// NOTE(review): behaviour when `bias` is `None` is loader-defined — check
    /// whether the accumulator is zero-filled in that case.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader;

    /// Creates the output writer over `out`, starting at `(x_offset, y_offset)`.
    fn init_writer(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Writer;

    /// Allocates the accumulator used by [`Self::execute`].
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
}
88
/// Provides configuration for a convolution kernel at any level
pub trait ConvolutionConfigFactory: Send + Sync + 'static {
    /// Configuration tailored to the convolution implementation
    type Config: ConvGemmConfig;

    /// Builds the config for `problem` on this `client`, given the matmul
    /// `selection` and the chosen `line_sizes`.
    ///
    /// # Errors
    ///
    /// Returns [`MatmulSetupError`] when no valid configuration can be produced.
    fn setup<R: Runtime, MP: MatmulPrecision>(
        client: &ComputeClient<R::Server, R::Channel>,
        problem: &ConvolutionProblem,
        selection: &MatmulSelection,
        line_sizes: &MatmulLineSizes,
    ) -> Result<Self::Config, MatmulSetupError>;
}
101
/// Provides the launch entry point to solve a convolution
pub trait ConvolutionLaunch: ConvolutionConfigFactory {
    /// Entry point: dispatches the convolution kernel on `client` with the
    /// given cube dimensions and runtime tensor arguments.
    ///
    /// # Safety
    ///
    /// Out-of-bounds can happen. NOTE(review): presumably the caller must
    /// guarantee that `problem` and `config` are consistent with the actual
    /// runtime arguments — confirm against the implementations.
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        // `None` when the convolution has no bias.
        bias: Option<TensorArg<'a, R>>,
        output: OutputRuntimeArg<'a, MS, R>,
        problem: &ConvolutionProblem,
        config: <Self as ConvolutionConfigFactory>::Config,
    );
}
121
#[derive(Clone, Debug)]
/// Description of a convolution problem to solve, regardless of actual data.
/// Convertible to its implicit-GEMM view via [`Self::as_matmul_problem`].
pub struct ConvolutionProblem {
    /// `m` dimension of the implicit GEMM.
    pub m: usize,
    /// `n` dimension of the implicit GEMM.
    pub n: usize,
    /// `k` (reduction) dimension of the implicit GEMM.
    pub k: usize,
    /// Memory layout of the LHS operand.
    pub lhs_layout: MatrixLayout,
    /// Memory layout of the RHS operand.
    pub rhs_layout: MatrixLayout,

    /// Kernel (filter) size, one entry per spatial dimension.
    pub kernel_size: Vec<u32>,
    /// Stride, one entry per spatial dimension.
    pub stride: Vec<u32>,
    /// Padding, one entry per spatial dimension (signed, so negative padding
    /// is representable).
    pub padding: Vec<i32>,
    /// Dilation, one entry per spatial dimension.
    pub dilation: Vec<u32>,

    /// Number of batches.
    pub batches: usize,
    /// Number of channels. NOTE(review): presumably the unpadded count, unlike
    /// `RuntimeArgs::padded_channels` — confirm.
    pub channels: usize,
    /// Input spatial shape (one entry per spatial dimension).
    pub shape: Vec<usize>,
    /// Output spatial shape (one entry per spatial dimension).
    pub out_shape: Vec<usize>,

    /// Number of spatial dimensions (1D/2D/3D).
    pub dimensionality: Dimensionality,
}
143
144impl ConvolutionProblem {
145    pub fn as_matmul_problem(&self) -> MatmulProblem {
146        MatmulProblem {
147            m: self.m,
148            n: self.n,
149            k: self.k,
150            lhs_batches: vec![],
151            rhs_batches: vec![],
152            lhs_layout: self.lhs_layout,
153            rhs_layout: self.rhs_layout,
154        }
155    }
156}
157
/// Spatial dimensionality of an operation
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Dimensionality {
    /// One spatial dimension (e.g. 1D convolution).
    Dim1,
    /// Two spatial dimensions (e.g. 2D convolution).
    Dim2,
    /// Three spatial dimensions (e.g. 3D convolution).
    Dim3,
}