use thiserror::Error;
/// Convenience alias for a `Result` whose error type is [`GpuError`].
pub type GpuResult<T> = Result<T, GpuError>;
/// Errors produced by the GPU compute layer.
///
/// Each variant's `#[error(...)]` attribute is its `Display` text (generated by
/// `thiserror`). [`GpuError::should_fallback`] and [`GpuError::is_recoverable`]
/// classify variants; any variant listed by neither method returns `false`
/// from both.
#[derive(Debug, Error)]
pub enum GpuError {
/// No GPU adapter was found. `should_fallback()` returns `true` for it.
#[error("No suitable GPU adapter found. Ensure a GPU with compute capabilities is available.")]
NoAdapter,
/// No compatible GPU device was found; the payload carries the details.
/// `should_fallback()` returns `true` for it.
#[error("No compatible GPU device found: {0}")]
NoDevice(String),
/// Creating the GPU device failed. `should_fallback()` returns `true` for it.
#[error("Failed to create GPU device: {0}")]
DeviceCreation(String),
/// Requesting a GPU device failed. Built by the `From<wgpu::RequestDeviceError>`
/// impl, whose payload is that error's `Display` text.
/// `should_fallback()` returns `true` for it.
#[error("Failed to request GPU device: {0}")]
DeviceRequestFailed(String),
/// Shader compilation failed; the payload carries the details.
#[error("Shader compilation failed: {0}")]
ShaderCompilation(String),
/// Buffer allocation failed, with a free-form reason.
/// NOTE(review): overlaps with `BufferAllocationFailed`, which also records the
/// requested size; when to prefer which is not visible here.
#[error("Buffer allocation failed: {0}")]
BufferAllocation(String),
/// Buffer allocation failed for a known request size.
#[error("Buffer allocation failed: requested {requested_bytes} bytes, reason: {reason}")]
BufferAllocationFailed {
/// Number of bytes that was requested.
requested_bytes: u64,
/// Human-readable reason for the failure.
reason: String,
},
/// A requested buffer size exceeded the allowed maximum.
#[error("Buffer size {size} exceeds maximum allowed {max}")]
BufferTooLarge {
/// The requested size (presumably in bytes — confirm with callers).
size: u64,
/// The maximum allowed size, in the same unit as `size`.
max: u64,
},
/// A buffer's size did not match the expected size.
#[error("Buffer size mismatch: expected {expected}, got {actual}")]
BufferSizeMismatch { expected: usize, actual: usize },
/// Reading a buffer back failed. `is_recoverable()` returns `true` for it.
/// NOTE(review): overlaps with `BufferRead`; the distinction is not visible here.
#[error("Buffer read-back failed: {0}")]
BufferReadFailed(String),
/// Mapping a buffer failed. Built by the `From<wgpu::BufferAsyncError>` impl.
/// Classified as neither fallback nor recoverable.
#[error("Buffer mapping failed: {0}")]
BufferMapFailed(String),
/// A dimension did not match the expected value.
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// The number of bindings did not match the expected count.
#[error("Invalid binding configuration: expected {expected} bindings, got {actual}")]
InvalidBindingCount {
/// Expected number of bindings.
expected: usize,
/// Number of bindings actually supplied.
actual: usize,
},
/// A workgroup size was rejected; the message reports it as exceeding device limits.
#[error("Invalid workgroup configuration: [{x}, {y}, {z}] exceeds device limits")]
InvalidWorkgroupSize {
/// Workgroup size along the x axis.
x: u32,
/// Workgroup size along the y axis.
y: u32,
/// Workgroup size along the z axis.
z: u32,
},
/// Creating a compute pipeline failed.
#[error("Failed to create compute pipeline: {0}")]
PipelineCreation(String),
/// Encoding GPU commands failed.
#[error("Command encoding failed: {0}")]
CommandEncoding(String),
/// GPU execution failed. `is_recoverable()` returns `true` for it.
#[error("GPU execution failed: {0}")]
ExecutionFailed(String),
/// Reading a buffer failed. `is_recoverable()` returns `true` for it.
#[error("Failed to read buffer: {0}")]
BufferRead(String),
/// Writing a buffer failed.
#[error("Failed to write buffer: {0}")]
BufferWrite(String),
/// A GPU operation timed out; the payload is rendered with an `ms` suffix,
/// i.e. milliseconds. `is_recoverable()` returns `true` for it.
#[error("GPU operation timed out after {0}ms")]
Timeout(u64),
/// The graph to compute has no edges.
#[error("Graph has no edges to compute")]
EmptyGraph,
/// The supplied GPU configuration was invalid.
#[error("Invalid GPU configuration: {0}")]
InvalidConfig(String),
/// A needed GPU feature is not supported. `should_fallback()` returns `true` for it.
#[error("GPU feature not supported: {0}")]
UnsupportedFeature(String),
/// Requesting a GPU adapter failed. `should_fallback()` returns `true` for it.
#[error("Failed to request GPU adapter: {0}")]
AdapterRequest(String),
/// The GPU ran out of memory.
#[error("Out of GPU memory: requested {requested_bytes} bytes")]
OutOfMemory {
/// Number of bytes that was requested.
requested_bytes: u64,
},
/// The GPU device was lost.
/// NOTE(review): currently classified as neither fallback nor recoverable —
/// confirm that is intended.
#[error("GPU device lost: {0}")]
DeviceLost(String),
/// Internal error; the payload carries the details.
#[error("Internal GPU error: {0}")]
Internal(String),
}
impl GpuError {
pub fn should_fallback(&self) -> bool {
matches!(
self,
GpuError::NoAdapter
| GpuError::NoDevice(_)
| GpuError::DeviceCreation(_)
| GpuError::DeviceRequestFailed(_)
| GpuError::AdapterRequest(_)
| GpuError::UnsupportedFeature(_)
)
}
pub fn is_recoverable(&self) -> bool {
matches!(
self,
GpuError::Timeout(_)
| GpuError::BufferRead(_)
| GpuError::BufferReadFailed(_)
| GpuError::ExecutionFailed(_)
)
}
}
impl From<wgpu::RequestDeviceError> for GpuError {
fn from(e: wgpu::RequestDeviceError) -> Self {
Self::DeviceRequestFailed(e.to_string())
}
}
impl From<wgpu::BufferAsyncError> for GpuError {
fn from(e: wgpu::BufferAsyncError) -> Self {
Self::BufferMapFailed(e.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant `should_fallback` lists reports `true`; a sample of
    /// other variants reports `false`.
    #[test]
    fn test_should_fallback() {
        assert!(GpuError::NoAdapter.should_fallback());
        assert!(GpuError::NoDevice("test".into()).should_fallback());
        assert!(GpuError::DeviceCreation("test".into()).should_fallback());
        assert!(GpuError::DeviceRequestFailed("test".into()).should_fallback());
        assert!(GpuError::AdapterRequest("test".into()).should_fallback());
        assert!(GpuError::UnsupportedFeature("test".into()).should_fallback());
        assert!(!GpuError::Timeout(100).should_fallback());
        assert!(!GpuError::EmptyGraph.should_fallback());
        assert!(!GpuError::ExecutionFailed("test".into()).should_fallback());
    }

    /// Every variant `is_recoverable` lists reports `true`; a sample of other
    /// variants reports `false`.
    #[test]
    fn test_is_recoverable() {
        assert!(GpuError::Timeout(100).is_recoverable());
        assert!(GpuError::BufferRead("test".into()).is_recoverable());
        assert!(GpuError::BufferReadFailed("test".into()).is_recoverable());
        assert!(GpuError::ExecutionFailed("test".into()).is_recoverable());
        assert!(!GpuError::NoDevice("test".into()).is_recoverable());
        assert!(!GpuError::NoAdapter.is_recoverable());
        assert!(!GpuError::EmptyGraph.is_recoverable());
    }

    /// Input/configuration errors belong to neither classification.
    #[test]
    fn test_unclassified_variants() {
        let unclassified = [
            GpuError::EmptyGraph,
            GpuError::ShaderCompilation("test".into()),
            GpuError::InvalidConfig("test".into()),
            GpuError::BufferTooLarge { size: 2, max: 1 },
            GpuError::Internal("test".into()),
        ];
        for err in &unclassified {
            assert!(!err.should_fallback(), "{err:?} unexpectedly triggers fallback");
            assert!(!err.is_recoverable(), "{err:?} unexpectedly recoverable");
        }
    }

    /// Pins the exact rendered message, not just substrings, so a broken
    /// format string is caught.
    #[test]
    fn test_error_display() {
        let err = GpuError::BufferAllocationFailed {
            requested_bytes: 1024,
            reason: "out of memory".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Buffer allocation failed: requested 1024 bytes, reason: out of memory"
        );
    }

    #[test]
    fn test_workgroup_error() {
        let err = GpuError::InvalidWorkgroupSize {
            x: 1000,
            y: 1,
            z: 1,
        };
        assert_eq!(
            err.to_string(),
            "Invalid workgroup configuration: [1000, 1, 1] exceeds device limits"
        );
    }

    #[test]
    fn test_timeout_display_includes_unit() {
        assert_eq!(
            GpuError::Timeout(250).to_string(),
            "GPU operation timed out after 250ms"
        );
    }
}