1use std::panic;
9
/// Registers a process-wide panic hook that writes a one-line
/// `[entrenar::panic] at <location>: <message>` diagnostic to stderr and then
/// forwards the panic to whichever hook was registered before this call.
///
/// The message is taken from the panic payload when it is a `&str` or a
/// `String`; any other payload type is reported as `unknown panic payload`.
/// The location is rendered as `file:line:column`, or `unknown location` when
/// the runtime provides none.
///
/// Each call captures the hook that is current at that moment and wraps it, so
/// calling this more than once stacks wrappers and every panic is logged once
/// per call.
pub fn install_panic_hook() {
    let previous_hook = panic::take_hook();
    panic::set_hook(Box::new(move |info| {
        let location = match info.location() {
            Some(loc) => format!("{}:{}:{}", loc.file(), loc.line(), loc.column()),
            None => "unknown location".to_string(),
        };

        // Literal panic messages arrive as `&'static str`, formatted ones as
        // `String`; borrow whichever is present instead of copying it.
        let payload = info.payload();
        let message: &str = payload
            .downcast_ref::<&str>()
            .copied()
            .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
            .unwrap_or("unknown panic payload");

        eprintln!("[entrenar::panic] at {location}: {message}");

        // Keep the previous hook's behavior (e.g. the default message/backtrace).
        previous_hook(info);
    }));
}
34
/// Runs `op` and converts a panic raised inside it into `None`.
///
/// Returns `Some(value)` when `op` completes normally. If `op` unwinds, the
/// unwind is stopped here, a short notice is written to stderr, and `None` is
/// returned; the panic payload itself is discarded. Any installed panic hook
/// has already run by the time the unwind is caught, so the panic message is
/// reported by the hook, not by this function.
///
/// This relies on [`panic::catch_unwind`], so it only intercepts unwinding
/// panics.
pub fn catch_training_panic<F, T>(op: F) -> Option<T>
where
    F: FnOnce() -> T + panic::UnwindSafe,
{
    let outcome = panic::catch_unwind(op);
    if outcome.is_err() {
        eprintln!("[entrenar::safety] Training operation panicked, returning None");
    }
    outcome.ok()
}
51
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// The panic hook is process-global, but the test harness runs tests in
    /// parallel. Tests that install or remove the hook hold this lock so one
    /// test's `take_hook` cannot tear down a hook another test just installed.
    static HOOK_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `HOOK_LOCK`, recovering from poisoning so that one failed
    /// hook test does not cascade into failures of the other hook tests.
    fn hook_guard() -> MutexGuard<'static, ()> {
        HOOK_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_catch_training_panic_success() {
        let result = catch_training_panic(|| 42);
        assert_eq!(result, Some(42));
    }

    #[test]
    fn test_catch_training_panic_failure() {
        let result = catch_training_panic(|| -> i32 { panic!("test panic") });
        assert_eq!(result, None);
    }

    #[test]
    fn test_install_panic_hook_does_not_panic() {
        let _guard = hook_guard();
        install_panic_hook();
        // Restore the default hook so other tests are unaffected.
        let _ = panic::take_hook();
    }

    #[test]
    fn test_catch_training_panic_with_string_result() {
        let result = catch_training_panic(|| "hello".to_string());
        assert_eq!(result, Some("hello".to_string()));
    }

    #[test]
    fn test_catch_training_panic_with_vec_result() {
        let result = catch_training_panic(|| vec![1, 2, 3]);
        assert_eq!(result, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_catch_training_panic_with_option_result() {
        let result = catch_training_panic(|| Some(42));
        assert_eq!(result, Some(Some(42)));
    }

    #[test]
    fn test_catch_training_panic_string_payload() {
        // A message with a runtime argument is formatted at panic time rather
        // than being a plain `&'static str` literal.
        let step = 7;
        let result = catch_training_panic(move || -> i32 {
            panic!("formatted panic message at step {step}");
        });
        assert_eq!(result, None);
    }

    #[test]
    fn test_catch_training_panic_complex_computation() {
        let result = catch_training_panic(|| (0..100).sum::<i32>());
        assert_eq!(result, Some(4950));
    }

    #[test]
    fn test_catch_training_panic_unit_return() {
        let result = catch_training_panic(|| {});
        assert_eq!(result, Some(()));
    }

    #[test]
    fn test_catch_training_panic_bool_return() {
        let result = catch_training_panic(|| true);
        assert_eq!(result, Some(true));
    }

    #[test]
    fn test_catch_training_panic_nested_panic() {
        let result = catch_training_panic(|| -> i32 {
            let inner = catch_training_panic(|| -> i32 {
                panic!("inner panic");
            });
            // The inner panic must be contained by the inner call.
            assert_eq!(inner, None);
            99
        });
        assert_eq!(result, Some(99));
    }

    #[test]
    fn test_install_panic_hook_twice_does_not_panic() {
        let _guard = hook_guard();
        // `install_panic_hook` is not idempotent: the second call wraps the
        // first hook, so a panic here would be logged twice. This only checks
        // that stacking hooks is safe.
        install_panic_hook();
        install_panic_hook();
        let _ = panic::take_hook();
    }

    #[test]
    fn test_catch_training_panic_after_hook_install() {
        let _guard = hook_guard();
        install_panic_hook();
        let result = catch_training_panic(|| -> i32 {
            panic!("test after hook install");
        });
        // Remove the hook before asserting so a failure cannot leak it.
        let _ = panic::take_hook();
        assert_eq!(result, None);
    }

    #[test]
    fn test_catch_training_panic_float_result() {
        // Avoid 3.14: clippy's deny-by-default `approx_constant` flags it as PI.
        let result = catch_training_panic(|| 2.5f64);
        assert_eq!(result, Some(2.5f64));
    }

    #[test]
    fn test_catch_training_panic_tuple_result() {
        let result = catch_training_panic(|| (42, "hello"));
        assert_eq!(result, Some((42, "hello")));
    }
}