use lazy_static::lazy_static;
use std::sync::{Arc, Mutex};
/// Action taken when a floating-point error condition is reported.
///
/// `handle_error` looks up the action configured for the reported error
/// category and behaves as described on each variant. `Warn` is the value
/// produced by `ErrorAction::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ErrorAction {
/// Do nothing; the error is accepted silently.
Ignore,
/// Print a warning line to standard error and continue.
#[default]
Warn,
/// Panic with a message describing the error.
Raise,
/// Invoke the callback registered through `seterrcall`, if one is set.
Call,
}
/// Renders the action as its lowercase keyword: `ignore`, `warn`, `raise` or `call`.
impl std::fmt::Display for ErrorAction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the keyword first so the formatter is written to exactly once.
        let keyword = match *self {
            ErrorAction::Ignore => "ignore",
            ErrorAction::Warn => "warn",
            ErrorAction::Raise => "raise",
            ErrorAction::Call => "call",
        };
        f.write_str(keyword)
    }
}
/// Parses an action from its keyword, ignoring letter case.
///
/// Accepted keywords are `ignore`, `warn`, `raise` and `call`. Any other input
/// yields `Err` with the message `Invalid error action: <input>`, where
/// `<input>` is the string exactly as it was passed in.
impl std::str::FromStr for ErrorAction {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Keyword table; the lookup below compares against the lowercased input.
        const KEYWORDS: [(&str, ErrorAction); 4] = [
            ("ignore", ErrorAction::Ignore),
            ("warn", ErrorAction::Warn),
            ("raise", ErrorAction::Raise),
            ("call", ErrorAction::Call),
        ];

