use thiserror::Error;
pub type Result<T> = core::result::Result<T, Error>;
#[derive(Debug, Error)]
#[non_exhaustive]
/// Errors reported by `baracuda-cutlass`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
pub enum Error {
    /// The requested kernel is not available; carries a static description.
    #[error("baracuda-cutlass: requested kernel is unavailable: {0}")]
    Unsupported(&'static str),

    /// The problem description was rejected; carries a static description.
    #[error("baracuda-cutlass: invalid problem: {0}")]
    InvalidProblem(&'static str),

    /// An operand failed an alignment check.
    #[error("baracuda-cutlass: misaligned operand")]
    MisalignedOperand,

    /// The caller-provided workspace is smaller than required (sizes in bytes).
    #[error("baracuda-cutlass: workspace too small (need {needed} bytes, got {got})")]
    WorkspaceTooSmall { needed: usize, got: usize },

    /// A buffer holds fewer elements than its declared shape requires.
    #[error("baracuda-cutlass: buffer too small for declared shape (need {needed} elements, got {got})")]
    BufferTooSmall { needed: usize, got: usize },

    /// A status code with no dedicated variant, preserved verbatim.
    #[error("baracuda-cutlass: CUTLASS internal error (status code {0})")]
    CutlassInternal(i32),

    /// A failure from the underlying driver layer.
    ///
    /// `#[from]` lets `?` convert a driver error into this variant.
    #[error("baracuda-cutlass: driver error: {0}")]
    Driver(#[from] baracuda_driver::Error),
}
pub(crate) fn status_to_result(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem("CUTLASS reported invalid problem")),
3 => Err(Error::Unsupported("CUTLASS reported unsupported configuration")),
n => Err(Error::CutlassInternal(n)),
}
}