cubecl_linalg/matmul/components/global/config.rs

1use crate::matmul::components::{
2    Ident, MatmulConfig, MatrixLayout, TilingDimensions,
3    stage::{self},
4};
5
/// Value reported by `precompute_job()` for [`CommonGlobalConfig`].
///
/// It is a module-level constant rather than a struct field, so every config
/// instance reports the same value. NOTE(review): what enabling the "precompute
/// job" path would change is not visible in this file; confirm against the
/// callers of `precompute_job()` before flipping it.
pub const PRECOMPUTE_JOB: bool = false;
7
8#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
9/// Configuration for the pipelined global matmul
10pub struct CommonGlobalConfig<S: stage::StageConfig> {
11    pub smm_config: S,
12    pub check_m_bounds: bool,
13    pub check_n_bounds: bool,
14    pub check_k_bounds: bool,
15    pub lhs_layout: MatrixLayout,
16    pub rhs_layout: MatrixLayout,
17    pub lhs_line_size: u32,
18    pub rhs_line_size: u32,
19    pub out_line_size: u32,
20    pub num_planes: u32,
21}
22
23impl<S: stage::StageConfig> super::GlobalConfig for CommonGlobalConfig<S> {
24    type SmmConfig = S;
25
26    fn to_smm_config(&self) -> Self::SmmConfig {
27        self.smm_config
28    }
29
30    fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32 {
31        match ident.into() {
32            Ident::Lhs => self.lhs_line_size,
33            Ident::Rhs => self.rhs_line_size,
34            Ident::Out => self.out_line_size,
35        }
36    }
37
38    fn tiling_dimensions<I: Into<Ident>>(&self, ident: I) -> TilingDimensions {
39        self.smm_config.tiling_dimensions(ident.into())
40    }
41
42    fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout {
43        match ident.into() {
44            Ident::Lhs => self.lhs_layout,
45            Ident::Rhs => self.rhs_layout,
46            Ident::Out => self.smm_config.matrix_layout(Ident::Out),
47        }
48    }
49
50    fn num_planes(&self) -> u32 {
51        self.num_planes
52    }
53
54    fn plane_dim(&self) -> u32 {
55        self.smm_config.plane_dim()
56    }
57
58    fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
59        match ident.into() {
60            Ident::Lhs => self.check_m_bounds,
61            Ident::Rhs => self.check_k_bounds,
62            Ident::Out => self.check_m_bounds,
63        }
64    }
65
66    fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool {
67        match ident.into() {
68            Ident::Lhs => self.check_k_bounds,
69            Ident::Rhs => self.check_n_bounds,
70            Ident::Out => self.check_n_bounds,
71        }
72    }
73
74    fn check_k_bounds(&self) -> bool {
75        self.check_k_bounds
76    }
77
78    fn precompute_job(&self) -> bool {
79        PRECOMPUTE_JOB
80    }
81}
82
83impl<S: stage::StageConfig> MatmulConfig for CommonGlobalConfig<S> {}
84
85impl<S: stage::StageConfig> CommonGlobalConfig<S> {
86    #[allow(clippy::too_many_arguments)]
87    pub fn new(
88        smm_config: S,
89        check_m_bounds: bool,
90        check_n_bounds: bool,
91        check_k_bounds: bool,
92        lhs_layout: MatrixLayout,
93        rhs_layout: MatrixLayout,
94        lhs_line_size: u32,
95        rhs_line_size: u32,
96        out_line_size: u32,
97        num_planes: u32,
98    ) -> Self {
99        Self {
100            smm_config,
101            check_m_bounds,
102            check_n_bounds,
103            check_k_bounds,
104            lhs_layout,
105            rhs_layout,
106            lhs_line_size,
107            rhs_line_size,
108            out_line_size,
109            num_planes,
110        }
111    }
112}