use cubecl::{CubeCount, CubeDim, VectorizationError, ir::StorageType, server::LaunchError};
use cubek_std::{InvalidConfigError, MatrixLayout, TileSize};
use std::fmt::{Debug, Display};
/// Error returned when a matmul cannot be set up or launched.
///
/// Each variant wraps the error produced by the failing stage; the `From` impls for
/// the wrapped types let `?` convert them into this type.
pub enum MatmulSetupError {
    /// A feature the matmul requires is unavailable; see [`MatmulAvailabilityError`].
    Unavailable(MatmulAvailabilityError),
    /// The matmul configuration is invalid.
    InvalidConfig(InvalidConfigError),
    /// No supported vector size could be found.
    Vectorization(VectorizationError),
    /// Launching the matmul failed.
    Launch(LaunchError),
}
impl From<LaunchError> for MatmulSetupError {
fn from(value: LaunchError) -> Self {
Self::Launch(value)
}
}
/// A capability the matmul needs is missing on the target, or the problem exceeds a limit.
pub enum MatmulAvailabilityError {
    /// The cube count is too big.
    CubeCountTooBig(CubeCount),
    /// The cube dimension is too big.
    CubeDimTooBig(CubeDim),
    /// The plane dimension `plane_dim` is not supported.
    PlaneDimUnsupported { plane_dim: u32 },
    /// The combination of `lhs`, `rhs` and `output` storage types is not supported.
    TypesUnavailable {
        lhs: StorageType,
        rhs: StorageType,
        output: StorageType,
    },
    /// No cmma instruction is available for these storage types.
    ///
    /// `size` carries the tile shape involved, if known; the `Debug` message includes
    /// its `m`/`n`/`k` dimensions only when it is `Some`.
    CmmaInstructionUnavailable {
        lhs: StorageType,
        rhs: StorageType,
        output: StorageType,
        size: Option<TileSize>,
    },
    /// No tile size is available for the problem.
    TileSizeNotFound,
    /// The `lhs`/`rhs` matrix layout combination is not supported.
    LayoutUnsupported {
        lhs: MatrixLayout,
        rhs: MatrixLayout,
    },
    /// Barriers are not available.
    BarrierUnavailable,
    /// TMA is not available.
    TmaUnavailable,
    /// Memory reinterpretation is not available.
    ReinterpretMemoryUnavailable,
    /// Plane-wide operations (such as `plane_sum`) are not available.
    PlaneOpsUnavailable,
}
impl From<MatmulAvailabilityError> for MatmulSetupError {
fn from(value: MatmulAvailabilityError) -> Self {
Self::Unavailable(value)
}
}
impl From<InvalidConfigError> for MatmulSetupError {
fn from(value: InvalidConfigError) -> Self {
Self::InvalidConfig(value)
}
}
impl From<VectorizationError> for MatmulSetupError {
fn from(value: VectorizationError) -> Self {
Self::Vectorization(value)
}
}
impl Display for MatmulSetupError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl Debug for MatmulSetupError {
    /// Writes a one-sentence description of why the matmul could not be launched.
    ///
    /// No trailing newline is emitted. `Display` for this type reuses this text, and it
    /// is usually embedded in a larger message (an `unwrap` panic, a log line, or a
    /// wrapping error), where a trailing line break would leave stray blank lines.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MatmulSetupError::Unavailable(err) => write!(
                f,
                "Unable to launch matmul because a required feature is unavailable: {err:?}"
            ),
            MatmulSetupError::InvalidConfig(err) => write!(
                f,
                "Unable to launch matmul because the config is invalid: {:?}",
                // The message is stringified, then Debug-formatted, so it appears quoted
                // with any embedded newlines escaped, keeping the output on one line.
                err.to_string()
            ),
            MatmulSetupError::Vectorization(err) => write!(
                f,
                "Unable to launch matmul because could not find supported vector size: {err:?}"
            ),
            MatmulSetupError::Launch(err) => {
                write!(f, "Unable to launch matmul with err: {err:?}")
            }
        }
    }
}
impl Debug for MatmulAvailabilityError {
    /// Writes a short description of the missing capability or exceeded limit.
    ///
    /// No trailing newline is emitted. This text is embedded in other messages (for
    /// example `MatmulSetupError::Unavailable`'s), where a trailing line break would
    /// leave stray blank lines in the combined output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MatmulAvailabilityError::CubeCountTooBig(count) => {
                write!(f, "Cube count too big {count:?}")
            }
            MatmulAvailabilityError::CubeDimTooBig(dim) => {
                write!(f, "Cube dim too big {dim:?}")
            }
            MatmulAvailabilityError::PlaneDimUnsupported { plane_dim } => write!(
                f,
                "Plane dimension unsupported: {plane_dim}. Only 32 & 64 are supported."
            ),
            MatmulAvailabilityError::TypesUnavailable { lhs, rhs, output } => write!(
                f,
                "Types lhs={lhs:?}, rhs={rhs:?} and/or output={output:?} not supported.",
            ),
            // Both cmma arms are kept together: the message includes the tile shape
            // only when one is known.
            MatmulAvailabilityError::CmmaInstructionUnavailable {
                lhs,
                rhs,
                output,
                size: Some(size),
            } => write!(
                f,
                "Cmma on lhs {lhs:?} rhs {rhs:?} and output {output:?} with shape m={:?}, n={:?}, k={:?} not supported.",
                size.m(),
                size.n(),
                size.k()
            ),
            MatmulAvailabilityError::CmmaInstructionUnavailable {
                lhs,
                rhs,
                output,
                size: None,
            } => write!(
                f,
                "Cmma on inputs lhs {lhs:?} rhs {rhs:?} and output {output:?} not supported.",
            ),
            MatmulAvailabilityError::TileSizeNotFound => {
                write!(f, "No tile size is available for the problem.")
            }
            MatmulAvailabilityError::LayoutUnsupported { lhs, rhs } => write!(
                f,
                "Cmma with layouts lhs {lhs:?} and rhs {rhs:?} not supported."
            ),
            MatmulAvailabilityError::BarrierUnavailable => {
                write!(f, "Barrier is not available.")
            }
            MatmulAvailabilityError::TmaUnavailable => {
                write!(f, "TMA is not available.")
            }
            MatmulAvailabilityError::ReinterpretMemoryUnavailable => {
                write!(f, "Memory reinterpretation is not available.")
            }
            MatmulAvailabilityError::PlaneOpsUnavailable => {
                write!(f, "Plane-wide operations like plane_sum are not available.")
            }
        }
    }
}