use oxicuda_blas::BlasError;
use oxicuda_driver::CudaError;
use thiserror::Error;
/// Error type for the operations in this module.
///
/// It wraps errors from the CUDA driver (`oxicuda_driver`), the BLAS layer
/// (`oxicuda_blas`), PTX generation and `std::io`, plus validation failures
/// described by a message string. The `Cuda`, `Blas` and `Io` variants carry
/// `#[from]`, so `?` converts those error types automatically.
/// [`DnnResult`] is the matching `Result` alias.
///
/// NOTE(review): `#[from]` also makes thiserror return the wrapped error from
/// `source()`. The `{0}` in those Display messages embeds it a second time, so
/// reporters that print the full error chain will show it twice.
#[derive(Debug, Error)]
pub enum DnnError {
/// An error reported by the CUDA driver layer (`oxicuda_driver::CudaError`).
#[error("CUDA driver error: {0}")]
Cuda(#[from] CudaError),
/// An error propagated from the BLAS layer (`oxicuda_blas::BlasError`).
#[error("BLAS error: {0}")]
Blas(#[from] BlasError),
/// PTX code generation failed. Only the rendered message is stored: the
/// `From<oxicuda_ptx::PtxGenError>` impl builds it with `to_string()` and
/// drops the original error value.
#[error("PTX generation error: {0}")]
PtxGeneration(String),
/// Tensor dimensions were rejected; the string describes the problem.
#[error("invalid tensor dimensions: {0}")]
InvalidDimension(String),
/// A supplied buffer was smaller than required. Both sizes are in bytes,
/// as stated in the Display message.
#[error("buffer too small: expected {expected} bytes, got {actual} bytes")]
BufferTooSmall {
/// Minimum number of bytes required.
expected: usize,
/// Number of bytes actually provided.
actual: usize,
},
/// The requested operation is not supported; the string says which one.
#[error("unsupported operation: {0}")]
UnsupportedOperation(String),
/// An argument failed validation; the string describes the problem.
#[error("invalid argument: {0}")]
InvalidArgument(String),
/// The operation needs a workspace buffer; the payload is the minimum size
/// in bytes.
#[error("workspace required: need at least {0} bytes")]
WorkspaceRequired(usize),
/// A kernel launch failed; the string describes the failure.
#[error("kernel launch failed: {0}")]
LaunchFailed(String),
/// An error from `std::io`, converted automatically via `#[from]`.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
}
impl From<oxicuda_ptx::PtxGenError> for DnnError {
fn from(e: oxicuda_ptx::PtxGenError) -> Self {
Self::PtxGeneration(e.to_string())
}
}
/// Shorthand for a `Result` whose error type is [`DnnError`].
pub type DnnResult<T> = std::result::Result<T, DnnError>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the whole message, so a swap of the `expected` / `actual`
    /// placeholders in the format string is caught. A `contains` check on
    /// each number would miss that.
    #[test]
    fn display_buffer_too_small() {
        let e = DnnError::BufferTooSmall {
            expected: 4096,
            actual: 1024,
        };
        assert_eq!(
            e.to_string(),
            "buffer too small: expected 4096 bytes, got 1024 bytes"
        );
    }

    #[test]
    fn display_workspace_required() {
        let e = DnnError::WorkspaceRequired(8192);
        assert_eq!(
            e.to_string(),
            "workspace required: need at least 8192 bytes"
        );
    }

    #[test]
    fn from_cuda_error() {
        let cuda_err = CudaError::InvalidValue;
        let dnn_err: DnnError = cuda_err.into();
        assert!(matches!(dnn_err, DnnError::Cuda(_)));
    }

    /// `std::io::Error` is the one wrapped type that can be built from std
    /// alone. This test checks the variant, the message prefix and that
    /// `source()` is preserved.
    #[test]
    fn from_io_error_keeps_source() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing file");
        let dnn_err: DnnError = io_err.into();
        assert!(matches!(dnn_err, DnnError::Io(_)));
        assert_eq!(dnn_err.to_string(), "I/O error: missing file");
        // `#[from]` implies `#[source]`, so the wrapped error stays reachable.
        assert!(std::error::Error::source(&dnn_err).is_some());
    }
}