use std::{
ffi::{self, CStr},
fmt::{self, Display, Formatter},
result,
};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_core::impl_enum_conversion;
use singe_cuda::error::Error as CudaError;
use thiserror::Error;
use singe_cublas_sys as sys;
/// Error type used by this module's [`Result`] alias.
///
/// The `Display` text of each variant is the string in its `#[error(...)]`
/// attribute (generated by `thiserror`).
#[derive(Error, Debug)]
pub enum Error {
/// Wraps a `singe_cuda` error; `#[from]` provides the `From<CudaError>`
/// conversion so `?` can propagate it directly.
#[error("cuda error: {0}")]
Cuda(#[from] CudaError),
/// A cuBLAS status other than success. `code` is the status as a
/// [`Status`]; `message` is descriptive text (the
/// `From<sys::cublasStatus_t>` impl in this file fills it from
/// `cublasGetStatusString`).
#[error("cublas error ({code}): {message}")]
Cublas { code: Status, message: String },
/// A string contained an interior NUL byte; this is what
/// `std::ffi::NulError` is converted into by this module.
#[error("string contains interior nul byte")]
InteriorNul,
/// A handle was null where a non-null one was expected.
#[error("unexpected null handle")]
NullHandle,
/// The value identified by `name` is out of range.
#[error("{name} is out of range")]
OutOfRange { name: String },
/// An attribute's size was `actual` bytes where `expected` bytes were required.
#[error("unexpected attribute size: expected {expected} bytes, got {actual}")]
AttributeSizeMismatch { expected: usize, actual: usize },
/// The item identified by `name` has a length that does not match.
#[error("{name} has mismatched length")]
MismatchedLength { name: String },
/// A vector increment was invalid.
#[error("invalid vector increment")]
InvalidIncrement,
/// A matrix leading dimension was invalid.
#[error("invalid matrix leading dimension")]
InvalidLeadingDimension,
/// A matrix shape was invalid.
#[error("invalid matrix shape")]
InvalidMatrixShape,
/// A vector shape was invalid.
#[error("invalid vector shape")]
InvalidVectorShape,
/// A stream belongs to a different CUDA context than expected.
#[error("stream belongs to a different cuda context")]
StreamContextMismatch,
/// The operation requires host pointer mode.
#[error("operation requires host pointer mode")]
RequiresHostPointerMode,
/// The pointer modes of the scalars involved do not match.
#[error("scalar pointer modes do not match")]
ScalarPointerModeMismatch,
}
/// Shorthand for `std::result::Result` with this module's [`Error`] as the error type.
pub type Result<T> = result::Result<T, Error>;
/// Rust-side mirror of `sys::cublasStatus_t`.
///
/// Each discriminant is taken from the corresponding `sys` constant, and the
/// enum is `#[repr(u32)]` so the `num_enum` derives can convert to and from
/// the raw integer value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum Status {
/// Mirrors `CUBLAS_STATUS_SUCCESS`.
Success = sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS as _,
/// Mirrors `CUBLAS_STATUS_NOT_INITIALIZED`.
NotInitialized = sys::cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED as _,
/// Mirrors `CUBLAS_STATUS_ALLOC_FAILED`.
AllocFailed = sys::cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED as _,
/// Mirrors `CUBLAS_STATUS_INVALID_VALUE`.
InvalidValue = sys::cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE as _,
/// Mirrors `CUBLAS_STATUS_ARCH_MISMATCH`.
ArchMismatch = sys::cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH as _,
/// Mirrors `CUBLAS_STATUS_MAPPING_ERROR`.
MappingError = sys::cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR as _,
/// Mirrors `CUBLAS_STATUS_EXECUTION_FAILED`.
ExecutionFailed = sys::cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED as _,
/// Mirrors `CUBLAS_STATUS_INTERNAL_ERROR`.
InternalError = sys::cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR as _,
/// Mirrors `CUBLAS_STATUS_NOT_SUPPORTED`.
NotSupported = sys::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED as _,
/// Mirrors `CUBLAS_STATUS_LICENSE_ERROR`.
LicenseError = sys::cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR as _,
}
// NOTE(review): presumably generates the conversions between
// `sys::cublasStatus_t` and `Status` in both directions; the `From` impls for
// `Error` in this file rely on both. Confirm against the macro definition in
// `singe_core`.
impl_enum_conversion!(sys::cublasStatus_t, Status);
impl Status {
pub const fn description(self) -> &'static str {
match self {
Self::Success => "success",
Self::NotInitialized => "not initialized",
Self::AllocFailed => "allocation failed",
Self::InvalidValue => "invalid value",
Self::ArchMismatch => "architecture mismatch",
Self::MappingError => "mapping error",
Self::ExecutionFailed => "execution failed",
Self::InternalError => "internal error",
Self::NotSupported => "not supported",
Self::LicenseError => "license error",
}
}
}
/// Formats a status as the name of its cuBLAS constant
/// (for example `CUBLAS_STATUS_SUCCESS`).
impl Display for Status {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Pick the constant name first, then emit it with a single write.
        let constant_name = match self {
            Self::Success => "CUBLAS_STATUS_SUCCESS",
            Self::NotInitialized => "CUBLAS_STATUS_NOT_INITIALIZED",
            Self::AllocFailed => "CUBLAS_STATUS_ALLOC_FAILED",
            Self::InvalidValue => "CUBLAS_STATUS_INVALID_VALUE",
            Self::ArchMismatch => "CUBLAS_STATUS_ARCH_MISMATCH",
            Self::MappingError => "CUBLAS_STATUS_MAPPING_ERROR",
            Self::ExecutionFailed => "CUBLAS_STATUS_EXECUTION_FAILED",
            Self::InternalError => "CUBLAS_STATUS_INTERNAL_ERROR",
            Self::NotSupported => "CUBLAS_STATUS_NOT_SUPPORTED",
            Self::LicenseError => "CUBLAS_STATUS_LICENSE_ERROR",
        };
        f.write_str(constant_name)
    }
}
/// Builds an [`Error::Cublas`] from a raw cuBLAS status.
impl From<sys::cublasStatus_t> for Error {
    /// Converts `status` into `Error::Cublas`, taking the message from
    /// `cublasGetStatusString`.
    ///
    /// If that function returns a null pointer, the message falls back to
    /// `"unknown cublas error"`. Debug builds assert that `status` is not
    /// `CUBLAS_STATUS_SUCCESS`, since success is not an error.
    fn from(status: sys::cublasStatus_t) -> Self {
        debug_assert_ne!(status, sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS);

        // SAFETY: the call takes the status by value and no pointer arguments.
        let raw_text = unsafe { sys::cublasGetStatusString(status) };

        let message = if raw_text.is_null() {
            String::from("unknown cublas error")
        } else {
            // SAFETY: `raw_text` is non-null (checked above). It is presumed to
            // point at a NUL-terminated string that stays valid for this call;
            // TODO confirm against the cuBLAS documentation. The text is copied
            // into an owned `String` before returning.
            unsafe { CStr::from_ptr(raw_text) }
                .to_string_lossy()
                .into_owned()
        };

        Self::Cublas {
            code: status.into(),
            message,
        }
    }
}
/// Builds an [`Error`] from a [`Status`] by going through the raw
/// `sys::cublasStatus_t` conversion.
impl From<Status> for Error {
    fn from(status: Status) -> Self {
        // Convert to the raw status first, then reuse the raw-status conversion.
        let raw: sys::cublasStatus_t = status.into();
        Self::from(raw)
    }
}
impl From<ffi::NulError> for Error {
fn from(_: ffi::NulError) -> Self {
Self::InteriorNul
}
}
/// Evaluates an expression yielding a `cublasStatus_t` and turns it into a
/// `Result`: `Ok(())` for `CUBLAS_STATUS_SUCCESS`, otherwise `Err` with the
/// status converted into this crate's `error::Error`.
///
/// NOTE(review): the expansion names `singe_cublas_sys` directly rather than
/// going through `$crate`, so that path must resolve at the call site.
#[macro_export]
macro_rules! try_ffi {
    ($expr:expr) => {{
        // The braces evaluate the expression once, in its own block.
        let code = { $expr };
        if code == singe_cublas_sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            Ok(())
        } else {
            Err($crate::error::Error::from(code))
        }
    }};
}