// alloc_safe/alloc.rs

//! Memory allocation error type and helpers for recovering from allocation failures.
2
3use std::alloc::Layout;
4use std::collections::TryReserveError;
5use std::error::Error;
6use std::fmt;
7use std::panic::{PanicInfo, UnwindSafe};
8use std::sync::atomic::{AtomicBool, Ordering};
9
/// The error type for allocation failure.
///
/// Wraps the [`Layout`] of the allocation request that could not be satisfied.
/// Two errors compare equal (and hash equally) when their layouts are equal.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct AllocError(Layout);

impl AllocError {
    /// Creates a new `AllocError` for a failed request with the given `layout`.
    #[must_use]
    #[inline]
    pub const fn new(layout: Layout) -> Self {
        AllocError(layout)
    }

    /// Returns the memory layout of the allocation request that failed.
    #[must_use]
    #[inline]
    pub const fn layout(self) -> Layout {
        self.0
    }
}

// Written by hand rather than derived so the output shows `size`/`align`
// directly instead of a nested `Layout { .. }`.
impl fmt::Debug for AllocError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AllocError")
            .field("size", &self.0.size())
            .field("align", &self.0.align())
            .finish()
    }
}

impl fmt::Display for AllocError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "failed to allocate memory by required layout {{size: {}, align: {}}}",
            self.0.size(),
            self.0.align()
        )
    }
}

