cubecl_convolution/components/global/
base.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_matmul::components::{
4 AccG, AvailableLineSizes, LhsG, MatmulElems, MatmulLineSizes, MatmulPrecision, MatmulSelection,
5 MatmulSetupError, RhsG,
6 global::GlobalWriter,
7 stage::{ContiguousTilingLayout, RowMajorTilingOrder},
8};
9use cubecl_std::{
10 CubeOption,
11 tensor::{View, layout::Coords2d},
12};
13
14use crate::components::{
15 ConvGemmConfig, ConvolutionProblem,
16 global::{args::RuntimeArgs, entry_point::ConvolutionLaunch},
17};
18
19pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
20
21pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
22
23pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
24 type Config: ConvGemmConfig;
26 type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
27
28 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
29
30 fn setup<R: Runtime>(
31 client: &ComputeClient<R>,
32 problem: &ConvolutionProblem,
33 selection: &MatmulSelection,
34 line_sizes: &MatmulLineSizes,
35 dtypes: &MatmulElems,
36 ) -> Result<Self::Config, MatmulSetupError>;
37}
38
39#[cube]
40pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
41 type LhsGlobalReader: CubeType;
43 type RhsGlobalReader: CubeType;
45 type AccGlobalReader: CubeType;
47 type Config: ConvGemmConfig;
49
50 type GlobalWriter: GlobalWriter<MP::Acc>;
52 type Accumulators: CubeType;
54
55 fn execute(
62 lhs_reader: Self::LhsGlobalReader,
63 rhs_reader: Self::RhsGlobalReader,
64 acc_reader: Self::AccGlobalReader,
65 writer: Self::GlobalWriter,
66 acc: &mut Self::Accumulators,
67 k_range: (u32, u32),
68 #[comptime] config: Self::Config,
69 );
70
71 fn init_lhs_global_reader(
73 lhs: View<Line<LhsG<MP>>, Coords2d>,
74 offset: Coords2d,
75 slice_size: Coords2d,
76 runtime_args: &RuntimeArgs,
77 #[comptime] config: Self::Config,
78 ) -> Self::LhsGlobalReader;
79
80 fn init_rhs_global_reader(
82 rhs: View<Line<RhsG<MP>>, Coords2d>,
83 runtime_args: &RuntimeArgs,
84 #[comptime] config: Self::Config,
85 ) -> Self::RhsGlobalReader;
86
87 fn init_bias_global_reader(
89 bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
90 #[comptime] config: Self::Config,
91 ) -> Self::AccGlobalReader;
92
93 fn init_global_writer(
95 out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
96 #[comptime] config: Self::Config,
97 ) -> Self::GlobalWriter;
98
99 fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
101}