cubecl_convolution/components/global/
base.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4    AccG, AvailableLineSizes, LhsG, MatmulElems, MatmulLineSizes, MatmulPrecision, MatmulSelection,
5    MatmulSetupError, RhsG,
6    global::GlobalWriter,
7    stage::{ContiguousTilingLayout, RowMajorTilingOrder},
8};
9use cubecl_std::{
10    CubeOption,
11    tensor::{View, layout::Coords2d},
12};
13
14use crate::components::{
15    ConvGemmConfig, ConvolutionProblem,
16    global::{args::RuntimeArgs, entry_point::ConvolutionLaunch},
17};
18
19pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
20
21pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
22
23pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
24    /// Configuration tailored to the matmul implementation
25    type Config: ConvGemmConfig;
26    type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
27
28    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
29
30    fn setup<R: Runtime>(
31        client: &ComputeClient<R>,
32        problem: &ConvolutionProblem,
33        selection: &MatmulSelection,
34        line_sizes: &MatmulLineSizes,
35        dtypes: &MatmulElems,
36    ) -> Result<Self::Config, MatmulSetupError>;
37}
38
39#[cube]
40pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
41    /// The global reader for the Lhs (input feature map) tensor
42    type LhsGlobalReader: CubeType;
43    /// The global reader for the Rhs (weight) tensor
44    type RhsGlobalReader: CubeType;
45    /// The global reader for the accumulator (bias) tensor
46    type AccGlobalReader: CubeType;
47    /// The config type of the convolution
48    type Config: ConvGemmConfig;
49
50    /// The writer used to write the results to the output feature map
51    type GlobalWriter: GlobalWriter<MP::Acc>;
52    /// The type of the tile matmul accumulator
53    type Accumulators: CubeType;
54
55    /// Performs the convolution over data loaded by the
56    /// LHS and RHS readers, over the range given for K, and stores with
57    /// using the output writer.
58    ///
59    /// To compute the whole range of k values, use k_range=(0, K) where
60    /// K is the K dimension of LHS and RHS.
61    fn execute(
62        lhs_reader: Self::LhsGlobalReader,
63        rhs_reader: Self::RhsGlobalReader,
64        acc_reader: Self::AccGlobalReader,
65        writer: Self::GlobalWriter,
66        acc: &mut Self::Accumulators,
67        k_range: (u32, u32),
68        #[comptime] config: Self::Config,
69    );
70
71    /// Initializes the global reader for the input feature map with an appropriate layout
72    fn init_lhs_global_reader(
73        lhs: View<Line<LhsG<MP>>, Coords2d>,
74        offset: Coords2d,
75        slice_size: Coords2d,
76        runtime_args: &RuntimeArgs,
77        #[comptime] config: Self::Config,
78    ) -> Self::LhsGlobalReader;
79
80    /// Initializes the global reader for the weights with an appropriate layout
81    fn init_rhs_global_reader(
82        rhs: View<Line<RhsG<MP>>, Coords2d>,
83        runtime_args: &RuntimeArgs,
84        #[comptime] config: Self::Config,
85    ) -> Self::RhsGlobalReader;
86
87    /// Initializes the global reader for the bias with an appropriate layout
88    fn init_bias_global_reader(
89        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
90        #[comptime] config: Self::Config,
91    ) -> Self::AccGlobalReader;
92
93    /// Initializes the output feature map global writer with an appropriate layout
94    fn init_global_writer(
95        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
96        #[comptime] config: Self::Config,
97    ) -> Self::GlobalWriter;
98
99    /// Initializes a new accumulator for the tile matmul
100    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
101}