pub use crate::error::{BackendError, BackendResult as CudaResult, ErrorContext, ErrorContextExt};
pub use torsh_core::error::TorshError;
/// Crate-local alias: CUDA-specific failures are represented by the shared
/// `TorshError` type rather than a dedicated error enum.
pub type CudaError = TorshError;
pub mod cuda_errors {
use super::*;
pub fn runtime_error(
cuda_error: impl std::fmt::Display,
device_id: Option<usize>,
) -> TorshError {
let mut context = ErrorContext::new("cuda_runtime")
.with_backend("CUDA")
.with_details(cuda_error.to_string());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::ComputeError(context.format())
}
pub fn device_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cuda_device")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::BackendError(context.format())
}
pub fn memory_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cuda_memory")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::AllocationError(context.format())
}
pub fn kernel_launch_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cuda_kernel_launch")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::ComputeError(context.format())
}
pub fn stream_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cuda_stream")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::ComputeError(context.format())
}
pub fn cudnn_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cudnn")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::ComputeError(context.format())
}
pub fn cublas_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cublas")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::ComputeError(context.format())
}
pub fn nccl_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("nccl")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::ComputeError(context.format())
}
pub fn invalid_device_error(device_id: usize) -> TorshError {
let context = ErrorContext::new("device_validation")
.with_backend("CUDA")
.with_device(format!("cuda:{}", device_id))
.with_details(format!("Invalid device ID: {}", device_id));
TorshError::InvalidArgument(context.format())
}
pub fn out_of_memory_error(requested: usize, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("memory_allocation")
.with_backend("CUDA")
.with_details(format!("Out of memory: requested {} bytes", requested));
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::AllocationError(context.format())
}
pub fn unsupported_operation_error(
operation: impl Into<String>,
device_id: Option<usize>,
) -> TorshError {
let mut context = ErrorContext::new("operation_validation")
.with_backend("CUDA")
.with_details(format!("Unsupported operation: {}", operation.into()));
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::InvalidArgument(context.format())
}
pub fn context_error(message: impl Into<String>, device_id: Option<usize>) -> TorshError {
let mut context = ErrorContext::new("cuda_context")
.with_backend("CUDA")
.with_details(message.into());
if let Some(device_id) = device_id {
context = context.with_device(format!("cuda:{}", device_id));
}
TorshError::BackendError(context.format())
}
}
/// Converts a raw `cust` CUDA error into the crate-wide `TorshError`,
/// wrapping the driver message in a `ComputeError`.
pub fn cuda_error_to_torsh(error: cust::error::CudaError) -> TorshError {
    let message = format!("CUDA error: {}", error);
    TorshError::ComputeError(message)
}
/// Converts a raw `cust` CUDA error into this backend's `BackendError`,
/// carrying the driver message in the `Runtime` variant.
pub fn cuda_error_to_backend(error: cust::error::CudaError) -> crate::BackendError {
    let message = format!("CUDA error: {}", error);
    crate::BackendError::Runtime { message }
}
/// Extension trait for `Result<T, cust::error::CudaError>` that converts the
/// raw driver error into this crate's error types, so call sites can chain
/// e.g. `cuda_call().cuda_err()?`.
pub trait CustResultExt<T> {
    /// Maps the driver error to `TorshError` (`ComputeError` variant).
    fn cuda_err(self) -> Result<T, TorshError>;
    /// Maps the driver error to `crate::BackendError` (`Runtime` variant).
    fn backend_err(self) -> Result<T, crate::BackendError>;
    /// Same conversion as `cuda_err`, expressed via the `CudaResult` alias.
    fn cuda_result(self) -> CudaResult<T>;
}
/// Blanket conversion impl for raw `cust` results.
impl<T> CustResultExt<T> for Result<T, cust::error::CudaError> {
    /// Driver error -> `TorshError` via `cuda_error_to_torsh`.
    fn cuda_err(self) -> Result<T, TorshError> {
        self.map_err(cuda_error_to_torsh)
    }

    /// Driver error -> `crate::BackendError` via `cuda_error_to_backend`.
    fn backend_err(self) -> Result<T, crate::BackendError> {
        self.map_err(cuda_error_to_backend)
    }

    /// Identical conversion to `cuda_err`; `CudaResult<T>` aliases the
    /// same `Result` type, so delegate rather than repeat the mapping.
    fn cuda_result(self) -> CudaResult<T> {
        self.cuda_err()
    }
}