Skip to main content

cubek_std/
cube_dim_resource.rs

1use cubecl::prelude::*;
2
3use crate::InvalidConfigError;
4
/// Represents how many planes are used for main computation and for loading-only tasks.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneFlowCounts {
    /// Planes that run the main flow; these may also take part in loading.
    pub main_flow: u32,
    /// Planes reserved exclusively for loading.
    pub load_only: u32,
}
13
14impl PlaneFlowCounts {
15    /// Return the total number of planes
16    pub fn total_count(&self) -> u32 {
17        self.main_flow + self.load_only
18    }
19}
20
/// How planes are partitioned by id between the main flow and load-only roles.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum PlaneFlowPartitionRule {
    /// All planes run the main flow; no plane is load-only.
    MainFlowOnly,
    /// Load-only planes are placed first; `load_only` is how many there are.
    /// NOTE(review): presumably plane ids `0..load_only` are the loaders — confirm against the
    /// kernel code that consumes this rule.
    LoadOnlyFirst { load_only: u32 },
    /// Main-flow planes are placed first and load-only planes last; `main_flow` is how many
    /// main-flow planes there are.
    /// NOTE(review): presumably ids `main_flow..` are the loaders — confirm against the kernel.
    LoadOnlyLast { main_flow: u32 },
}
28
29#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
30/// Plane-flow configuration carried by [`CubeDimResource::Specialized`]. Holds the
31/// counts for main-flow vs load-only planes and the partition rule used at runtime.
32pub struct SpecializedCubeDim {
33    pub counts: PlaneFlowCounts,
34    pub partition_rule: PlaneFlowPartitionRule,
35}
36
37impl SpecializedCubeDim {
38    /// All planes participate in the main flow; no load-only planes.
39    pub fn new_unspecialized(num_planes: u32) -> Self {
40        Self {
41            counts: PlaneFlowCounts {
42                main_flow: num_planes,
43                load_only: 0,
44            },
45            partition_rule: PlaneFlowPartitionRule::MainFlowOnly,
46        }
47    }
48
49    /// Number of planes participating in main flow.
50    pub fn main_flow_count(&self) -> u32 {
51        self.counts.main_flow
52    }
53
54    /// Whether the configuration uses dedicated load-only planes.
55    pub fn has_specialization(&self) -> bool {
56        self.counts.load_only > 0
57    }
58}
59
60#[derive(Debug)]
61/// Number of compute primitives required by some component, specified as either units, planes,
62/// or a specialized plane-flow split.
63pub enum CubeDimResource {
64    Units(u32),
65    Planes(u32),
66    Specialized(SpecializedCubeDim),
67}
68
69impl CubeDimResource {
70    /// Ensures [CubeDimResource] is the Planes variant, converting
71    /// units using plane_dim, the number of units in a plane.
72    ///
73    /// Will fail if the number of units does not correspond to an exact number of planes
74    pub fn as_plane_resource(self, plane_dim: u32) -> Result<Self, InvalidConfigError> {
75        match self {
76            CubeDimResource::Units(units) => {
77                if units % plane_dim == 0 {
78                    Ok(CubeDimResource::Planes(units / plane_dim))
79                } else {
80                    Err(Box::new(format!(
81                        "Number of units {units:?} should be divisible by plane_dim {plane_dim:?}"
82                    )))
83                }
84            }
85            CubeDimResource::Planes(_) => Ok(self),
86            CubeDimResource::Specialized(spec) => {
87                Ok(CubeDimResource::Planes(spec.counts.total_count()))
88            }
89        }
90    }
91
92    /// Make a [CubeDim] from specified resources.
93    ///
94    /// Obtained CubeDim is always (plane_dim, number_of_planes, 1)
95    ///
96    /// Will fail if the number of units does not correspond to an exact number of planes
97    pub fn to_cube_dim(self, plane_dim: u32) -> Result<CubeDim, InvalidConfigError> {
98        match self {
99            CubeDimResource::Units(_) => self.as_plane_resource(plane_dim)?.to_cube_dim(plane_dim),
100            CubeDimResource::Planes(num_planes) => Ok(CubeDim::new_2d(plane_dim, num_planes)),
101            CubeDimResource::Specialized(_) => {
102                self.as_plane_resource(plane_dim)?.to_cube_dim(plane_dim)
103            }
104        }
105    }
106
107    /// Get the number of planes
108    ///
109    /// Will fail if the number of units does not correspond to an exact number of planes
110    pub fn num_planes(self, plane_dim: u32) -> Result<u32, InvalidConfigError> {
111        let plane_resources = self.as_plane_resource(plane_dim)?;
112        if let CubeDimResource::Planes(num_planes) = plane_resources {
113            Ok(num_planes)
114        } else {
115            unreachable!()
116        }
117    }
118
119    /// Recover the [SpecializedCubeDim] view of this resource. `Units`/`Planes` produce a
120    /// non-specialized config (all planes in the main flow).
121    pub fn as_specialized(self, plane_dim: u32) -> Result<SpecializedCubeDim, InvalidConfigError> {
122        match self {
123            CubeDimResource::Units(_) => {
124                self.as_plane_resource(plane_dim)?.as_specialized(plane_dim)
125            }
126            CubeDimResource::Planes(num_planes) => {
127                Ok(SpecializedCubeDim::new_unspecialized(num_planes))
128            }
129            CubeDimResource::Specialized(spec) => Ok(spec),
130        }
131    }
132}