// cubecl_matmul/components/global/specialization/roles.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::components::error::MatmulSetupError;
use crate::components::global::specialization::config::LoadSpecializationConfig;
use crate::components::global::{MaxGlobalReaderPlanes, SpecializationTensorConfig};
/// Number of planes assigned to each of the two roles.
///
/// The two counts are stored independently; their sum is available through
/// [`PlaneRoles::total_count`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneRoles {
    /// Number of main-flow planes.
    pub main_flow: u32,
    /// Number of load-only planes.
    pub load_only: u32,
}

impl PlaneRoles {
    /// Returns the total number of planes, i.e. `main_flow + load_only`.
    pub fn total_count(&self) -> u32 {
        self.main_flow + self.load_only
    }
}
24#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
25pub struct PlaneRoleConfig {
27 pub plane_roles: PlaneRoles,
28 pub rule: RoleRuleConfig,
29}
/// Comptime description of how plane ids are split between roles.
///
/// This is the configuration-side counterpart of [`RoleRule`], which is built
/// from it by `RoleRule::new`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum RoleRuleConfig {
    /// There are no load-only planes.
    MainFlowOnly,
    /// Plane ids `< load_only` are load-only; the remaining ids are main flow.
    LoadOnlyFirst { load_only: u32 },
    /// Plane ids `< main_flow` are main flow; the remaining ids are load-only.
    LoadOnlyLast { main_flow: u32 },
}
39#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
40pub struct Threshold {
45 #[cube(comptime)]
46 threshold: u32,
47}
49#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
50pub enum RoleRule {
52 MainFlowOnly,
54 LoadOnlyFirst(Threshold),
57 LoadOnlyLast(Threshold),
60}
62impl PlaneRoleConfig {
63 pub fn new(
65 load_specialization_config: LoadSpecializationConfig,
66 reader_tasks: Option<MaxGlobalReaderPlanes>,
67 num_main_flow_planes: u32,
68 ) -> Result<PlaneRoleConfig, MatmulSetupError> {
69 let plane_roles = match reader_tasks {
70 Some(reader_tasks) => {
71 load_specialization_config.to_plane_roles(num_main_flow_planes, reader_tasks)
72 }
73
74 None => {
75 if load_specialization_config.has_specialization() {
76 return Err(MatmulSetupError::InvalidConfig(Box::new(
77 "Error: Load specialization config has specialization but no reader tasks were given."
78 .to_string(),
79 )));
80 } else {
81 PlaneRoles {
82 main_flow: num_main_flow_planes,
83 load_only: 0,
84 }
85 }
86 }
87 };
88
89 let rule = match plane_roles.load_only {
91 0 => RoleRuleConfig::MainFlowOnly,
92 _ => RoleRuleConfig::LoadOnlyFirst {
93 load_only: plane_roles.load_only,
94 },
95 };
96
97 Ok(Self { plane_roles, rule })
98 }
99
100 pub fn new_unspecialized(num_planes: u32) -> PlaneRoleConfig {
101 PlaneRoleConfig {
102 plane_roles: PlaneRoles {
103 main_flow: num_planes,
104 load_only: 0,
105 },
106 rule: RoleRuleConfig::MainFlowOnly,
107 }
108 }
109
110 pub fn main_flow_count(&self) -> u32 {
112 self.plane_roles.main_flow
113 }
114
115 pub fn has_specialization(&self) -> bool {
117 self.plane_roles.load_only > 0
118 }
119}
121#[cube]
122impl RoleRule {
123 pub fn new(#[comptime] comptime_rule: RoleRuleConfig) -> RoleRule {
125 match comptime!(comptime_rule) {
126 RoleRuleConfig::MainFlowOnly => RoleRule::new_MainFlowOnly(),
127 RoleRuleConfig::LoadOnlyFirst { load_only } => RoleRule::new_LoadOnlyFirst(Threshold {
128 threshold: load_only,
129 }),
130 RoleRuleConfig::LoadOnlyLast { main_flow } => RoleRule::new_LoadOnlyLast(Threshold {
131 threshold: main_flow,
132 }),
133 }
134 }
135
136 pub fn compute_index(self) -> u32 {
139 match self {
140 RoleRule::MainFlowOnly => UNIT_POS_Y,
141 RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y - load_only.threshold,
142 RoleRule::LoadOnlyLast(_) => UNIT_POS_Y,
143 }
144 }
145
146 pub fn load_index(
149 self,
150 #[comptime] specialization_tensor_config: SpecializationTensorConfig,
151 ) -> u32 {
152 match self {
153 RoleRule::MainFlowOnly => UNIT_POS_Y,
154 RoleRule::LoadOnlyFirst(load_only) => match specialization_tensor_config {
155 SpecializationTensorConfig::MainFlowOnly => UNIT_POS_Y - load_only.threshold,
156 SpecializationTensorConfig::LoadFlowOnly => UNIT_POS_Y,
157 },
158 RoleRule::LoadOnlyLast(main_flow) => match specialization_tensor_config {
159 SpecializationTensorConfig::LoadFlowOnly => UNIT_POS_Y - main_flow.threshold,
160 SpecializationTensorConfig::MainFlowOnly => UNIT_POS_Y,
161 },
162 }
163 }
164
165 pub fn elect_load_leader(self) -> bool {
171 let plane_id = plane_broadcast(UNIT_POS_Y, 0);
172
173 let is_elected_plane = match self {
174 RoleRule::MainFlowOnly | RoleRule::LoadOnlyFirst(_) => plane_id == 0,
175 RoleRule::LoadOnlyLast(main_flow) => plane_id == main_flow.threshold,
176 };
177
178 is_elected_plane && plane_elect()
179 }
180
181 pub fn is_load_plane(self) -> bool {
183 match self {
184 RoleRule::MainFlowOnly => false,
185 RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y < load_only.threshold,
186 RoleRule::LoadOnlyLast(main_flow) => UNIT_POS_Y >= main_flow.threshold,
187 }
188 }
189
190 pub fn is_compute_plane(self) -> bool {
195 let plane_id = plane_broadcast(UNIT_POS_Y, 0);
196
197 match self {
198 RoleRule::MainFlowOnly => true,
199 RoleRule::LoadOnlyFirst(load_only) => plane_id >= load_only.threshold,
200 RoleRule::LoadOnlyLast(main_flow) => plane_id < main_flow.threshold,
201 }
202 }
203}