cubecl_convolution/components/global/
base.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4    AccG, AvailableLineSizes, LhsG, MatmulElems, MatmulLineSizes, MatmulPrecision, MatmulSelection,
5    MatmulSetupError, RhsG,
6    global::GlobalWriter,
7    stage::{ContiguousTilingLayout, RowMajorTilingOrder},
8};
9use cubecl_std::{
10    CubeOption,
11    tensor::{View, layout::Coords2d},
12};
13
14use crate::{
15    components::{ConvGemmConfig, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
16    kernels::layered::selector::RuntimeArgs,
17};
18
/// Alias for [`ContiguousTilingLayout`] parameterized with [`RowMajorTilingOrder`].
19pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
20
/// Shorthand for the `Config` associated type of a [`GlobalConvolutionFamily`] `F`.
21pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
22
/// A family of global convolution implementations that share one configuration type.
///
/// Implementors must also implement [`ConvolutionLaunch`] for that same configuration.
23pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
24    /// Configuration tailored to the matmul implementation
25    type Config: ConvGemmConfig;
    /// The [`GlobalConvolution`] implementation for a given precision `MP`.
    /// It is required to use the same `Config` as this family.
26    type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
27
    /// Takes the available line sizes and returns a (possibly modified) set of line sizes.
    ///
    /// NOTE(review): presumably narrows the candidates to those this family supports;
    /// the actual criteria live in the implementations — confirm there.
28    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
29
    /// Builds the family's [`Self::Config`] from the compute client, the convolution
    /// problem, the matmul selection, the chosen line sizes and the element types.
    ///
    /// # Errors
    ///
    /// Returns a [`MatmulSetupError`] when no configuration can be produced.
    /// NOTE(review): the specific failure conditions are implementation-defined.
30    fn setup<R: Runtime>(
31        client: &ComputeClient<R>,
32        problem: &ConvolutionProblem,
33        selection: &MatmulSelection,
34        line_sizes: &MatmulLineSizes,
35        dtypes: &MatmulElems,
36    ) -> Result<Self::Config, MatmulSetupError>;
37}
38
/// Contract for a convolution at the global level, for matmul precision `MP`.
///
/// Declares the reader, writer and accumulator types, one initializer for each of
/// them, and the `execute` entry point that consumes them.
39#[cube]
40pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
41    /// The global reader for the Lhs (input feature map) tensor
42    type LhsGlobalReader: CubeType;
43    /// The global reader for the Rhs (weight) tensor
44    type RhsGlobalReader: CubeType;
45    /// The global reader for the accumulator (bias) tensor
46    type AccGlobalReader: CubeType;
47    /// The config type of the convolution
48    type Config: ConvGemmConfig;
49
50    /// The writer used to write the results to the output feature map
51    type GlobalWriter: GlobalWriter<MP::Acc>;
52    /// The type of the tile matmul accumulator
53    type Accumulators: CubeType;
54
55    /// Performs the convolution over data loaded by the
56    /// LHS and RHS readers, over the range given for K, and stores
57    /// the result using the output writer.
58    ///
59    /// To compute the whole range of k values, use k_range=(0, K) where
60    /// K is the K dimension of LHS and RHS.
    ///
    /// The readers and the writer are taken by value, `acc` is borrowed mutably,
    /// and `config` is a comptime argument.
    /// NOTE(review): `acc_reader` presumably supplies the bias used to seed `acc`,
    /// and `acc` is presumably obtained from `init_accumulator` — confirm in the
    /// implementations.
61    fn execute(
62        lhs_reader: Self::LhsGlobalReader,
63        rhs_reader: Self::RhsGlobalReader,
64        acc_reader: Self::AccGlobalReader,
65        writer: Self::GlobalWriter,
66        acc: &mut Self::Accumulators,
67        k_range: (u32, u32),
68        #[comptime] config: Self::Config,
69    );
70
71    /// Initializes the global reader for the input feature map with an appropriate layout
    ///
    /// `lhs` is a 2D-indexed view over lines of the Lhs global element type.
    /// NOTE(review): `offset` and `slice_size` look like the origin and extent of the
    /// region of `lhs` handled by this reader — confirm against the implementations.
72    fn init_lhs_global_reader(
73        lhs: View<Line<LhsG<MP>>, Coords2d>,
74        offset: Coords2d,
75        slice_size: Coords2d,
76        runtime_args: &RuntimeArgs,
77        #[comptime] config: Self::Config,
78    ) -> Self::LhsGlobalReader;
79
80    /// Initializes the global reader for the weights with an appropriate layout
    ///
    /// `rhs` is a 2D-indexed view over lines of the Rhs global element type.
81    fn init_rhs_global_reader(
82        rhs: View<Line<RhsG<MP>>, Coords2d>,
83        #[comptime] config: Self::Config,
84    ) -> Self::RhsGlobalReader;
85
86    /// Initializes the global reader for the bias with an appropriate layout
    ///
    /// `bias` is wrapped in [`CubeOption`], so a bias view may be absent.
    /// NOTE(review): how the absent case is handled is up to the implementation.
87    fn init_bias_global_reader(
88        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
89        #[comptime] config: Self::Config,
90    ) -> Self::AccGlobalReader;
91
92    /// Initializes the output feature map global writer with an appropriate layout
    ///
    /// `out` is a 2D-indexed `ReadWrite` view over lines of the accumulator global
    /// element type.
93    fn init_global_writer(
94        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
95        #[comptime] config: Self::Config,
96    ) -> Self::GlobalWriter;
97
98    /// Initializes a new accumulator for the tile matmul
    ///
    /// Takes only the comptime `config` and returns a fresh [`Self::Accumulators`].
99    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
100}