1use std::alloc::Layout;
4use std::collections::TryReserveError;
5use std::error::Error;
6use std::fmt;
7use std::panic::{PanicInfo, UnwindSafe};
8use std::sync::atomic::{AtomicBool, Ordering};
9
/// Error describing an allocation that could not be satisfied.
///
/// It carries the [`Layout`] of the failed request. It is `Copy`, so it can be
/// moved out of a panic payload cheaply. It is `Eq`/`Hash`, so callers can compare
/// errors or use them as map keys.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct AllocError(Layout);

impl AllocError {
    /// Creates an error for an allocation request with the given `layout`.
    #[must_use]
    #[inline]
    pub const fn new(layout: Layout) -> Self {
        AllocError(layout)
    }

    /// Returns the layout of the allocation request that failed.
    #[must_use]
    #[inline]
    pub const fn layout(self) -> Layout {
        self.0
    }
}
30
31impl fmt::Debug for AllocError {
32 #[inline]
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 f.debug_struct("AllocError")
35 .field("size", &self.0.size())
36 .field("align", &self.0.align())
37 .finish()
38 }
39}
40
41impl fmt::Display for AllocError {
42 #[inline]
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 write!(
45 f,
46 "failed to allocate memory by required layout {{size: {}, align: {}}}",
47 self.0.size(),
48 self.0.align()
49 )
50 }
51}
52
53impl Error for AllocError {}
54
55impl From<TryReserveError> for AllocError {
56 #[inline]
57 fn from(e: TryReserveError) -> Self {
58 use std::collections::TryReserveErrorKind;
59 match e.kind() {
60 TryReserveErrorKind::AllocError { layout, .. } => AllocError::new(layout),
61 TryReserveErrorKind::CapacityOverflow => {
62 unreachable!("unexpected capacity overflow")
63 }
64 }
65 }
66}
67
68fn alloc_error_hook(layout: Layout) {
69 std::panic::panic_any(AllocError(layout))
70}
71
/// Boxed panic-hook callback, in the shape accepted by `std::panic::set_hook`.
type PanicHook = Box<dyn Fn(&PanicInfo<'_>) + Send + Sync + 'static>;
73
74fn panic_hook(panic_info: &PanicInfo<'_>) {
75 if !panic_info.payload().is::<AllocError>() {
77 std::process::abort();
78 }
79}
80
81#[inline]
89pub fn catch_alloc_error<F: FnOnce() -> R + UnwindSafe, R>(f: F) -> Result<R, AllocError> {
90 static SET_HOOK: AtomicBool = AtomicBool::new(false);
91 if !SET_HOOK.load(Ordering::Acquire) {
92 let hook: PanicHook = Box::try_new(panic_hook)
93 .map_err(|_| AllocError::new(Layout::new::<fn(&PanicInfo)>()))?;
94 std::panic::set_hook(hook);
95 std::alloc::set_alloc_error_hook(alloc_error_hook);
96 SET_HOOK.store(true, Ordering::Release);
97 }
98
99 #[cfg(feature = "global-allocator")]
100 allocator::ThreadPanic::try_reserve_mem()?;
101
102 let result = std::panic::catch_unwind(f);
103 match result {
104 Ok(r) => Ok(r),
105 Err(e) => match e.downcast_ref::<AllocError>() {
106 None => unreachable!(),
107 Some(e) => Err(*e),
108 },
109 }
110}
111
/// Optional global allocator that lets the panic machinery still allocate while
/// unwinding out of an out-of-memory condition.
///
/// Before `catch_alloc_error` runs user code, it reserves per-thread blocks with
/// fixed layouts. If a `System` allocation of one of those exact layouts fails while
/// the thread is panicking, the reserved block is handed out instead.
#[cfg(feature = "global-allocator")]
mod allocator {
    use crate::AllocError;
    use std::alloc::{GlobalAlloc, Layout, System};
    use std::cell::RefCell;
    use std::ptr::NonNull;

    #[global_allocator]
    static GLOBAL: Alloc = Alloc;

    /// Wrapper around [`System`] that falls back to the thread's reserved panic
    /// memory when `alloc` fails during unwinding.
    struct Alloc;

    unsafe impl GlobalAlloc for Alloc {
        #[inline]
        unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
            let ptr = System.alloc(layout);

            if ptr.is_null() && std::thread::panicking() {
                if let Some(p) = ThreadPanic::take_mem(layout) {
                    return p.as_ptr();
                }
            }

            ptr
        }

        #[inline]
        unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
            // Reserved blocks come from `System.alloc` with exactly the layout they
            // are handed out for, so `System` can free them as well.
            System.dealloc(ptr, layout)
        }

        // Forward to `System` so it can grow in place and zero efficiently. The
        // trait defaults would route every resize through `alloc` + copy + `dealloc`.
        #[inline]
        unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
            System.alloc_zeroed(layout)
        }

        #[inline]
        unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
            System.realloc(ptr, layout, new_size)
        }
    }

    /// Per-thread blocks set aside for allocations made while panicking.
    struct PanicMem {
        box_me_up: Option<NonNull<u8>>,

        #[cfg(not(target_os = "windows"))]
        exception: Option<NonNull<u8>>,
    }

    impl PanicMem {
        // NOTE(review): presumably the layout of the boxed panic payload
        // (`Box<AllocError>`) created when the unwind starts. Confirm per target.
        const BOX_ME_UP_LAYOUT: Layout = PanicMem::reserved_layout(16, 8);

        // NOTE(review): presumably the layout of the unwinder's boxed exception
        // object on non-Windows targets. Confirm per target/toolchain.
        #[cfg(not(target_os = "windows"))]
        const EXCEPTION_LAYOUT: Layout = PanicMem::reserved_layout(80, 8);

        /// Builds a layout at compile time. Invalid arguments fail the build rather
        /// than relying on `from_size_align_unchecked`.
        const fn reserved_layout(size: usize, align: usize) -> Layout {
            match Layout::from_size_align(size, align) {
                Ok(layout) => layout,
                Err(_) => panic!("invalid reserved panic-memory layout"),
            }
        }

        #[inline]
        const fn new() -> Self {
            PanicMem {
                box_me_up: None,

                #[cfg(not(target_os = "windows"))]
                exception: None,
            }
        }

        /// Fills `slot` with a fresh `System` block of `layout` if it is empty.
        fn reserve_slot(slot: &mut Option<NonNull<u8>>, layout: Layout) -> Result<(), AllocError> {
            if slot.is_none() {
                // SAFETY: callers only pass the reserved layouts above, which have a
                // non-zero size as `GlobalAlloc::alloc` requires.
                let ptr = unsafe { System.alloc(layout) };
                *slot = Some(NonNull::new(ptr).ok_or(AllocError::new(layout))?);
            }
            Ok(())
        }

        /// Reserves every block that is currently missing.
        ///
        /// # Errors
        ///
        /// Returns the layout of the first block that could not be allocated. Blocks
        /// reserved before the failure are kept.
        #[inline]
        fn try_reserve(&mut self) -> Result<(), AllocError> {
            PanicMem::reserve_slot(&mut self.box_me_up, PanicMem::BOX_ME_UP_LAYOUT)?;

            #[cfg(not(target_os = "windows"))]
            PanicMem::reserve_slot(&mut self.exception, PanicMem::EXCEPTION_LAYOUT)?;

            Ok(())
        }

        /// Hands out the reserved block whose layout equals `layout`, if it is still
        /// present. Each block can be taken only once per reservation.
        #[inline]
        fn take_mem(&mut self, layout: Layout) -> Option<NonNull<u8>> {
            if layout == PanicMem::BOX_ME_UP_LAYOUT {
                return self.box_me_up.take();
            }

            #[cfg(not(target_os = "windows"))]
            if layout == PanicMem::EXCEPTION_LAYOUT {
                return self.exception.take();
            }

            None
        }
    }

    impl Drop for PanicMem {
        /// Returns blocks that were never handed out to the system allocator.
        #[inline]
        fn drop(&mut self) {
            if let Some(ptr) = self.box_me_up.take() {
                // SAFETY: allocated by `System` with exactly this layout and never
                // handed out (a handed-out block leaves the slot `None`).
                unsafe { System.dealloc(ptr.as_ptr(), PanicMem::BOX_ME_UP_LAYOUT) };
            }

            #[cfg(not(target_os = "windows"))]
            if let Some(ptr) = self.exception.take() {
                // SAFETY: allocated by `System` with exactly this layout and never
                // handed out (a handed-out block leaves the slot `None`).
                unsafe { System.dealloc(ptr.as_ptr(), PanicMem::EXCEPTION_LAYOUT) };
            }
        }
    }

    thread_local! {
        static THREAD_PANIC_MEM: RefCell<PanicMem> = RefCell::new(PanicMem::new());
    }

    /// Entry points to the current thread's reserved panic memory.
    pub struct ThreadPanic;

    impl ThreadPanic {
        /// Reserves this thread's panic memory if it is not already reserved.
        ///
        /// # Errors
        ///
        /// Returns an [`AllocError`] with the layout that could not be allocated.
        #[inline]
        pub fn try_reserve_mem() -> Result<(), AllocError> {
            THREAD_PANIC_MEM.with(|panic_mem| panic_mem.borrow_mut().try_reserve())
        }

        /// Takes this thread's reserved block for `layout`, if any.
        ///
        /// This is called from inside the global allocator, where unwinding is
        /// undefined behavior. It therefore returns `None` in two cases where
        /// `with`/`borrow_mut` would panic:
        /// - the thread-local has already been destroyed;
        /// - the thread-local is currently borrowed.
        #[inline]
        pub fn take_mem(layout: Layout) -> Option<NonNull<u8>> {
            THREAD_PANIC_MEM
                .try_with(|panic_mem| {
                    panic_mem
                        .try_borrow_mut()
                        .ok()
                        .and_then(|mut mem| mem.take_mem(layout))
                })
                .ok()
                .flatten()
        }
    }
}
240
#[cfg(test)]
mod tests {
    use super::catch_alloc_error;
    use std::alloc::{AllocError as StdAllocError, Allocator, Layout};
    use std::ptr::NonNull;

    /// Allocator that refuses every request, forcing the allocation-error path.
    struct AlwaysFails;

    unsafe impl Allocator for AlwaysFails {
        #[inline]
        fn allocate(&self, _layout: Layout) -> Result<NonNull<[u8]>, StdAllocError> {
            Err(StdAllocError)
        }

        #[inline]
        unsafe fn deallocate(&self, _ptr: NonNull<u8>, _layout: Layout) {
            // Nothing is ever handed out, so nothing can come back.
            unreachable!()
        }
    }

    #[test]
    fn test_catch_alloc_error() {
        let expected = Layout::from_size_align(10, 1).unwrap();
        let err = catch_alloc_error(|| Vec::<u8, _>::with_capacity_in(10, AlwaysFails))
            .unwrap_err();
        assert_eq!(err.layout(), expected);
    }
}
269}