use oxicuda_driver::CudaError;
use oxicuda_ptx::error::PtxGenError;
use thiserror::Error;
/// Error type for the fallible operations of this random-number crate.
///
/// Lower-level failures from the CUDA driver ([`CudaError`]) and from PTX
/// generation ([`PtxGenError`]) are wrapped via `#[from]`, so `?` converts
/// them into a `RandError` automatically. The remaining variants carry a
/// human-readable `String` detail that is appended to a fixed message prefix
/// when the error is displayed.
#[derive(Debug, Error)]
pub enum RandError {
    /// Wraps a [`CudaError`] from the `oxicuda_driver` layer.
    ///
    /// Displayed as `CUDA driver error: <driver message>`. Because the field is
    /// marked `#[from]`, it is also returned by
    /// [`std::error::Error::source`].
    #[error("CUDA driver error: {0}")]
    Cuda(#[from] CudaError),

    /// Wraps a [`PtxGenError`] from the `oxicuda_ptx` code generator.
    ///
    /// Displayed as `PTX generation error: <generator message>`. Because the
    /// field is marked `#[from]`, it is also returned by
    /// [`std::error::Error::source`].
    #[error("PTX generation error: {0}")]
    PtxGeneration(#[from] PtxGenError),

    /// An output size was rejected.
    ///
    /// The payload explains why; it is shown after `invalid output size: `.
    #[error("invalid output size: {0}")]
    InvalidSize(String),

    /// A seed value was rejected.
    ///
    /// The payload explains why; it is shown after `invalid seed: `.
    #[error("invalid seed: {0}")]
    InvalidSeed(String),

    /// The requested distribution (or distribution/type combination) is not
    /// supported.
    ///
    /// The payload names what was requested; it is shown after
    /// `unsupported distribution: `.
    #[error("unsupported distribution: {0}")]
    UnsupportedDistribution(String),

    /// An internal failure that does not fit the other variants.
    ///
    /// The payload describes it; it is shown after `internal error: `.
    #[error("internal error: {0}")]
    InternalError(String),
}
/// Shorthand for a [`std::result::Result`] whose error type is [`RandError`].
pub type RandResult<T> = std::result::Result<T, RandError>;
#[cfg(test)]
mod tests {
    use super::*;

    // Exact comparisons pin both the fixed message prefix and the payload,
    // so a regression in either `#[error(...)]` string is caught.

    #[test]
    fn display_invalid_size() {
        let err = RandError::InvalidSize("must be a multiple of 4".to_string());
        assert_eq!(
            err.to_string(),
            "invalid output size: must be a multiple of 4"
        );
    }

    #[test]
    fn display_invalid_seed() {
        let err = RandError::InvalidSeed("example detail".to_string());
        assert_eq!(err.to_string(), "invalid seed: example detail");
    }

    #[test]
    fn display_unsupported_distribution() {
        let err = RandError::UnsupportedDistribution("poisson f64".to_string());
        assert_eq!(err.to_string(), "unsupported distribution: poisson f64");
    }

    #[test]
    fn display_internal_error() {
        let err = RandError::InternalError("example detail".to_string());
        assert_eq!(err.to_string(), "internal error: example detail");
    }

    #[test]
    fn from_cuda_error() {
        let cuda_err = CudaError::NotInitialized;
        let rand_err: RandError = cuda_err.into();
        assert!(matches!(rand_err, RandError::Cuda(_)));
        // The driver message follows a fixed prefix in the Display output.
        assert!(rand_err.to_string().starts_with("CUDA driver error: "));
        // `#[from]` implies `#[source]`, so the driver error stays reachable
        // through the standard error chain.
        assert!(std::error::Error::source(&rand_err).is_some());
    }
}