cubecl_matmul/components/global/specialization/
config.rs1use crate::components::{
2 MatmulIdent,
3 global::{MaxGlobalReaderPlanes, specialization::roles::PlaneRoles},
4};
5
6#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
9pub struct LoadSpecializationConfig {
10 pub lhs: SpecializationTensorConfig,
12 pub rhs: SpecializationTensorConfig,
14}
15
/// Specialization setting for a single input tensor.
///
/// The two variants select which plane role is associated with the tensor when a
/// [`LoadSpecializationConfig`] is converted into loading sides.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SpecializationTensorConfig {
    /// The tensor is associated with the main-flow role only.
    ///
    /// This is the default and does not count as specialization.
    #[default]
    MainFlowOnly,

    /// The tensor is associated with the load-only role.
    ///
    /// Selecting this variant counts as specialization.
    LoadFlowOnly,
}
30
31impl LoadSpecializationConfig {
32 pub fn has_specialization(&self) -> bool {
34 self.lhs.has_specialization() || self.rhs.has_specialization()
35 }
36}
37
38impl SpecializationTensorConfig {
39 pub fn has_specialization(&self) -> bool {
41 match self {
42 SpecializationTensorConfig::MainFlowOnly => false,
43 SpecializationTensorConfig::LoadFlowOnly => true,
44 }
45 }
46}
47
48impl LoadSpecializationConfig {
49 pub fn to_plane_roles(
56 &self,
57 main_flow: u32,
58 reader_tasks: MaxGlobalReaderPlanes,
59 ) -> PlaneRoles {
60 use SpecializationTensorConfig::*;
61
62 let ideal_load_only = match (self.lhs, self.rhs) {
63 (MainFlowOnly, MainFlowOnly) => 0,
64 (MainFlowOnly, LoadFlowOnly) => reader_tasks.rhs,
65 (LoadFlowOnly, MainFlowOnly) => reader_tasks.lhs,
66 (LoadFlowOnly, LoadFlowOnly) => gcd(reader_tasks.lhs, reader_tasks.rhs),
67 };
68
69 let load_only = best_divisor_close_to_reference(ideal_load_only, main_flow);
71
72 PlaneRoles {
73 main_flow,
74 load_only,
75 }
76 }
77}
78
/// Returns the divisor of `dividible_value` that is closest to `reference`.
///
/// When two divisors are equally close, the larger one is returned. When
/// `dividible_value` is `0` there is no candidate in `1..=0`, and `1` is returned.
fn best_divisor_close_to_reference(dividible_value: u32, reference: u32) -> u32 {
    (1..=dividible_value)
        .filter(|&candidate| dividible_value.is_multiple_of(candidate))
        // Smallest distance first; `Reverse` makes the larger divisor win a tie.
        .min_by_key(|&candidate| (reference.abs_diff(candidate), core::cmp::Reverse(candidate)))
        .unwrap_or(1)
}
96
/// Selection of matmul input tensors: both, only one of them, or neither.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadingSides {
    /// Both the left-hand side and the right-hand side.
    Both,
    /// The left-hand side only.
    Lhs,
    /// The right-hand side only.
    Rhs,
    /// Neither side.
    None,
}
109
110impl LoadingSides {
111 pub fn includes_lhs(&self) -> bool {
113 self.includes(MatmulIdent::Lhs)
114 }
115
116 pub fn includes_rhs(&self) -> bool {
118 self.includes(MatmulIdent::Rhs)
119 }
120
121 pub fn includes(&self, ident: MatmulIdent) -> bool {
123 matches!(
124 (self, ident),
125 (LoadingSides::Both, _)
126 | (LoadingSides::Lhs, MatmulIdent::Lhs)
127 | (LoadingSides::Rhs, MatmulIdent::Rhs)
128 )
129 }
130}
131
132#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
133pub struct SpecializedLoadingSides {
135 pub main_flow: LoadingSides,
136 pub load_only: LoadingSides,
137}
138
139impl SpecializedLoadingSides {
140 pub fn num_loading_planes(
142 &self,
143 specialized: bool,
144 ident: MatmulIdent,
145 plane_roles: PlaneRoles,
146 ) -> u32 {
147 if specialized {
148 let mut num_loading_planes = 0;
149 if self.main_flow.includes(ident) {
150 num_loading_planes += plane_roles.main_flow;
151 }
152 if self.load_only.includes(ident) {
153 num_loading_planes += plane_roles.load_only;
154 }
155 num_loading_planes
156 } else {
157 plane_roles.main_flow
158 }
159 }
160}
161
162impl From<LoadSpecializationConfig> for SpecializedLoadingSides {
163 fn from(lsc: LoadSpecializationConfig) -> Self {
164 use SpecializationTensorConfig::*;
165 match (lsc.lhs, lsc.rhs) {
166 (MainFlowOnly, MainFlowOnly) => SpecializedLoadingSides {
167 main_flow: LoadingSides::Both,
168 load_only: LoadingSides::None,
169 },
170 (MainFlowOnly, LoadFlowOnly) => SpecializedLoadingSides {
171 main_flow: LoadingSides::Lhs,
172 load_only: LoadingSides::Rhs,
173 },
174 (LoadFlowOnly, MainFlowOnly) => SpecializedLoadingSides {
175 main_flow: LoadingSides::Rhs,
176 load_only: LoadingSides::Lhs,
177 },
178 (LoadFlowOnly, LoadFlowOnly) => SpecializedLoadingSides {
179 main_flow: LoadingSides::None,
180 load_only: LoadingSides::Both,
181 },
182 }
183 }
184}
185
/// Greatest common divisor of `a` and `b`, computed with the Euclidean algorithm.
///
/// If one argument is `0` the other is returned, so `gcd(0, 0)` is `0`.
pub(crate) fn gcd(a: u32, b: u32) -> u32 {
    let (mut dividend, mut divisor) = (a, b);
    while divisor != 0 {
        (dividend, divisor) = (divisor, dividend % divisor);
    }
    dividend
}