Skip to main content

ferrotorch_gpu/
error.rs

1use core::fmt;
2
/// Errors produced by GPU operations.
#[derive(Debug)]
#[non_exhaustive]
pub enum GpuError {
    /// CUDA driver error forwarded from cudarc.
    #[cfg(feature = "cuda")]
    Driver(cudarc::driver::DriverError),

    /// Attempted a GPU operation but the `cuda` feature is not enabled.
    #[cfg(not(feature = "cuda"))]
    NoCudaFeature,

    /// Device ordinal is out of range.
    InvalidDevice {
        /// The device ordinal that was requested.
        ordinal: usize,
        /// The number of devices that are available.
        count: usize,
    },

    /// Tried to operate on buffers from different devices.
    DeviceMismatch {
        /// Ordinal of the device the operation expected.
        expected: usize,
        /// Ordinal of the device the offending buffer actually lives on.
        got: usize,
    },

    /// GPU out of memory. Contains the requested size and the free bytes at
    /// the time of the failed allocation.
    OutOfMemory {
        /// Size of the allocation that failed, in bytes.
        requested_bytes: usize,
        /// Free device memory at the time of the failure, in bytes.
        free_bytes: usize,
    },

    /// Allocation rejected because it would exceed the user-configured memory
    /// budget (see [`crate::memory_guard::MemoryGuard::set_budget`]).
    BudgetExceeded {
        /// Size of the rejected allocation, in bytes.
        requested_bytes: usize,
        /// The configured budget, in bytes.
        budget_bytes: usize,
        /// Bytes already counted against the budget when the request was made.
        used_bytes: usize,
    },

    /// Binary op received buffers with different lengths.
    LengthMismatch {
        /// Length of the first operand buffer.
        a: usize,
        /// Length of the second operand buffer.
        b: usize,
    },

    /// An operation received an operand whose shape differs from the shape it
    /// requires (e.g. a matrix multiplication whose inner dimensions differ).
    ShapeMismatch {
        /// Name of the operation that rejected the shapes; used as the prefix
        /// of the rendered error message.
        op: &'static str,
        /// The shape the operation required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },

    /// cuBLAS error forwarded from cudarc.
    #[cfg(feature = "cuda")]
    Blas(cudarc::cublas::result::CublasError),

    /// PTX kernel compilation failed (e.g. unsupported GPU architecture).
    PtxCompileFailed {
        /// Name of the kernel that failed to compile.
        kernel: &'static str,
    },

    /// An operation was attempted in an invalid state (e.g., capture on a
    /// sealed pool, or cuSOLVER reported a negative info value).
    InvalidState {
        /// Human-readable description of what was invalid.
        message: String,
    },
}
57
58impl fmt::Display for GpuError {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        match self {
61            #[cfg(feature = "cuda")]
62            GpuError::Driver(e) => write!(f, "CUDA driver error: {e}"),
63
64            #[cfg(not(feature = "cuda"))]
65            GpuError::NoCudaFeature => {
66                write!(f, "GPU operations require the `cuda` feature")
67            }
68
69            GpuError::InvalidDevice { ordinal, count } => {
70                write!(
71                    f,
72                    "invalid device ordinal {ordinal} (only {count} devices available)"
73                )
74            }
75
76            GpuError::DeviceMismatch { expected, got } => {
77                write!(
78                    f,
79                    "device mismatch: expected cuda:{expected}, got cuda:{got}"
80                )
81            }
82
83            GpuError::OutOfMemory {
84                requested_bytes,
85                free_bytes,
86            } => {
87                write!(
88                    f,
89                    "GPU out of memory: requested {requested_bytes} bytes but only \
90                     {free_bytes} bytes free"
91                )
92            }
93
94            GpuError::BudgetExceeded {
95                requested_bytes,
96                budget_bytes,
97                used_bytes,
98            } => {
99                write!(
100                    f,
101                    "memory budget exceeded: requested {requested_bytes} bytes, \
102                     budget is {budget_bytes} bytes with {used_bytes} bytes already used"
103                )
104            }
105
106            GpuError::LengthMismatch { a, b } => {
107                write!(f, "buffer length mismatch: {a} vs {b}")
108            }
109
110            GpuError::ShapeMismatch { op, expected, got } => {
111                write!(
112                    f,
113                    "{op}: shape mismatch, expected {expected:?}, got {got:?}"
114                )
115            }
116
117            #[cfg(feature = "cuda")]
118            GpuError::Blas(e) => write!(f, "cuBLAS error: {e}"),
119
120            GpuError::PtxCompileFailed { kernel } => {
121                write!(f, "PTX kernel compilation failed: {kernel}")
122            }
123
124            GpuError::InvalidState { message } => {
125                write!(f, "invalid state: {message}")
126            }
127        }
128    }
129}
130
131impl std::error::Error for GpuError {
132    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
133        match self {
134            #[cfg(feature = "cuda")]
135            GpuError::Driver(e) => Some(e),
136            #[cfg(feature = "cuda")]
137            GpuError::Blas(e) => Some(e),
138            _ => None,
139        }
140    }
141}
142
#[cfg(feature = "cuda")]
impl From<cudarc::driver::DriverError> for GpuError {
    /// Wraps a cudarc driver error in [`GpuError::Driver`], which lets `?`
    /// convert it automatically in functions that return [`GpuResult`].
    fn from(err: cudarc::driver::DriverError) -> Self {
        Self::Driver(err)
    }
}
149
#[cfg(feature = "cuda")]
impl From<cudarc::cublas::result::CublasError> for GpuError {
    /// Wraps a cuBLAS error reported by cudarc in [`GpuError::Blas`], which
    /// lets `?` convert it automatically in functions that return
    /// [`GpuResult`].
    fn from(err: cudarc::cublas::result::CublasError) -> Self {
        Self::Blas(err)
    }
}
156
157/// Convenience alias for GPU results.
158pub type GpuResult<T> = Result<T, GpuError>;