use std::fmt;
/// Crate-wide error type unifying failures from the ROCm library wrappers
/// (HIP, rocRAND, MIOpen, rocFFT, rocBLAS), standard I/O, and a set of
/// message-only error kinds.
///
/// `Display` renders every variant as a fixed prefix followed by either the
/// wrapped error's own `Display` output or the carried message
/// (e.g. `"HIP error: ..."`, `"Invalid argument: ..."`).
#[derive(Debug)]
pub enum Error {
/// Error from the HIP wrapper; convertible via `From<crate::hip::Error>`.
Hip(crate::hip::Error),
/// Error from the rocRAND wrapper; convertible via `From<crate::rocrand::Error>`.
RocRand(crate::rocrand::Error),
/// Error from the MIOpen wrapper; only present when the `miopen` feature is enabled.
#[cfg(feature = "miopen")]
MIOpen(crate::miopen::Error),
/// Error from the rocFFT wrapper; convertible via `From<crate::rocfft::error::Error>`.
RocFFT(crate::rocfft::error::Error),
/// Error from the rocBLAS wrapper.
// NOTE(review): unlike the other library variants, this file has no
// `From<crate::rocblas::Error>` impl and `source()` returns `None` for it —
// confirm whether a conversion lives elsewhere or is intentionally missing.
RocBLAS(crate::rocblas::Error),
/// Free-form message. The `ErrorContext` helpers also produce this variant,
/// flattening the wrapped error's text into the message.
Custom(String),
/// Message rendered after the prefix `"Invalid operation: "`.
InvalidOperation(String),
/// Message rendered after the prefix `"Out of memory: "`.
OutOfMemory(String),
/// Message rendered after the prefix `"Invalid argument: "`.
InvalidArgument(String),
/// Message rendered after the prefix `"Not implemented: "`.
NotImplemented(String),
/// Standard I/O error; convertible via `From<std::io::Error>` and exposed by `source()`.
Io(std::io::Error),
/// Message rendered after the prefix `"Parse error: "`.
Parse(String),
/// Message rendered after the prefix `"Timeout: "`.
Timeout(String),
/// Message rendered after the prefix `"Device error: "`.
DeviceError(String),
/// Message rendered after the prefix `"Kernel compilation error: "`.
KernelCompilation(String),
/// Message rendered after the prefix `"Synchronization error: "`.
SynchronizationError(String),
}
impl From<crate::hip::Error> for Error {
fn from(error: crate::hip::Error) -> Self {
Error::Hip(error)
}
}
impl From<crate::rocrand::Error> for Error {
fn from(error: crate::rocrand::Error) -> Self {
Error::RocRand(error)
}
}
/// Lifts an MIOpen wrapper error into the crate [`Error`] so `?` can propagate it.
///
/// Only compiled when the `miopen` feature is enabled, matching the
/// feature-gated `Error::MIOpen` variant.
#[cfg(feature = "miopen")]
impl From<crate::miopen::Error> for Error {
    fn from(err: crate::miopen::Error) -> Self {
        Self::MIOpen(err)
    }
}
impl From<crate::rocfft::error::Error> for Error {
fn from(error: crate::rocfft::error::Error) -> Self {
Error::RocFFT(error)
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Error::Io(error)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Hip(e) => write!(f, "HIP error: {}", e),
Error::RocRand(e) => write!(f, "rocRAND error: {}", e),
#[cfg(feature = "miopen")]
Error::MIOpen(e) => write!(f, "MIOpen error: {}", e),
Error::RocFFT(e) => write!(f, "rocFFT error: {}", e),
Error::RocBLAS(e) => write!(f, "rocBLAS error: {}", e),
Error::Custom(msg) => write!(f, "Error: {}", msg),
Error::InvalidOperation(msg) => write!(f, "Invalid operation: {}", msg),
Error::OutOfMemory(msg) => write!(f, "Out of memory: {}", msg),
Error::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
Error::NotImplemented(msg) => write!(f, "Not implemented: {}", msg),
Error::Io(e) => write!(f, "I/O error: {}", e),
Error::Parse(msg) => write!(f, "Parse error: {}", msg),
Error::Timeout(msg) => write!(f, "Timeout: {}", msg),
Error::DeviceError(msg) => write!(f, "Device error: {}", msg),
Error::KernelCompilation(msg) => write!(f, "Kernel compilation error: {}", msg),
Error::SynchronizationError(msg) => write!(f, "Synchronization error: {}", msg),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Hip(e) => Some(e),
Error::RocRand(e) => Some(e),
#[cfg(feature = "miopen")]
Error::MIOpen(e) => Some(e),
Error::RocFFT(e) => Some(e),
Error::Io(e) => Some(e),
_ => None,
}
}
}
pub type Result<T> = std::result::Result<T, Error>;
/// Builds an [`Error::Custom`] from any string-like `message`.
///
/// `#[must_use]` flags accidental statements like `custom_error("x");` where the
/// caller meant to `return Err(...)`; a discarded error is always a bug.
#[must_use]
pub fn custom_error<S: Into<String>>(message: S) -> Error {
    Error::Custom(message.into())
}
/// Builds an [`Error::InvalidOperation`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn invalid_operation<S: Into<String>>(message: S) -> Error {
    Error::InvalidOperation(message.into())
}
/// Builds an [`Error::OutOfMemory`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn out_of_memory<S: Into<String>>(message: S) -> Error {
    Error::OutOfMemory(message.into())
}
/// Builds an [`Error::InvalidArgument`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn invalid_argument<S: Into<String>>(message: S) -> Error {
    Error::InvalidArgument(message.into())
}
/// Builds an [`Error::NotImplemented`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn not_implemented<S: Into<String>>(message: S) -> Error {
    Error::NotImplemented(message.into())
}
/// Builds an [`Error::Parse`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn parse_error<S: Into<String>>(message: S) -> Error {
    Error::Parse(message.into())
}
/// Builds an [`Error::Timeout`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn timeout_error<S: Into<String>>(message: S) -> Error {
    Error::Timeout(message.into())
}
/// Builds an [`Error::DeviceError`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn device_error<S: Into<String>>(message: S) -> Error {
    Error::DeviceError(message.into())
}
/// Builds an [`Error::KernelCompilation`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn kernel_compilation_error<S: Into<String>>(message: S) -> Error {
    Error::KernelCompilation(message.into())
}
/// Builds an [`Error::SynchronizationError`] from any string-like `message`.
///
/// `#[must_use]` flags accidentally discarding the constructed error.
#[must_use]
pub fn synchronization_error<S: Into<String>>(message: S) -> Error {
    Error::SynchronizationError(message.into())
}
/// Builds a `crate::error::Error` of the named variant from a `format!`-style message.
///
/// `$kind` must name a variant that carries a single `String`
/// (e.g. `InvalidOperation`, `Timeout`); variants that wrap a library error
/// (e.g. `Hip`) will not type-check.
///
/// ```ignore
/// let err = rocm_error!(InvalidOperation, "stream {} is busy", id);
/// ```
#[macro_export]
macro_rules! rocm_error {
    ($kind:ident, $($arg:tt)*) => {
        // Fully qualified so a caller-side `format` macro cannot hijack the expansion.
        $crate::error::Error::$kind(::std::format!($($arg)*))
    };
}
/// Builds an `Error::Custom` from a `format!`-style message via
/// `crate::error::custom_error`.
///
/// ```ignore
/// let err = custom_error!("device {} unavailable", ordinal);
/// ```
#[macro_export]
macro_rules! custom_error {
    ($($arg:tt)*) => {
        // Fully qualified so a caller-side `format` macro cannot hijack the expansion.
        $crate::error::custom_error(::std::format!($($arg)*))
    };
}
/// Builds an `Error::InvalidOperation` from a `format!`-style message via
/// `crate::error::invalid_operation`.
///
/// ```ignore
/// let err = invalid_operation!("stream {} already destroyed", id);
/// ```
#[macro_export]
macro_rules! invalid_operation {
    ($($arg:tt)*) => {
        // Fully qualified so a caller-side `format` macro cannot hijack the expansion.
        $crate::error::invalid_operation(::std::format!($($arg)*))
    };
}
/// Attaches a human-readable prefix to the error of a failed result.
pub trait ErrorContext<T> {
    /// Prefixes the error with the message produced by `f`.
    ///
    /// Taking a closure lets implementations defer building the message.
    fn with_context<F: FnOnce() -> String>(self, f: F) -> Result<T>;

