cubecl_convolution/components/global/
base.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4 AccG, AvailableLineSizes, LhsG, MatmulLineSizes, MatmulPrecision, MatmulSelection,
5 MatmulSetupError, RhsG,
6 global::GlobalWriter,
7 stage::{ContiguousTilingLayout, RowMajorTilingOrder},
8};
9use cubecl_std::{
10 CubeOption,
11 tensor::{View, layout::Coords2d},
12};
13
14use crate::{
15 components::{ConvGemmConfig, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
16 kernels::layered::selector::RuntimeArgs,
17};
18
/// Stage tiling layout used by the convolution components.
///
/// Alias for [`ContiguousTilingLayout`] parameterized with the [`RowMajorTilingOrder`]
/// tile ordering.
pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
20
/// Shorthand for the [`Config`](GlobalConvolutionFamily::Config) associated type of the
/// global convolution family `F`.
pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
22
23pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
24 type Config: ConvGemmConfig;
26 type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
27
28 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
29
30 fn setup<R: Runtime, MP: MatmulPrecision>(
31 client: &ComputeClient<R::Server>,
32 problem: &ConvolutionProblem,
33 selection: &MatmulSelection,
34 line_sizes: &MatmulLineSizes,
35 ) -> Result<Self::Config, MatmulSetupError>;
36}
37
/// A global-level convolution for matmul precision `MP`.
///
/// An implementation provides:
/// - constructors for its readers, writer and accumulators (the `init_*` functions);
/// - an [`execute`](GlobalConvolution::execute) step that consumes the readers and writer.
///
/// The trait is annotated `#[cube]`, so the CubeCL macro processes it.
#[cube]
pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Reader for the left-hand-side input; built by `init_lhs_global_reader`.
    type LhsGlobalReader: CubeType;
    /// Reader for the right-hand-side input; built by `init_rhs_global_reader`.
    type RhsGlobalReader: CubeType;
    /// Reader for the accumulator input; built by `init_bias_global_reader` from an
    /// optional bias view.
    type AccGlobalReader: CubeType;
    /// Configuration type, passed as a `#[comptime]` argument to every function below.
    type Config: ConvGemmConfig;

    /// Output writer over `MP::Acc` elements; built by `init_global_writer`.
    type GlobalWriter: GlobalWriter<MP::Acc>;
    /// Accumulator storage; built by `init_accumulator` and borrowed mutably by `execute`.
    type Accumulators: CubeType;

    /// Runs the convolution's main step.
    ///
    /// Takes the three readers and the writer by value and borrows `acc` mutably.
    /// NOTE(review): `k_range` is presumably a `(start, end)` range along the reduction
    /// dimension. Confirm against the implementations.
    fn execute(
        lhs_reader: Self::LhsGlobalReader,
        rhs_reader: Self::RhsGlobalReader,
        acc_reader: Self::AccGlobalReader,
        writer: Self::GlobalWriter,
        acc: &mut Self::Accumulators,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

    /// Builds the left-hand-side reader over the 2-D `lhs` view.
    ///
    /// NOTE(review): `offset` and `slice_size` presumably select the origin and extent of
    /// the region of `lhs` to read. Confirm. `runtime_args` carries the runtime
    /// (non-comptime) [`RuntimeArgs`].
    fn init_lhs_global_reader(
        lhs: View<Line<LhsG<MP>>, Coords2d>,
        offset: Coords2d,
        slice_size: Coords2d,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsGlobalReader;

    /// Builds the right-hand-side reader over the 2-D `rhs` view.
    fn init_rhs_global_reader(
        rhs: View<Line<RhsG<MP>>, Coords2d>,
        #[comptime] config: Self::Config,
    ) -> Self::RhsGlobalReader;

    /// Builds the accumulator reader from `bias`.
    ///
    /// `bias` is wrapped in a `CubeOption`, so it may be absent.
    fn init_bias_global_reader(
        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
        #[comptime] config: Self::Config,
    ) -> Self::AccGlobalReader;

    /// Builds the output writer over the read-write 2-D view `out`.
    fn init_global_writer(
        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
        #[comptime] config: Self::Config,
    ) -> Self::GlobalWriter;

    /// Creates the accumulators from the configuration alone.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
}