// cubecl_matmul/components/global/specialization/config.rs

1use crate::components::{
2    MatmulIdent,
3    global::{MaxGlobalReaderPlanes, specialization::roles::PlaneRoles},
4};
5
6/// Configuration for how each input tensor (Lhs and Rhs) is loaded,
7/// specifying the plane roles responsible for loading them.
8#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
9pub struct LoadSpecializationConfig {
10    /// Load strategy for the Lhs tensor.
11    pub lhs: SpecializationTensorConfig,
12    /// Load strategy for the Rhs tensor.
13    pub rhs: SpecializationTensorConfig,
14}
15
/// Selects the kind of plane that loads a given input tensor.
///
/// TODO: consider a "MainPlusExtra" variant in which main-flow planes and
/// load-only planes cooperate on loading the same tensor.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SpecializationTensorConfig {
    /// Only planes taking part in the main computation flow load this tensor.
    /// This is the default variant.
    #[default]
    MainFlowOnly,

    /// Only dedicated load-only planes, which take no part in the computation,
    /// load this tensor.
    LoadFlowOnly,
}
30
31impl LoadSpecializationConfig {
32    /// Whether there is specialization in the algorithm
33    pub fn has_specialization(&self) -> bool {
34        self.lhs.has_specialization() || self.rhs.has_specialization()
35    }
36}
37
38impl SpecializationTensorConfig {
39    /// Whether there is specialization for the tensor
40    pub fn has_specialization(&self) -> bool {
41        match self {
42            SpecializationTensorConfig::MainFlowOnly => false,
43            SpecializationTensorConfig::LoadFlowOnly => true,
44        }
45    }
46}
47
48impl LoadSpecializationConfig {
49    /// Computes how many planes of each role there should be,
50    /// using the number of planes needed for main execution, and how
51    /// many planes each reader can handle
52    ///
53    /// The strategy is to find a balanced divisor for reader planes that stays as
54    /// close as possible to the main execution plane count.
55    pub fn to_plane_roles(
56        &self,
57        main_flow: u32,
58        reader_tasks: MaxGlobalReaderPlanes,
59    ) -> PlaneRoles {
60        use SpecializationTensorConfig::*;
61
62        let ideal_load_only = match (self.lhs, self.rhs) {
63            (MainFlowOnly, MainFlowOnly) => 0,
64            (MainFlowOnly, LoadFlowOnly) => reader_tasks.rhs,
65            (LoadFlowOnly, MainFlowOnly) => reader_tasks.lhs,
66            (LoadFlowOnly, LoadFlowOnly) => gcd(reader_tasks.lhs, reader_tasks.rhs),
67        };
68
69        // Don't stray too far from main_flow
70        let load_only = best_divisor_close_to_reference(ideal_load_only, main_flow);
71
72        PlaneRoles {
73            main_flow,
74            load_only,
75        }
76    }
77}
78
/// Returns the divisor of `dividible_value` closest to `reference`, preferring the
/// larger divisor on ties.
///
/// A `dividible_value` of 0 means "nothing to divide" and yields 0. A request for
/// zero planes therefore stays at zero instead of being rounded up to a single
/// plane. Previously the empty search range made this case return 1.
///
/// Every candidate in `1..=dividible_value` is scanned.
fn best_divisor_close_to_reference(dividible_value: u32, reference: u32) -> u32 {
    if dividible_value == 0 {
        return 0;
    }

    (1..=dividible_value)
        .filter(|&d| dividible_value.is_multiple_of(d))
        // Minimize the distance to `reference`. On equal distance the larger
        // divisor has the smaller `Reverse` key and therefore wins.
        .min_by_key(|&d| (reference.abs_diff(d), std::cmp::Reverse(d)))
        .expect("1 divides every positive integer, so there is always a candidate")
}
96
/// Which input tensor(s) a given plane role is responsible for loading.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadingSides {
    /// Loads both the Lhs and the Rhs tensors.
    Both,
    /// Loads the Lhs tensor only.
    Lhs,
    /// Loads the Rhs tensor only.
    Rhs,
    /// Performs no loading at all.
    None,
}
109
110impl LoadingSides {
111    /// Returns `true` if Lhs is included.
112    pub fn includes_lhs(&self) -> bool {
113        self.includes(MatmulIdent::Lhs)
114    }
115
116    /// Returns `true` if Rhs is included.
117    pub fn includes_rhs(&self) -> bool {
118        self.includes(MatmulIdent::Rhs)
119    }
120
121    /// Returns `true` if the given input is included.
122    pub fn includes(&self, ident: MatmulIdent) -> bool {
123        matches!(
124            (self, ident),
125            (LoadingSides::Both, _)
126                | (LoadingSides::Lhs, MatmulIdent::Lhs)
127                | (LoadingSides::Rhs, MatmulIdent::Rhs)
128        )
129    }
130}
131
132#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
133/// Aggregates loading sides for both main flow and load only roles
134pub struct SpecializedLoadingSides {
135    pub main_flow: LoadingSides,
136    pub load_only: LoadingSides,
137}
138
139impl SpecializedLoadingSides {
140    /// Returns the number of planes participating in the loading of `ident`
141    pub fn num_loading_planes(
142        &self,
143        specialized: bool,
144        ident: MatmulIdent,
145        plane_roles: PlaneRoles,
146    ) -> u32 {
147        if specialized {
148            let mut num_loading_planes = 0;
149            if self.main_flow.includes(ident) {
150                num_loading_planes += plane_roles.main_flow;
151            }
152            if self.load_only.includes(ident) {
153                num_loading_planes += plane_roles.load_only;
154            }
155            num_loading_planes
156        } else {
157            plane_roles.main_flow
158        }
159    }
160}
161
162impl From<LoadSpecializationConfig> for SpecializedLoadingSides {
163    fn from(lsc: LoadSpecializationConfig) -> Self {
164        use SpecializationTensorConfig::*;
165        match (lsc.lhs, lsc.rhs) {
166            (MainFlowOnly, MainFlowOnly) => SpecializedLoadingSides {
167                main_flow: LoadingSides::Both,
168                load_only: LoadingSides::None,
169            },
170            (MainFlowOnly, LoadFlowOnly) => SpecializedLoadingSides {
171                main_flow: LoadingSides::Lhs,
172                load_only: LoadingSides::Rhs,
173            },
174            (LoadFlowOnly, MainFlowOnly) => SpecializedLoadingSides {
175                main_flow: LoadingSides::Rhs,
176                load_only: LoadingSides::Lhs,
177            },
178            (LoadFlowOnly, LoadFlowOnly) => SpecializedLoadingSides {
179                main_flow: LoadingSides::None,
180                load_only: LoadingSides::Both,
181            },
182        }
183    }
184}
185
/// Greatest common divisor of `a` and `b`, computed with Euclid's algorithm.
///
/// `gcd(x, 0)` is `x`, so `gcd(0, 0)` is 0. For `u32` inputs the recursion is
/// only a few dozen levels deep at most.
pub(crate) fn gcd(a: u32, b: u32) -> u32 {
    if b == 0 { a } else { gcd(b, a % b) }
}