1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//! A port of my `bounded_alloc` module from *Monadic Bot*.
//! Unlike the original, this one is more tightly bound to the
//! closures it is used with.
//! Also unlike the original, this one does not rely on *undefined behavior*.
//! By default, that is.
//! There are two ways to make this bounded allocator useful:
//! - Rely on `std` library UB and make alloc() unwind
//! - Rely on a Nightly feature and set an unwinding alloc_error_hook
//!
//! There is a secret CFG option for the first thing, and you can do the second thing yourself like:
//! ```ignore
//! #![feature(alloc_error_hook)]
//! use ::std::alloc::{Layout, set_alloc_error_hook};
//! fn your_alloc_error_hook(layout: Layout) {
//!    panic!("memory allocation of {} bytes failed", layout.size());
//! }
//! ::std::alloc::set_alloc_error_hook(your_alloc_error_hook);
//! ```
//! In either case, the [`with_max_alloc`](./macro.with_max_alloc.html) macro simplifies basic usage
//! by performing the `catch_unwind` for you.
use ::std::alloc::{System, Layout};
use ::core::sync::atomic::{AtomicUsize, Ordering};
use super::Store;

/// A [`Store`] wrapper which provides an allocation limit on top of whatever allocator it is given.
pub struct Bounded<T> {
    // Exclusive upper bound: the running total must stay strictly below this.
    exclusive_upper: usize,
    // Running total of live (allocated, not yet freed) bytes, shared across threads.
    total: AtomicUsize,
    // The wrapped allocator that actually services requests.
    alloc: T,
}
impl Bounded<System> {
    pub const fn system(exclusive_upper: usize) -> Bounded<System> {
        Bounded { exclusive_upper, total: AtomicUsize::new(0), alloc: System }
    }
}
impl<T> Bounded<T> {
    /// Wrap `alloc` so that at most `exclusive_upper - 1` bytes may be live at once.
    pub const fn new(exclusive_upper: usize, alloc: T) -> Self {
        Self { exclusive_upper, total: AtomicUsize::new(0), alloc }
    }
    /// Atomically try to grow the running total by `amt` bytes.
    ///
    /// Returns the previous total on success. The update is committed only
    /// when the new total remains strictly below `exclusive_upper`; on
    /// rejection the stored total is untouched and returned as the `Err`.
    fn attempt_add(&self, amt: usize) -> Result<usize, usize> {
        // fetch_update runs a compare-exchange loop: the closure may be retried
        // until its result sticks, or until it votes `None` to abort.
        // An addition that would overflow is rejected outright, which also
        // serves as the proof that the stored total never wraps: every value
        // ever committed is below `exclusive_upper`.
        self.total.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            match current.checked_add(amt) {
                Some(next) if next < self.exclusive_upper => Some(next),
                _ => None,
            }
        })
    }
}

/// Zero-sized panic payload thrown by [`fail_alloc`] when the
/// `OUTSOURCE_HEAP_BOUNDED_UB` cfg is enabled; downcast it after
/// `catch_unwind` to distinguish a budget failure from other panics.
#[cfg(OUTSOURCE_HEAP_BOUNDED_UB)]
pub struct Oom;
// Opt-in variant: unwinding out of an allocator is *library UB* per the
// `GlobalAlloc` contract — this is the "secret CFG option" mentioned in
// the module docs. Enabled only by the cfg flag above.
#[cfg(OUTSOURCE_HEAP_BOUNDED_UB)]
fn fail_alloc() -> *mut u8 { ::std::panic::panic_any(Oom) }
// Default, sound variant: signal failure the way the allocator API
// expects, by returning a null pointer.
#[cfg(not(OUTSOURCE_HEAP_BOUNDED_UB))]
fn fail_alloc() -> *mut u8 { ::core::ptr::null_mut() }

