// cubecl_linalg/convolution/base.rs

1use crate::matmul::{
2    components::{
3        InputRuntimeArg, InvalidConfigError, MatmulPrecision, MatmulProblem, MatmulSpec,
4        MatrixLayout, OutputRuntimeArg,
5        global::{AccumulatorLoader, OutputLoader},
6    },
7    kernels::MatmulAvailabilityError,
8};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11use cubecl_std::{
12    CubeOption, FastDivmod,
13    tensor::r#virtual::{ReadWrite, VirtualTensor},
14};
15
16use super::ConvGemmConfig;
17
/// Scalar arguments resolved at launch time (as opposed to comptime
/// configuration) and passed to the convolution kernels.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    /// Size of the `m` dimension of the implicit GEMM.
    pub size_m: u32,
    /// Size of the `n` dimension of the implicit GEMM.
    pub size_n: u32,
    /// Size of the `k` (reduction) dimension of the implicit GEMM.
    pub size_k: u32,
    /// Channel count wrapped for fast division/modulo; padded
    /// (presumably to a tiling/line multiple — confirm at the call site).
    pub padded_channels: FastDivmod,
    /// Output height, wrapped for fast division/modulo.
    pub out_h: FastDivmod,
    /// Output width, wrapped for fast division/modulo.
    pub out_w: FastDivmod,
}
27
/// Family of convolution algorithms: instantiates a concrete [`Convolution`]
/// for any [`MatmulPrecision`], while sharing a single config type and launch
/// entry point across precisions.
pub trait ConvolutionFamily:
    ConvolutionConfigFactory<Config: ConvGemmConfig> + ConvolutionLaunch
{
    /// The concrete convolution implementation for precision `MP`.
    type Convolution<MP: MatmulPrecision>: Convolution<MP, Config = Self::Config>;
}
33
/// A convolution kernel expressed as an implicit GEMM, executed on the GPU
/// through CubeCL (`#[cube]`).
#[cube]
pub trait Convolution<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Loader feeding the LHS operand of the implicit GEMM.
    type LhsLoader: CubeType;
    /// Loader feeding the RHS operand of the implicit GEMM.
    type RhsLoader: CubeType;
    /// Comptime configuration for this convolution.
    type Config: ConvGemmConfig;
    /// Loader used to initialize the accumulator, e.g. from an optional bias
    /// (see [`Self::init_bias_loader`]).
    type AccumulatorLoader: AccumulatorLoader<MP>;

    /// Writer for the kernel output.
    type Out: OutputLoader<MP::EO>;
    /// Storage for the partial results accumulated over the `k` range.
    type Accumulator: CubeType;

    /// Performs the convolution over data loaded by the
    /// LHS and RHS loaders, over the range given for K, and stores the result
    /// using the output unloader.
    ///
    /// To compute the whole range of k values, use k_range=(0, K) where
    /// K is the K dimension of LHS and RHS.
    fn execute(
        lhs_loader: Self::LhsLoader,
        rhs_loader: Self::RhsLoader,
        acc_loader: Self::AccumulatorLoader,
        unloader: Self::Out,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

    /// Creates the LHS loader over `lhs`, starting at the given offsets
    /// (presumably this cube's row/col position in the implicit GEMM — confirm).
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsLoader;

    /// Creates the RHS (weight) loader over `rhs`, starting at the given offsets.
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsLoader;

    /// Creates the accumulator loader; `bias` may be absent (`CubeOption`),
    /// in which case the accumulator is initialized without it.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        #[comptime] config: Self::Config,
    ) -> Self::AccumulatorLoader;

    /// Creates the output writer over the (read-write) output tensor.
    fn init_unloader(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Out;

    /// Allocates a fresh accumulator for this configuration.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
}
90
/// Provides configuration for a convolution kernel at any level
pub trait ConvolutionConfigFactory: Send + Sync + 'static {
    /// Configuration tailored to the convolution implementation
    type Config: ConvGemmConfig;
    /// Algorithm-specific input consumed by [`Self::make_config`].
    type Input;

    /// Asserts that the configuration for this convolution will lead to a valid computation
    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>;

    /// Builds the comptime config from the problem description and the
    /// launch shape (cube dimensions and count).
    fn make_config(
        input: Self::Input,
        problem: &ConvolutionProblem,
        cube_dim: &CubeDim,
        cube_count: &CubeCount,
    ) -> Self::Config;

    /// Checks that the runtime/client can execute kernels with this config
    /// at precision `MP`.
    fn check_availability<R: Runtime, MP: MatmulPrecision>(
        client: &ComputeClient<R::Server, R::Channel>,
        config: &Self::Config,
    ) -> Result<(), MatmulAvailabilityError>;
}
112
/// Provides launch entry point to solve a convolution
pub trait ConvolutionLaunch: ConvolutionConfigFactory {
    /// Entry point: launches the convolution kernel on `client` with the
    /// given cube dimensions/count.
    ///
    /// `bias` is optional; pass `None` to run without a bias term.
    ///
    /// # Safety
    ///
    /// Out-of-bounds accesses can happen if `config`/`problem` do not match
    /// the actual shapes of `input`, `bias`, and `output`; the caller must
    /// guarantee they are consistent.
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        bias: Option<TensorArg<'a, R>>,
        output: OutputRuntimeArg<'a, MS, R>,
        problem: &ConvolutionProblem,
        config: <Self as ConvolutionConfigFactory>::Config,
    );
}
132
#[derive(Clone, Debug)]
/// Description of a convolution problem to solve, regardless of actual data
pub struct ConvolutionProblem {
    /// `m` dimension of the implicit GEMM.
    pub m: usize,
    /// `n` dimension of the implicit GEMM.
    pub n: usize,
    /// `k` (reduction) dimension of the implicit GEMM.
    pub k: usize,
    /// Memory layout of the LHS operand.
    pub lhs_layout: MatrixLayout,
    /// Memory layout of the RHS operand.
    pub rhs_layout: MatrixLayout,
    /// Line (vectorization) size used when reading the LHS.
    pub lhs_line_size: u8,
    /// Line (vectorization) size used when reading the RHS.
    pub rhs_line_size: u8,
    /// Line (vectorization) size used when writing the output.
    pub out_line_size: u8,

    /// Kernel window size as (height, width).
    pub kernel_size: (u32, u32),
    /// Stride as (height, width).
    pub stride: (u32, u32),
    /// Padding as (height, width); signed, so it may be negative.
    pub padding: (i32, i32),
    /// Dilation as (height, width).
    pub dilation: (u32, u32),

    /// Number of batches in the input.
    pub batches: usize,
    /// Input spatial height.
    pub height: usize,
    /// Input spatial width.
    pub width: usize,
    /// Number of input channels.
    pub channels: usize,

    /// Output spatial height.
    pub out_h: usize,
    /// Output spatial width.
    pub out_w: usize,
}
158
159impl ConvolutionProblem {
160    pub fn as_matmul_problem(&self) -> MatmulProblem {
161        MatmulProblem {
162            m: self.m,
163            n: self.n,
164            k: self.k,
165            batches: (vec![], vec![]),
166            lhs_layout: self.lhs_layout,
167            rhs_layout: self.rhs_layout,
168            lhs_line_size: self.lhs_line_size,
169            rhs_line_size: self.rhs_line_size,
170            out_line_size: self.out_line_size,
171        }
172    }
173}