use crate::macros::{
debug_handleallocerror_precondition, debug_handleallocerror_precondition_valid_layout,
precondition_memory_range,
};
use crate::zeroize::zeroize_mem;
use alloc::alloc::handle_alloc_error;
use allocator_api2::alloc::{AllocError, Allocator};
use core::alloc::{GlobalAlloc, Layout};
use core::ptr::NonNull;
/// Allocator wrapper that forwards every request to a backend allocator and
/// runs `zeroize_mem` over a block immediately before the block is released.
///
/// The wrapper implements [`GlobalAlloc`] when `A: GlobalAlloc` and
/// [`Allocator`] when `A: Allocator`.
#[derive(Debug, Default)]
pub struct ZeroizeAlloc<A> {
    /// Allocator that actually provides and releases the memory.
    backend_alloc: A,
}

impl<A> ZeroizeAlloc<A> {
    /// Wraps `backend_alloc` in a `ZeroizeAlloc`.
    ///
    /// This is a `const fn`, so the wrapper can be built in constant contexts
    /// (for example as the initializer of a `static`).
    pub const fn new(backend_alloc: A) -> Self {
        ZeroizeAlloc { backend_alloc }
    }
}
// SAFETY (review note): allocation requests are handed to `backend_alloc` with
// unchanged arguments, so this impl inherits the backend's `GlobalAlloc`
// guarantees. The only extra work is the `zeroize_mem` call in `dealloc`, which
// happens while the block is still allocated.
//
// `realloc` is intentionally not overridden here, so the trait's provided
// implementation is used (per the std docs it allocates a new block, copies,
// and then calls `dealloc` on the old one — confirm for your toolchain).
unsafe impl<A: GlobalAlloc> GlobalAlloc for ZeroizeAlloc<A> {
    /// Allocates memory by forwarding `layout` to the backend's `alloc`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`GlobalAlloc::alloc`]; the
    /// precondition macros below additionally check that `layout` is valid and
    /// has a non-zero size (presumably only in debug builds, as their names
    /// suggest — see `crate::macros`).
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        debug_handleallocerror_precondition_valid_layout!(layout);
        debug_handleallocerror_precondition!(layout.size() != 0, layout);

        // SAFETY: the caller's obligations are passed on unchanged.
        unsafe { self.backend_alloc.alloc(layout) }
    }

    /// Runs `zeroize_mem` over the `layout.size()` bytes at `ptr`, then
    /// releases the block through the backend's `dealloc`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`GlobalAlloc::dealloc`]: `ptr`
    /// must denote a block currently allocated by this allocator with the
    /// same `layout`.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // Argument checks: non-null pointer, valid layout, non-zero size, and a
        // memory range check on `ptr..ptr + size` (semantics of the macros are
        // defined in `crate::macros`).
        debug_handleallocerror_precondition!(!ptr.is_null(), layout);
        debug_handleallocerror_precondition_valid_layout!(layout);
        debug_handleallocerror_precondition!(layout.size() != 0, layout);
        precondition_memory_range!(ptr, layout.size());

        let len = layout.size();

        // Explicit debug-build guard: a block may not wrap around the end of
        // the address space.
        if cfg!(debug_assertions) && ptr.addr().checked_add(len).is_none() {
            handle_alloc_error(layout);
        }

        // The wipe must happen before the block is returned to the backend.
        // SAFETY: by the caller's contract `ptr` is a live allocation of `len`
        // bytes, so the range is writable (assuming that is all `zeroize_mem`
        // requires — see `crate::zeroize`).
        unsafe {
            zeroize_mem(ptr, len);
        }

        // SAFETY: the caller's obligations are passed on unchanged.
        unsafe { self.backend_alloc.dealloc(ptr, layout) }
    }

    /// Allocates memory by forwarding `layout` to the backend's
    /// `alloc_zeroed`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`GlobalAlloc::alloc_zeroed`];
    /// the same layout checks as in `alloc` are applied first.
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        debug_handleallocerror_precondition_valid_layout!(layout);
        debug_handleallocerror_precondition!(layout.size() != 0, layout);

        // SAFETY: the caller's obligations are passed on unchanged.
        unsafe { self.backend_alloc.alloc_zeroed(layout) }
    }
}
// SAFETY (review note): allocation requests are handed to `backend_alloc` with
// unchanged arguments, so this impl inherits the backend's `Allocator`
// guarantees. Only `allocate`, `allocate_zeroed` and `deallocate` are
// overridden; every other trait method uses the trait's provided
// implementation.
unsafe impl<A: Allocator> Allocator for ZeroizeAlloc<A> {
    /// Allocates a block by forwarding `layout` to the backend's `allocate`.
    ///
    /// Unlike the `GlobalAlloc` impl, no non-zero-size check is made here.
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        debug_handleallocerror_precondition_valid_layout!(layout);

