use std::fmt;
/// Errors reported by this crate.
///
/// Human-readable text for each variant comes from the `Display`
/// implementation; the derived `Debug` output is intended for diagnostics.
#[derive(Debug)]
pub enum Error {
    /// A CUDA failure, carried as an already-rendered message string.
    Cuda(String),

    /// No deterministic bucket exists for the requested shape in this tier.
    Uncovered {
        /// Name of the operation that was requested.
        op: &'static str,
        /// `M` extent of the requested shape.
        m: usize,
        /// `K` extent of the requested shape.
        k: usize,
        /// `N` extent of the requested shape.
        n: usize,
    },

    /// Element types did not match; the payload describes what mismatched.
    DtypeMismatch(&'static str),

    /// The GPU architecture (reported as `sm_<major><minor>`) is not
    /// supported; Ampere or newer (`sm_80+`) is required.
    UnsupportedArch {
        /// Major component of the architecture version.
        major: u32,
        /// Minor component of the architecture version.
        minor: u32,
    },
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Cuda(s) => write!(f, "CUDA error: {s}"),
Error::Uncovered { op, m, k, n } => write!(
f,
"{op}: no deterministic bucket for shape M={m} K={k} N={n} in this tier"
),
Error::DtypeMismatch(what) => write!(f, "dtype mismatch: {what}"),
Error::UnsupportedArch { major, minor } => write!(
f,
"unsupported GPU architecture sm_{major}{minor}: sgemm-bi requires \
Ampere or newer (sm_80+)"
),
}
}
}
// Marker impl so `Error` can be used wherever a `std::error::Error` is
// expected (e.g. boxed as `Box<dyn std::error::Error>`). `source()` is not
// overridden, so no underlying cause is exposed through the trait.
impl std::error::Error for Error {}
/// Lets `?` turn a cudarc driver error into [`Error::Cuda`].
impl From<cudarc::driver::DriverError> for Error {
    /// Wraps the driver error's `Debug` rendering in [`Error::Cuda`].
    ///
    /// Only the rendered text is kept; the original `DriverError` value is
    /// dropped.
    fn from(e: cudarc::driver::DriverError) -> Self {
        let rendered = format!("{e:?}");
        Self::Cuda(rendered)
    }
}
/// Shorthand for a `std::result::Result` whose error type is this module's
/// [`Error`].
pub type Result<T> = std::result::Result<T, Error>;