use crate::internals::mem;
use crate::util::{
align_up_ptr_mut, align_up_usize, is_aligned_ptr, large_offset_from, nonnull_as_mut_ptr,
unlikely,
};
use crate::zeroize::zeroize_mem;
use allocator_api2::alloc::{AllocError, Allocator};
use core::alloc::Layout;
use core::cell::Cell;
use core::ptr::{self, NonNull};
use mirai_annotations::debug_checked_precondition;
/// Stack-like (bump) allocator serving every allocation from one single
/// memory page, wiping memory on deallocation (see `deallocate`).
///
/// Invariants (asserted by `consistency_check` in test builds):
/// `stack_offset` is a multiple of 8 and at most the page size; the page
/// base pointer is 8-aligned; `bytes` is a multiple of 8 and never exceeds
/// `stack_offset`.
pub struct SecStackSinglePageAlloc {
    // Total bytes currently allocated, counted in 8-byte-rounded sizes.
    // Reaching 0 allows the stack offset to be reset to the page start.
    bytes: Cell<usize>,
    // The single backing page. NOTE(review): allocation relies on fresh and
    // freed page memory being zero (allocate delegates to allocate_zeroed
    // without clearing) — obtained via `Page::alloc_new_lock`, presumably a
    // locked, zero-initialized mapping; confirm in `internals::mem`.
    page: mem::Page,
    // Offset of the top of the allocation stack within the page; new
    // allocations are carved out at this offset.
    stack_offset: Cell<usize>,
}
impl SecStackSinglePageAlloc {
    /// Asserts every internal invariant of the allocator. Compiled only in
    /// test builds; panics with a descriptive message on any violation.
    #[cfg(test)]
    fn consistency_check(&self) {
        let offset = self.stack_offset.get();
        assert!(
            offset % 8 == 0,
            "safety critical SecStackSinglePageAlloc invariant: offset alignment"
        );
        assert!(
            offset <= self.page.page_size(),
            "safety critical SecStackSinglePageAlloc invariant: offset in page size"
        );
        assert!(
            is_aligned_ptr(self.page.as_ptr(), 8),
            "safety critical SecStackSinglePageAlloc invariant: page alignment"
        );
        let allocated = self.bytes.get();
        assert!(
            allocated <= offset,
            "critical SecStackSinglePageAlloc consistency: allocated bytes in offset"
        );
        assert!(
            allocated % 8 == 0,
            "SecStackSinglePageAlloc consistency: allocated bytes 8 multiple"
        );
    }
}
// Debug-build-only destructor: verifies at teardown that every allocation
// was returned (`bytes == 0`) and that the whole page reads as zero again
// (freed memory is zeroized by `deallocate`). Release builds have no `Drop`
// impl at all, so this check costs nothing there.
#[cfg(debug_assertions)]
impl Drop for SecStackSinglePageAlloc {
    #[cfg(feature = "std")]
    fn drop(&mut self) {
        // Outstanding allocations at drop time indicate a leak or a missed
        // `deallocate`; abort rather than panic so the failure cannot be
        // caught or unwound past.
        if self.bytes.get() != 0 {
            std::process::abort();
        }
        // Scan the entire page: any non-zero byte means data survived past
        // its deallocation (or memory was corrupted).
        let page_ptr: *const u8 = self.page.as_ptr();
        for offset in 0..self.page.page_size() {
            let byte = unsafe { page_ptr.wrapping_add(offset).read() };
            if byte != 0 {
                std::process::abort();
            }
        }
    }
    // `no_std` variant: `process::abort` is unavailable, so the same checks
    // are expressed as (debug_)assertions, which panic instead.
    #[cfg(not(feature = "std"))]
    fn drop(&mut self) {
        debug_assert!(self.bytes.get() == 0);
        let page_ptr: *const u8 = self.page.as_ptr();
        for offset in 0..self.page.page_size() {
            let byte = unsafe { page_ptr.wrapping_add(offset).read() };
            assert!(byte == 0);
        }
    }
}
#[cfg(any(unix, windows))]
impl SecStackSinglePageAlloc {
    /// Creates a fresh allocator backed by one newly allocated, locked page,
    /// with an empty allocation stack.
    ///
    /// # Errors
    /// Returns `mem::PageAllocError` if the backing page cannot be obtained.
    pub fn new() -> Result<Self, mem::PageAllocError> {
        mem::Page::alloc_new_lock().map(|page| Self {
            bytes: Cell::new(0),
            page,
            stack_offset: Cell::new(0),
        })
    }
}
impl SecStackSinglePageAlloc {
    /// Returns `true` if the allocation starting at `ptr` with 8-byte
    /// rounded size `rounded_size` is the topmost allocation on the stack,
    /// i.e. ends exactly at the current stack offset. Only the topmost
    /// allocation can be grown or shrunk in place.
    fn ptr_is_last_allocation(&self, ptr: NonNull<u8>, rounded_size: usize) -> bool {
        // NOTE(review): assumes `ptr` points into `self.page` (callers pass
        // pointers previously handed out by this allocator); otherwise the
        // offset computation would be unsound.
        let alloc_start_offset = unsafe { large_offset_from(ptr.as_ptr(), self.page.as_ptr()) };
        let alloc_end_offset = alloc_start_offset + rounded_size;
        alloc_end_offset == self.stack_offset.get()
    }
    /// Produces the canonical zero-sized allocation: a dangling pointer at
    /// address `align` paired with length 0. No page memory is consumed.
    ///
    /// # Safety
    /// `align` must be a power of two (hence non-zero, which is what makes
    /// the `NonNull::new_unchecked` below sound).
    #[must_use]
    pub unsafe fn allocate_zerosized(align: usize) -> NonNull<[u8]> {
        debug_checked_precondition!(align.is_power_of_two());
        // A pointer with address `align` is both non-null and `align`-aligned.
        let dangling: *mut u8 = ptr::without_provenance_mut(align);
        let zerosized_slice: *mut [u8] = ptr::slice_from_raw_parts_mut(dangling, 0);
        unsafe { NonNull::new_unchecked(zerosized_slice) }
    }
    /// Shrinks by moving: allocates a fresh smaller block, copies the
    /// surviving `new_layout.size()` bytes over, then deallocates (and
    /// thereby zeroizes) the old block.
    ///
    /// # Safety
    /// `ptr` must denote a currently live allocation of this allocator with
    /// layout `old_layout`, and `new_layout.size() <= old_layout.size()`.
    pub unsafe fn realloc_shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_checked_precondition!(
            new_layout.size() <= old_layout.size(),
            "`new_layout.size()` must be smaller than or equal to `old_layout.size()`"
        );
        let new_ptr = self.allocate(new_layout)?;
        // SAFETY: the new block is carved out above the current stack top,
        // so it cannot overlap the old allocation.
        unsafe {
            ptr::copy_nonoverlapping(ptr.as_ptr(), nonnull_as_mut_ptr(new_ptr), new_layout.size());
            self.deallocate(ptr, old_layout);
        }
        Ok(new_ptr)
    }
    /// Grows by moving: allocates a fresh larger (zeroed) block, copies the
    /// full `old_layout.size()` bytes over, then deallocates (and thereby
    /// zeroizes) the old block.
    ///
    /// # Safety
    /// `ptr` must denote a currently live allocation of this allocator with
    /// layout `old_layout`, and `new_layout.size() >= old_layout.size()`.
    pub unsafe fn realloc_grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_checked_precondition!(
            new_layout.size() >= old_layout.size(),
            "`new_layout.size()` must be greater than or equal to `old_layout.size()`"
        );
        let new_ptr = self.allocate(new_layout)?;
        // SAFETY: see `realloc_shrink`; the regions are disjoint.
        unsafe {
            ptr::copy_nonoverlapping(ptr.as_ptr(), nonnull_as_mut_ptr(new_ptr), old_layout.size());
            self.deallocate(ptr, old_layout);
        }
        Ok(new_ptr)
    }
}
unsafe impl Allocator for SecStackSinglePageAlloc {
    /// Bump-allocates `layout` from the page.
    ///
    /// The returned memory is zero without explicit clearing here: freed
    /// regions are zeroized by `deallocate`, and untouched page memory is
    /// assumed zero (the debug `Drop` check relies on the same property).
    /// Sizes are rounded up to a multiple of 8 so the stack offset stays
    /// 8-aligned.
    fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        debug_checked_precondition!(layout.align().is_power_of_two());
        // Zero-sized requests never touch the page.
        if layout.size() == 0 {
            return Ok(unsafe { Self::allocate_zerosized(layout.align()) });
        }
        // Round up to a multiple of 8. `wrapping_add` lets a size near
        // `usize::MAX` wrap to 0, which is rejected right below instead of
        // overflowing in debug builds.
        let rounded_req_size = layout.size().wrapping_add(7usize) & !7usize;
        if unlikely(rounded_req_size == 0) {
            return Err(AllocError);
        }
        // Reject requests that do not fit into the remaining page space
        // (`stack_offset <= page_size` is an invariant, so the subtraction
        // cannot underflow).
        if rounded_req_size > self.page.page_size() - self.stack_offset.get() {
            return Err(AllocError);
        }
        let stack_ptr: *mut u8 = unsafe { self.page.as_ptr_mut().add(self.stack_offset.get()) };
        if layout.align() <= 8 {
            // The stack top is always 8-aligned, so it already satisfies any
            // alignment up to 8: allocate directly at the top.
            debug_assert!(
                layout.align() == 1
                    || layout.align() == 2
                    || layout.align() == 4
                    || layout.align() == 8
            );
            let alloc_slice_ptr: *mut [u8] =
                ptr::slice_from_raw_parts_mut(stack_ptr, rounded_req_size);
            let alloc_slice_ptr: NonNull<[u8]> = unsafe { NonNull::new_unchecked(alloc_slice_ptr) };
            self.stack_offset
                .set(self.stack_offset.get() + rounded_req_size);
            self.bytes.set(self.bytes.get() + rounded_req_size);
            Ok(alloc_slice_ptr)
        } else {
            // Over-aligned request: bump the start up to the next
            // `layout.align()` boundary, leaving a (zero) gap below it.
            let next_aligned_ptr = unsafe { align_up_ptr_mut(stack_ptr, layout.align()) };
            if unlikely(next_aligned_ptr.is_null()) {
                // Aligning up wrapped past the end of the address space.
                return Err(AllocError);
            }
            let next_align_pageoffset =
                unsafe { large_offset_from(next_aligned_ptr, self.page.as_ptr()) };
            // The aligned start must still lie inside the page...
            if next_align_pageoffset >= self.page.page_size() {
                return Err(AllocError);
            }
            // ...and the rounded size must fit behind it.
            if rounded_req_size > self.page.page_size() - next_align_pageoffset {
                return Err(AllocError);
            }
            let alloc_slice_ptr: *mut [u8] =
                ptr::slice_from_raw_parts_mut(next_aligned_ptr, rounded_req_size);
            let alloc_slice_ptr: NonNull<[u8]> = unsafe { NonNull::new_unchecked(alloc_slice_ptr) };
            self.stack_offset
                .set(next_align_pageoffset + rounded_req_size);
            self.bytes.set(self.bytes.get() + rounded_req_size);
            Ok(alloc_slice_ptr)
        }
    }
    /// Every allocation from this allocator is zeroed anyway, so plain
    /// `allocate` simply delegates to `allocate_zeroed`.
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        self.allocate_zeroed(layout)
    }
    /// Zeroizes the freed region and pops it off the stack when possible.
    ///
    /// If the freed block was the last live allocation the stack offset
    /// resets to 0; if it was the topmost block the offset moves back to the
    /// block's start; otherwise only the byte count shrinks and the zeroed
    /// hole is reclaimed once everything above it is freed.
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        // Zero-sized "allocations" are dangling pointers; nothing to do.
        if layout.size() == 0 {
            return;
        }
        debug_checked_precondition!(self.page.as_ptr().addr() <= ptr.as_ptr().addr());
        debug_checked_precondition!(
            ptr.as_ptr().addr() <= self.page.as_ptr().addr() + self.stack_offset.get()
        );
        let rounded_req_size = align_up_usize(layout.size(), 8);
        // Re-derive the pointer from the page's own pointer so it carries
        // the page's provenance regardless of what the caller handed us.
        let ptr = self.page.as_ptr_mut().with_addr(ptr.as_ptr().addr());
        // Security-critical: wipe freed memory immediately.
        unsafe {
            zeroize_mem(ptr, rounded_req_size);
        }
        self.bytes.set(self.bytes.get() - rounded_req_size);
        // Last live allocation gone: the entire stack is free again.
        if self.bytes.get() == 0 {
            self.stack_offset.set(0);
            return;
        }
        // If this was the topmost allocation, pop it off the stack.
        let alloc_start_offset = unsafe { large_offset_from(ptr, self.page.as_ptr()) };
        let alloc_end_offset = alloc_start_offset + rounded_req_size;
        if alloc_end_offset == self.stack_offset.get() {
            self.stack_offset.set(alloc_start_offset);
        }
    }
    /// Shrinks in place when `ptr` already satisfies the new alignment
    /// (zeroizing the cut-off tail); otherwise falls back to
    /// allocate-copy-deallocate via `realloc_shrink`.
    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_checked_precondition!(
            new_layout.size() <= old_layout.size(),
            "`new_layout.size()` must be smaller than or equal to `old_layout.size()`"
        );
        // Shrinking to nothing is a plain deallocation followed by the
        // canonical dangling zero-sized allocation.
        if new_layout.size() == 0 {
            unsafe {
                self.deallocate(ptr, old_layout);
            }
            return Ok(unsafe { Self::allocate_zerosized(new_layout.align()) });
        }
        debug_checked_precondition!(self.page.as_ptr().addr() <= ptr.as_ptr().addr());
        debug_checked_precondition!(
            ptr.as_ptr().addr() <= self.page.as_ptr().addr() + self.stack_offset.get()
        );
        if is_aligned_ptr(ptr.as_ptr(), new_layout.align()) {
            let rounded_size: usize = align_up_usize(old_layout.size(), 8);
            let new_rounded_size: usize = align_up_usize(new_layout.size(), 8);
            let new_alloc_end: *mut u8 = unsafe { ptr.as_ptr().add(new_rounded_size) };
            let size_decrease: usize = rounded_size - new_rounded_size;
            // Wipe the part of the allocation that is being given back.
            unsafe {
                zeroize_mem(new_alloc_end, size_decrease);
            }
            self.bytes.set(self.bytes.get() - size_decrease);
            // Topmost allocation: also move the stack top down.
            if self.ptr_is_last_allocation(ptr, rounded_size) {
                self.stack_offset
                    .set(self.stack_offset.get() - size_decrease);
            }
            let alloc_slice_ptr: *mut [u8] =
                ptr::slice_from_raw_parts_mut(ptr.as_ptr(), new_rounded_size);
            let alloc_slice_ptr: NonNull<[u8]> = unsafe { NonNull::new_unchecked(alloc_slice_ptr) };
            Ok(alloc_slice_ptr)
        } else {
            unsafe { self.realloc_shrink(ptr, old_layout, new_layout) }
        }
    }
    /// Grows in place when `ptr` is suitably aligned for the new layout AND
    /// is the topmost allocation (the newly exposed tail is already zero);
    /// otherwise falls back to allocate-copy-deallocate via `realloc_grow`.
    unsafe fn grow_zeroed(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        debug_checked_precondition!(
            new_layout.size() >= old_layout.size(),
            "`new_layout.size()` must be greater than or equal to `old_layout.size()`"
        );
        // Growing a zero-sized allocation is just a fresh allocation; the
        // old dangling pointer needs no deallocation.
        if old_layout.size() == 0 {
            return self.allocate(new_layout);
        }
        debug_checked_precondition!(self.page.as_ptr().addr() <= ptr.as_ptr().addr());
        debug_checked_precondition!(
            ptr.as_ptr().addr() <= self.page.as_ptr().addr() + self.stack_offset.get()
        );
        if is_aligned_ptr(ptr.as_ptr(), new_layout.align()) {
            let rounded_size: usize = align_up_usize(old_layout.size(), 8);
            if self.ptr_is_last_allocation(ptr, rounded_size) {
                let new_rounded_size: usize = align_up_usize(new_layout.size(), 8);
                if unlikely(new_rounded_size == 0) {
                    // Presumably `align_up_usize` wrapped on an absurdly
                    // large request — TODO confirm its overflow behavior.
                    return Err(AllocError);
                }
                let alloc_start_offset =
                    unsafe { large_offset_from(ptr.as_ptr(), self.page.as_ptr()) };
                // The grown block must still fit inside the page.
                if new_rounded_size > self.page.page_size() - alloc_start_offset {
                    return Err(AllocError);
                }
                let size_increase: usize = new_rounded_size - rounded_size;
                self.bytes.set(self.bytes.get() + size_increase);
                self.stack_offset
                    .set(self.stack_offset.get() + size_increase);
                let alloc_slice_ptr: *mut [u8] =
                    ptr::slice_from_raw_parts_mut(ptr.as_ptr(), new_rounded_size);
                let alloc_slice_ptr: NonNull<[u8]> =
                    unsafe { NonNull::new_unchecked(alloc_slice_ptr) };
                return Ok(alloc_slice_ptr);
            }
        }
        unsafe { self.realloc_grow(ptr, old_layout, new_layout) }
    }
    /// All memory from this allocator is zeroed, so `grow` and `grow_zeroed`
    /// coincide.
    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        unsafe { self.grow_zeroed(ptr, old_layout, new_layout) }
    }
}
// Tests exercise the allocator through `Box`/`Vec` (the `allocator_api`
// wrappers) and call `consistency_check` after every state transition.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::allocator_api::{Box, Vec};
    use std::mem::drop;

    // 16-byte aligned, 16-byte sized payload (align > 8 exercises the
    // over-aligned allocation path).
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    #[repr(align(16))]
    struct Align16(u128);

    // 16-byte aligned but only one data byte (size padded to 16).
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    #[repr(align(16))]
    struct ByteAlign16(u8);

    #[test]
    fn create_consistency() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
    }

    // Exactly one 8-byte rounding unit.
    #[test]
    fn box_allocation_8b() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Box::new_in([1u8; 8], &allocator);
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Non-multiple-of-8 size: exercises rounding up to 16 bytes.
    #[test]
    fn box_allocation_9b() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Box::new_in([1u8; 9], &allocator);
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Zero-sized type: must take the dangling-pointer path, not page memory.
    #[test]
    fn box_allocation_zst() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Box::new_in([(); 8], &allocator);
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Stacked alloc/dealloc pairs: the popped space should be reusable.
    #[test]
    fn multiple_box_allocations() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Box::new_in([1u8; 9], &allocator);
            allocator.consistency_check();
            {
                let _heap_mem2 = Box::new_in([1u8; 9], &allocator);
                allocator.consistency_check();
            }
            allocator.consistency_check();
            {
                let _heap_mem2prime = Box::new_in([1u8; 9], &allocator);
                allocator.consistency_check();
            }
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Same pattern with 16-byte alignment throughout.
    #[test]
    fn multiple_box_allocations_high_align() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Box::new_in([Align16(1); 5], &allocator);
            allocator.consistency_check();
            {
                let _heap_mem2 = Box::new_in([Align16(1); 9], &allocator);
                allocator.consistency_check();
            }
            allocator.consistency_check();
            {
                let _heap_mem2prime = Box::new_in([Align16(1); 2], &allocator);
                allocator.consistency_check();
            }
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // A 17-byte (odd) first allocation forces an alignment gap before the
    // following 16-aligned allocations.
    #[test]
    fn multiple_box_allocations_mixed_align() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Box::new_in([1u8; 17], &allocator);
            allocator.consistency_check();
            {
                let _heap_mem2 = Box::new_in([Align16(1); 9], &allocator);
                allocator.consistency_check();
            }
            allocator.consistency_check();
            {
                let _heap_mem2prime = Box::new_in([Align16(1); 2], &allocator);
                allocator.consistency_check();
            }
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Drops out of stack order: exercises the "hole in the middle" case
    // where only `bytes` shrinks until the blocks above are freed.
    #[test]
    fn many_box_allocations_mixed_align_nonstacked_drop() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let heap_mem1 = Box::new_in([Align16(1); 11], &allocator);
            allocator.consistency_check();
            let heap_mem2 = Box::new_in([ByteAlign16(1); 51], &allocator);
            allocator.consistency_check();
            let heap_mem3 = Box::new_in([1u8; 143], &allocator);
            allocator.consistency_check();
            drop(heap_mem3);
            allocator.consistency_check();
            let heap_mem4 = Box::new_in(ByteAlign16(1), &allocator);
            allocator.consistency_check();
            let heap_mem5 = Box::new_in(Align16(1), &allocator);
            allocator.consistency_check();
            drop(heap_mem2);
            allocator.consistency_check();
            drop(heap_mem1);
            allocator.consistency_check();
            drop(heap_mem4);
            allocator.consistency_check();
            drop(heap_mem5);
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    #[test]
    fn vec_allocation_9b() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let _heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Repeated `reserve` on the topmost allocation: should grow in place.
    #[test]
    fn vec_allocation_grow_repeated() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let mut heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
            allocator.consistency_check();
            heap_mem.reserve(10);
            allocator.consistency_check();
            heap_mem.reserve(17);
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // A Box on top of the Vec forces growth via the realloc (move) path.
    #[test]
    fn vec_allocation_nonfinal_grow() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let mut heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
            allocator.consistency_check();
            {
                let _heap_mem2 = Box::new_in(37_u64, &allocator);
                allocator.consistency_check();
                heap_mem.reserve(10);
                allocator.consistency_check();
                heap_mem.reserve(17);
                allocator.consistency_check();
            }
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // `shrink_to_fit` on the topmost allocation: in-place shrink path.
    #[test]
    fn vec_allocation_shrink() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let mut heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
            allocator.consistency_check();
            heap_mem.push(255);
            allocator.consistency_check();
            heap_mem.shrink_to_fit();
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Shrink of a non-topmost allocation (a Box sits above the Vec).
    #[test]
    fn vec_allocation_nonfinal_shrink() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        allocator.consistency_check();
        {
            let mut heap_mem = Vec::<u8, _>::with_capacity_in(9, &allocator);
            allocator.consistency_check();
            {
                let _heap_mem2 = Box::new_in(37_u64, &allocator);
                allocator.consistency_check();
                heap_mem.push(1);
                allocator.consistency_check();
                heap_mem.shrink_to_fit();
                allocator.consistency_check();
            }
            allocator.consistency_check();
        }
        allocator.consistency_check();
    }

    // Direct Allocator-API use: the returned block must read as all zeros.
    #[test]
    fn allocate_zeroed() {
        let allocator = SecStackSinglePageAlloc::new().expect("allocator creation failed");
        let layout = Layout::new::<[u8; 16]>();
        let ptr = allocator
            .allocate_zeroed(layout)
            .expect("allocation failed");
        for i in 0..16 {
            let val: u8 = unsafe { (ptr.as_ptr() as *const u8).add(i).read() };
            assert_eq!(val, 0_u8);
        }
        unsafe {
            allocator.deallocate(ptr.cast(), layout);
        }
    }
}