cubecl_matmul/components/global/specialization/
roles.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::MatmulIdent;
5use crate::components::error::MatmulSetupError;
6use crate::components::global::MaxGlobalReaderPlanes;
7use crate::components::global::specialization::config::{
8    LoadSpecializationConfig, SpecializedLoadingSides,
9};
10
/// Represents how many planes are used for main matmul computation and for loading-only tasks.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneRoles {
    /// Number of planes participating in main matmul and (possibly) loading.
    pub main_flow: u32,
    /// Number of planes dedicated solely to loading.
    pub load_only: u32,
}

impl PlaneRoles {
    /// Returns the total number of planes, i.e. main-flow planes plus load-only planes.
    pub fn total_count(&self) -> u32 {
        self.main_flow + self.load_only
    }
}
26
27#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
28/// Contains the number of plane in each role and the rule to distinguish planes based on their plane id
29pub struct PlaneRoleConfig {
30    pub plane_roles: PlaneRoles,
31    pub rule: RoleRuleConfig,
32}
33
34#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
35/// Comptime version of [RoleRule]
36pub enum RoleRuleConfig {
37    MainFlowOnly,
38    LoadOnlyFirst { load_only: u32 },
39    LoadOnlyLast { main_flow: u32 },
40}
41
42#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
43/// Threshold of plane id at which the roles change
44///
45/// Note: this struct is only necessary because Cube enums cannot hold
46/// a comptime value directly
47pub struct Threshold {
48    #[cube(comptime)]
49    threshold: u32,
50}
51
52#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
53/// Rule to distinguish a plane's role based on its plane id
54pub enum RoleRule {
55    /// All planes are in the main flow, this is equivalent of having no specialization
56    MainFlowOnly,
57    /// Load-only planes: [0, Threshold)
58    /// Main flow planes: [Threshold, total)
59    LoadOnlyFirst(Threshold),
60    /// Main flow planes: [0, Threshold)
61    /// Load-only planes: [Threshold, total)
62    LoadOnlyLast(Threshold),
63}
64
65impl PlaneRoleConfig {
66    /// Make a new PlaneRoleConfig
67    pub fn new(
68        load_specialization_config: LoadSpecializationConfig,
69        reader_tasks: Option<MaxGlobalReaderPlanes>,
70        num_main_flow_planes: u32,
71    ) -> Result<PlaneRoleConfig, MatmulSetupError> {
72        let plane_roles = match reader_tasks {
73            Some(reader_tasks) => {
74                load_specialization_config.to_plane_roles(num_main_flow_planes, reader_tasks)
75            }
76
77            None => {
78                if load_specialization_config.has_specialization() {
79                    return Err(MatmulSetupError::InvalidConfig(Box::new(
80                        "Error: Load specialization config has specialization but no reader tasks were given."
81                            .to_string(),
82                    )));
83                } else {
84                    PlaneRoles {
85                        main_flow: num_main_flow_planes,
86                        load_only: 0,
87                    }
88                }
89            }
90        };
91
92        // TODO make possible to select LoadOnlyLast
93        let rule = match plane_roles.load_only {
94            0 => RoleRuleConfig::MainFlowOnly,
95            _ => RoleRuleConfig::LoadOnlyFirst {
96                load_only: plane_roles.load_only,
97            },
98        };
99
100        Ok(Self { plane_roles, rule })
101    }
102
103    /// Returns the number of planes participating in main flow
104    pub fn main_flow_count(&self) -> u32 {
105        self.plane_roles.main_flow
106    }
107
108    /// Whether the plane role config implies specialization
109    pub fn has_specialization(&self) -> bool {
110        self.plane_roles.load_only > 0
111    }
112}
113
114#[cube]
115impl RoleRule {
116    /// Make a cube role rule from comptime config
117    pub fn new(#[comptime] comptime_rule: RoleRuleConfig) -> RoleRule {
118        match comptime!(comptime_rule) {
119            RoleRuleConfig::MainFlowOnly => RoleRule::new_MainFlowOnly(),
120            RoleRuleConfig::LoadOnlyFirst { load_only } => RoleRule::new_LoadOnlyFirst(Threshold {
121                threshold: load_only,
122            }),
123            RoleRuleConfig::LoadOnlyLast { main_flow } => RoleRule::new_LoadOnlyLast(Threshold {
124                threshold: main_flow,
125            }),
126        }
127    }
128
129    /// Whether the current plane is a load-only plane
130    pub fn is_load_only(self) -> bool {
131        match self {
132            RoleRule::MainFlowOnly => false,
133            RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y < load_only.threshold,
134            RoleRule::LoadOnlyLast(main_flow) => UNIT_POS_Y >= main_flow.threshold,
135        }
136    }
137
138    /// The index of the current plane among planes that perform compute,
139    /// ignoring load-only planes
140    pub fn compute_index(self) -> u32 {
141        match self {
142            RoleRule::MainFlowOnly => UNIT_POS_Y,
143            RoleRule::LoadOnlyFirst(load_only) => UNIT_POS_Y - load_only.threshold,
144            RoleRule::LoadOnlyLast(_) => UNIT_POS_Y,
145        }
146    }
147
148    /// The index of the current plane among planes that perform loading,
149    /// ignoring any plane that does not participate for this `ident`.
150    pub fn load_index(
151        self,
152        #[comptime] ident: MatmulIdent,
153        #[comptime] specialized_loading_sides: SpecializedLoadingSides,
154    ) -> u32 {
155        match self {
156            RoleRule::MainFlowOnly => UNIT_POS_Y,
157            RoleRule::LoadOnlyFirst(load_only) => {
158                if comptime!(!specialized_loading_sides.load_only.includes(ident)) {
159                    UNIT_POS_Y - load_only.threshold
160                } else {
161                    UNIT_POS_Y
162                }
163            }
164            RoleRule::LoadOnlyLast(main_flow) => {
165                if comptime!(!specialized_loading_sides.main_flow.includes(ident)) {
166                    UNIT_POS_Y - main_flow.threshold
167                } else {
168                    UNIT_POS_Y
169                }
170            }
171        }
172    }
173}