use oxicuda_driver::CudaError;
use thiserror::Error;
/// Unified error type for the BLAS layer; returned through [`BlasResult`].
///
/// Every variant renders a human-readable message via `Display` (generated by
/// `thiserror` from the `#[error(...)]` attributes below).
#[derive(Debug, Error)]
pub enum BlasError {
    /// A failure reported by the CUDA driver layer (`oxicuda_driver`).
    ///
    /// `#[from]` generates `From<CudaError>`, so `?` converts driver errors
    /// automatically, and it also exposes the driver error through `source()`.
    // NOTE(review): the driver error is both interpolated into this message and
    // returned from `source()`, so reporters that walk the source chain will
    // print it twice — consider dropping `{0}` from the message.
    #[error("CUDA driver error: {0}")]
    Cuda(#[from] CudaError),

    /// A matrix dimension was rejected; the string describes the offending value.
    #[error("invalid matrix dimensions: {0}")]
    InvalidDimension(String),

    /// A buffer holds fewer elements than the operation needs.
    #[error("buffer too small: expected at least {expected} elements, got {actual}")]
    BufferTooSmall {
        /// Minimum number of elements the operation requires.
        expected: usize,
        /// Number of elements the supplied buffer actually has.
        actual: usize,
    },

    /// Operand shapes are incompatible with each other (e.g. `A.cols != B.rows`).
    #[error("dimension mismatch: {0}")]
    DimensionMismatch(String),

    /// The requested operation is not supported.
    #[error("unsupported operation: {0}")]
    UnsupportedOperation(String),

    /// PTX code generation failed; carries the rendered message of the
    /// underlying generator error.
    #[error("PTX generation error: {0}")]
    PtxGeneration(String),

    /// A kernel launch failed.
    #[error("kernel launch failed: {0}")]
    LaunchFailed(String),

    /// An argument was rejected for a reason other than its dimensions.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// The autotuner reported an error.
    #[error("autotuner error: {0}")]
    AutotuneError(String),
}
impl From<oxicuda_ptx::PtxGenError> for BlasError {
fn from(err: oxicuda_ptx::PtxGenError) -> Self {
Self::PtxGeneration(err.to_string())
}
}
/// Shorthand for a `Result` whose error type is [`BlasError`].
pub type BlasResult<T> = std::result::Result<T, BlasError>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_buffer_too_small() {
        let err = BlasError::BufferTooSmall {
            expected: 1024,
            actual: 512,
        };
        // Pin the whole message: checking only that each number appears would
        // still pass if `expected` and `actual` were swapped in the format string.
        assert_eq!(
            err.to_string(),
            "buffer too small: expected at least 1024 elements, got 512"
        );
    }

    #[test]
    fn display_dimension_mismatch() {
        let err = BlasError::DimensionMismatch("A.cols != B.rows".to_string());
        assert_eq!(err.to_string(), "dimension mismatch: A.cols != B.rows");
    }

    #[test]
    fn display_string_variants() {
        // Each string-carrying variant prefixes its payload with a fixed label.
        let cases = [
            (
                BlasError::InvalidDimension("m = 0".to_string()),
                "invalid matrix dimensions: m = 0",
            ),
            (
                BlasError::UnsupportedOperation("herk".to_string()),
                "unsupported operation: herk",
            ),
            (
                BlasError::PtxGeneration("bad reg".to_string()),
                "PTX generation error: bad reg",
            ),
            (
                BlasError::LaunchFailed("grid too large".to_string()),
                "kernel launch failed: grid too large",
            ),
            (
                BlasError::InvalidArgument("lda < m".to_string()),
                "invalid argument: lda < m",
            ),
            (
                BlasError::AutotuneError("no candidates".to_string()),
                "autotuner error: no candidates",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn from_cuda_error() {
        let cuda_err = CudaError::NotInitialized;
        let blas_err: BlasError = cuda_err.into();
        assert!(matches!(blas_err, BlasError::Cuda(_)));
        assert!(blas_err.to_string().starts_with("CUDA driver error: "));
        // `#[from]` implies `#[source]`, so the driver error must be reachable
        // through the standard error chain.
        assert!(std::error::Error::source(&blas_err).is_some());
    }

    #[test]
    fn question_mark_converts_cuda_error() {
        fn fail() -> BlasResult<()> {
            let driver: Result<(), CudaError> = Err(CudaError::NotInitialized);
            driver?;
            Ok(())
        }
        assert!(matches!(fail(), Err(BlasError::Cuda(_))));
    }
}