1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
//! This crate contains [an allocator](NAlloc) that wraps another allocator
//! and lets you turn allocation on and off. It is meant to be used in unit tests.
//!
//! To use it, declare a static variable with the `#[global_allocator]`
//! attribute. It can wrap any allocator implementing
//! [`GlobalAlloc`](https://doc.rust-lang.org/std/alloc/trait.GlobalAlloc.html).
//!
//! ```rust
//! # extern crate std;
//! #[global_allocator]
//! static ALLOCATOR: nalloc::NAlloc<std::alloc::System> = {
//!     nalloc::NAlloc::new(std::alloc::System)
//! };
//! ```
//!
//! Allocation is allowed by default. To prevent it, call the `deny` method
//! on the allocator and keep the returned lock alive. If allocation is
//! attempted while any lock is alive, the process aborts.
//!
//! ```rust,should_panic
//! # extern crate std;
//! # #[global_allocator]
//! # static ALLOCATOR: nalloc::NAlloc<std::alloc::System> = {
//! #     nalloc::NAlloc::new(std::alloc::System)
//! # };
//! let this_is_allowed = vec![1, 2, 3];
//!
//! let _lock = ALLOCATOR.deny();
//! let this_will_abort = vec![4, 5, 6];
//! ```
//!
//! # Limitations
//! ## Parallel tests
//!
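//! Because `#[global_allocator]` applies to the whole process, the default
//! test harness also uses this allocator when you declare it in your tests.
//! The lock counter is shared by all threads, so a lock taken in one test also
//! denies the allocations made by the harness itself, which allocates memory,
//! and by tests running concurrently. You can circumvent this by running the
//! tests serially with `cargo test -- --test-threads=1`.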
//!
//! ## Aborting
//!
//! If allocation is attempted while a lock is alive, the process will abort.
//! This means the entire process is killed, rather than a single thread,
//! and the failure cannot be caught with
//! [`catch_unwind`](https://doc.rust-lang.org/std/panic/fn.catch_unwind.html).

#![no_std]

#![forbid(warnings)]
#![forbid(missing_docs)]
extern crate alloc;

/// A wrapper around another allocator that can turn allocation on and off.
///
/// While no lock is alive, every request is forwarded to the wrapped
/// allocator. While at least one lock from `NAlloc::deny` is alive, new
/// allocations are refused; deallocations are always forwarded.
#[derive(Debug)]
pub struct NAlloc<T> {
    /// The allocator that actually serves the forwarded requests.
    wrapped: T,
    /// Number of currently held locks. Allocation is only allowed while this
    /// is 0.
    state: core::sync::atomic::AtomicU64,
}

impl<T> NAlloc<T> {
    /// Wraps an allocator.
    ///
    /// The returned wrapper starts with allocation allowed. This is a
    /// `const fn` so it can initialize a `#[global_allocator]` static.
    pub const fn new(wrapped: T) -> NAlloc<T> {
        Self {
            wrapped,
            state: core::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Forbid allocations.
    ///
    /// This function returns a lock that must be kept alive as long as no
    /// allocations are allowed. Locks nest: allocation is allowed again only
    /// once every lock returned by this method has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if `u64::MAX` locks are already alive (only reachable by leaking
    /// locks, e.g. with `core::mem::forget`). The counter is left unchanged in
    /// that case, so the existing locks keep denying allocation.
    #[must_use = "The lock must stay alive as long as no allocations are allowed."]
    pub fn deny(&self) -> AllocationLocker<'_, T> {
        // `checked_add` makes the update fail instead of wrapping the counter
        // back to 0, which would silently re-enable allocation while locks
        // are still alive.
        let locked = self.state.fetch_update(
            core::sync::atomic::Ordering::Release,
            core::sync::atomic::Ordering::Relaxed,
            |count| count.checked_add(1),
        );
        if locked.is_err() {
            panic!("Allocation counter wrapped around");
        }

        AllocationLocker { allocator: self }
    }

    /// Releases one lock previously taken by `deny`. This is called when an
    /// `AllocationLocker` is dropped.
    ///
    /// # Panics
    ///
    /// Panics if no lock is held. The counter stays at 0 in that case instead
    /// of wrapping to `u64::MAX`, which would deny allocation permanently.
    fn unlock(&self) {
        let unlocked = self.state.fetch_update(
            core::sync::atomic::Ordering::Release,
            core::sync::atomic::Ordering::Relaxed,
            |count| count.checked_sub(1),
        );
        if unlocked.is_err() {
            panic!("Allocation counter wrapped around");
        }
    }
}

/// Forwards every request to the wrapped allocator while no lock is held.
///
/// While the lock counter is non-zero, `alloc`, `alloc_zeroed` and `realloc`
/// call `handle_alloc_error` instead of allocating. `dealloc` is always
/// forwarded, so memory allocated earlier can still be freed while allocation
/// is denied.
unsafe impl<T: alloc::alloc::GlobalAlloc> alloc::alloc::GlobalAlloc for NAlloc<T> {
    unsafe fn alloc(&self, layout: alloc::alloc::Layout) -> *mut u8 {
        if self.state.load(core::sync::atomic::Ordering::Relaxed) == 0 {
            self.wrapped.alloc(layout)
        } else {
            alloc::alloc::handle_alloc_error(layout)
        }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: alloc::alloc::Layout) {
        // Freeing is never blocked by a lock.
        self.wrapped.dealloc(ptr, layout)
    }

    unsafe fn alloc_zeroed(&self, layout: alloc::alloc::Layout) -> *mut u8 {
        if self.state.load(core::sync::atomic::Ordering::Relaxed) == 0 {
            self.wrapped.alloc_zeroed(layout)
        } else {
            alloc::alloc::handle_alloc_error(layout)
        }
    }

    unsafe fn realloc(
        &self,
        ptr: *mut u8,
        layout: alloc::alloc::Layout,
        new_size: usize,
    ) -> *mut u8 {
        if self.state.load(core::sync::atomic::Ordering::Relaxed) == 0 {
            self.wrapped.realloc(ptr, layout, new_size)
        } else {
            // Report the layout that was requested rather than the one being
            // replaced, so the error names the size of the refused allocation.
            // `realloc`'s contract already requires `new_size` to be valid for
            // `layout.align()`; falling back to `layout` is only a safety net.
            let requested = alloc::alloc::Layout::from_size_align(new_size, layout.align())
                .unwrap_or(layout);
            alloc::alloc::handle_alloc_error(requested)
        }
    }
}

/// A lock that must be kept alive as long as no allocation is allowed.
///
/// Returned by `NAlloc::deny`. While it is alive, the allocator it borrows
/// refuses new allocations; dropping it gives the lock back.
#[derive(Debug)]
pub struct AllocationLocker<'a, T> {
    /// The allocator whose lock counter this lock was counted in.
    allocator: &'a NAlloc<T>,
}

impl<'a, T> Drop for AllocationLocker<'a, T> {
    fn drop(&mut self) {
        self.allocator.unlock()
    }
}