use std::panic::Location;
use thiserror::Error;
/// Device-related errors.
///
/// Every variant stores the source location at which it was constructed in its
/// `location` field, and that location is appended to the rendered message.
#[derive(Debug, Error)]
pub enum DeviceError {
    /// The requested device could not be found.
    #[error("Device {device} not found at {location}")]
    NotFound {
        /// Identifier of the device that was looked up.
        device: String,
        /// Source location where this error was created.
        location: &'static Location<'static>,
    },

    /// A CUDA driver call failed (only present with the `cuda` feature).
    #[cfg(feature = "cuda")]
    #[error("CUDA driver error: {message} at {location}")]
    CudaDriverError {
        /// Human-readable description of the failure.
        message: String,
        /// Underlying driver error, reported through `Error::source` when present.
        #[source]
        source: Option<cudarc::driver::result::DriverError>,
        /// Source location where this error was created.
        location: &'static Location<'static>,
    },

    /// A cuBLAS call failed (only present with the `cuda` feature).
    #[cfg(feature = "cuda")]
    #[error("CUDA Cublas error: {message} at {location}")]
    CudaCublasError {
        /// Human-readable description of the failure.
        message: String,
        /// Underlying cuBLAS error, reported through `Error::source` when present.
        #[source]
        source: Option<cudarc::cublas::result::CublasError>,
        /// Source location where this error was created.
        location: &'static Location<'static>,
    },

    /// An environment variable was not set.
    #[error("Environment variable {variable} not set at {location}")]
    EnvVarNotSet {
        /// Name of the missing environment variable.
        variable: String,
        /// Source location where this error was created.
        location: &'static Location<'static>,
    },
}
/// `From` conversions that lift cudarc errors into `TensorError`, so driver and
/// cuBLAS failures can be propagated with `?` / `.into()`.
#[cfg(feature = "cuda")]
mod impls {
    use crate::error::base::TensorError;
    use crate::error::device::DeviceError;
    use std::panic::Location;

    impl From<cudarc::driver::result::DriverError> for TensorError {
        /// Wraps a CUDA driver error in `DeviceError::CudaDriverError`.
        ///
        /// `#[track_caller]` is required for `Location::caller()` to report the
        /// code that performed the conversion. Without it, the location would
        /// always be this function's own body, so every error would carry the
        /// same, useless location.
        #[track_caller]
        fn from(source: cudarc::driver::result::DriverError) -> Self {
            Self::Device(DeviceError::CudaDriverError {
                // Fields evaluate in order: render the message before `source`
                // is moved into the variant.
                message: source.to_string(),
                source: Some(source),
                location: Location::caller(),
            })
        }
    }

    impl From<cudarc::cublas::result::CublasError> for TensorError {
        /// Wraps a cuBLAS error in `DeviceError::CudaCublasError`.
        ///
        /// `#[track_caller]` is required for `Location::caller()` to report the
        /// code that performed the conversion rather than this function's body.
        #[track_caller]
        fn from(source: cudarc::cublas::result::CublasError) -> Self {
            Self::Device(DeviceError::CudaCublasError {
                // Render the message before `source` is moved into the variant.
                message: source.to_string(),
                source: Some(source),
                location: Location::caller(),
            })
        }
    }
}