cubecl_linalg/matmul/components/global/
config.rs1use crate::matmul::components::{
2 Ident, MatmulConfig, MatrixLayout, TilingDimensions,
3 stage::{self},
4};
5
/// Value returned by `precompute_job()` on [`CommonGlobalConfig`].
///
/// It is a module-level constant rather than a struct field, so every config
/// instance reports the same value. It is currently disabled.
/// NOTE(review): presumably toggles whether per-job data is precomputed up
/// front. Confirm against the consumers of `GlobalConfig::precompute_job`.
pub const PRECOMPUTE_JOB: bool = false;
7
8#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
9pub struct CommonGlobalConfig<S: stage::StageConfig> {
11 pub smm_config: S,
12 pub check_m_bounds: bool,
13 pub check_n_bounds: bool,
14 pub check_k_bounds: bool,
15 pub lhs_layout: MatrixLayout,
16 pub rhs_layout: MatrixLayout,
17 pub lhs_line_size: u32,
18 pub rhs_line_size: u32,
19 pub out_line_size: u32,
20 pub num_planes: u32,
21}
22
23impl<S: stage::StageConfig> super::GlobalConfig for CommonGlobalConfig<S> {
24 type SmmConfig = S;
25
26 fn to_smm_config(&self) -> Self::SmmConfig {
27 self.smm_config
28 }
29
30 fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
31 match ident.into() {
32 Ident::Lhs => self.lhs_line_size,
33 Ident::Rhs => self.rhs_line_size,
34 Ident::Out => self.out_line_size,
35 }
36 }
37
38 fn tiling_dimensions<I: Into<Ident>>(&self, ident: I) -> TilingDimensions {
39 self.smm_config.tiling_dimensions(ident.into())
40 }
41
42 fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
43 match ident.into() {
44 Ident::Lhs => self.lhs_layout,
45 Ident::Rhs => self.rhs_layout,
46 Ident::Out => self.smm_config.matrix_layout(Ident::Out),
47 }
48 }
49
50 fn num_planes(&self) -> u32 {
51 self.num_planes
52 }
53
54 fn plane_dim(&self) -> u32 {
55 self.smm_config.plane_dim()
56 }
57
58 fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
59 match ident.into() {
60 Ident::Lhs => self.check_m_bounds,
61 Ident::Rhs => self.check_k_bounds,
62 Ident::Out => self.check_m_bounds,
63 }
64 }
65
66 fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
67 match ident.into() {
68 Ident::Lhs => self.check_k_bounds,
69 Ident::Rhs => self.check_n_bounds,
70 Ident::Out => self.check_n_bounds,
71 }
72 }
73
74 fn check_k_bounds(&self) -> bool {
75 self.check_k_bounds
76 }
77
78 fn precompute_job(&self) -> bool {
79 PRECOMPUTE_JOB
80 }
81}
82
83impl<S: stage::StageConfig> MatmulConfig for CommonGlobalConfig<S> {}
84
85impl<S: stage::StageConfig> CommonGlobalConfig<S> {
86 #[allow(clippy::too_many_arguments)]
87 pub fn new(
88 smm_config: S,
89 check_m_bounds: bool,
90 check_n_bounds: bool,
91 check_k_bounds: bool,
92 lhs_layout: MatrixLayout,
93 rhs_layout: MatrixLayout,
94 lhs_line_size: u32,
95 rhs_line_size: u32,
96 out_line_size: u32,
97 num_planes: u32,
98 ) -> Self {
99 Self {
100 smm_config,
101 check_m_bounds,
102 check_n_bounds,
103 check_k_bounds,
104 lhs_layout,
105 rhs_layout,
106 lhs_line_size,
107 rhs_line_size,
108 out_line_size,
109 num_planes,
110 }
111 }
112}