oxicuda_blas/error.rs
1//! Error types for OxiCUDA BLAS operations.
2//!
3//! Provides [`BlasError`] covering all failure modes for GPU-accelerated
4//! BLAS routines — dimension mismatches, buffer validation, PTX generation,
5//! kernel launch issues, and underlying CUDA driver errors.
6
7use oxicuda_driver::CudaError;
8use thiserror::Error;
9
10/// BLAS-specific error type.
11///
12/// Every fallible BLAS operation returns [`BlasResult<T>`] which uses this
13/// enum as its error variant. The variants are ordered roughly by how early
14/// in the call chain they are likely to appear: argument validation first,
15/// then PTX/launch errors, then driver-level failures.
16#[derive(Debug, Error)]
17pub enum BlasError {
18 /// A CUDA driver call failed.
19 #[error("CUDA driver error: {0}")]
20 Cuda(#[from] CudaError),
21
22 /// A matrix or vector dimension is invalid (e.g. zero rows).
23 #[error("invalid matrix dimensions: {0}")]
24 InvalidDimension(String),
25
26 /// A device buffer is too small for the requested operation.
27 #[error("buffer too small: expected at least {expected} elements, got {actual}")]
28 BufferTooSmall {
29 /// Minimum number of elements required.
30 expected: usize,
31 /// Actual number of elements in the buffer.
32 actual: usize,
33 },
34
35 /// Two operands have incompatible dimensions (e.g. inner dims of A and B
36 /// in a GEMM do not match).
37 #[error("dimension mismatch: {0}")]
38 DimensionMismatch(String),
39
40 /// The requested operation or precision is not supported on this device.
41 #[error("unsupported operation: {0}")]
42 UnsupportedOperation(String),
43
44 /// PTX kernel source generation failed.
45 #[error("PTX generation error: {0}")]
46 PtxGeneration(String),
47
48 /// A kernel launch (grid/block configuration or driver call) failed.
49 #[error("kernel launch failed: {0}")]
50 LaunchFailed(String),
51
52 /// A caller-provided argument is invalid.
53 #[error("invalid argument: {0}")]
54 InvalidArgument(String),
55
56 /// The autotuner encountered an error while profiling kernel variants.
57 #[error("autotuner error: {0}")]
58 AutotuneError(String),
59}
60
61// -- Conversions from dependency error types -----------------------------------
62
63impl From<oxicuda_ptx::PtxGenError> for BlasError {
64 fn from(err: oxicuda_ptx::PtxGenError) -> Self {
65 Self::PtxGeneration(err.to_string())
66 }
67}
68
69/// Convenience alias used throughout the BLAS crate.
70pub type BlasResult<T> = Result<T, BlasError>;
71
#[cfg(test)]
mod tests {
    use super::*;

    /// The full message is compared (not just substrings) so that swapping
    /// `expected` and `actual` in the format string would be detected.
    #[test]
    fn display_buffer_too_small() {
        let err = BlasError::BufferTooSmall {
            expected: 1024,
            actual: 512,
        };
        assert_eq!(
            err.to_string(),
            "buffer too small: expected at least 1024 elements, got 512"
        );
    }

    #[test]
    fn display_dimension_mismatch() {
        let err = BlasError::DimensionMismatch("A.cols != B.rows".to_string());
        assert_eq!(err.to_string(), "dimension mismatch: A.cols != B.rows");
    }

    /// Every single-string variant renders as its fixed prefix followed by
    /// the payload.
    #[test]
    fn display_message_variants() {
        let cases = [
            (
                BlasError::InvalidDimension("m".to_string()),
                "invalid matrix dimensions: m",
            ),
            (
                BlasError::DimensionMismatch("m".to_string()),
                "dimension mismatch: m",
            ),
            (
                BlasError::UnsupportedOperation("m".to_string()),
                "unsupported operation: m",
            ),
            (
                BlasError::PtxGeneration("m".to_string()),
                "PTX generation error: m",
            ),
            (
                BlasError::LaunchFailed("m".to_string()),
                "kernel launch failed: m",
            ),
            (
                BlasError::InvalidArgument("m".to_string()),
                "invalid argument: m",
            ),
            (
                BlasError::AutotuneError("m".to_string()),
                "autotuner error: m",
            ),
        ];

        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn from_cuda_error() {
        let cuda_err = CudaError::NotInitialized;
        let blas_err: BlasError = cuda_err.into();
        assert!(matches!(blas_err, BlasError::Cuda(_)));
        // `#[from]` also marks the field as the error source, so the driver
        // error stays reachable through the standard error chain.
        assert!(std::error::Error::source(&blas_err).is_some());
    }
}
98}