1use lazy_static::lazy_static;
8use std::sync::{Arc, Mutex};
9
/// Response applied when a floating-point error condition is reported.
///
/// The lowercase names `"ignore"`, `"warn"`, `"raise"` and `"call"` are what the
/// `Display` impl writes and what the `FromStr` impl accepts (case-insensitively).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ErrorAction {
    /// Do nothing.
    Ignore,
    /// Print a warning to standard error. This is the `Default` variant.
    #[default]
    Warn,
    /// Panic with a descriptive message.
    Raise,
    /// Invoke the callback registered with `seterrcall`, if one is registered.
    Call,
}
23
24impl std::fmt::Display for ErrorAction {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 ErrorAction::Ignore => write!(f, "ignore"),
28 ErrorAction::Warn => write!(f, "warn"),
29 ErrorAction::Raise => write!(f, "raise"),
30 ErrorAction::Call => write!(f, "call"),
31 }
32 }
33}
34
35impl std::str::FromStr for ErrorAction {
36 type Err = String;
37
38 fn from_str(s: &str) -> Result<Self, Self::Err> {
39 match s.to_lowercase().as_str() {
40 "ignore" => Ok(ErrorAction::Ignore),
41 "warn" => Ok(ErrorAction::Warn),
42 "raise" => Ok(ErrorAction::Raise),
43 "call" => Ok(ErrorAction::Call),
44 _ => Err(format!("Invalid error action: {}", s)),
45 }
46 }
47}
48
49#[derive(Debug, Clone)]
51pub struct ErrorState {
52 pub divide: ErrorAction,
54 pub over: ErrorAction,
56 pub under: ErrorAction,
58 pub invalid: ErrorAction,
60}
61
62impl Default for ErrorState {
63 fn default() -> Self {
64 Self {
65 divide: ErrorAction::Warn,
66 over: ErrorAction::Warn,
67 under: ErrorAction::Ignore,
68 invalid: ErrorAction::Warn,
69 }
70 }
71}
72
73impl ErrorState {
74 pub fn new(action: ErrorAction) -> Self {
76 Self {
77 divide: action,
78 over: action,
79 under: action,
80 invalid: action,
81 }
82 }
83
84 pub fn with_actions(
86 divide: ErrorAction,
87 over: ErrorAction,
88 under: ErrorAction,
89 invalid: ErrorAction,
90 ) -> Self {
91 Self {
92 divide,
93 over,
94 under,
95 invalid,
96 }
97 }
98}
99
100lazy_static! {
102 static ref GLOBAL_ERROR_STATE: Arc<Mutex<ErrorState>> =
103 Arc::new(Mutex::new(ErrorState::default()));
104}
105
/// Shared, thread-safe callback invoked with a formatted error description when the
/// active action is `ErrorAction::Call`. The `'static` bound is the one Rust already
/// implies for `Arc<dyn ...>` in a type alias; it is written out here for clarity.
pub type ErrorCallback = Arc<dyn Fn(&str) + Sync + Send + 'static>;
108
109lazy_static! {
111 static ref GLOBAL_ERROR_CALLBACK: Arc<Mutex<Option<ErrorCallback>>> =
112 Arc::new(Mutex::new(None));
113}
114
115pub fn seterr(
147 all: Option<ErrorAction>,
148 divide: Option<ErrorAction>,
149 over: Option<ErrorAction>,
150 under: Option<ErrorAction>,
151 invalid: Option<ErrorAction>,
152) -> ErrorState {
153 let mut state = GLOBAL_ERROR_STATE
154 .lock()
155 .expect("global error state lock should not be poisoned");
156 let old_state = state.clone();
157
158 if let Some(action) = all {
159 *state = ErrorState::new(action);
160 }
161
162 if let Some(action) = divide {
163 state.divide = action;
164 }
165 if let Some(action) = over {
166 state.over = action;
167 }
168 if let Some(action) = under {
169 state.under = action;
170 }
171 if let Some(action) = invalid {
172 state.invalid = action;
173 }
174
175 old_state
176}
177
178pub fn geterr() -> ErrorState {
196 GLOBAL_ERROR_STATE
197 .lock()
198 .expect("global error state lock should not be poisoned")
199 .clone()
200}
201
202pub fn seterrcall(callback: Option<ErrorCallback>) {
223 let mut cb = GLOBAL_ERROR_CALLBACK
224 .lock()
225 .expect("global error callback lock should not be poisoned");
226 *cb = callback;
227}
228
229pub fn geterrcall() -> Option<ErrorCallback> {
235 GLOBAL_ERROR_CALLBACK
236 .lock()
237 .expect("global error callback lock should not be poisoned")
238 .clone()
239}
240
/// Category of floating-point error reported to `handle_error`. Each category has
/// its own `ErrorAction` in `ErrorState`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FloatingPointError {
    /// Division by zero (displayed as `divide by zero`).
    DivideByZero,
    /// Overflow (displayed as `overflow`).
    Overflow,
    /// Underflow (displayed as `underflow`).
    Underflow,
    /// Invalid operation (displayed as `invalid operation`).
    Invalid,
}
253
254impl std::fmt::Display for FloatingPointError {
255 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 match self {
257 FloatingPointError::DivideByZero => write!(f, "divide by zero"),
258 FloatingPointError::Overflow => write!(f, "overflow"),
259 FloatingPointError::Underflow => write!(f, "underflow"),
260 FloatingPointError::Invalid => write!(f, "invalid operation"),
261 }
262 }
263}
264
265pub fn handle_error(error_type: FloatingPointError, message: &str) -> bool {
276 let state = geterr();
277 let action = match error_type {
278 FloatingPointError::DivideByZero => state.divide,
279 FloatingPointError::Overflow => state.over,
280 FloatingPointError::Underflow => state.under,
281 FloatingPointError::Invalid => state.invalid,
282 };
283
284 match action {
285 ErrorAction::Ignore => true,
286 ErrorAction::Warn => {
287 eprintln!("Warning: {} - {}", error_type, message);
288 true
289 }
290 ErrorAction::Raise => {
291 panic!("NumRS2 Error: {} - {}", error_type, message);
292 }
293 ErrorAction::Call => {
294 if let Some(callback) = geterrcall() {
295 callback(&format!("{} - {}", error_type, message));
296 }
297 true
298 }
299 }
300}
301
302pub struct ErrorStateGuard {
323 old_state: ErrorState,
324}
325
326impl Drop for ErrorStateGuard {
327 fn drop(&mut self) {
328 let mut state = GLOBAL_ERROR_STATE
330 .lock()
331 .expect("global error state lock should not be poisoned");
332 *state = self.old_state.clone();
333 }
334}
335
336pub struct ErrorStateBuilder {
338 divide: Option<ErrorAction>,
339 over: Option<ErrorAction>,
340 under: Option<ErrorAction>,
341 invalid: Option<ErrorAction>,
342}
343
344impl ErrorStateBuilder {
345 pub fn divide(mut self, action: ErrorAction) -> Self {
347 self.divide = Some(action);
348 self
349 }
350
351 pub fn over(mut self, action: ErrorAction) -> Self {
353 self.over = Some(action);
354 self
355 }
356
357 pub fn under(mut self, action: ErrorAction) -> Self {
359 self.under = Some(action);
360 self
361 }
362
363 pub fn invalid(mut self, action: ErrorAction) -> Self {
365 self.invalid = Some(action);
366 self
367 }
368
369 pub fn enter(self) -> ErrorStateGuard {
375 let old_state = seterr(None, self.divide, self.over, self.under, self.invalid);
376 ErrorStateGuard { old_state }
377 }
378}
379
380pub fn errstate() -> ErrorStateBuilder {
401 ErrorStateBuilder {
402 divide: None,
403 over: None,
404 under: None,
405 invalid: None,
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::sync::{Arc, Mutex};

    /// Wraps a previously returned state in a guard that writes it back on drop.
    /// Tests that change the global state stay isolated even when an assertion fails.
    fn restore_on_drop(old_state: ErrorState) -> ErrorStateGuard {
        ErrorStateGuard { old_state }
    }

    #[test]
    fn test_error_state_default() {
        let state = ErrorState::default();
        assert_eq!(state.divide, ErrorAction::Warn);
        assert_eq!(state.over, ErrorAction::Warn);
        assert_eq!(state.under, ErrorAction::Ignore);
        assert_eq!(state.invalid, ErrorAction::Warn);
    }

    #[test]
    #[serial]
    fn test_seterr_geterr() {
        // Restored by the guard even if an assertion below fails. Manual restoration
        // at the end of the test would be skipped and leak `Raise` into later tests.
        let _restore = restore_on_drop(seterr(Some(ErrorAction::Raise), None, None, None, None));

        let current_state = geterr();
        assert_eq!(current_state.divide, ErrorAction::Raise);
        assert_eq!(current_state.over, ErrorAction::Raise);
        assert_eq!(current_state.under, ErrorAction::Raise);
        assert_eq!(current_state.invalid, ErrorAction::Raise);
    }

    #[test]
    #[serial]
    fn test_seterr_specific_action_overrides_all() {
        let original_state = geterr();
        let previous = seterr(
            Some(ErrorAction::Ignore),
            Some(ErrorAction::Raise),
            None,
            None,
            None,
        );
        let _restore = restore_on_drop(previous.clone());

        // `seterr` returns the configuration that was active before the call.
        assert_eq!(previous.divide, original_state.divide);
        assert_eq!(previous.over, original_state.over);
        assert_eq!(previous.under, original_state.under);
        assert_eq!(previous.invalid, original_state.invalid);

        let current_state = geterr();
        assert_eq!(current_state.divide, ErrorAction::Raise);
        assert_eq!(current_state.over, ErrorAction::Ignore);
        assert_eq!(current_state.under, ErrorAction::Ignore);
        assert_eq!(current_state.invalid, ErrorAction::Ignore);
    }

    #[test]
    #[serial]
    fn test_errstate_context_manager() {
        let original_state = geterr();

        {
            let _guard = errstate()
                .divide(ErrorAction::Ignore)
                .over(ErrorAction::Raise)
                .enter();

            let current_state = geterr();
            assert_eq!(current_state.divide, ErrorAction::Ignore);
            assert_eq!(current_state.over, ErrorAction::Raise);
            assert_eq!(current_state.under, original_state.under);
            assert_eq!(current_state.invalid, original_state.invalid);
        }

        let restored_state = geterr();
        assert_eq!(restored_state.divide, original_state.divide);
        assert_eq!(restored_state.over, original_state.over);
        assert_eq!(restored_state.under, original_state.under);
        assert_eq!(restored_state.invalid, original_state.invalid);
    }

    #[test]
    fn test_error_action_from_str() {
        assert_eq!(
            "ignore"
                .parse::<ErrorAction>()
                .expect("'ignore' should parse to ErrorAction::Ignore"),
            ErrorAction::Ignore
        );
        assert_eq!(
            "warn"
                .parse::<ErrorAction>()
                .expect("'warn' should parse to ErrorAction::Warn"),
            ErrorAction::Warn
        );
        assert_eq!(
            "raise"
                .parse::<ErrorAction>()
                .expect("'raise' should parse to ErrorAction::Raise"),
            ErrorAction::Raise
        );
        assert_eq!(
            "call"
                .parse::<ErrorAction>()
                .expect("'call' should parse to ErrorAction::Call"),
            ErrorAction::Call
        );
        assert_eq!(
            "RAISE"
                .parse::<ErrorAction>()
                .expect("parsing should ignore case"),
            ErrorAction::Raise
        );

        assert!("invalid".parse::<ErrorAction>().is_err());
    }

    #[test]
    fn test_error_action_display_round_trip() {
        for action in [
            ErrorAction::Ignore,
            ErrorAction::Warn,
            ErrorAction::Raise,
            ErrorAction::Call,
        ] {
            assert_eq!(action.to_string().parse::<ErrorAction>(), Ok(action));
        }
    }

    #[test]
    #[serial]
    fn test_handle_error_ignore() {
        let _guard = errstate().divide(ErrorAction::Ignore).enter();
        let result = handle_error(FloatingPointError::DivideByZero, "test error");
        assert!(result);
    }

    #[test]
    #[serial]
    fn test_handle_error_warn() {
        let _guard = errstate().divide(ErrorAction::Warn).enter();
        let result = handle_error(FloatingPointError::DivideByZero, "test error");
        assert!(result);
    }

    #[test]
    #[serial]
    fn test_handle_error_call() {
        let received: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&received);
        let callback: ErrorCallback = Arc::new(move |message: &str| {
            sink.lock()
                .expect("test sink lock should not be poisoned")
                .push(message.to_string());
        });

        let previous_callback = geterrcall();
        seterrcall(Some(callback));
        let result = {
            let _guard = errstate().invalid(ErrorAction::Call).enter();
            handle_error(FloatingPointError::Invalid, "test error")
        };
        // Put the previous callback back before asserting so a failure cannot leak it.
        seterrcall(previous_callback);

        assert!(result);
        assert_eq!(
            *received
                .lock()
                .expect("test sink lock should not be poisoned"),
            vec!["invalid operation - test error".to_string()]
        );
    }

    #[test]
    #[serial]
    #[should_panic(expected = "NumRS2 Error: divide by zero - test error")]
    fn test_handle_error_raise() {
        let _guard = errstate().divide(ErrorAction::Raise).enter();
        handle_error(FloatingPointError::DivideByZero, "test error");
    }
}