        let lowered = s.to_lowercase();
        KEYWORDS
            .iter()
            .find(|(keyword, _)| *keyword == lowered.as_str())
            .map(|&(_, action)| action)
            // The error reports the caller's original spelling, not the lowercased one.
            .ok_or_else(|| format!("Invalid error action: {}", s))
    }
}
/// Error-handling configuration: one [`ErrorAction`] per error category.
///
/// `handle_error` selects the field matching the reported
/// `FloatingPointError` and applies that action.
///
/// The type derives `PartialEq`/`Eq` so two configurations can be compared
/// directly instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorState {
    /// Action applied to `FloatingPointError::DivideByZero`.
    pub divide: ErrorAction,
    /// Action applied to `FloatingPointError::Overflow`.
    pub over: ErrorAction,
    /// Action applied to `FloatingPointError::Underflow`.
    pub under: ErrorAction,
    /// Action applied to `FloatingPointError::Invalid`.
    pub invalid: ErrorAction,
}
/// Default configuration: warn on everything except underflow, which is ignored.
impl Default for ErrorState {
    fn default() -> Self {
        let warn = ErrorAction::Warn;
        Self {
            divide: warn,
            over: warn,
            // Underflow is the only category that is silent by default.
            under: ErrorAction::Ignore,
            invalid: warn,
        }
    }
}
impl ErrorState {
pub fn new(action: ErrorAction) -> Self {
Self {
divide: action,
over: action,
under: action,
invalid: action,
}
}
pub fn with_actions(
divide: ErrorAction,
over: ErrorAction,
under: ErrorAction,
invalid: ErrorAction,
) -> Self {
Self {
divide,
over,
under,
invalid,
}
}
}
lazy_static! {
    /// Shared error-handling state, created from `ErrorState::default()`.
    ///
    /// Read and written by `seterr`, `geterr` and the `ErrorStateGuard` drop handler.
    static ref GLOBAL_ERROR_STATE: Arc<Mutex<ErrorState>> = Arc::new(Mutex::default());
}
/// Shared callback invoked with an error description.
///
/// The `Send + Sync` bounds allow the callback to be stored in the global
/// callback slot; `handle_error` calls it with a message of the form
/// `"<error kind> - <message>"`.
pub type ErrorCallback = Arc<dyn Fn(&str) + Send + Sync>;
lazy_static! {
    /// Shared slot for the optional error callback; starts out empty (`None`).
    ///
    /// Written by `seterrcall` and read by `geterrcall`.
    static ref GLOBAL_ERROR_CALLBACK: Arc<Mutex<Option<ErrorCallback>>> =
        Arc::new(Mutex::default());
}
pub fn seterr(
all: Option<ErrorAction>,
divide: Option<ErrorAction>,
over: Option<ErrorAction>,
under: Option<ErrorAction>,
invalid: Option<ErrorAction>,
) -> ErrorState {
let mut state = GLOBAL_ERROR_STATE
.lock()
.expect("global error state lock should not be poisoned");
let old_state = state.clone();
if let Some(action) = all {
*state = ErrorState::new(action);
}
if let Some(action) = divide {
state.divide = action;
}
if let Some(action) = over {
state.over = action;
}
if let Some(action) = under {
state.under = action;
}
if let Some(action) = invalid {
state.invalid = action;
}
old_state
}
pub fn geterr() -> ErrorState {
GLOBAL_ERROR_STATE
.lock()
.expect("global error state lock should not be poisoned")
.clone()
}
/// Replaces the global error callback.
///
/// Passing `None` clears the callback. The callback is the one `handle_error`
/// invokes for categories configured with `ErrorAction::Call`.
///
/// # Panics
///
/// Panics if the global callback mutex is poisoned.
pub fn seterrcall(callback: Option<ErrorCallback>) {
    *GLOBAL_ERROR_CALLBACK
        .lock()
        .expect("global error callback lock should not be poisoned") = callback;
}
pub fn geterrcall() -> Option<ErrorCallback> {
GLOBAL_ERROR_CALLBACK
.lock()
.expect("global error callback lock should not be poisoned")
.clone()
}
/// Category of floating-point error passed to `handle_error`.
///
/// Each variant is handled according to the matching field of `ErrorState`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FloatingPointError {
/// Division by zero; governed by `ErrorState::divide`.
DivideByZero,
/// Overflow; governed by `ErrorState::over`.
Overflow,
/// Underflow; governed by `ErrorState::under`.
Underflow,
/// Invalid operation; governed by `ErrorState::invalid`.
Invalid,
}
/// Renders the error kind as a short lowercase phrase, e.g. `divide by zero`.
impl std::fmt::Display for FloatingPointError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match *self {
            FloatingPointError::DivideByZero => "divide by zero",
            FloatingPointError::Overflow => "overflow",
            FloatingPointError::Underflow => "underflow",
            FloatingPointError::Invalid => "invalid operation",
        };
        f.write_str(description)
    }
}
pub fn handle_error(error_type: FloatingPointError, message: &str) -> bool {
let state = geterr();
let action = match error_type {
FloatingPointError::DivideByZero => state.divide,
FloatingPointError::Overflow => state.over,
FloatingPointError::Underflow => state.under,
FloatingPointError::Invalid => state.invalid,
};
match action {
ErrorAction::Ignore => true,
ErrorAction::Warn => {
eprintln!("Warning: {} - {}", error_type, message);
true
}
ErrorAction::Raise => {
panic!("NumRS2 Error: {} - {}", error_type, message);
}
ErrorAction::Call => {
if let Some(callback) = geterrcall() {
callback(&format!("{} - {}", error_type, message));
}
true
}
}
}
/// Scope guard holding the error state that was in effect when it was created.
///
/// Produced by `ErrorStateBuilder::enter`. When the guard is dropped, its
/// `Drop` implementation writes the saved state back into the global error
/// state, so the overrides last exactly as long as the guard is alive.
///
/// The type is `#[must_use]` because an unbound guard is dropped immediately,
/// which would undo the overrides right after applying them.
#[must_use = "the previous error state is restored as soon as the guard is dropped; bind it to a variable"]
#[derive(Debug)]
pub struct ErrorStateGuard {
    /// State to restore on drop.
    old_state: ErrorState,
}
/// Restores the saved error state when the guard goes out of scope.
impl Drop for ErrorStateGuard {
    /// Overwrites the global error state with the state captured by this guard.
    ///
    /// # Panics
    ///
    /// Panics if the global state mutex is poisoned.
    fn drop(&mut self) {
        let mut live = GLOBAL_ERROR_STATE
            .lock()
            .expect("global error state lock should not be poisoned");
        live.clone_from(&self.old_state);
    }
}
/// Builder collecting per-category overrides for a scoped error-state change.
///
/// Each field is `None` until the matching setter is called; `None` means
/// "leave this category as it currently is". Calling `enter` applies the
/// collected overrides and returns a guard.
///
/// The type is `#[must_use]` because every setter consumes the builder and
/// returns the updated one; discarding that value loses the configuration.
#[must_use = "builder methods return the updated builder; call `enter` to apply it"]
#[derive(Debug, Clone, Default)]
pub struct ErrorStateBuilder {
    /// Override for the divide-by-zero category, if any.
    divide: Option<ErrorAction>,
    /// Override for the overflow category, if any.
    over: Option<ErrorAction>,
    /// Override for the underflow category, if any.
    under: Option<ErrorAction>,
    /// Override for the invalid-operation category, if any.
    invalid: Option<ErrorAction>,
}
impl ErrorStateBuilder {
    /// Sets the action to use for divide-by-zero errors.
    pub fn divide(self, action: ErrorAction) -> Self {
        Self { divide: Some(action), ..self }
    }

    /// Sets the action to use for overflow errors.
    pub fn over(self, action: ErrorAction) -> Self {
        Self { over: Some(action), ..self }
    }

