use cubecl::prelude::*;
use crate::InvalidConfigError;
/// Number of planes counted in each flow of a cube: the main flow and load-only planes.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PlaneFlowCounts {
    /// Number of planes counted as main-flow planes.
    pub main_flow: u32,
    /// Number of planes counted as load-only planes.
    pub load_only: u32,
}

impl PlaneFlowCounts {
    /// Total number of planes across both flows, i.e. `main_flow + load_only`.
    pub fn total_count(&self) -> u32 {
        let Self {
            main_flow,
            load_only,
        } = *self;
        load_only + main_flow
    }
}
/// Rule describing how a cube's planes are split between main-flow and load-only planes.
///
/// NOTE(review): the exact meaning of `LoadOnlyFirst` / `LoadOnlyLast` is inferred from the
/// variant names; confirm against the code that consumes this rule.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum PlaneFlowPartitionRule {
    /// No dedicated load-only planes. `SpecializedCubeDim::new_unspecialized` pairs this rule
    /// with `load_only: 0`.
    MainFlowOnly,
    /// Presumably the first `load_only` planes are load-only and the remaining planes run the
    /// main flow — TODO confirm.
    LoadOnlyFirst { load_only: u32 },
    /// Presumably the first `main_flow` planes run the main flow and the remaining planes are
    /// load-only — TODO confirm.
    LoadOnlyLast { main_flow: u32 },
}
/// Plane layout of a cube whose planes may be split between main-flow and load-only planes.
///
/// Pairs the per-flow plane counts with the rule that describes how the planes are
/// partitioned between the two flows.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SpecializedCubeDim {
    /// How many planes belong to each flow.
    pub counts: PlaneFlowCounts,
    /// How the planes are partitioned between the flows.
    pub partition_rule: PlaneFlowPartitionRule,
}

impl SpecializedCubeDim {
    /// Builds a layout in which all `num_planes` planes belong to the main flow and no plane
    /// is load-only (`PlaneFlowPartitionRule::MainFlowOnly`).
    pub fn new_unspecialized(num_planes: u32) -> Self {
        let partition_rule = PlaneFlowPartitionRule::MainFlowOnly;
        let counts = PlaneFlowCounts {
            main_flow: num_planes,
            load_only: 0,
        };
        Self {
            counts,
            partition_rule,
        }
    }

    /// Number of planes counted in the main flow.
    pub fn main_flow_count(&self) -> u32 {
        let PlaneFlowCounts { main_flow, .. } = self.counts;
        main_flow
    }

    /// Returns `true` when at least one plane is counted as load-only.
    pub fn has_specialization(&self) -> bool {
        self.counts.load_only != 0
    }
}
/// Amount of compute requested for a cube, expressed either in units, in planes, or as a
/// specialized plane layout.
// `Copy`/`Eq`/`Hash` match the sibling config types; every payload already implements them.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum CubeDimResource {
    /// A raw unit count; must be a multiple of `plane_dim` to be converted into planes.
    Units(u32),
    /// A plane count.
    Planes(u32),
    /// Planes split between main-flow and load-only planes.
    Specialized(SpecializedCubeDim),
}

impl CubeDimResource {
    /// Converts this resource into the equivalent `CubeDimResource::Planes` value.
    ///
    /// `Planes` is returned unchanged, `Specialized` collapses to its total plane count, and
    /// `Units` is divided by `plane_dim`.
    ///
    /// # Errors
    /// See [`CubeDimResource::num_planes`].
    pub fn as_plane_resource(self, plane_dim: u32) -> Result<Self, InvalidConfigError> {
        self.num_planes(plane_dim).map(CubeDimResource::Planes)
    }

    /// Builds the launch `CubeDim` as `CubeDim::new_2d(plane_dim, num_planes)`.
    ///
    /// # Errors
    /// See [`CubeDimResource::num_planes`].
    pub fn to_cube_dim(self, plane_dim: u32) -> Result<CubeDim, InvalidConfigError> {
        let num_planes = self.num_planes(plane_dim)?;
        Ok(CubeDim::new_2d(plane_dim, num_planes))
    }

    /// Number of planes this resource represents for the given `plane_dim`.
    ///
    /// `plane_dim` is only read for the `Units` variant.
    ///
    /// # Errors
    /// For `Units`, returns an error when `plane_dim` is zero or when the unit count is not
    /// divisible by `plane_dim`. The other variants never fail.
    pub fn num_planes(self, plane_dim: u32) -> Result<u32, InvalidConfigError> {
        match self {
            CubeDimResource::Units(units) => {
                // `units % 0` would panic; surface it as an invalid configuration instead.
                if plane_dim == 0 {
                    return Err(Box::new(format!(
                        "plane_dim must be non-zero to split {units:?} units into planes"
                    )));
                }
                if units % plane_dim == 0 {
                    Ok(units / plane_dim)
                } else {
                    Err(Box::new(format!(
                        "Number of units {units:?} should be divisible by plane_dim {plane_dim:?}"
                    )))
                }
            }
            CubeDimResource::Planes(num_planes) => Ok(num_planes),
            CubeDimResource::Specialized(spec) => Ok(spec.counts.total_count()),
        }
    }

    /// Returns this resource as a `SpecializedCubeDim`.
    ///
    /// A `Specialized` resource is returned as-is. `Units` and `Planes` become an
    /// unspecialized layout in which every plane belongs to the main flow.
    ///
    /// # Errors
    /// See [`CubeDimResource::num_planes`].
    pub fn as_specialized(self, plane_dim: u32) -> Result<SpecializedCubeDim, InvalidConfigError> {
        match self {
            CubeDimResource::Specialized(spec) => Ok(spec),
            CubeDimResource::Units(_) | CubeDimResource::Planes(_) => self
                .num_planes(plane_dim)
                .map(SpecializedCubeDim::new_unspecialized),
        }
    }
}