Skip to main content

numrs2/
error_handling.rs

1//! Error handling configuration for NumRS2
2//!
3//! This module provides error handling configuration similar to NumPy's seterr, geterr, and errstate.
4//! It allows controlling how floating-point errors (like division by zero, overflow, underflow, etc.)
5//! are handled in NumRS2 operations.
6
7use lazy_static::lazy_static;
8use std::sync::{Arc, Mutex};
9
/// Error handling behavior for different types of floating-point errors.
///
/// One action is stored per error category in [`ErrorState`]. `Warn` is the variant produced
/// by [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ErrorAction {
    /// Ignore the error (continue execution)
    Ignore,
    /// Issue a warning but continue execution
    #[default]
    Warn,
    /// Raise an exception/error
    Raise,
    /// Call a user-defined callback function
    Call,
}
23
24impl std::fmt::Display for ErrorAction {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            ErrorAction::Ignore => write!(f, "ignore"),
28            ErrorAction::Warn => write!(f, "warn"),
29            ErrorAction::Raise => write!(f, "raise"),
30            ErrorAction::Call => write!(f, "call"),
31        }
32    }
33}
34
35impl std::str::FromStr for ErrorAction {
36    type Err = String;
37
38    fn from_str(s: &str) -> Result<Self, Self::Err> {
39        match s.to_lowercase().as_str() {
40            "ignore" => Ok(ErrorAction::Ignore),
41            "warn" => Ok(ErrorAction::Warn),
42            "raise" => Ok(ErrorAction::Raise),
43            "call" => Ok(ErrorAction::Call),
44            _ => Err(format!("Invalid error action: {}", s)),
45        }
46    }
47}
48
49/// Configuration for error handling
50#[derive(Debug, Clone)]
51pub struct ErrorState {
52    /// How to handle division by zero
53    pub divide: ErrorAction,
54    /// How to handle overflow
55    pub over: ErrorAction,
56    /// How to handle underflow
57    pub under: ErrorAction,
58    /// How to handle invalid operations (like sqrt of negative number)
59    pub invalid: ErrorAction,
60}
61
62impl Default for ErrorState {
63    fn default() -> Self {
64        Self {
65            divide: ErrorAction::Warn,
66            over: ErrorAction::Warn,
67            under: ErrorAction::Ignore,
68            invalid: ErrorAction::Warn,
69        }
70    }
71}
72
73impl ErrorState {
74    /// Create a new error state with all actions set to the same value
75    pub fn new(action: ErrorAction) -> Self {
76        Self {
77            divide: action,
78            over: action,
79            under: action,
80            invalid: action,
81        }
82    }
83
84    /// Create a new error state with specific actions for each error type
85    pub fn with_actions(
86        divide: ErrorAction,
87        over: ErrorAction,
88        under: ErrorAction,
89        invalid: ErrorAction,
90    ) -> Self {
91        Self {
92            divide,
93            over,
94            under,
95            invalid,
96        }
97    }
98}
99
100// Global error state for NumRS2
101lazy_static! {
102    static ref GLOBAL_ERROR_STATE: Arc<Mutex<ErrorState>> =
103        Arc::new(Mutex::new(ErrorState::default()));
104}
105
/// User-defined error callback function type.
///
/// The callback takes the error description as a `&str` and returns nothing. It is shared
/// through an `Arc` and must be `Send + Sync`. [`handle_error`] invokes it with a string of
/// the form `<error type> - <message>`.
pub type ErrorCallback = Arc<dyn Fn(&str) + Send + Sync>;
108
109// Global error callback
110lazy_static! {
111    static ref GLOBAL_ERROR_CALLBACK: Arc<Mutex<Option<ErrorCallback>>> =
112        Arc::new(Mutex::new(None));
113}
114
115/// Set the error handling behavior for floating-point errors
116///
117/// # Arguments
118///
119/// * `all` - Action to take for all error types (if specified, overrides individual settings)
120/// * `divide` - Action for division by zero
121/// * `over` - Action for overflow
122/// * `under` - Action for underflow
123/// * `invalid` - Action for invalid operations
124///
125/// # Returns
126///
127/// The previous error state
128///
129/// # Examples
130///
131/// ```
132/// use numrs2::error_handling::{seterr, ErrorAction};
133///
134/// // Set all errors to raise exceptions
135/// let old_state = seterr(Some(ErrorAction::Raise), None, None, None, None);
136///
137/// // Set specific error handling
138/// let old_state = seterr(
139///     None,
140///     Some(ErrorAction::Raise),  // Division by zero raises
141///     Some(ErrorAction::Warn),   // Overflow warns
142///     Some(ErrorAction::Ignore), // Underflow ignored
143///     Some(ErrorAction::Warn),   // Invalid warns
144/// );
145/// ```
146pub fn seterr(
147    all: Option<ErrorAction>,
148    divide: Option<ErrorAction>,
149    over: Option<ErrorAction>,
150    under: Option<ErrorAction>,
151    invalid: Option<ErrorAction>,
152) -> ErrorState {
153    let mut state = GLOBAL_ERROR_STATE
154        .lock()
155        .expect("global error state lock should not be poisoned");
156    let old_state = state.clone();
157
158    if let Some(action) = all {
159        *state = ErrorState::new(action);
160    }
161
162    if let Some(action) = divide {
163        state.divide = action;
164    }
165    if let Some(action) = over {
166        state.over = action;
167    }
168    if let Some(action) = under {
169        state.under = action;
170    }
171    if let Some(action) = invalid {
172        state.invalid = action;
173    }
174
175    old_state
176}
177
178/// Get the current error handling behavior
179///
180/// # Returns
181///
182/// The current error state
183///
184/// # Examples
185///
186/// ```
187/// use numrs2::error_handling::geterr;
188///
189/// let current_state = geterr();
190/// println!("Division by zero: {}", current_state.divide);
191/// println!("Overflow: {}", current_state.over);
192/// println!("Underflow: {}", current_state.under);
193/// println!("Invalid: {}", current_state.invalid);
194/// ```
195pub fn geterr() -> ErrorState {
196    GLOBAL_ERROR_STATE
197        .lock()
198        .expect("global error state lock should not be poisoned")
199        .clone()
200}
201
202/// Set the callback function for error handling
203///
204/// # Arguments
205///
206/// * `callback` - Function to call when errors occur (if ErrorAction::Call is set)
207///
208/// # Examples
209///
210/// ```
211/// use numrs2::error_handling::{seterrcall, ErrorAction, seterr};
212/// use std::sync::Arc;
213///
214/// // Set a custom error callback
215/// seterrcall(Some(Arc::new(|msg: &str| {
216///     eprintln!("NumRS2 Error: {}", msg);
217/// })));
218///
219/// // Configure to use the callback for division by zero
220/// seterr(None, Some(ErrorAction::Call), None, None, None);
221/// ```
222pub fn seterrcall(callback: Option<ErrorCallback>) {
223    let mut cb = GLOBAL_ERROR_CALLBACK
224        .lock()
225        .expect("global error callback lock should not be poisoned");
226    *cb = callback;
227}
228
229/// Get the current error callback
230///
231/// # Returns
232///
233/// The current error callback (if any)
234pub fn geterrcall() -> Option<ErrorCallback> {
235    GLOBAL_ERROR_CALLBACK
236        .lock()
237        .expect("global error callback lock should not be poisoned")
238        .clone()
239}
240
/// Error types that can occur in floating-point operations.
///
/// Each variant corresponds to one field of [`ErrorState`]; [`handle_error`] uses the variant
/// to pick the configured [`ErrorAction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FloatingPointError {
    /// Division by zero (configured by `ErrorState::divide`)
    DivideByZero,
    /// Overflow (configured by `ErrorState::over`)
    Overflow,
    /// Underflow (configured by `ErrorState::under`)
    Underflow,
    /// Invalid operation (configured by `ErrorState::invalid`)
    Invalid,
}
253
254impl std::fmt::Display for FloatingPointError {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        match self {
257            FloatingPointError::DivideByZero => write!(f, "divide by zero"),
258            FloatingPointError::Overflow => write!(f, "overflow"),
259            FloatingPointError::Underflow => write!(f, "underflow"),
260            FloatingPointError::Invalid => write!(f, "invalid operation"),
261        }
262    }
263}
264
265/// Handle a floating-point error according to the current error state
266///
267/// # Arguments
268///
269/// * `error_type` - The type of error that occurred
270/// * `message` - Descriptive message about the error
271///
272/// # Returns
273///
274/// `true` if the operation should continue, `false` if it should abort
275pub fn handle_error(error_type: FloatingPointError, message: &str) -> bool {
276    let state = geterr();
277    let action = match error_type {
278        FloatingPointError::DivideByZero => state.divide,
279        FloatingPointError::Overflow => state.over,
280        FloatingPointError::Underflow => state.under,
281        FloatingPointError::Invalid => state.invalid,
282    };
283
284    match action {
285        ErrorAction::Ignore => true,
286        ErrorAction::Warn => {
287            eprintln!("Warning: {} - {}", error_type, message);
288            true
289        }
290        ErrorAction::Raise => {
291            panic!("NumRS2 Error: {} - {}", error_type, message);
292        }
293        ErrorAction::Call => {
294            if let Some(callback) = geterrcall() {
295                callback(&format!("{} - {}", error_type, message));
296            }
297            true
298        }
299    }
300}
301
302/// Context manager for temporarily changing error handling behavior
303///
304/// Similar to NumPy's errstate context manager. This allows you to temporarily
305/// change the error handling behavior and automatically restore it when done.
306///
307/// # Examples
308///
309/// ```
310/// use numrs2::error_handling::{errstate, ErrorAction};
311///
312/// {
313///     let _guard = errstate()
314///         .divide(ErrorAction::Ignore)
315///         .over(ErrorAction::Raise)
316///         .enter();
317///     
318///     // Operations here will ignore division by zero and raise on overflow
319///     // Error state is automatically restored when _guard goes out of scope
320/// }
321/// ```
322pub struct ErrorStateGuard {
323    old_state: ErrorState,
324}
325
326impl Drop for ErrorStateGuard {
327    fn drop(&mut self) {
328        // Restore the old error state
329        let mut state = GLOBAL_ERROR_STATE
330            .lock()
331            .expect("global error state lock should not be poisoned");
332        *state = self.old_state.clone();
333    }
334}
335
336/// Builder for creating temporary error state contexts
337pub struct ErrorStateBuilder {
338    divide: Option<ErrorAction>,
339    over: Option<ErrorAction>,
340    under: Option<ErrorAction>,
341    invalid: Option<ErrorAction>,
342}
343
344impl ErrorStateBuilder {
345    /// Set the action for division by zero errors
346    pub fn divide(mut self, action: ErrorAction) -> Self {
347        self.divide = Some(action);
348        self
349    }
350
351    /// Set the action for overflow errors
352    pub fn over(mut self, action: ErrorAction) -> Self {
353        self.over = Some(action);
354        self
355    }
356
357    /// Set the action for underflow errors
358    pub fn under(mut self, action: ErrorAction) -> Self {
359        self.under = Some(action);
360        self
361    }
362
363    /// Set the action for invalid operation errors
364    pub fn invalid(mut self, action: ErrorAction) -> Self {
365        self.invalid = Some(action);
366        self
367    }
368
369    /// Enter the error state context
370    ///
371    /// # Returns
372    ///
373    /// A guard that will restore the previous error state when dropped
374    pub fn enter(self) -> ErrorStateGuard {
375        let old_state = seterr(None, self.divide, self.over, self.under, self.invalid);
376        ErrorStateGuard { old_state }
377    }
378}
379
380/// Create a new error state context manager
381///
382/// # Returns
383///
384/// A builder for configuring the temporary error state
385///
386/// # Examples
387///
388/// ```
389/// use numrs2::error_handling::{errstate, ErrorAction};
390///
391/// {
392///     let _guard = errstate()
393///         .divide(ErrorAction::Ignore)
394///         .over(ErrorAction::Raise)
395///         .enter();
396///     
397///     // Your code here with modified error handling
398/// } // Error state automatically restored here
399/// ```
400pub fn errstate() -> ErrorStateBuilder {
401    ErrorStateBuilder {
402        divide: None,
403        over: None,
404        under: None,
405        invalid: None,
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Flattens a state into `[divide, over, under, invalid]` so a whole configuration can be
    /// checked with a single assertion.
    fn actions(state: &ErrorState) -> [ErrorAction; 4] {
        [state.divide, state.over, state.under, state.invalid]
    }

    #[test]
    fn test_error_state_default() {
        use ErrorAction::{Ignore, Warn};

        assert_eq!(actions(&ErrorState::default()), [Warn, Warn, Ignore, Warn]);
    }

    // Tests that touch the global state are serialised so they cannot interfere.
    #[test]
    #[serial]
    fn test_seterr_geterr() {
        // Remember where we started so the state can be put back afterwards.
        let ErrorState { divide, over, under, invalid } = geterr();

        // Switch every category to `Raise` in one call ...
        let _previous = seterr(Some(ErrorAction::Raise), None, None, None, None);

        // ... and confirm that all four categories picked it up.
        assert_eq!(actions(&geterr()), [ErrorAction::Raise; 4]);

        // Put the starting state back, category by category.
        seterr(None, Some(divide), Some(over), Some(under), Some(invalid));
    }

    #[test]
    #[serial]
    fn test_errstate_context_manager() {
        let before = actions(&geterr());

        {
            let _guard = errstate()
                .divide(ErrorAction::Ignore)
                .over(ErrorAction::Raise)
                .enter();

            // Only the two overridden categories may differ while the guard is alive.
            let [_, _, under, invalid] = before;
            assert_eq!(
                actions(&geterr()),
                [ErrorAction::Ignore, ErrorAction::Raise, under, invalid]
            );
        }

        // Dropping the guard must bring back the state from before the block.
        assert_eq!(actions(&geterr()), before);
    }

    #[test]
    fn test_error_action_from_str() {
        let cases = [
            (
                "ignore",
                ErrorAction::Ignore,
                "'ignore' should parse to ErrorAction::Ignore",
            ),
            (
                "warn",
                ErrorAction::Warn,
                "'warn' should parse to ErrorAction::Warn",
            ),
            (
                "raise",
                ErrorAction::Raise,
                "'raise' should parse to ErrorAction::Raise",
            ),
            (
                "call",
                ErrorAction::Call,
                "'call' should parse to ErrorAction::Call",
            ),
        ];

        for (text, expected, why) in cases {
            assert_eq!(text.parse::<ErrorAction>().expect(why), expected);
        }

        assert!("invalid".parse::<ErrorAction>().is_err());
    }

    #[test]
    #[serial]
    fn test_handle_error_ignore() {
        let _guard = errstate().divide(ErrorAction::Ignore).enter();

        // Ignored errors let the operation continue.
        assert!(handle_error(FloatingPointError::DivideByZero, "test error"));
    }

    #[test]
    #[serial]
    fn test_handle_error_warn() {
        let _guard = errstate().divide(ErrorAction::Warn).enter();

        // A warning is printed, but the operation still continues.
        assert!(handle_error(FloatingPointError::DivideByZero, "test error"));
    }

    #[test]
    #[serial]
    #[should_panic(expected = "NumRS2 Error: divide by zero - test error")]
    fn test_handle_error_raise() {
        let _guard = errstate().divide(ErrorAction::Raise).enter();
        handle_error(FloatingPointError::DivideByZero, "test error");
    }
}