        self.backend_alloc.allocate(layout)
    }

    /// Allocates a block by forwarding `layout` to the backend's
    /// `allocate_zeroed`.
    fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        debug_handleallocerror_precondition_valid_layout!(layout);

        self.backend_alloc.allocate_zeroed(layout)
    }

    /// Runs `zeroize_mem` over the `layout.size()` bytes at `ptr`, then
    /// releases the block through the backend's `deallocate`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `Allocator::deallocate`: `ptr`
    /// must denote a block currently allocated by this allocator, and `layout`
    /// must fit that block.
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        debug_handleallocerror_precondition_valid_layout!(layout);

        let raw = ptr.as_ptr();
        let len = layout.size();

        // The wipe must happen before the block is returned to the backend.
        // SAFETY: by the caller's contract `ptr` is a live allocation fitting
        // `layout`, so `len` bytes are writable (assuming that is all
        // `zeroize_mem` requires — see `crate::zeroize`).
        unsafe {
            zeroize_mem(raw, len);
        }

        // SAFETY: the caller's obligations are passed on unchanged.
        unsafe { self.backend_alloc.deallocate(ptr, layout) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::allocator_api::{Box, Vec};
    use std::alloc::System;

    // Most tests here are smoke tests: they pass when allocating through
    // `ZeroizeAlloc<System>` and dropping the container (which deallocates
    // through the wrapper) completes without tripping a check or crashing.

    /// Allocates and drops a boxed 8-byte array.
    #[test]
    fn box_allocation_8b() {
        let allocator = ZeroizeAlloc::new(System);
        let _heap_mem = Box::new_in([1u8; 8], &allocator);
    }

    /// Allocates and drops a boxed 9-byte array.
    #[test]
    fn box_allocation_9b() {
        let allocator = ZeroizeAlloc::new(System);
        let _heap_mem = Box::new_in([1u8; 9], &allocator);
    }

    /// Allocates and drops a box holding a zero-sized value.
    #[test]
    fn box_allocation_zst() {
        let allocator = ZeroizeAlloc::new(System);
        let _heap_mem = Box::new_in([(); 8], &allocator);
    }

    /// Allocates and drops a vector with room for 9 bytes.
    #[test]
    fn vec_allocation_9b() {
        let allocator = ZeroizeAlloc::new(System);
        let _heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
    }

    /// Grows a vector's backing allocation twice.
    ///
    /// `reserve(additional)` only reallocates when `len + additional` exceeds
    /// the current capacity. The vector is empty, so each request must be
    /// larger than the capacity it already has for the grow path to run.
    #[test]
    fn vec_allocation_grow_repeated() {
        let allocator = ZeroizeAlloc::new(System);
        let mut heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);

        let initial_capacity = heap_mem.capacity();
        heap_mem.reserve(initial_capacity + 1);
        let grown_capacity = heap_mem.capacity();
        assert!(grown_capacity > initial_capacity);

        heap_mem.reserve(grown_capacity + 1);
        assert!(heap_mem.capacity() > grown_capacity);
    }

    /// Shrinks a vector's backing allocation from 9 bytes down to its length.
    #[test]
    fn vec_allocation_shrink() {
        let allocator = ZeroizeAlloc::new(System);
        let mut heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
        heap_mem.push(255);
        heap_mem.shrink_to_fit();
    }

    /// Checks that `allocate_zeroed` hands out memory whose requested bytes
    /// all read as zero.
    #[test]
    fn allocate_zeroed() {
        let allocator = ZeroizeAlloc::new(System);
        let layout = Layout::new::<[u8; 16]>();
        let ptr = allocator
            .allocate_zeroed(layout)
            .expect("allocation failed");

        for i in 0..layout.size() {
            // SAFETY: `ptr` points to a live allocation of at least
            // `layout.size()` bytes and `i` stays below that size.
            let val: u8 = unsafe { (ptr.as_ptr() as *const u8).add(i).read() };
            assert_eq!(val, 0_u8);
        }

        // SAFETY: `ptr` was allocated above by `allocator` with `layout` and
        // has not been deallocated yet.
        unsafe {
            allocator.deallocate(ptr.cast(), layout);
        }
    }
}