// hpt_common/error/device.rs

use std::panic::Location;

use thiserror::Error;
5/// Device-related errors such as device not found, CUDA errors
6#[derive(Debug, Error)]
7pub enum DeviceError {
8    /// Device not found
9    #[error("Device {device} not found at {location}")]
10    NotFound {
11        /// Name or ID of the device that was not found
12        device: String,
13        /// Location where the error occurred
14        location: &'static Location<'static>,
15    },
16
17    /// CUDA driver error
18    #[error("CUDA driver error: {message} at {location}")]
19    #[cfg(feature = "cuda")]
20    CudaDriverError {
21        /// Error message
22        message: String,
23        /// Source error
24        #[source]
25        source: Option<cudarc::driver::result::DriverError>,
26        /// Location where the error occurred
27        location: &'static Location<'static>,
28    },
29
30    #[cfg(feature = "cuda")]
31    /// CUDA Cublas error
32    #[error("CUDA Cublas error: {message} at {location}")]
33    CudaCublasError {
34        /// Error message
35        message: String,
36        /// Source error
37        #[source]
38        source: Option<cudarc::cublas::result::CublasError>,
39        /// Location where the error occurred
40        location: &'static Location<'static>,
41    },
42
43    /// Environment variable not set
44    #[error("Environment variable {variable} not set at {location}")]
45    EnvVarNotSet {
46        /// Name of the environment variable that was not set
47        variable: String,
48        /// Location where the error occurred
49        location: &'static Location<'static>,
50    },
51}
52
#[cfg(feature = "cuda")]
mod impls {
    use crate::error::base::TensorError;
    use crate::error::device::DeviceError;
    use std::panic::Location;

    /// Converts a cudarc driver error into
    /// `TensorError::Device(DeviceError::CudaDriverError { .. })`.
    ///
    /// The error's `Display` text is copied into `message`, and the original
    /// error is kept as `source`.
    impl From<cudarc::driver::result::DriverError> for TensorError {
        // Without `#[track_caller]`, `Location::caller()` below would always
        // report this line; with it, the location is that of the code
        // invoking the conversion.
        #[track_caller]
        fn from(source: cudarc::driver::result::DriverError) -> Self {
            Self::Device(DeviceError::CudaDriverError {
                message: source.to_string(),
                source: Some(source),
                location: Location::caller(),
            })
        }
    }

    /// Converts a cudarc cuBLAS error into
    /// `TensorError::Device(DeviceError::CudaCublasError { .. })`.
    ///
    /// The error's `Display` text is copied into `message`, and the original
    /// error is kept as `source`.
    impl From<cudarc::cublas::result::CublasError> for TensorError {
        // See the driver-error impl: `#[track_caller]` makes the recorded
        // location point at the caller instead of this impl.
        #[track_caller]
        fn from(source: cudarc::cublas::result::CublasError) -> Self {
            Self::Device(DeviceError::CudaCublasError {
                message: source.to_string(),
                source: Some(source),
                location: Location::caller(),
            })
        }
    }
}