// cubecl_convolution/components/global/base.rs

use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_matmul::components::{
    AccG, AvailableLineSizes, LhsG, MatmulLineSizes, MatmulPrecision, MatmulSelection,
    MatmulSetupError, RhsG,
    global::GlobalWriter,
    stage::{ContiguousTilingLayout, RowMajorTilingOrder},
};
use cubecl_std::{
    CubeOption,
    tensor::{View, layout::Coords2d},
};

use crate::{
    components::{ConvGemmConfig, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
    kernels::layered::selector::RuntimeArgs,
};

19pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
20
21pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
22
23pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
24    /// Configuration tailored to the matmul implementation
25    type Config: ConvGemmConfig;
26    type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
27
28    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
29
30    fn setup<R: Runtime, MP: MatmulPrecision>(
31        client: &ComputeClient<R::Server>,
32        problem: &ConvolutionProblem,
33        selection: &MatmulSelection,
34        line_sizes: &MatmulLineSizes,
35    ) -> Result<Self::Config, MatmulSetupError>;
36}
37
38#[cube]
39pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
40    /// The global reader for the Lhs (input feature map) tensor
41    type LhsGlobalReader: CubeType;
42    /// The global reader for the Rhs (weight) tensor
43    type RhsGlobalReader: CubeType;
44    /// The global reader for the accumulator (bias) tensor
45    type AccGlobalReader: CubeType;
46    /// The config type of the convolution
47    type Config: ConvGemmConfig;
48
49    /// The writer used to write the results to the output feature map
50    type GlobalWriter: GlobalWriter<MP::Acc>;
51    /// The type of the tile matmul accumulator
52    type Accumulators: CubeType;
53
54    /// Performs the convolution over data loaded by the
55    /// LHS and RHS readers, over the range given for K, and stores with
56    /// using the output writer.
57    ///
58    /// To compute the whole range of k values, use k_range=(0, K) where
59    /// K is the K dimension of LHS and RHS.
60    fn execute(
61        lhs_reader: Self::LhsGlobalReader,
62        rhs_reader: Self::RhsGlobalReader,
63        acc_reader: Self::AccGlobalReader,
64        writer: Self::GlobalWriter,
65        acc: &mut Self::Accumulators,
66        k_range: (u32, u32),
67        #[comptime] config: Self::Config,
68    );
69
70    /// Initializes the global reader for the input feature map with an appropriate layout
71    fn init_lhs_global_reader(
72        lhs: View<Line<LhsG<MP>>, Coords2d>,
73        offset: Coords2d,
74        slice_size: Coords2d,
75        runtime_args: &RuntimeArgs,
76        #[comptime] config: Self::Config,
77    ) -> Self::LhsGlobalReader;
78
79    /// Initializes the global reader for the weights with an appropriate layout
80    fn init_rhs_global_reader(
81        rhs: View<Line<RhsG<MP>>, Coords2d>,
82        #[comptime] config: Self::Config,
83    ) -> Self::RhsGlobalReader;
84
85    /// Initializes the global reader for the bias with an appropriate layout
86    fn init_bias_global_reader(
87        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
88        #[comptime] config: Self::Config,
89    ) -> Self::AccGlobalReader;
90
91    /// Initializes the output feature map global writer with an appropriate layout
92    fn init_global_writer(
93        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
94        #[comptime] config: Self::Config,
95    ) -> Self::GlobalWriter;
96
97    /// Initializes a new accumulator for the tile matmul
98    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
99}