Skip to main content

oxicuda_blas/
error.rs

1//! Error types for OxiCUDA BLAS operations.
2//!
//! Provides [`BlasError`] covering the failure modes of GPU-accelerated
//! BLAS routines — invalid dimensions and arguments, buffer validation,
//! unsupported operations, PTX generation, autotuning, kernel launch issues,
//! and underlying CUDA driver errors.
6
7use oxicuda_driver::CudaError;
8use thiserror::Error;
9
10/// BLAS-specific error type.
11///
12/// Every fallible BLAS operation returns [`BlasResult<T>`] which uses this
13/// enum as its error variant. The variants are ordered roughly by how early
14/// in the call chain they are likely to appear: argument validation first,
15/// then PTX/launch errors, then driver-level failures.
16#[derive(Debug, Error)]
17pub enum BlasError {
18    /// A CUDA driver call failed.
19    #[error("CUDA driver error: {0}")]
20    Cuda(#[from] CudaError),
21
22    /// A matrix or vector dimension is invalid (e.g. zero rows).
23    #[error("invalid matrix dimensions: {0}")]
24    InvalidDimension(String),
25
26    /// A device buffer is too small for the requested operation.
27    #[error("buffer too small: expected at least {expected} elements, got {actual}")]
28    BufferTooSmall {
29        /// Minimum number of elements required.
30        expected: usize,
31        /// Actual number of elements in the buffer.
32        actual: usize,
33    },
34
35    /// Two operands have incompatible dimensions (e.g. inner dims of A and B
36    /// in a GEMM do not match).
37    #[error("dimension mismatch: {0}")]
38    DimensionMismatch(String),
39
40    /// The requested operation or precision is not supported on this device.
41    #[error("unsupported operation: {0}")]
42    UnsupportedOperation(String),
43
44    /// PTX kernel source generation failed.
45    #[error("PTX generation error: {0}")]
46    PtxGeneration(String),
47
48    /// A kernel launch (grid/block configuration or driver call) failed.
49    #[error("kernel launch failed: {0}")]
50    LaunchFailed(String),
51
52    /// A caller-provided argument is invalid.
53    #[error("invalid argument: {0}")]
54    InvalidArgument(String),
55
56    /// The autotuner encountered an error while profiling kernel variants.
57    #[error("autotuner error: {0}")]
58    AutotuneError(String),
59}
60
61// -- Conversions from dependency error types -----------------------------------
62
63impl From<oxicuda_ptx::PtxGenError> for BlasError {
64    fn from(err: oxicuda_ptx::PtxGenError) -> Self {
65        Self::PtxGeneration(err.to_string())
66    }
67}
68
69/// Convenience alias used throughout the BLAS crate.
70pub type BlasResult<T> = Result<T, BlasError>;
71
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `Error::source` into scope without clashing with the
    // `thiserror::Error` derive that `super::*` re-exports.
    use std::error::Error as _;

    #[test]
    fn display_buffer_too_small() {
        let err = BlasError::BufferTooSmall {
            expected: 1024,
            actual: 512,
        };
        assert_eq!(
            err.to_string(),
            "buffer too small: expected at least 1024 elements, got 512"
        );
    }

    #[test]
    fn display_dimension_mismatch() {
        let err = BlasError::DimensionMismatch("A.cols != B.rows".to_string());
        assert_eq!(err.to_string(), "dimension mismatch: A.cols != B.rows");
    }

    /// Pins the exact message prefix of every `String`-payload variant so an
    /// accidental edit to an `#[error(...)]` attribute is caught.
    #[test]
    fn display_string_variants() {
        let cases = [
            (
                BlasError::InvalidDimension("m == 0".to_string()),
                "invalid matrix dimensions: m == 0",
            ),
            (
                BlasError::UnsupportedOperation("fp8 gemm".to_string()),
                "unsupported operation: fp8 gemm",
            ),
            (
                BlasError::PtxGeneration("bad register".to_string()),
                "PTX generation error: bad register",
            ),
            (
                BlasError::LaunchFailed("grid too large".to_string()),
                "kernel launch failed: grid too large",
            ),
            (
                BlasError::InvalidArgument("alpha is NaN".to_string()),
                "invalid argument: alpha is NaN",
            ),
            (
                BlasError::AutotuneError("no variant ran".to_string()),
                "autotuner error: no variant ran",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn from_cuda_error() {
        let cuda_err = CudaError::NotInitialized;
        let blas_err: BlasError = cuda_err.into();
        assert!(matches!(blas_err, BlasError::Cuda(_)));
        assert!(blas_err.to_string().starts_with("CUDA driver error: "));
        // `#[from]` also marks the wrapped driver error as the source.
        assert!(blas_err.source().is_some());
    }

    #[test]
    fn question_mark_converts_cuda_error() {
        fn fails() -> BlasResult<()> {
            let driver_result: Result<(), CudaError> = Err(CudaError::NotInitialized);
            driver_result?;
            Ok(())
        }
        assert!(matches!(fails(), Err(BlasError::Cuda(_))));
    }
}