1use oxicuda_driver::CudaError;
8use oxicuda_ptx::error::PtxGenError;
9use thiserror::Error;
10
11#[derive(Debug, Error)]
16pub enum RandError {
17 #[error("CUDA driver error: {0}")]
19 Cuda(#[from] CudaError),
20
21 #[error("PTX generation error: {0}")]
23 PtxGeneration(#[from] PtxGenError),
24
25 #[error("invalid output size: {0}")]
28 InvalidSize(String),
29
30 #[error("invalid seed: {0}")]
32 InvalidSeed(String),
33
34 #[error("invalid parameter: {0}")]
36 InvalidParameter(String),
37
38 #[error("unsupported distribution: {0}")]
41 UnsupportedDistribution(String),
42
43 #[error("internal error: {0}")]
45 InternalError(String),
46}
47
48pub type RandResult<T> = Result<T, RandError>;
50
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `Error::source` into scope without clashing with the
    // `thiserror::Error` derive re-exported through `super::*`.
    use std::error::Error as _;

    #[test]
    fn display_invalid_size() {
        let err = RandError::InvalidSize("must be a multiple of 4".to_string());
        // Pin the full message so a changed prefix is caught, not just the payload.
        assert_eq!(
            err.to_string(),
            "invalid output size: must be a multiple of 4"
        );
    }

    #[test]
    fn display_invalid_seed() {
        let err = RandError::InvalidSeed("seed must be non-zero".to_string());
        assert_eq!(err.to_string(), "invalid seed: seed must be non-zero");
    }

    #[test]
    fn display_invalid_parameter() {
        let err = RandError::InvalidParameter("lambda must be >= 0".to_string());
        assert_eq!(err.to_string(), "invalid parameter: lambda must be >= 0");
    }

    #[test]
    fn display_unsupported_distribution() {
        let err = RandError::UnsupportedDistribution("poisson f64".to_string());
        assert_eq!(err.to_string(), "unsupported distribution: poisson f64");
    }

    #[test]
    fn display_internal_error() {
        let err = RandError::InternalError("unexpected state".to_string());
        assert_eq!(err.to_string(), "internal error: unexpected state");
    }

    #[test]
    fn string_variants_have_no_source() {
        // Plain-message variants do not wrap another error.
        let err = RandError::InvalidParameter("x".to_string());
        assert!(err.source().is_none());
    }

    #[test]
    fn from_cuda_error() {
        let rand_err: RandError = CudaError::NotInitialized.into();
        assert!(matches!(rand_err, RandError::Cuda(_)));
        assert!(rand_err.to_string().starts_with("CUDA driver error: "));
        // `#[from]` also marks the field as the error source.
        assert!(rand_err.source().is_some());
    }
}