Skip to main content

entrenar/
safety.rs

1//! Panic safety and graceful degradation
2//!
3//! Provides panic hooks and catch_unwind wrappers for training operations
4//! to prevent data corruption on panic.
5//!
6//! Batuta: SF-08 (Panic Safety)
7
8use std::panic;
9
/// Install a process-wide panic hook that emits one structured line per panic.
///
/// For every panic the hook writes
/// `[entrenar::panic] at <file>:<line>:<column>: <message>` to stderr and then
/// delegates to the hook that was registered before this call, so the regular
/// panic report is still produced.
///
/// The message is taken from the panic payload when it is a `&str` or a
/// `String`; any other payload type is reported as `unknown panic payload`.
///
/// The hook only logs. It does not inspect or modify any training or
/// checkpoint state itself.
///
/// Each call wraps whichever hook is current at that moment, so calling this
/// function repeatedly stacks the wrappers and the structured line is printed
/// once per installation.
pub fn install_panic_hook() {
    use std::io::Write as _;

    let previous_hook = panic::take_hook();
    panic::set_hook(Box::new(move |info| {
        // Borrow the payload text instead of copying it: the hook runs while
        // the thread is already panicking, so it should do as little as possible.
        let payload = info.payload();
        let message: &str = payload
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
            .unwrap_or("unknown panic payload");

        // Best-effort write. `eprintln!` panics when stderr cannot be written,
        // and a panic raised inside the panic hook aborts the whole process,
        // so a failed write is deliberately ignored here.
        let mut stderr = std::io::stderr();
        let _ = match info.location() {
            Some(loc) => writeln!(
                stderr,
                "[entrenar::panic] at {}:{}:{}: {message}",
                loc.file(),
                loc.line(),
                loc.column(),
            ),
            None => writeln!(stderr, "[entrenar::panic] at unknown location: {message}"),
        };

        // Hand over to the previously installed hook for the normal panic output.
        previous_hook(info);
    }));
}
34
/// Run a training operation and convert an unwinding panic into `None`.
///
/// Returns `Some(value)` when `op` completes normally. If `op` panics, the
/// panic is caught with [`std::panic::catch_unwind`], a one-line notice is
/// written to stderr, the panic payload is discarded, and `None` is returned.
///
/// This keeps an unwinding panic from travelling further up the call stack
/// (for example across an FFI boundary). It does not repair state that `op`
/// left half-updated; the `UnwindSafe` bound is the only guard against
/// observing such state afterwards.
///
/// Only unwinding panics can be caught. With `panic = "abort"` the process
/// terminates before this function gets a chance to return.
pub fn catch_training_panic<F, T>(op: F) -> Option<T>
where
    F: FnOnce() -> T + panic::UnwindSafe,
{
    use std::io::Write as _;

    let outcome = panic::catch_unwind(op).ok();
    if outcome.is_none() {
        // Best-effort notice. `eprintln!` panics when stderr cannot be written,
        // which would let a panic escape the very function meant to contain it,
        // so a failed write is deliberately ignored.
        let _ = writeln!(
            std::io::stderr(),
            "[entrenar::safety] Training operation panicked, returning None"
        );
    }
    outcome
}
51
#[cfg(test)]
mod tests {
    // Unit tests for `catch_training_panic` and `install_panic_hook`.
    //
    // Tests that install the hook finish by calling `panic::take_hook()`,
    // which removes the installed hook and leaves the default one registered.
    use super::*;

    // ---- catch_training_panic: the operation completes normally ----

    #[test]
    fn test_catch_training_panic_success() {
        let result = catch_training_panic(|| 42);
        assert_eq!(result, Some(42));
    }

    #[test]
    fn test_catch_training_panic_with_string_result() {
        let result = catch_training_panic(|| "hello".to_string());
        assert_eq!(result, Some("hello".to_string()));
    }

    #[test]
    fn test_catch_training_panic_with_vec_result() {
        let result = catch_training_panic(|| vec![1, 2, 3]);
        assert_eq!(result, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_catch_training_panic_with_option_result() {
        // The operation's own `Option` is wrapped, not flattened.
        let result = catch_training_panic(|| Some(42));
        assert_eq!(result, Some(Some(42)));
    }

    #[test]
    fn test_catch_training_panic_complex_computation() {
        let result = catch_training_panic(|| {
            let mut sum = 0;
            for i in 0..100 {
                sum += i;
            }
            sum
        });
        assert_eq!(result, Some(4950));
    }

    #[test]
    fn test_catch_training_panic_unit_return() {
        let result = catch_training_panic(|| {
            // Operation that returns ()
        });
        assert_eq!(result, Some(()));
    }

    #[test]
    fn test_catch_training_panic_bool_return() {
        let result = catch_training_panic(|| true);
        assert_eq!(result, Some(true));
    }

    #[test]
    fn test_catch_training_panic_float_result() {
        // 2.5 is exactly representable, so the equality check is exact, and it
        // does not resemble a well-known constant such as PI.
        let result = catch_training_panic(|| 2.5f64);
        assert_eq!(result, Some(2.5f64));
    }

    #[test]
    fn test_catch_training_panic_tuple_result() {
        let result = catch_training_panic(|| (42, "hello"));
        assert_eq!(result, Some((42, "hello")));
    }

    // ---- catch_training_panic: the operation panics ----

    #[test]
    fn test_catch_training_panic_failure() {
        // A literal message produces a `&'static str` payload.
        let result = catch_training_panic(|| -> i32 { panic!("test panic") });
        assert_eq!(result, None);
    }

    #[test]
    fn test_catch_training_panic_string_payload() {
        // Panic with an owned `String` payload (not `&str`).
        let result = catch_training_panic(|| -> i32 {
            panic::panic_any(String::from("formatted panic message"))
        });
        assert_eq!(result, None);
    }

    #[test]
    fn test_catch_training_panic_nested_panic() {
        let result = catch_training_panic(|| -> i32 {
            let _inner = catch_training_panic(|| -> i32 {
                panic!("inner panic");
            });
            // The inner panic is caught, so the outer operation still succeeds.
            99
        });
        assert_eq!(result, Some(99));
    }

    // ---- install_panic_hook ----

    #[test]
    fn test_install_panic_hook_does_not_panic() {
        // Only verifies that installation itself does not crash.
        install_panic_hook();
        // Remove the installed hook; the default hook is registered again.
        let _ = panic::take_hook();
    }

    #[test]
    fn test_install_panic_hook_idempotent() {
        // Installing twice must not crash. The second call wraps the first
        // hook rather than replacing it; this test only checks for a crash.
        install_panic_hook();
        install_panic_hook();
        // Remove the installed hooks; the default hook is registered again.
        let _ = panic::take_hook();
    }

    #[test]
    fn test_catch_training_panic_after_hook_install() {
        install_panic_hook();
        let result = catch_training_panic(|| -> i32 {
            panic!("test after hook install");
        });
        // Remove the installed hook before asserting so that a failed
        // assertion is not reported through it.
        let _ = panic::take_hook();
        assert_eq!(result, None);
    }
}