cubecl_matmul/components/global/specialization/
roles.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::error::MatmulSetupError;
5use crate::components::global::specialization::config::LoadSpecializationConfig;
6use crate::components::global::{MaxGlobalReaderPlanes, SpecializationTensorConfig};
7
/// How the planes are split between the main matmul computation and loading-only work.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneRoles {
    /// Planes that take part in the main matmul (and possibly in loading as well).
    pub main_flow: u32,
    /// Planes that are dedicated to loading and nothing else.
    pub load_only: u32,
}

impl PlaneRoles {
    /// Total number of planes, i.e. main-flow planes plus load-only planes.
    pub fn total_count(&self) -> u32 {
        let computing = self.main_flow;
        let loading = self.load_only;
        computing + loading
    }
}
23
24#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
25/// Contains the number of plane in each role and the rule to distinguish planes based on their plane id
26pub struct PlaneRoleConfig {
27    pub plane_roles: PlaneRoles,
28    pub rule: RoleRuleConfig,
29}
30
/// Comptime counterpart of [`RoleRule`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum RoleRuleConfig {
    /// Every plane is in the main flow.
    MainFlowOnly,
    /// The load-only planes come first; `load_only` is how many there are.
    LoadOnlyFirst { load_only: u32 },
    /// The main-flow planes come first; `main_flow` is how many there are.
    LoadOnlyLast { main_flow: u32 },
}
38
39#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
40/// Threshold of plane id at which the roles change
41///
42/// Note: this struct is only necessary because Cube enums cannot hold
43/// a comptime value directly
44pub struct Threshold {
45    #[cube(comptime)]
46    threshold: u32,
47}
48
49#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
50/// Rule to distinguish a plane's role based on its plane id
51pub enum RoleRule {
52    /// All planes are in the main flow, this is equivalent of having no specialization
53    MainFlowOnly,
54    /// Load-only planes: [0, Threshold)
55    /// Main flow planes: [Threshold, total)
56    LoadOnlyFirst(Threshold),
57    /// Main flow planes: [0, Threshold)
58    /// Load-only planes: [Threshold, total)
59    LoadOnlyLast(Threshold),
60}
61
62impl PlaneRoleConfig {
63    /// Make a new PlaneRoleConfig
64    pub fn new(
65        load_specialization_config: LoadSpecializationConfig,
66        reader_tasks: Option<MaxGlobalReaderPlanes>,
67        num_main_flow_planes: u32,
68    ) -> Result<PlaneRoleConfig, MatmulSetupError> {
69        let plane_roles = match reader_tasks {
70            Some(reader_tasks) => {
71                load_specialization_config.to_plane_roles(num_main_flow_planes, reader_tasks)
72            }
73
74            None => {
75                if load_specialization_config.has_specialization() {
76                    return Err(MatmulSetupError::InvalidConfig(Box::new(
77                        "Error: Load specialization config has specialization but no reader tasks were given."
78                            .to_string(),
79                    )));
80                } else {
81                    PlaneRoles {
82                        main_flow: num_main_flow_planes,
83                        load_only: 0,
84                    }
85                }
86            }
87        };
88
89        // TODO make possible to select LoadOnlyLast
90        let rule = match plane_roles.load_only {
91            0 => RoleRuleConfig::MainFlowOnly,
92            _ => RoleRuleConfig::LoadOnlyFirst {
93                load_only: plane_roles.load_only,
94            },
95        };
96
97        Ok(Self { plane_roles, rule })
98    }
99
100    pub fn new_unspecialized(num_planes: u32) -> PlaneRoleConfig {
101        PlaneRoleConfig {
102            plane_roles: PlaneRoles {
103                main_flow: num_planes,
104                load_only: 0,
105            },
106            rule: RoleRuleConfig::MainFlowOnly,
107        }
108    }
109
110    /// Returns the number of planes participating in main flow
111    pub fn main_flow_count(&self) -> u32 {
112        self.plane_roles.main_flow
113    }
114
115    /// Whether the plane role config implies specialization
116    pub fn has_specialization(&self) -> bool {
117        self.plane_roles.load_only > 0
118    }
119}
120
121#[cube]
122impl RoleRule {
123    /// Make a cube role rule from comptime config
124    pub fn new(#[comptime] comptime_rule: RoleRuleConfig) -> RoleRule {
125        match comptime!(comptime_rule) {
126            RoleRuleConfig::MainFlowOnly => RoleRule::new_MainFlowOnly(),
127            RoleRuleConfig::LoadOnlyFirst { load_only } => RoleRule::new_LoadOnlyFirst(Threshold {
128                threshold: load_only,
129            }),
130            RoleRuleConfig::LoadOnlyLast { main_flow } => RoleRule::new_LoadOnlyLast(Threshold {
131                threshold: main_flow,
132            }),
133        }
134    }
135
136    /// The index of the current plane among planes that perform compute,
137    /// ignoring load-only planes
138    pub fn compute_index(self) -> u32 {
139        match self {
140            RoleRule::MainFlowOnly => UNIT_POS_Y,
141            RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y - load_only.threshold,
142            RoleRule::LoadOnlyLast(_) => UNIT_POS_Y,
143        }
144    }
145
146    /// The index of the current plane among planes that perform loading,
147    /// ignoring any plane that does not participate for this `ident`.
148    pub fn load_index(
149        self,
150        #[comptime] specialization_tensor_config: SpecializationTensorConfig,
151    ) -> u32 {
152        match self {
153            RoleRule::MainFlowOnly => UNIT_POS_Y,
154            RoleRule::LoadOnlyFirst(load_only) => match specialization_tensor_config {
155                SpecializationTensorConfig::MainFlowOnly => UNIT_POS_Y - load_only.threshold,
156                SpecializationTensorConfig::LoadFlowOnly => UNIT_POS_Y,
157            },
158            RoleRule::LoadOnlyLast(main_flow) => match specialization_tensor_config {
159                SpecializationTensorConfig::LoadFlowOnly => UNIT_POS_Y - main_flow.threshold,
160                SpecializationTensorConfig::MainFlowOnly => UNIT_POS_Y,
161            },
162        }
163    }
164
165    /// Whether this unit is the leader of the loading units. Will always be the lowest unit in the
166    /// correct group.
167    ///
168    /// Only used with TMA, so has some CUDA optimizations. `plane_broadcast` and `plane_elect`
169    /// ensure the compiler recognizes the values as warp uniform.
170    pub fn elect_load_leader(self) -> bool {
171        let plane_id = plane_broadcast(UNIT_POS_Y, 0);
172
173        let is_elected_plane = match self {
174            RoleRule::MainFlowOnly | RoleRule::LoadOnlyFirst(_) => plane_id == 0,
175            RoleRule::LoadOnlyLast(main_flow) => plane_id == main_flow.threshold,
176        };
177
178        is_elected_plane && plane_elect()
179    }
180
181    /// Whether the current plane is a load-only plane
182    pub fn is_load_plane(self) -> bool {
183        match self {
184            RoleRule::MainFlowOnly => false,
185            RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y < load_only.threshold,
186            RoleRule::LoadOnlyLast(main_flow) => UNIT_POS_Y >= main_flow.threshold,
187        }
188    }
189
190    /// Whether this plane is part of the compute planes
191    ///
192    /// Only used in specialized, so has some CUDA optimizations. `plane_broadcast` ensure the
193    /// compiler recognizes the values as warp uniform.
194    pub fn is_compute_plane(self) -> bool {
195        let plane_id = plane_broadcast(UNIT_POS_Y, 0);
196
197        match self {
198            RoleRule::MainFlowOnly => true,
199            RoleRule::LoadOnlyFirst(load_only) => plane_id >= load_only.threshold,
200            RoleRule::LoadOnlyLast(main_flow) => plane_id < main_flow.threshold,
201        }
202    }
203}