    /// Sets the action to use for underflow errors.
    pub fn under(self, action: ErrorAction) -> Self {
        Self { under: Some(action), ..self }
    }

    /// Sets the action to use for invalid-operation errors.
    pub fn invalid(self, action: ErrorAction) -> Self {
        Self { invalid: Some(action), ..self }
    }

    /// Applies the collected overrides to the global error state.
    ///
    /// The overrides are passed to `seterr` (with no blanket `all` action), and
    /// the state it returns is stored in the guard so it can be restored when
    /// the guard is dropped.
    pub fn enter(self) -> ErrorStateGuard {
        let Self { divide, over, under, invalid } = self;
        ErrorStateGuard {
            old_state: seterr(None, divide, over, under, invalid),
        }
    }
}
pub fn errstate() -> ErrorStateBuilder {
ErrorStateBuilder {
divide: None,
over: None,
under: None,
invalid: None,
}
}
// Unit tests for the error-state API.
//
// Tests that read or modify the global error state are marked `#[serial]`
// (from the `serial_test` crate) so they do not run concurrently with each other.
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
// `ErrorState::default()` warns on everything except underflow, which is ignored.
#[test]
fn test_error_state_default() {
let state = ErrorState::default();
assert_eq!(state.divide, ErrorAction::Warn);
assert_eq!(state.over, ErrorAction::Warn);
assert_eq!(state.under, ErrorAction::Ignore);
assert_eq!(state.invalid, ErrorAction::Warn);
}
// `seterr(all = Raise)` changes every category, and `geterr` observes it.
#[test]
#[serial]
fn test_seterr_geterr() {
let original_state = geterr();
let _old_state = seterr(Some(ErrorAction::Raise), None, None, None, None);
let current_state = geterr();
assert_eq!(current_state.divide, ErrorAction::Raise);
assert_eq!(current_state.over, ErrorAction::Raise);
assert_eq!(current_state.under, ErrorAction::Raise);
assert_eq!(current_state.invalid, ErrorAction::Raise);
// Put the global state back the way it was found. This is skipped if an
// assertion above fails.
seterr(
None,
Some(original_state.divide),
Some(original_state.over),
Some(original_state.under),
Some(original_state.invalid),
);
}
// The guard from `errstate()...enter()` applies overrides only for its scope.
#[test]
#[serial]
fn test_errstate_context_manager() {
let original_state = geterr();
{
let _guard = errstate()
.divide(ErrorAction::Ignore)
.over(ErrorAction::Raise)
.enter();
let current_state = geterr();
assert_eq!(current_state.divide, ErrorAction::Ignore);
assert_eq!(current_state.over, ErrorAction::Raise);
// Categories without an override keep their previous action.
assert_eq!(current_state.under, original_state.under);
assert_eq!(current_state.invalid, original_state.invalid);
}
// Leaving the scope drops the guard, which restores the original state.
let restored_state = geterr();
assert_eq!(restored_state.divide, original_state.divide);
assert_eq!(restored_state.over, original_state.over);
assert_eq!(restored_state.under, original_state.under);
assert_eq!(restored_state.invalid, original_state.invalid);
}
// Each keyword parses to its variant; an unknown keyword is rejected.
#[test]
fn test_error_action_from_str() {
assert_eq!(
"ignore"
.parse::<ErrorAction>()
.expect("'ignore' should parse to ErrorAction::Ignore"),
ErrorAction::Ignore
);
assert_eq!(
"warn"
.parse::<ErrorAction>()
.expect("'warn' should parse to ErrorAction::Warn"),
ErrorAction::Warn
);
assert_eq!(
"raise"
.parse::<ErrorAction>()
.expect("'raise' should parse to ErrorAction::Raise"),
ErrorAction::Raise
);
assert_eq!(
"call"
.parse::<ErrorAction>()
.expect("'call' should parse to ErrorAction::Call"),
ErrorAction::Call
);
assert!("invalid".parse::<ErrorAction>().is_err());
}
// With `Ignore`, `handle_error` returns `true`.
#[test]
#[serial]
fn test_handle_error_ignore() {
let _guard = errstate().divide(ErrorAction::Ignore).enter();
let result = handle_error(FloatingPointError::DivideByZero, "test error");
assert!(result); }
// With `Warn`, `handle_error` prints a warning to stderr and returns `true`.
#[test]
#[serial]
fn test_handle_error_warn() {
let _guard = errstate().divide(ErrorAction::Warn).enter();
let result = handle_error(FloatingPointError::DivideByZero, "test error");
assert!(result); }
// With `Raise`, `handle_error` panics with the formatted message.
#[test]
#[serial]
#[should_panic(expected = "NumRS2 Error: divide by zero - test error")]
fn test_handle_error_raise() {
let _guard = errstate().divide(ErrorAction::Raise).enter();
handle_error(FloatingPointError::DivideByZero, "test error");
}
}