use thiserror::Error;
/// Error type for the CUDA dispatch path.
///
/// `Display` text for every variant comes from its `#[error(...)]`
/// attribute, generated by `thiserror`. A value of this type can be turned
/// into `oxionnx_core::OnnxError` through the `From` impl in this module,
/// which keeps only the rendered message.
#[derive(Debug, Error)]
pub enum CudaDispatchError {
/// A failure reported by `oxicuda_driver`.
///
/// `#[from]` makes `thiserror` generate `From<oxicuda_driver::CudaError>`,
/// so `?` converts driver errors into this variant. The wrapped error is
/// also exposed via `std::error::Error::source`. Note that the message
/// embeds the driver error's own text as well, so chain-printing reporters
/// may show it twice.
#[error("CUDA driver error: {0}")]
Driver(#[from] oxicuda_driver::CudaError),
/// A BLAS-level failure, described by a free-form message that is shown
/// after the `CUDA BLAS error:` prefix.
#[error("CUDA BLAS error: {0}")]
Blas(String),
/// A DNN-library failure, described by a free-form message that is shown
/// after the `CUDA DNN error:` prefix.
#[error("CUDA DNN error: {0}")]
Dnn(String),
/// A failure while generating PTX code, described by a free-form message.
#[error("PTX generation error: {0}")]
Ptx(String),
/// The requested configuration is not supported by the CUDA path for the
/// named op.
#[error("Unsupported CUDA config for op '{op}': {reason}")]
Unsupported {
/// Name of the operation, interpolated into the message. It is
/// `'static`, so callers typically pass a string literal.
op: &'static str,
/// Human-readable explanation of why the configuration is rejected.
reason: String,
},
/// A shape-related error for the named op.
#[error("Shape error for op '{op}': {msg}")]
Shape {
/// Name of the operation, interpolated into the message. It is
/// `'static`, so callers typically pass a string literal.
op: &'static str,
/// Details describing the shape problem.
msg: String,
},
}
/// Bridges CUDA dispatch failures into the ONNX core error type.
///
/// Every variant is flattened into `OnnxError::Internal` carrying the
/// error's `Display` text. The variant identity and any `source()` chain
/// are not preserved; only the rendered message survives.
impl From<CudaDispatchError> for oxionnx_core::OnnxError {
    fn from(err: CudaDispatchError) -> Self {
        // Render through `Display` (the `#[error(...)]` formats) so the
        // human-readable description is what ends up in the ONNX error.
        let message = err.to_string();
        Self::Internal(message)
    }
}