cubek_std/
cube_dim_resource.rs1use cubecl::prelude::*;
2
3use crate::InvalidConfigError;
4
/// Number of planes assigned to each execution flow of a cube.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneFlowCounts {
    /// Planes executing the main flow.
    pub main_flow: u32,
    /// Planes assigned to the load-only flow.
    pub load_only: u32,
}

impl PlaneFlowCounts {
    /// Total number of planes across both flows (`main_flow + load_only`).
    pub fn total_count(&self) -> u32 {
        let Self {
            main_flow,
            load_only,
        } = *self;
        load_only + main_flow
    }
}

/// How a cube's planes are split between the main flow and the load-only flow.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum PlaneFlowPartitionRule {
    /// Every plane runs the main flow; there are no load-only planes.
    MainFlowOnly,
    /// Load-only planes presumably come first; `load_only` is their count (TODO confirm ordering).
    LoadOnlyFirst { load_only: u32 },
    /// Main-flow planes presumably come first; `main_flow` is their count (TODO confirm ordering).
    LoadOnlyLast { main_flow: u32 },
}

29#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
30pub struct SpecializedCubeDim {
33 pub counts: PlaneFlowCounts,
34 pub partition_rule: PlaneFlowPartitionRule,
35}
36
37impl SpecializedCubeDim {
38 pub fn new_unspecialized(num_planes: u32) -> Self {
40 Self {
41 counts: PlaneFlowCounts {
42 main_flow: num_planes,
43 load_only: 0,
44 },
45 partition_rule: PlaneFlowPartitionRule::MainFlowOnly,
46 }
47 }
48
49 pub fn main_flow_count(&self) -> u32 {
51 self.counts.main_flow
52 }
53
54 pub fn has_specialization(&self) -> bool {
56 self.counts.load_only > 0
57 }
58}
59
60#[derive(Debug)]
61pub enum CubeDimResource {
64 Units(u32),
65 Planes(u32),
66 Specialized(SpecializedCubeDim),
67}
68
69impl CubeDimResource {
70 pub fn as_plane_resource(self, plane_dim: u32) -> Result<Self, InvalidConfigError> {
75 match self {
76 CubeDimResource::Units(units) => {
77 if units % plane_dim == 0 {
78 Ok(CubeDimResource::Planes(units / plane_dim))
79 } else {
80 Err(Box::new(format!(
81 "Number of units {units:?} should be divisible by plane_dim {plane_dim:?}"
82 )))
83 }
84 }
85 CubeDimResource::Planes(_) => Ok(self),
86 CubeDimResource::Specialized(spec) => {
87 Ok(CubeDimResource::Planes(spec.counts.total_count()))
88 }
89 }
90 }
91
92 pub fn to_cube_dim(self, plane_dim: u32) -> Result<CubeDim, InvalidConfigError> {
98 match self {
99 CubeDimResource::Units(_) => self.as_plane_resource(plane_dim)?.to_cube_dim(plane_dim),
100 CubeDimResource::Planes(num_planes) => Ok(CubeDim::new_2d(plane_dim, num_planes)),
101 CubeDimResource::Specialized(_) => {
102 self.as_plane_resource(plane_dim)?.to_cube_dim(plane_dim)
103 }
104 }
105 }
106
107 pub fn num_planes(self, plane_dim: u32) -> Result<u32, InvalidConfigError> {
111 let plane_resources = self.as_plane_resource(plane_dim)?;
112 if let CubeDimResource::Planes(num_planes) = plane_resources {
113 Ok(num_planes)
114 } else {
115 unreachable!()
116 }
117 }
118
119 pub fn as_specialized(self, plane_dim: u32) -> Result<SpecializedCubeDim, InvalidConfigError> {
122 match self {
123 CubeDimResource::Units(_) => {
124 self.as_plane_resource(plane_dim)?.as_specialized(plane_dim)
125 }
126 CubeDimResource::Planes(num_planes) => {
127 Ok(SpecializedCubeDim::new_unspecialized(num_planes))
128 }
129 CubeDimResource::Specialized(spec) => Ok(spec),
130 }
131 }
132}