use std::cell::RefCell;
use ferray_core::error::FerrayError;
/// Category of floating-point event that carries its own handling policy.
///
/// Each variant selects one slot of the per-thread policy table (see
/// [`seterr`] / [`geterr`]) and is the value stored when an event is
/// recorded through [`record_fp_event`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FpErrorClass {
/// Division-by-zero events.
DivideByZero,
/// Overflow events.
Overflow,
/// Underflow events.
Underflow,
/// Invalid-operation events.
Invalid,
}
/// Policy that decides what [`record_fp_event`] does with an event of a
/// given [`FpErrorClass`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FpErrorState {
/// The event is discarded; nothing is recorded.
Ignore,
/// The event is appended to the thread-local event log, which
/// [`take_fp_events`] drains.
Warn,
/// The event is queued; the next [`check_fp_errors`] call drains the
/// queue and returns an error describing it.
Raise,
}
/// Policy table holding one [`FpErrorState`] per [`FpErrorClass`] variant.
///
/// The type is `Copy`, so callers snapshot and restore it by value.
#[derive(Debug, Clone, Copy)]
struct Policies {
/// Policy for `FpErrorClass::DivideByZero`.
divide_by_zero: FpErrorState,
/// Policy for `FpErrorClass::Overflow`.
overflow: FpErrorState,
/// Policy for `FpErrorClass::Underflow`.
underflow: FpErrorState,
/// Policy for `FpErrorClass::Invalid`.
invalid: FpErrorState,
}
impl Default for Policies {
    /// Builds the initial policy table: `invalid` starts as `Warn`, every
    /// other class starts as `Ignore`.
    fn default() -> Self {
        let silent = FpErrorState::Ignore;
        Self {
            invalid: FpErrorState::Warn,
            divide_by_zero: silent,
            overflow: silent,
            underflow: silent,
        }
    }
}
// Per-thread state: every thread gets its own policy table, warning log and
// raise queue, so changes made on one thread are not visible on another.
thread_local! {
// Active policy table; initialised from `Policies::default()`.
static POLICIES: RefCell<Policies> = RefCell::new(Policies::default());
// Events recorded under the `Warn` policy; drained by `take_fp_events`.
static EVENT_LOG: RefCell<Vec<FpErrorClass>> = const { RefCell::new(Vec::new()) };
// Events recorded under the `Raise` policy; drained by `check_fp_errors`.
static PENDING_RAISES: RefCell<Vec<FpErrorClass>> = const { RefCell::new(Vec::new()) };
}
/// Scope guard that writes a saved policy table back into the current
/// thread's policy slot when it is dropped.
///
/// `with_errstate` creates one before applying its overrides so the previous
/// table is restored when the guard goes out of scope.
#[must_use]
pub struct ErrstateGuard {
/// Snapshot of the policy table to restore on drop.
previous: Policies,
}
impl Drop for ErrstateGuard {
    /// Restores the saved policy table into this thread's `POLICIES` slot.
    fn drop(&mut self) {
        // `Policies` is `Copy`, so lift the snapshot out before entering the closure.
        let saved = self.previous;
        POLICIES.with(|cell| {
            let mut active = cell.borrow_mut();
            *active = saved;
        });
    }
}
/// Sets the policy for `class` on the current thread.
///
/// Returns the policy that was in force for `class` before the call, so the
/// caller can put it back later.
pub fn seterr(class: FpErrorClass, state: FpErrorState) -> FpErrorState {
    POLICIES.with(|cell| {
        let mut table = cell.borrow_mut();
        // Resolve the field once, then swap the new state in and hand back the old one.
        let slot = match class {
            FpErrorClass::DivideByZero => &mut table.divide_by_zero,
            FpErrorClass::Overflow => &mut table.overflow,
            FpErrorClass::Underflow => &mut table.underflow,
            FpErrorClass::Invalid => &mut table.invalid,
        };
        std::mem::replace(slot, state)
    })
}
/// Returns the policy currently in force for `class` on the calling thread.
#[must_use]
pub fn geterr(class: FpErrorClass) -> FpErrorState {
    // `Policies` is `Copy`: take a snapshot so the borrow ends before the lookup.
    let snapshot = POLICIES.with(|cell| *cell.borrow());
    match class {
        FpErrorClass::DivideByZero => snapshot.divide_by_zero,
        FpErrorClass::Overflow => snapshot.overflow,
        FpErrorClass::Underflow => snapshot.underflow,
        FpErrorClass::Invalid => snapshot.invalid,
    }
}
/// Runs `f` with the given policy overrides applied on the current thread and
/// returns whatever `f` returns.
///
/// The policy table is snapshotted before any override is applied and written
/// back when the internal guard is dropped, which happens both on normal
/// return and when `f` unwinds. Overrides are applied in slice order, so a
/// later entry for the same class wins.
pub fn with_errstate<F, R>(overrides: &[(FpErrorClass, FpErrorState)], f: F) -> R
where
    F: FnOnce() -> R,
{
    // Snapshot first so the guard restores the table as it was before any override.
    let _restore = ErrstateGuard {
        previous: POLICIES.with(|cell| *cell.borrow()),
    };
    overrides.iter().for_each(|&(class, state)| {
        seterr(class, state);
    });
    f()
}
/// Records one floating-point event of the given `class`, routed by the
/// policy currently set for that class on this thread.
///
/// * `Raise` — the class is queued for the next `check_fp_errors` call.
/// * `Warn` — the class is appended to the event log read by `take_fp_events`.
/// * `Ignore` — nothing is recorded.
pub fn record_fp_event(class: FpErrorClass) {
    match geterr(class) {
        FpErrorState::Raise => PENDING_RAISES.with(|queue| queue.borrow_mut().push(class)),
        FpErrorState::Warn => EVENT_LOG.with(|events| events.borrow_mut().push(class)),
        FpErrorState::Ignore => {}
    }
}
/// Drains and returns this thread's log of events recorded under the `Warn`
/// policy, in recording order. The log is left empty.
pub fn take_fp_events() -> Vec<FpErrorClass> {
    EVENT_LOG.with(|events| events.take())
}
/// Drains this thread's queue of events recorded under the `Raise` policy and
/// reports them as a single error.
///
/// The queue is emptied by every call, so an error is reported at most once.
///
/// # Errors
///
/// If any event was queued, returns the error built by
/// `FerrayError::invalid_value` from a message listing the queued classes in
/// recording order.
pub fn check_fp_errors() -> Result<(), FerrayError> {
    let pending = PENDING_RAISES.with(|queue| queue.take());
    if pending.is_empty() {
        Ok(())
    } else {
        Err(FerrayError::invalid_value(format!(
            "floating-point exception(s) raised: {pending:?}"
        )))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the thread-local floating-point error policy.
    //!
    //! All state under test is thread-local. Tests that depend on a particular
    //! starting policy set it explicitly rather than assuming a fresh thread.
    use super::*;

    /// The initial per-thread table ignores everything except `Invalid`.
    #[test]
    fn default_policy_invalid_is_warn() {
        // Thread-local state is initialised per thread, so a freshly spawned
        // thread observes the real defaults regardless of what has already
        // run on the current one. Setting the policies by hand first would
        // only test `seterr`, not the defaults.
        let (invalid, divide_by_zero, overflow, underflow) = std::thread::spawn(|| {
            (
                geterr(FpErrorClass::Invalid),
                geterr(FpErrorClass::DivideByZero),
                geterr(FpErrorClass::Overflow),
                geterr(FpErrorClass::Underflow),
            )
        })
        .join()
        .expect("policy probe thread panicked");
        assert_eq!(invalid, FpErrorState::Warn);
        assert_eq!(divide_by_zero, FpErrorState::Ignore);
        assert_eq!(overflow, FpErrorState::Ignore);
        assert_eq!(underflow, FpErrorState::Ignore);
    }

    /// `seterr` installs the new policy and hands back the one it replaced.
    #[test]
    fn seterr_changes_and_returns_previous() {
        // Pin the starting point so the returned value can be asserted exactly.
        seterr(FpErrorClass::DivideByZero, FpErrorState::Ignore);
        let prev = seterr(FpErrorClass::DivideByZero, FpErrorState::Raise);
        assert_eq!(prev, FpErrorState::Ignore);
        assert_eq!(geterr(FpErrorClass::DivideByZero), FpErrorState::Raise);
        let replaced = seterr(FpErrorClass::DivideByZero, prev);
        assert_eq!(replaced, FpErrorState::Raise);
        assert_eq!(geterr(FpErrorClass::DivideByZero), prev);
    }

    /// Under `Warn`, every recorded event lands in the log, and taking the log empties it.
    #[test]
    fn record_fp_event_warn_appends_to_log() {
        let _ = take_fp_events();
        seterr(FpErrorClass::Overflow, FpErrorState::Warn);
        record_fp_event(FpErrorClass::Overflow);
        record_fp_event(FpErrorClass::Overflow);
        let events = take_fp_events();
        assert_eq!(events, vec![FpErrorClass::Overflow, FpErrorClass::Overflow]);
        assert!(take_fp_events().is_empty());
        seterr(FpErrorClass::Overflow, FpErrorState::Ignore);
    }

    /// Under `Ignore`, recording an event leaves the log untouched.
    #[test]
    fn record_fp_event_ignore_is_noop() {
        let _ = take_fp_events();
        seterr(FpErrorClass::Underflow, FpErrorState::Ignore);
        record_fp_event(FpErrorClass::Underflow);
        assert!(take_fp_events().is_empty());
    }

    /// Under `Raise`, the event is reported once by `check_fp_errors` and
    /// never shows up in the warning log.
    #[test]
    fn record_fp_event_raise_queues_for_check() {
        let _ = check_fp_errors();
        let _ = take_fp_events();
        seterr(FpErrorClass::Invalid, FpErrorState::Raise);
        record_fp_event(FpErrorClass::Invalid);
        assert!(take_fp_events().is_empty());
        let err = check_fp_errors().unwrap_err();
        assert!(err.to_string().contains("Invalid"));
        assert!(check_fp_errors().is_ok());
        seterr(FpErrorClass::Invalid, FpErrorState::Warn);
    }

    /// Overrides are visible inside the closure and gone once it returns.
    #[test]
    fn with_errstate_scopes_overrides() {
        let _ = take_fp_events();
        seterr(FpErrorClass::DivideByZero, FpErrorState::Ignore);
        let inner_state =
            with_errstate(&[(FpErrorClass::DivideByZero, FpErrorState::Warn)], || {
                record_fp_event(FpErrorClass::DivideByZero);
                geterr(FpErrorClass::DivideByZero)
            });
        assert_eq!(inner_state, FpErrorState::Warn);
        assert_eq!(geterr(FpErrorClass::DivideByZero), FpErrorState::Ignore);
        let events = take_fp_events();
        assert_eq!(events, vec![FpErrorClass::DivideByZero]);
    }

    /// The previous policy comes back even when the closure panics.
    #[test]
    fn with_errstate_restores_on_panic() {
        seterr(FpErrorClass::Overflow, FpErrorState::Ignore);
        let result = std::panic::catch_unwind(|| {
            with_errstate(&[(FpErrorClass::Overflow, FpErrorState::Raise)], || {
                panic!("forced panic")
            })
        });
        assert!(result.is_err());
        assert_eq!(geterr(FpErrorClass::Overflow), FpErrorState::Ignore);
    }
}