cubecl_matmul/components/global/specialization/
roles.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::MatmulIdent;
5use crate::components::error::MatmulSetupError;
6use crate::components::global::MaxGlobalReaderPlanes;
7use crate::components::global::specialization::config::{
8 LoadSpecializationConfig, SpecializedLoadingSides,
9};
10
/// Number of planes assigned to each role of a specialized global matmul.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneRoles {
    /// Number of main-flow planes: the planes that perform compute (and, depending on
    /// `SpecializedLoadingSides`, may also take part in loading).
    pub main_flow: u32,
    /// Number of planes dedicated only to loading.
    pub load_only: u32,
}

impl PlaneRoles {
    /// Returns the total number of planes across both roles.
    pub fn total_count(&self) -> u32 {
        self.main_flow + self.load_only
    }
}
26
27#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
28pub struct PlaneRoleConfig {
30 pub plane_roles: PlaneRoles,
31 pub rule: RoleRuleConfig,
32}
33
/// Comptime description of a [`RoleRule`].
///
/// The `u32` carried by each specialized variant is the plane position at which the
/// roles switch; it becomes the [`Threshold`] of the matching `RoleRule` variant.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum RoleRuleConfig {
    /// Every plane is in the main flow; no load specialization.
    MainFlowOnly,
    /// The first `load_only` planes are load-only; the remaining planes are main flow.
    LoadOnlyFirst { load_only: u32 },
    /// The first `main_flow` planes are main flow; the remaining planes are load-only.
    LoadOnlyLast { main_flow: u32 },
}
41
42#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
43pub struct Threshold {
48 #[cube(comptime)]
49 threshold: u32,
50}
51
52#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
53pub enum RoleRule {
55 MainFlowOnly,
57 LoadOnlyFirst(Threshold),
60 LoadOnlyLast(Threshold),
63}
64
65impl PlaneRoleConfig {
66 pub fn new(
68 load_specialization_config: LoadSpecializationConfig,
69 reader_tasks: Option<MaxGlobalReaderPlanes>,
70 num_main_flow_planes: u32,
71 ) -> Result<PlaneRoleConfig, MatmulSetupError> {
72 let plane_roles = match reader_tasks {
73 Some(reader_tasks) => {
74 load_specialization_config.to_plane_roles(num_main_flow_planes, reader_tasks)
75 }
76
77 None => {
78 if load_specialization_config.has_specialization() {
79 return Err(MatmulSetupError::InvalidConfig(Box::new(
80 "Error: Load specialization config has specialization but no reader tasks were given."
81 .to_string(),
82 )));
83 } else {
84 PlaneRoles {
85 main_flow: num_main_flow_planes,
86 load_only: 0,
87 }
88 }
89 }
90 };
91
92 let rule = match plane_roles.load_only {
94 0 => RoleRuleConfig::MainFlowOnly,
95 _ => RoleRuleConfig::LoadOnlyFirst {
96 load_only: plane_roles.load_only,
97 },
98 };
99
100 Ok(Self { plane_roles, rule })
101 }
102
103 pub fn main_flow_count(&self) -> u32 {
105 self.plane_roles.main_flow
106 }
107
108 pub fn has_specialization(&self) -> bool {
110 self.plane_roles.load_only > 0
111 }
112}
113
114#[cube]
115impl RoleRule {
116 pub fn new(#[comptime] comptime_rule: RoleRuleConfig) -> RoleRule {
118 match comptime!(comptime_rule) {
119 RoleRuleConfig::MainFlowOnly => RoleRule::new_MainFlowOnly(),
120 RoleRuleConfig::LoadOnlyFirst { load_only } => RoleRule::new_LoadOnlyFirst(Threshold {
121 threshold: load_only,
122 }),
123 RoleRuleConfig::LoadOnlyLast { main_flow } => RoleRule::new_LoadOnlyLast(Threshold {
124 threshold: main_flow,
125 }),
126 }
127 }
128
129 pub fn is_load_only(self) -> bool {
131 match self {
132 RoleRule::MainFlowOnly => false,
133 RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y < load_only.threshold,
134 RoleRule::LoadOnlyLast(main_flow) => UNIT_POS_Y >= main_flow.threshold,
135 }
136 }
137
138 pub fn compute_index(self) -> u32 {
141 match self {
142 RoleRule::MainFlowOnly => UNIT_POS_Y,
143 RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y - load_only.threshold,
144 RoleRule::LoadOnlyLast(_) => UNIT_POS_Y,
145 }
146 }
147
148 pub fn load_index(
151 self,
152 #[comptime] ident: MatmulIdent,
153 #[comptime] specialized_loading_sides: SpecializedLoadingSides,
154 ) -> u32 {
155 match self {
156 RoleRule::MainFlowOnly => UNIT_POS_Y,
157 RoleRule::LoadOnlyFirst(load_only) => {
158 if comptime!(!specialized_loading_sides.load_only.includes(ident)) {
159 UNIT_POS_Y - load_only.threshold
160 } else {
161 UNIT_POS_Y
162 }
163 }
164 RoleRule::LoadOnlyLast(main_flow) => {
165 if comptime!(!specialized_loading_sides.main_flow.includes(ident)) {
166 UNIT_POS_Y - main_flow.threshold
167 } else {
168 UNIT_POS_Y
169 }
170 }
171 }
172 }
173}