use std::cell::Cell;
use std::mem;
use std::ptr;
use super::{Allocator, Error, Block, BlockOwner, HeapAllocator, HEAP};
/// A stack-like (region) allocator that bump-allocates out of a single
/// backing block obtained from a parent allocator. Child scopes (see
/// `scope`) continue from the current position and rewind on exit.
pub struct Scoped<'parent, A: 'parent + Allocator> {
    // Parent allocator the backing block was obtained from; only the root
    // scope returns the block to it on drop.
    allocator: &'parent A,
    // Bump pointer: address of the next free byte. Set to null while an
    // inner scope is active, which locks this level out of allocating.
    current: Cell<*mut u8>,
    // One past the last usable byte of the backing block.
    end: *mut u8,
    // True only for the outermost `Scoped`; child scopes created by `scope`
    // have `root == false` and never deallocate the backing block.
    root: bool,
    // First byte of this scope's region.
    start: *mut u8,
}
impl Scoped<'static, HeapAllocator> {
    /// Creates a new root scoped allocator backed by `size` bytes taken
    /// from the global heap allocator.
    ///
    /// # Errors
    /// Propagates any allocation failure from the heap allocator.
    pub fn new(size: usize) -> Result<Self, Error> {
        Self::new_from(HEAP, size)
    }
}
impl<'parent, A: Allocator> Scoped<'parent, A> {
    /// Creates a new scoped allocator whose backing block of `size` bytes
    /// is obtained from `alloc`.
    ///
    /// # Errors
    /// Propagates any allocation error from the parent allocator.
    pub fn new_from(alloc: &'parent A, size: usize) -> Result<Self, Error> {
        // The backing block is requested with pointer-size alignment; finer
        // per-allocation alignment is handled in `allocate_raw`.
        unsafe { alloc.allocate_raw(size, mem::align_of::<usize>()) }.map(|block| {
            Scoped {
                allocator: alloc,
                current: Cell::new(block.ptr()),
                // `end` is one past the last usable byte.
                end: unsafe { block.ptr().offset(block.size() as isize) },
                root: true,
                start: block.ptr(),
            }
        })
    }

    /// Runs `f` with a child allocator that continues bump-allocating from
    /// this scope's current position. Everything the child allocates is
    /// reclaimed in one shot when `f` returns.
    ///
    /// While the child is live, allocating through `self` fails (see
    /// `is_scoped`).
    ///
    /// # Errors
    /// Returns `Err(())` if this allocator is itself already scoped.
    pub fn scope<F, U>(&self, f: F) -> Result<U, ()>
        where F: FnOnce(&Self) -> U
    {
        // `FnOnce` (rather than `FnMut`) is the most general bound: the
        // closure is only ever invoked once below.
        if self.is_scoped() {
            return Err(());
        }

        let old = self.current.get();
        let alloc = Scoped {
            allocator: self.allocator,
            current: self.current.clone(),
            end: self.end,
            root: false,
            start: old,
        };

        // Null out our bump pointer so this level refuses to allocate while
        // the child scope is live.
        self.current.set(ptr::null_mut());
        let u = f(&alloc);
        // Rewind: everything the child allocated is freed at once.
        self.current.set(old);
        // The child is not the root and must not run its destructor.
        mem::forget(alloc);

        Ok(u)
    }

    /// Whether an inner scope is currently active, in which case allocating
    /// through this level is locked out.
    pub fn is_scoped(&self) -> bool {
        self.current.get().is_null()
    }
}
unsafe impl<'a, A: Allocator> Allocator for Scoped<'a, A> {
    /// Bump-allocates `size` bytes aligned to `align` from this scope.
    unsafe fn allocate_raw(&self, size: usize, align: usize) -> Result<Block, Error> {
        if self.is_scoped() {
            return Err(Error::AllocatorSpecific("Called allocate on already scoped \
                                                 allocator."
                .into()));
        }

        if size == 0 {
            return Ok(Block::empty());
        }

        let current_ptr = self.current.get();
        let aligned_ptr = super::align_forward(current_ptr, align);
        let end_ptr = aligned_ptr.offset(size as isize);

        // Filling the region exactly up to `end` is allowed.
        if end_ptr > self.end {
            Err(Error::OutOfMemory)
        } else {
            self.current.set(end_ptr);
            Ok(Block::new(aligned_ptr, size, align))
        }
    }

    /// Resizes `block` to `new_size` bytes. In-place only when `block` is
    /// the most recent allocation; otherwise allocates fresh space and
    /// copies the surviving bytes.
    unsafe fn reallocate_raw<'b>(&'b self, block: Block<'b>, new_size: usize) -> Result<Block<'b>, (Error, Block<'b>)> {
        let current_ptr = self.current.get();
        if new_size == 0 {
            // Shrinking to nothing: give the space back (a no-op unless
            // `block` is the top of the stack) and hand out the empty block.
            // The original silently dropped the block, leaking its space
            // until the scope ended.
            self.deallocate_raw(block);
            Ok(Block::empty())
        } else if block.is_empty() {
            Err((Error::UnsupportedAlignment, block))
        } else if block.ptr().offset(block.size() as isize) == current_ptr {
            // `block` is the most recent allocation: resize in place by
            // moving the bump pointer. Signed arithmetic here — the original
            // `(new_size - block.size()) as isize` underflowed (a debug-build
            // panic) whenever the block was shrunk.
            let delta = new_size as isize - block.size() as isize;
            let new_cur = current_ptr.offset(delta);
            // `<=`, not `<`: reaching `end` exactly is valid, matching the
            // bounds check in `allocate_raw`.
            if new_cur <= self.end {
                self.current.set(new_cur);
                Ok(Block::new(block.ptr(), new_size, block.align()))
            } else {
                Err((Error::OutOfMemory, block))
            }
        } else {
            // Not the top of the stack: allocate fresh space and copy over.
            match self.allocate_raw(new_size, block.align()) {
                Ok(new_block) => {
                    // Copy only the bytes that fit — the original always
                    // copied `block.size()` bytes, overrunning `new_block`
                    // when relocating to a smaller size.
                    let to_copy = ::std::cmp::min(block.size(), new_size);
                    ptr::copy_nonoverlapping(block.ptr(), new_block.ptr(), to_copy);
                    Ok(new_block)
                }
                Err(err) => Err((err, block)),
            }
        }
    }

