ext_php_rs/zend/
try_catch.rs1use crate::ffi::{
2    ext_php_rs_zend_bailout, ext_php_rs_zend_first_try_catch, ext_php_rs_zend_try_catch,
3};
4use std::ffi::c_void;
5use std::panic::{catch_unwind, resume_unwind, RefUnwindSafe};
6use std::ptr::null_mut;
7
/// Error returned by [`try_catch`] and [`try_catch_first`] when the wrapped closure did not
/// return normally because PHP bailed out of it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CatchError;

impl std::fmt::Display for CatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("PHP bailout occurred")
    }
}

impl std::error::Error for CatchError {}
11
/// C-callable trampoline that runs the Rust closure behind `ctx` for the
/// `ext_php_rs_zend_*try_catch` FFI helpers.
///
/// The closure is invoked inside [`catch_unwind`]. A panic must not unwind out of this
/// `extern "C"` function, so it is captured here and handed back as the `Err` variant. The
/// Rust caller can then re-raise it once control is back on the Rust side.
///
/// Returns an owned, heap-allocated `std::thread::Result<R>` as an opaque pointer.
///
/// # Safety
///
/// * `ctx` must point to a live, initialised `F` that may be borrowed mutably for the
///   whole call. The pointer must be derived from a `*mut F` / `&mut F`, not from a
///   shared borrow, because `F` is called as `FnMut`. Nothing else may access that `F`
///   during the call.
/// * The returned pointer comes from [`Box::into_raw`]. The caller takes ownership and must
///   reclaim it exactly once via `Box::from_raw` as a `*mut std::thread::Result<R>`.
pub(crate) unsafe extern "C" fn panic_wrapper<R, F: FnMut() -> R + RefUnwindSafe>(
    ctx: *const c_void,
) -> *const c_void {
    let closure = ctx.cast_mut().cast::<F>();

    // SAFETY: the caller guarantees `closure` points to a live `F` that we may borrow
    // mutably for the duration of this call.
    let outcome = catch_unwind(|| unsafe { (*closure)() });

    Box::into_raw(Box::new(outcome)).cast::<c_void>()
}
21
22pub fn try_catch<R, F: FnMut() -> R + RefUnwindSafe>(func: F) -> Result<R, CatchError> {
37    do_try_catch(func, false)
38}
39
40pub fn try_catch_first<R, F: FnMut() -> R + RefUnwindSafe>(func: F) -> Result<R, CatchError> {
58    do_try_catch(func, true)
59}
60
61fn do_try_catch<R, F: FnMut() -> R + RefUnwindSafe>(func: F, first: bool) -> Result<R, CatchError> {
62    let mut panic_ptr = null_mut();
63    let has_bailout = unsafe {
64        if first {
65            ext_php_rs_zend_first_try_catch(
66                panic_wrapper::<R, F>,
67                (&raw const func).cast::<c_void>(),
68                &raw mut panic_ptr,
69            )
70        } else {
71            ext_php_rs_zend_try_catch(
72                panic_wrapper::<R, F>,
73                (&raw const func).cast::<c_void>(),
74                &raw mut panic_ptr,
75            )
76        }
77    };
78
79    let panic = panic_ptr.cast::<std::thread::Result<R>>();
80
81    if panic.is_null() || has_bailout {
83        return Err(CatchError);
84    }
85
86    match unsafe { *Box::from_raw(panic.cast::<std::thread::Result<R>>()) } {
87        Ok(r) => Ok(r),
88        Err(err) => {
89            resume_unwind(err);
91        }
92    }
93}
94
95pub unsafe fn bailout() -> ! {
109    ext_php_rs_zend_bailout();
110}
111
#[cfg(feature = "embed")]
#[cfg(test)]
mod tests {
    use crate::embed::Embed;
    use crate::zend::{bailout, try_catch};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A bailout inside `try_catch` is reported as `Err`, and the rest of the closure
    /// does not run.
    #[test]
    fn test_catch() {
        Embed::run(|| {
            let catch = try_catch(|| {
                unsafe {
                    bailout();
                }

                #[allow(unreachable_code)]
                #[allow(clippy::assertions_on_constants)]
                {
                    assert!(false);
                }
            });

            assert!(catch.is_err());
        });
    }

    /// A closure that returns normally yields `Ok`.
    #[test]
    fn test_no_catch() {
        Embed::run(|| {
            let catch = try_catch(|| {
                #[allow(clippy::assertions_on_constants)]
                {
                    assert!(true);
                }
            });

            assert!(catch.is_ok());
        });
    }

    /// A bailout outside `try_catch` must not reach the code after it. This presumably
    /// relies on `Embed::run` establishing its own PHP try block.
    #[test]
    fn test_bailout() {
        Embed::run(|| {
            unsafe {
                bailout();
            }

            #[allow(unreachable_code)]
            #[allow(clippy::assertions_on_constants)]
            {
                assert!(false);
            }
        });
    }

    /// A Rust panic inside `try_catch` is re-raised to the caller, not swallowed.
    #[test]
    #[should_panic(expected = "should panic")]
    fn test_panic() {
        Embed::run(|| {
            let _ = try_catch(|| {
                panic!("should panic");
            });
        });
    }

    /// The closure's return value is passed through on success.
    #[test]
    fn test_return() {
        let foo = Embed::run(|| {
            let result = try_catch(|| "foo");

            assert!(result.is_ok());

            #[allow(clippy::unwrap_used)]
            result.unwrap()
        });

        assert_eq!(foo, "foo");
    }

    /// Values owned by the closure are not dropped when it bails out, which is why a
    /// bailout can leak memory.
    ///
    /// The skipped destructor is observed through a flag. A pointer into the closure's
    /// stack frame cannot be used for this, because that frame is dead once the bailout
    /// has jumped past it.
    #[test]
    fn test_memory_leak() {
        static DROPPED: AtomicBool = AtomicBool::new(false);

        struct DropFlag;

        impl Drop for DropFlag {
            fn drop(&mut self) {
                DROPPED.store(true, Ordering::SeqCst);
            }
        }

        Embed::run(|| {
            let result = try_catch(|| {
                let _flag = DropFlag;

                unsafe {
                    bailout();
                }
            });

            assert!(result.is_err());
            assert!(!DROPPED.load(Ordering::SeqCst));
        });
    }
}