impl Error for AllocError {}
54
55impl From<TryReserveError> for AllocError {
56    #[inline]
57    fn from(e: TryReserveError) -> Self {
58        use std::collections::TryReserveErrorKind;
59        match e.kind() {
60            TryReserveErrorKind::AllocError { layout, .. } => AllocError::new(layout),
61            TryReserveErrorKind::CapacityOverflow => {
62                unreachable!("unexpected capacity overflow")
63            }
64        }
65    }
66}
67
68fn alloc_error_hook(layout: Layout) {
69    std::panic::panic_any(AllocError(layout))
70}
71
/// Boxed panic-hook closure, in the form handed to `std::panic::set_hook`.
type PanicHook = Box<dyn Fn(&PanicInfo<'_>) + Send + Sync + 'static>;
73
74fn panic_hook(panic_info: &PanicInfo<'_>) {
75    // panic abort except alloc error
76    if !panic_info.payload().is::<AllocError>() {
77        std::process::abort();
78    }
79}
80
81/// Invokes a closure, capturing the panic of memory allocation error if one occurs.
82///
83/// This function will return `Ok` with the closure's result if the closure
84/// does not panic, and will return `AllocError` if allocation error occurs. The
85/// process will abort if other panics occur.
86///
87/// Notes that this function will set panic hook and alloc error hook.
88#[inline]
89pub fn catch_alloc_error<F: FnOnce() -> R + UnwindSafe, R>(f: F) -> Result<R, AllocError> {
90    static SET_HOOK: AtomicBool = AtomicBool::new(false);
91    if !SET_HOOK.load(Ordering::Acquire) {
92        let hook: PanicHook = Box::try_new(panic_hook)
93            .map_err(|_| AllocError::new(Layout::new::<fn(&PanicInfo)>()))?;
94        std::panic::set_hook(hook);
95        std::alloc::set_alloc_error_hook(alloc_error_hook);
96        SET_HOOK.store(true, Ordering::Release);
97    }
98
99    #[cfg(feature = "global-allocator")]
100    allocator::ThreadPanic::try_reserve_mem()?;
101
102    let result = std::panic::catch_unwind(f);
103    match result {
104        Ok(r) => Ok(r),
105        Err(e) => match e.downcast_ref::<AllocError>() {
106            None => unreachable!(),
107            Some(e) => Err(*e),
108        },
109    }
110}
111
#[cfg(feature = "global-allocator")]
mod allocator {
    //! Global allocator that keeps a small per-thread reserve of memory so the
    //! panic runtime can still allocate while unwinding with an [`AllocError`]
    //! after the system allocator has run out.

    use crate::AllocError;
    use std::alloc::{GlobalAlloc, Layout, System};
    use std::cell::RefCell;
    use std::ptr::NonNull;

    #[global_allocator]
    static GLOBAL: Alloc = Alloc;

    /// Wrapper around [`System`] that falls back to this thread's reserved
    /// panic memory when an allocation fails while the thread is panicking.
    struct Alloc;

    unsafe impl GlobalAlloc for Alloc {
        #[inline]
        unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
            let ptr = System.alloc(layout);

            if ptr.is_null() && std::thread::panicking() {
                if let Some(p) = ThreadPanic::take_mem(layout) {
                    return p.as_ptr();
                }
            }

            ptr
        }

        #[inline]
        unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
            // Reserved blocks come from `System` with exactly the layout they are
            // handed out for, so every pointer is released through `System`.
            System.dealloc(ptr, layout)
        }

        // The default `alloc_zeroed`/`realloc` are built from `alloc` plus
        // zeroing/copying plus `dealloc`; forwarding keeps `System`'s own
        // implementations. The reserve fallback stays in `alloc` only, since the
        // blocks it hands out are fresh allocations of the reserved layouts.
        #[inline]
        unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
            System.alloc_zeroed(layout)
        }

        #[inline]
        unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
            System.realloc(ptr, layout, new_size)
        }
    }

    /// Memory reserved by one thread for the allocations the panic runtime
    /// makes while unwinding with an [`AllocError`] payload.
    struct PanicMem {
        // Block for the boxed panic payload; see `core::panic::BoxMeUp`.
        box_me_up: Option<NonNull<u8>>,

        // Block for the unwinder's exception object. The panic runtime doesn't
        // allocate memory for the exception on Windows.
        #[cfg(not(target_os = "windows"))]
        exception: Option<NonNull<u8>>,
    }

    impl PanicMem {
        /// The payload is the `AllocError` passed to `panic_any`, so its box has
        /// `AllocError`'s layout; deriving it from the type (rather than
        /// hard-coding `(16, 8)`) keeps it correct on non-64-bit targets.
        const BOX_ME_UP_LAYOUT: Layout = Layout::new::<AllocError>();

        // NOTE(review): hard-coded size/alignment of the unwinder's exception
        // object; it depends on the target and on the std version's unwinding
        // implementation — confirm it matches the toolchain in use.
        #[cfg(not(target_os = "windows"))]
        const EXCEPTION_LAYOUT: Layout = match Layout::from_size_align(80, 8) {
            Ok(layout) => layout,
            Err(_) => panic!("invalid exception layout"),
        };

        #[inline]
        const fn new() -> Self {
            PanicMem {
                box_me_up: None,

                #[cfg(not(target_os = "windows"))]
                exception: None,
            }
        }

        /// Fills every empty reserve slot with a fresh block from [`System`].
        ///
        /// # Errors
        ///
        /// Returns an [`AllocError`] for the first layout that could not be
        /// reserved; slots filled before the failure stay filled.
        #[inline]
        fn try_reserve(&mut self) -> Result<(), AllocError> {
            Self::reserve_slot(&mut self.box_me_up, Self::BOX_ME_UP_LAYOUT)?;

            #[cfg(not(target_os = "windows"))]
            Self::reserve_slot(&mut self.exception, Self::EXCEPTION_LAYOUT)?;

            Ok(())
        }

        /// Allocates a `layout` block from [`System`] into `slot` unless the slot
        /// already holds one.
        fn reserve_slot(
            slot: &mut Option<NonNull<u8>>,
            layout: Layout,
        ) -> Result<(), AllocError> {
            if slot.is_none() {
                // SAFETY: both reserve layouts have a non-zero size, as
                // `GlobalAlloc::alloc` requires.
                let ptr = unsafe { System.alloc(layout) };
                *slot = Some(NonNull::new(ptr).ok_or(AllocError::new(layout))?);
            }
            Ok(())
        }

        /// Hands out the reserved block whose layout equals `layout`, if it is
        /// still available.
        #[inline]
        fn take_mem(&mut self, layout: Layout) -> Option<NonNull<u8>> {
            if layout == Self::BOX_ME_UP_LAYOUT {
                return self.box_me_up.take();
            }

            #[cfg(not(target_os = "windows"))]
            if layout == Self::EXCEPTION_LAYOUT {
                return self.exception.take();
            }

            None
        }
    }

    impl Drop for PanicMem {
        #[inline]
        fn drop(&mut self) {
            // `as_ptr` (not `as_mut`) avoids creating a reference to memory that
            // was never initialized.
            if let Some(ptr) = self.box_me_up.take() {
                // SAFETY: `ptr` was allocated by `System` with this layout in
                // `reserve_slot` and, still being in the slot, was never handed out.
                unsafe { System.dealloc(ptr.as_ptr(), Self::BOX_ME_UP_LAYOUT) };
            }

            #[cfg(not(target_os = "windows"))]
            if let Some(ptr) = self.exception.take() {
                // SAFETY: same reasoning as above, for the exception block.
                unsafe { System.dealloc(ptr.as_ptr(), Self::EXCEPTION_LAYOUT) };
            }
        }
    }

    thread_local! {
        static THREAD_PANIC_MEM: RefCell<PanicMem> = const { RefCell::new(PanicMem::new()) };
    }

    /// Per-thread access to the panic-memory reserve.
    pub struct ThreadPanic;

    impl ThreadPanic {
        /// Ensures this thread's panic reserve is filled.
        #[inline]
        pub fn try_reserve_mem() -> Result<(), AllocError> {
            THREAD_PANIC_MEM.with(|panic_mem| panic_mem.borrow_mut().try_reserve())
        }

        /// Takes the reserved block matching `layout`, if any.
        ///
        /// Called from inside the global allocator, which must not unwind, so
        /// this never panics: it returns `None` if the thread-local has already
        /// been destroyed or is currently borrowed.
        #[inline]
        pub fn take_mem(layout: Layout) -> Option<NonNull<u8>> {
            THREAD_PANIC_MEM
                .try_with(|panic_mem| panic_mem.try_borrow_mut().ok()?.take_mem(layout))
                .ok()
                .flatten()
        }
    }
}
240
#[cfg(test)]
mod tests {
    use super::{catch_alloc_error, AllocError};
    use std::alloc::{AllocError as StdAllocError, Allocator, Layout};
    use std::ptr::NonNull;