    /// Prefixes the error with `msg`.
    fn context<S>(self, msg: S) -> Result<T>
    where
        S: Into<String>;
}
impl<T> ErrorContext<T> for Result<T> {
fn with_context<F>(self, f: F) -> Result<T>
where
F: FnOnce() -> String,
{
self.map_err(|e| Error::Custom(format!("{}: {}", f(), e)))
}
fn context<S: Into<String>>(self, msg: S) -> Result<T> {
self.map_err(|e| Error::Custom(format!("{}: {}", msg.into(), e)))
}
}
/// Integer status codes.
///
/// `SUCCESS` is `0` and every failure code is negative. The names mirror the
/// string-carrying variants of `super::Error`, but this module does not itself
/// map `Error` values to codes.
// NOTE(review): presumably intended for callers that need a numeric status
// (e.g. across an FFI boundary) — confirm against the users of these constants.
pub mod error_codes {
/// Success status.
pub const SUCCESS: i32 = 0;
/// Generic failure code, not tied to a more specific category.
pub const ERROR: i32 = -1;
/// Code for the invalid-argument case (cf. `Error::InvalidArgument`).
pub const INVALID_ARGUMENT: i32 = -2;
/// Code for the out-of-memory case (cf. `Error::OutOfMemory`).
pub const OUT_OF_MEMORY: i32 = -3;
/// Code for the not-implemented case (cf. `Error::NotImplemented`).
pub const NOT_IMPLEMENTED: i32 = -4;
/// Code for the device-error case (cf. `Error::DeviceError`).
pub const DEVICE_ERROR: i32 = -5;
/// Code for the kernel-compilation case (cf. `Error::KernelCompilation`).
pub const KERNEL_COMPILATION: i32 = -6;
/// Code for the synchronization case (cf. `Error::SynchronizationError`).
pub const SYNCHRONIZATION_ERROR: i32 = -7;
/// Code for the timeout case (cf. `Error::Timeout`).
pub const TIMEOUT: i32 = -8;
/// Code for the I/O case (cf. `Error::Io`).
pub const IO_ERROR: i32 = -9;
/// Code for the parse case (cf. `Error::Parse`).
pub const PARSE_ERROR: i32 = -10;
}
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `source()` into scope without clashing with the crate `Error` enum.
    use std::error::Error as _;

    #[test]
    fn test_error_display() {
        let err = custom_error("Test error");
        assert_eq!(format!("{}", err), "Error: Test error");
        let err = invalid_operation("Invalid operation test");
        assert_eq!(
            format!("{}", err),
            "Invalid operation: Invalid operation test"
        );
    }

    #[test]
    fn test_error_macros() {
        let err = rocm_error!(InvalidOperation, "Test {} error", "formatted");
        match err {
            Error::InvalidOperation(msg) => assert_eq!(msg, "Test formatted error"),
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    /// The two convenience macros format their arguments and pick the right variant.
    #[test]
    fn test_format_macros() {
        let err = custom_error!("code {}", 42);
        assert_eq!(format!("{}", err), "Error: code 42");
        let err = invalid_operation!("stream {} busy", 3);
        assert_eq!(format!("{}", err), "Invalid operation: stream 3 busy");
    }

    /// `From<io::Error>` produces `Error::Io`, which displays and chains the cause.
    #[test]
    fn test_io_conversion_and_source() {
        let io = std::io::Error::new(std::io::ErrorKind::Other, "disk gone");
        let err: Error = io.into();
        assert_eq!(format!("{}", err), "I/O error: disk gone");
        assert_eq!(
            err.source().map(|s| s.to_string()),
            Some("disk gone".to_string())
        );
    }

    /// Message-only variants report no underlying cause.
    #[test]
    fn test_message_variants_have_no_source() {
        assert!(timeout_error("t").source().is_none());
        assert!(custom_error("c").source().is_none());
    }

    /// `context` flattens the original error's text into an `Error::Custom`.
    #[test]
    fn test_context_prefixes_message() {
        let res: Result<()> = Err(invalid_argument("bad size"));
        let err = res.context("allocating buffer").unwrap_err();
        assert_eq!(
            format!("{}", err),
            "Error: allocating buffer: Invalid argument: bad size"
        );
    }

    /// `with_context` leaves `Ok` untouched and never calls the closure.
    #[test]
    fn test_with_context_passes_ok_through() {
        let ok: Result<i32> = Ok(7);
        let value = ok
            .with_context(|| unreachable!("closure must not run for Ok"))
            .unwrap();
        assert_eq!(value, 7);
    }
}