cubek_convolution/components/global/
base.rs1use cubecl::prelude::*;
2use cubecl::std::{
3 CubeOption,
4 tensor::{View, layout::Coords2d},
5};
6use cubecl::{self, ir::DeviceProperties};
7use cubek_matmul::components::{
8 global::GlobalWriter,
9 stage::{ContiguousTilingLayout, RowMajorTilingOrder},
10};
11use cubek_matmul::definition::{
12 AccG, AvailableLineSizes, LhsG, MatmulElems, MatmulLineSizes, MatmulPrecision,
13 MatmulSetupError, RhsG, TilingBlueprint,
14};
15
16use crate::components::{
17 ConvGemmConfig, ConvolutionProblem,
18 global::{args::RuntimeArgs, entry_point::ConvolutionLaunch},
19};
20
21pub type ConvTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
22
23pub type GlobalConfig<F> = <F as GlobalConvolutionFamily>::Config;
24
25pub trait GlobalConvolutionFamily: ConvolutionLaunch<Self::Config> + 'static {
26 type Config: ConvGemmConfig;
28 type Convolution<MP: MatmulPrecision>: GlobalConvolution<MP, Config = Self::Config>;
29
30 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes;
31
32 fn expand_config(
33 device_props: &DeviceProperties,
34 problem: &ConvolutionProblem,
35 selection: &TilingBlueprint,
36 line_sizes: &MatmulLineSizes,
37 dtypes: &MatmulElems,
38 ) -> Result<Self::Config, MatmulSetupError>;
39}
40
/// Global-level convolution for matmul precision `MP`.
///
/// Implementors supply:
/// - the reader types for the Lhs, Rhs and accumulator (bias) inputs,
/// - the output writer and the accumulator storage,
/// - the `#[cube]` functions that construct them,
/// - the main loop in `execute`.
#[cube]
pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Reader for the Lhs operand, created by `init_lhs_global_reader`.
    type LhsGlobalReader: CubeType;
    /// Reader for the Rhs operand, created by `init_rhs_global_reader`.
    type RhsGlobalReader: CubeType;
    /// Reader for the accumulator input, created by `init_bias_global_reader` from an
    /// optional bias view.
    type AccGlobalReader: CubeType;
    /// Comptime configuration passed to every function of this trait.
    type Config: ConvGemmConfig;

    /// Writer for `MP::Acc` output values, created by `init_global_writer`.
    type GlobalWriter: GlobalWriter<MP::Acc>;
    /// Accumulator storage, created by `init_accumulator` and passed mutably to `execute`.
    type Accumulators: CubeType;

    /// Runs the convolution main loop.
    ///
    /// Takes ownership of the three readers and the writer, and receives the
    /// accumulators mutably. NOTE(review): presumably results are stored through
    /// `writer`. Confirm against implementations.
    ///
    /// `k_range` is a `(start, end)` pair selecting which part of the K dimension to
    /// process. NOTE(review): presumably half-open `[start, end)` over the reduction
    /// dimension. Confirm against implementations.
    fn execute(
        lhs_reader: Self::LhsGlobalReader,
        rhs_reader: Self::RhsGlobalReader,
        acc_reader: Self::AccGlobalReader,
        writer: Self::GlobalWriter,
        acc: &mut Self::Accumulators,
        k_range: (u32, u32),
        #[comptime] config: Self::Config,
    );

    /// Creates the Lhs reader over the `lhs` view.
    ///
    /// `offset` and `slice_size` are 2D coordinates. NOTE(review): presumably they give
    /// the origin and extent of the region of `lhs` this cube reads. Confirm.
    /// `runtime_args` carries the runtime problem arguments.
    fn init_lhs_global_reader(
        lhs: View<Line<LhsG<MP>>, Coords2d>,
        offset: Coords2d,
        slice_size: Coords2d,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::LhsGlobalReader;

    /// Creates the Rhs reader over the `rhs` view, using the runtime problem arguments.
    fn init_rhs_global_reader(
        rhs: View<Line<RhsG<MP>>, Coords2d>,
        runtime_args: &RuntimeArgs,
        #[comptime] config: Self::Config,
    ) -> Self::RhsGlobalReader;

    /// Creates the accumulator reader from an optional bias view.
    ///
    /// `bias` is a `CubeOption`, so it may be absent. NOTE(review): what the reader
    /// yields when no bias is given is up to each implementation. Confirm.
    fn init_bias_global_reader(
        bias: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
        #[comptime] config: Self::Config,
    ) -> Self::AccGlobalReader;

    /// Creates the output writer over the read-write `out` view.
    fn init_global_writer(
        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
        #[comptime] config: Self::Config,
    ) -> Self::GlobalWriter;

    /// Creates the accumulator storage used by `execute`.
    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
}