use num_traits::Float;
use rand::distr::uniform::Error as UniformError;
use rand_distr::{ExpError, NormalError, PoissonError};
use std::fmt::Debug;
use thiserror::Error;
/// Convenience alias for results whose error type is [`XError`].
pub type XResult<T> = Result<T, XError>;
/// Crate-wide error type.
///
/// What it wraps:
/// - errors from the sampling distributions (`rand` / `rand_distr`, plus the crate's own [`StableError`]);
/// - simulation-parameter errors ([`SimulationError`]);
/// - feature-gated visualization (`visualize`), CSV/I/O (`io`) and CUDA (`cuda`) errors.
///
/// Variants whose field is tagged `#[from]` get an automatic `From` impl from `thiserror`,
/// so `?` works directly on those source errors. Several other sources (`realfft`, `csv`,
/// `std::io`, `plotters`, the plot-config builder) are converted by hand-written `From` impls
/// in this file. Those impls store only the source's `Display` text in a `String` variant,
/// presumably so that `XError` can keep deriving `Clone`.
#[derive(Error, Debug, Clone)]
pub enum XError {
/// Error reported by `rand`'s uniform distribution (`rand::distr::uniform::Error`).
#[error("Sample Uniform Distribution Error: {0}")]
UniformSampleError(#[from] UniformError),
/// Error reported by `rand_distr`'s normal distribution (`NormalError`).
#[error("Sample Normal Distribution Error: {0}")]
NormalSampleError(#[from] NormalError),
/// Error reported by `rand_distr`'s Poisson distribution (`PoissonError`).
#[error("Sample Poisson Distribution Error: {0}")]
PoissonSampleError(#[from] PoissonError),
/// Error reported by `rand_distr`'s exponential distribution (`ExpError`).
#[error("Sample Exponential Distribution Error: {0}")]
ExpSampleError(#[from] ExpError),
/// Invalid parameters for the crate's stable distribution; see [`StableError`].
#[error("Sample Stable Distribution Error: {0}")]
StableSampleError(#[from] StableError),
/// A probability argument was outside the range 0..1.
/// NOTE(review): presumably raised by boolean/Bernoulli sampling; confirm at call sites.
#[error("Probability must be between 0 and 1")]
BoolSampleError,
/// Simulation setup/run failure; see [`SimulationError`].
/// `check_duration_time_step` produces this variant via `.into()`.
#[error("Simulate Error: {0}")]
SimulateError(#[from] SimulationError),
/// Stringified plotting failure (`visualize` feature only).
/// Built by the `From` impls for `plotters` drawing-area errors and for [`PlotterError`].
#[cfg(feature = "visualize")]
#[error("Visualization Error: {0}")]
VisualizationError(String),
/// Generic invalid-parameter error carrying a human-readable description.
#[error("Invalid parameters: {0}")]
InvalidParameters(String),
/// The circulant embedding matrix had a non-positive-definite spectrum.
/// The payload is the offending eigenvalue, as shown in the message.
#[error("Circulant embedding matrix is not positive definite, eigenvalue: {0}")]
NotPositiveDefinite(f64),
/// Stringified `realfft::FftError`, displayed verbatim (no prefix).
#[error("{0}")]
FFTError(String),
/// Stringified `csv::Error` (`io` feature only).
#[cfg(feature = "io")]
#[error("CSV IO Error: {0}")]
CSVError(String),
/// Stringified `std::io::Error` (`io` feature only).
/// NOTE(review): every `io::Error` converted via `From` lands here, not only write failures.
#[cfg(feature = "io")]
#[error("CSV Write Error: {0}")]
CSVWriteError(String),
/// Stringified plot-config builder error (`visualize` feature only).
#[cfg(feature = "visualize")]
#[error("Plotter Config Error: {0}")]
BuilderError(String),
/// CUDA driver error from `cudarc`, displayed transparently (`cuda` feature only).
#[cfg(feature = "cuda")]
#[error(transparent)]
CUDAError(#[from] cudarc::driver::result::DriverError),
/// `std::time::SystemTimeError`, displayed transparently; gated behind the `cuda` feature.
#[cfg(feature = "cuda")]
#[error(transparent)]
SystemTimeError(#[from] std::time::SystemTimeError),
/// cuRAND error from `cudarc`, displayed transparently (`cuda` feature only).
#[cfg(feature = "cuda")]
#[error(transparent)]
CURANDError(#[from] cudarc::curand::result::CurandError),
/// Catch-all for errors that have no dedicated variant.
#[error("Error: {0}")]
Other(String),
}
/// Converts a `realfft` FFT failure into [`XError::FFTError`].
///
/// Only the error's `Display` text is kept; the original `FftError` value is dropped.
impl From<realfft::FftError> for XError {
    fn from(fft_err: realfft::FftError) -> Self {
        let message = fft_err.to_string();
        Self::FFTError(message)
    }
}
/// Parameter-validation errors for the crate's stable distribution.
///
/// Wrapped into [`XError::StableSampleError`] automatically via `#[from]`.
/// The type is `Copy` and comparable with `==`, so callers can match or compare variants directly.
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)]
pub enum StableError {
/// The stability index (alpha) lies outside (0, 2].
#[error("Index of stability must be in the range (0, 2]")]
InvalidIndex,
/// The skewness parameter (beta) lies outside [-1, 1].
#[error("Skewness parameter must be in the range [-1, 1]")]
InvalidSkewness,
/// The scale parameter is not strictly positive.
#[error("Scale parameter must be positive")]
InvalidScale,
/// The location parameter is not a real number.
/// NOTE(review): presumably NaN/infinite; confirm at the validation site.
#[error("Location parameter must be a real number")]
InvalidLocation,
/// A skewness index lies outside the open interval (0, 1).
/// NOTE(review): which distribution parameter this refers to is not visible here.
#[error("Index of skewness must be in the range (0, 1)")]
InvalidSkewIndex,
}
/// Errors raised while setting up or running a simulation.
///
/// Wrapped into [`XError::SimulateError`] automatically via `#[from]`.
/// For example, `check_duration_time_step` in this file reports bad `duration`/`time_step`
/// values as [`SimulationError::InvalidParameters`].
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum SimulationError {
/// Simulation parameters failed validation; the string describes which one and why.
#[error("Invalid parameters: {0}")]
InvalidParameters(String),
/// An invalid time step, with a description.
#[error("Invalid time step: {0}")]
InvalidTimeStep(String),
/// An invalid time interval, with a description.
#[error("Invalid time interval: {0}")]
InvalidTimeInterval(String),
/// The simulation failed for an unspecified reason.
#[error("Unknown error, simulation failed")]
Unknown,
/// No result was available.
/// NOTE(review): presumably a result was requested before one was produced; confirm at call sites.
#[error("No result available")]
NoResult,
}
/// Errors raised by the plotting layer (`visualize` feature only).
///
/// Converted into [`XError::VisualizationError`] by a `From` impl in this file.
/// That conversion keeps only the `Display` text.
#[cfg(feature = "visualize")]
#[derive(Error, Debug)]
pub enum PlotterError {
/// Invalid plot configuration, with a description.
#[error("Config Error: {0}")]
ConfigError(String),
/// A color specification could not be used; the string carries the offending value or a description.
#[error("Invalid color: {0}")]
InvalidColor(String),
/// A failure while drawing the plot, with a description.
#[error("Plot Error: {0}")]
DrawingError(String),
}
/// Converts a `plotters` drawing-area error into [`XError::VisualizationError`].
///
/// Only the rendered message is kept, so the generic backend error `E` does not leak into `XError`.
#[cfg(feature = "visualize")]
impl<E> From<plotters::drawing::DrawingAreaErrorKind<E>> for XError
where
    E: std::error::Error + Send + Sync,
{
    fn from(drawing_err: plotters::drawing::DrawingAreaErrorKind<E>) -> Self {
        Self::VisualizationError(drawing_err.to_string())
    }
}
/// Converts a [`PlotterError`] into [`XError::VisualizationError`], keeping only its message.
#[cfg(feature = "visualize")]
impl From<PlotterError> for XError {
    fn from(plot_err: PlotterError) -> Self {
        let message = plot_err.to_string();
        Self::VisualizationError(message)
    }
}
/// Converts a plot-config builder failure into [`XError::BuilderError`], keeping only its message.
#[cfg(feature = "visualize")]
impl From<crate::visualize::config::PlotConfigBuilderError> for XError {
    fn from(builder_err: crate::visualize::config::PlotConfigBuilderError) -> Self {
        let message = builder_err.to_string();
        Self::BuilderError(message)
    }
}
/// Converts a `csv` crate error into [`XError::CSVError`], keeping only its message.
#[cfg(feature = "io")]
impl From<csv::Error> for XError {
    fn from(csv_err: csv::Error) -> Self {
        let message = csv_err.to_string();
        Self::CSVError(message)
    }
}
/// Converts a `std::io::Error` into [`XError::CSVWriteError`], keeping only its message.
///
/// NOTE(review): coherence allows only one `From<io::Error>` impl, so *every* I/O failure
/// propagated with `?` (reads and file opens included) is reported as a "CSV Write Error".
/// A dedicated variant would be clearer. The target variant is kept as-is because callers
/// may match on it.
#[cfg(feature = "io")]
impl From<std::io::Error> for XError {
    fn from(io_err: std::io::Error) -> Self {
        let message = io_err.to_string();
        Self::CSVWriteError(message)
    }
}
/// Builds an owned [`XError`] from a borrowed one by cloning it.
///
/// This lets `?` propagate a `Result<_, &XError>` out of a function returning [`XResult`].
impl From<&XError> for XError {
    fn from(borrowed: &XError) -> Self {
        Self::clone(borrowed)
    }
}
/// Validates the `duration` / `time_step` pair used to discretise a simulation.
///
/// # Errors
///
/// Returns [`XError::SimulateError`] wrapping [`SimulationError::InvalidParameters`] when:
/// - `duration` is not a finite, strictly positive number (zero, negatives, NaN and
///   infinities are all rejected);
/// - `time_step` is not a finite, strictly positive number;
/// - `time_step` is greater than `duration`.
pub(crate) fn check_duration_time_step<T: Float + Debug>(duration: T, time_step: T) -> XResult<()> {
    // Every comparison against NaN is false, so a bare `x <= 0` check lets NaN through.
    // `!is_finite()` catches NaN as well as +/-inf, which would otherwise yield an unbounded
    // number of steps downstream.
    if !duration.is_finite() || duration <= T::zero() {
        return Err(SimulationError::InvalidParameters(format!(
            "The `duration` must be positive and finite, got `{duration:?}`"
        ))
        .into());
    }
    if !time_step.is_finite() || time_step <= T::zero() {
        return Err(SimulationError::InvalidParameters(format!(
            "The `time_step` must be positive and finite, got `{time_step:?}`"
        ))
        .into());
    }
    // Both values are finite and positive here, so this comparison is well defined.
    if time_step > duration {
        return Err(SimulationError::InvalidParameters(format!(
            "The `time_step` must be less than or equal to the `duration`, got `{time_step:?}` > `{duration:?}`"
        ))
        .into());
    }
    Ok(())
}