cubek_convolution/components/global/
base.rs

1use cubecl::prelude::*;
2use cubecl::std::{
3    CubeOption,
4    tensor::{View, layout::Coords2d},
5};
6use cubecl::{self, ir::DeviceProperties};
7use cubek_matmul::components::{
8    global::GlobalWriter,
9    stage::{ContiguousTilingLayout, RowMajorTilingOrder},
10};
11use cubek_matmul::definition::{
12    AccG, AvailableLineSizes, LhsG, MatmulElems, MatmulLineSizes, MatmulPrecision,
13    MatmulSetupError, RhsG, TilingBlueprint,
14};
15
16use crate::components::{
17    ConvGemmConfig, ConvolutionProblem,
18    global::{args::RuntimeArgs, entry_point::ConvolutionLaunch},
19};
20
21pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
22
23pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
24
25pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
26    /// Configuration tailored to the matmul implementation
27    type Config: ConvGemmConfig;
28    type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
29
30    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
31
32    fn expand_config(
33        device_props: &DeviceProperties,
34        problem: &ConvolutionProblem,
35        selection: &TilingBlueprint,
36        line_sizes: &MatmulLineSizes,
37        dtypes: &MatmulElems,
38    ) -> Result<Self::Config, MatmulSetupError>;
39}
40
41#[cube]
42pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
43    /// The global reader for the Lhs (input feature map) tensor
44    type LhsGlobalReader: CubeType;
45    /// The global reader for the Rhs (weight) tensor
46    type RhsGlobalReader: CubeType;
47    /// The global reader for the accumulator (bias) tensor
48    type AccGlobalReader: CubeType;
49    /// The config type of the convolution
50    type Config: ConvGemmConfig;
51
52    /// The writer used to write the results to the output feature map
53    type GlobalWriter: GlobalWriter<MP::Acc>;
54    /// The type of the tile matmul accumulator
55    type Accumulators: CubeType;
56
57    /// Performs the convolution over data loaded by the
58    /// LHS and RHS readers, over the range given for K, and stores with
59    /// using the output writer.
60    ///
61    /// To compute the whole range of k values, use k_range=(0, K) where
62    /// K is the K dimension of LHS and RHS.
63    fn execute(
64        lhs_reader: Self::LhsGlobalReader,
65        rhs_reader: Self::RhsGlobalReader,
66        acc_reader: Self::AccGlobalReader,
67        writer: Self::GlobalWriter,
68        acc: &mut Self::Accumulators,
69        k_range: (u32, u32),
70        #[comptime] config: Self::Config,
71    );
72
73    /// Initializes the global reader for the input feature map with an appropriate layout
74    fn init_lhs_global_reader(
75        lhs: View<Line<LhsG<MP>>, Coords2d>,
76        offset: Coords2d,
77        slice_size: Coords2d,
78        runtime_args: &RuntimeArgs,
79        #[comptime] config: Self::Config,
80    ) -> Self::LhsGlobalReader;
81
82    /// Initializes the global reader for the weights with an appropriate layout
83    fn init_rhs_global_reader(
84        rhs: View<Line<RhsG<MP>>, Coords2d>,
85        runtime_args: &RuntimeArgs,
86        #[comptime] config: Self::Config,
87    ) -> Self::RhsGlobalReader;
88
89    /// Initializes the global reader for the bias with an appropriate layout
90    fn init_bias_global_reader(
91        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
92        #[comptime] config: Self::Config,
93    ) -> Self::AccGlobalReader;
94
95    /// Initializes the output feature map global writer with an appropriate layout
96    fn init_global_writer(
97        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
98        #[comptime] config: Self::Config,
99    ) -> Self::GlobalWriter;
100
101    /// Initializes a new accumulator for the tile matmul
102    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
103}