    /// Returns `block`'s space to the scope. Only the most recent
    /// allocation can actually be reclaimed (LIFO discipline); anything
    /// else is a no-op and is freed when the scope ends.
    unsafe fn deallocate_raw(&self, block: Block) {
        if block.is_empty() || block.ptr().is_null() {
            return;
        }
        let current_ptr = self.current.get();
        if !self.is_scoped() && block.ptr().offset(block.size() as isize) == current_ptr {
            self.current.set(block.ptr());
        }
    }
}
impl<'a, A: Allocator> BlockOwner for Scoped<'a, A> {
    /// Whether `block`'s pointer lies inside this scope's region.
    fn owns_block(&self, block: &Block) -> bool {
        let ptr = block.ptr();
        // Half-open range: `end` is one past the last byte, so a pointer
        // equal to `end` cannot be the start of an allocation made here
        // (nonzero allocations start strictly before `end`; empty blocks
        // are null). The original `ptr <= self.end` wrongly claimed such
        // one-past-the-end pointers as owned.
        ptr >= self.start && ptr < self.end
    }
}
impl<'a, A: Allocator> Drop for Scoped<'a, A> {
    /// Returns the backing block to the parent allocator. Only the root
    /// scope owns the block; children created by `scope` are
    /// `mem::forget`-ed there and never reach this destructor.
    fn drop(&mut self) {
        if !self.root {
            return;
        }
        let size = self.end as usize - self.start as usize;
        if size == 0 {
            return;
        }
        unsafe {
            self.allocator
                .deallocate_raw(Block::new(self.start, size, mem::align_of::<usize>()))
        }
    }
}
unsafe impl<'a, A: 'a + Allocator + Sync> Send for Scoped<'a, A> {}
#[cfg(test)]
mod tests {
    use super::super::*;

    // While an inner scope is live the outer allocator's bump pointer is
    // nulled, so allocating through the outer handle fails and the `unwrap`
    // panics — which is exactly what this test expects.
    #[test]
    #[should_panic]
    fn use_outer() {
        let alloc = Scoped::new(4).unwrap();
        let mut outer_val = alloc.allocate(0i32).unwrap();
        alloc.scope(|_inner| {
                outer_val = alloc.allocate(1i32).unwrap();
            })
            .unwrap();
    }

    // Scopes nest: each level continues bump-allocating from its parent's
    // current position and rewinds on exit.
    #[test]
    fn scope_scope() {
        let alloc = Scoped::new(64).unwrap();
        let _ = alloc.allocate(0).unwrap();
        alloc.scope(|inner| {
                let _ = inner.allocate(32);
                inner.scope(|bottom| {
                        let _ = bottom.allocate(23);
                    })
                    .unwrap();
            })
            .unwrap();
    }

    // A zero-sized backing region cannot satisfy any nonzero allocation.
    #[test]
    fn out_of_memory() {
        let alloc = Scoped::new(0).unwrap();
        let (err, _) = alloc.allocate(1i32).err().unwrap();
        assert_eq!(err, Error::OutOfMemory);
    }

    // Constructs a large value directly in allocator memory, avoiding a
    // stack copy. NOTE(review): `in PLACE { EXPR }` is the old nightly
    // placement syntax, since removed from Rust — this test only builds on
    // the nightly versions this crate targeted.
    #[test]
    fn placement_in() {
        let alloc = Scoped::new(8_000_000).unwrap();
        let _big = in alloc.make_place().unwrap() { [0u8; 8_000_000] };
    }

    // `owns` reflects which scope's region a value was carved from: the
    // inner scope does not own values allocated before it began.
    #[test]
    fn owning() {
        let alloc = Scoped::new(64).unwrap();
        let val = alloc.allocate(1i32).unwrap();
        assert!(alloc.owns(&val));
        alloc.scope(|inner| {
                let in_val = inner.allocate(2i32).unwrap();
                assert!(inner.owns(&in_val));
                assert!(!inner.owns(&val));
            })
            .unwrap();
    }

    // Exercises the `Send` impl: the allocator can be shared across threads
    // behind `Arc<Mutex<_>>`. NOTE(review): the spawned threads are never
    // joined, so the allocations may not actually run before the test
    // returns — this mostly verifies the code compiles with these bounds.
    #[test]
    fn mutex_sharing() {
        use std::thread;
        use std::sync::{Arc, Mutex};

        let alloc = Scoped::new(64).unwrap();
        let data = Arc::new(Mutex::new(alloc));
        for i in 0..10 {
            let data = data.clone();
            thread::spawn(move || {
                let alloc_handle = data.lock().unwrap();
                let _ = alloc_handle.allocate(i).unwrap();
            });
        }
    }
}