    /// Allocator that fails every request, used to force an allocation error.
    struct NoMem;

    unsafe impl Allocator for NoMem {
        #[inline]
        fn allocate(&self, _layout: Layout) -> Result<NonNull<[u8]>, StdAllocError> {
            Err(StdAllocError)
        }

        #[inline]
        unsafe fn deallocate(&self, _ptr: NonNull<u8>, _layout: Layout) {
            // `allocate` never succeeds, so there is never anything to free.
            unreachable!()
        }
    }

    #[test]
    fn test_catch_alloc_error() {
        let result = catch_alloc_error(|| Vec::<u8, _>::with_capacity_in(10, NoMem));
        assert_eq!(
            result.unwrap_err().layout(),
            Layout::from_size_align(10, 1).unwrap()
        );
    }

    #[test]
    fn test_catch_alloc_error_ok() {
        assert_eq!(catch_alloc_error(|| 1 + 1).unwrap(), 2);
    }

    #[test]
    fn test_alloc_error_formatting() {
        let err = AllocError::new(Layout::from_size_align(10, 1).unwrap());
        assert_eq!(format!("{err:?}"), "AllocError { size: 10, align: 1 }");
        assert_eq!(
            err.to_string(),
            "failed to allocate memory by required layout {size: 10, align: 1}"
        );
    }
}
269}