unsafe impl<T: Store> Store for Bounded<T> {
    /// Reserve budget for `layout.size()` bytes, then defer to the inner allocator.
    ///
    /// Fails via [`fail_alloc`] when the budget is exhausted, without touching
    /// the inner allocator. If the inner allocator itself fails (returns null),
    /// the reservation is rolled back — otherwise a failed allocation would
    /// permanently consume capacity that no `dealloc` will ever return.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        if layout.size() > 0 {
            // A first pass outline of what we *want* to do here is:
            // 1. Increment the allocation total.
            // 2. Check if we went over our capacity.
            // 3. If we did go over, fail allocation and restore total to what it was.
            // But we're in a multithreaded context, and making the intermediate
            // state of total visible to other threads will cause spurious failures.
            // Loading first, checking, then storing would instead be a
            // Time of Check to Time of Use bug: a concurrent allocation could
            // complete in between and push us over capacity.
            // So attempt_add uses AtomicUsize::fetch_update, which performs the
            // check-and-commit as one compare_exchange loop.
            match self.attempt_add(layout.size()) {
                Ok(_) => {
                    let ptr = unsafe { Store::alloc(&self.alloc, layout) };
                    if ptr.is_null() {
                        // The inner allocator failed *after* we reserved budget;
                        // release the reservation or those bytes leak forever.
                        self.total.fetch_sub(layout.size(), Ordering::Relaxed);
                    }
                    ptr
                }
                Err(_) => fail_alloc(),
            }
        } else {
            unreachable!("calling GlobalAlloc::alloc with a layout.size == 0 is library UB")
        }
    }
    /// Release `layout.size()` bytes of budget, then defer to the inner allocator.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // It is impossible for this subtraction to overflow,
        // since the corresponding addition already passed alloc.
        // attempt_add proves that no overflowing adds to the stored total
        // ever occur, which means subtracting that which has been
        // successfully added cannot possibly overflow past 0.
        let prev = self.total.fetch_sub(layout.size(), Ordering::Relaxed);
        debug_assert!(prev >= layout.size());
        unsafe { Store::dealloc(&self.alloc, ptr, layout) }
    }
    /// Zeroed variant of `alloc`; identical accounting rules apply.
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        if layout.size() > 0 {
            match self.attempt_add(layout.size()) {
                Ok(_) => {
                    let ptr = unsafe { Store::alloc_zeroed(&self.alloc, layout) };
                    if ptr.is_null() {
                        // Roll back the reservation on inner-allocator failure.
                        self.total.fetch_sub(layout.size(), Ordering::Relaxed);
                    }
                    ptr
                }
                Err(_) => fail_alloc(),
            }
        } else {
            // Consistent with alloc: a zero-sized layout is library UB.
            unreachable!("calling GlobalAlloc::alloc_zeroed with a layout.size == 0 is library UB")
        }
    }
    /// Adjust the budget by the size delta, then defer to the inner allocator.
    ///
    /// Growth reserves the difference up front (rolled back if the inner
    /// realloc fails); shrinkage releases the difference only *after* the
    /// inner realloc succeeds, because on failure the original allocation is
    /// still live at its full size and its eventual `dealloc` will subtract
    /// the full `layout.size()`.
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        if new_size > layout.size() {
            let diff = new_size - layout.size();
            match self.attempt_add(diff) {
                Ok(_) => {
                    let new_ptr = unsafe { Store::realloc(&self.alloc, ptr, layout, new_size) };
                    if new_ptr.is_null() {
                        // Inner realloc failed: the allocation is still
                        // layout.size() bytes, so give back the extra budget.
                        self.total.fetch_sub(diff, Ordering::Relaxed);
                    }
                    new_ptr
                }
                Err(_) => fail_alloc(),
            }
        } else if new_size < layout.size() {
            let diff = layout.size() - new_size;
            let new_ptr = unsafe { Store::realloc(&self.alloc, ptr, layout, new_size) };
            if !new_ptr.is_null() {
                // See proof in dealloc. Shrinking an existing allocation is
                // analogous: diff was part of a successfully recorded add.
                let prev = self.total.fetch_sub(diff, Ordering::Relaxed);
                debug_assert!(prev >= diff);
            }
            new_ptr
        } else {
            // Same size: nothing to move and nothing to account for.
            ptr
        }
    }
}

/// Run a task with a statically known bound on allocation.
///
/// `$amt` must be a constant expression, since it initializes a `static`
/// bound provider. The closure runs under `$crate::run` and any unwind
/// (e.g. an OOM panic from the bounded allocator) is caught and returned
/// as the `Err` variant of `catch_unwind`.
#[macro_export]
macro_rules! with_max_alloc {
    ($amt:expr, || $blk:expr) => {
        {
            // `System` must be fully qualified: a macro expansion cannot
            // assume the caller has `std::alloc::System` in scope.
            static BOUND_PROVIDER: $crate::bounded::Bounded<::std::alloc::System> =
                $crate::bounded::Bounded::system($amt);
            ::std::panic::catch_unwind(
                ::core::panic::AssertUnwindSafe(|| $crate::run(&BOUND_PROVIDER, || $blk)))
        }
    }
}
#[doc(inline)]
pub use crate::with_max_alloc;