use std::alloc::Layout;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use crate::api::alloc::SmartAlloc;
/// Per-task allocator that owns every value and buffer handed out through
/// `TaskAlloc::alloc_box` and `TaskAlloc::alloc_slice`, releasing all of them
/// when the `TaskAlloc` itself is dropped.
pub struct TaskAlloc {
    // Clone of the allocator passed to `new`; exposed via `allocator()`.
    // NOTE(review): the allocations themselves go through `std::alloc`, not
    // through this `SmartAlloc` — confirm that is intended.
    alloc: SmartAlloc,
    // Every allocation made so far; drained and released in insertion order
    // by `Drop`.
    allocations: Mutex<Vec<TrackedAllocation>>,
    // Number of allocations made through this instance (only ever incremented).
    count: AtomicUsize,
}
/// Bookkeeping for a single allocation owned by a `TaskAlloc`.
struct TrackedAllocation {
    // Start of the allocation (type-erased).
    ptr: NonNull<u8>,
    // Layout the memory was allocated with; needed again to free it.
    layout: Layout,
    // Type-erased destructor for the value stored at `ptr`. `None` for
    // `alloc_slice` buffers, whose contents are never dropped.
    drop_fn: Option<unsafe fn(NonNull<u8>)>,
}
// SAFETY: a `TrackedAllocation` is a raw pointer plus metadata. In this file it
// is only pushed while holding `TaskAlloc::allocations`' lock and only consumed
// in `TaskAlloc::drop`, which has exclusive access.
// NOTE(review): these impls let the destructor of *any* `T` stored via
// `alloc_box` run on whichever thread drops the `TaskAlloc`, yet `alloc_box`
// has no `T: Send` bound — confirm this is sound for the types used.
unsafe impl Send for TrackedAllocation {}
unsafe impl Sync for TrackedAllocation {}
impl TaskAlloc {
pub fn new(alloc: &SmartAlloc) -> Self {
Self {
alloc: alloc.clone(),
allocations: Mutex::new(Vec::new()),
count: AtomicUsize::new(0),
}
}
pub fn alloc_box<T>(&self, value: T) -> TaskBox<T> {
let layout = Layout::new::<T>();
let ptr = unsafe {
let raw = std::alloc::alloc(layout);
if raw.is_null() {
std::alloc::handle_alloc_error(layout);
}
std::ptr::write(raw as *mut T, value);
NonNull::new_unchecked(raw)
};
let tracked = TrackedAllocation {
ptr,
layout,
drop_fn: Some(drop_typed::<T>),
};
self.allocations.lock().unwrap().push(tracked);
self.count.fetch_add(1, Ordering::Relaxed);
TaskBox {
ptr: ptr.cast(),
_marker: PhantomData,
}
}
pub fn alloc_slice<T>(&self, len: usize) -> TaskSlice<T> {
let layout = Layout::array::<T>(len).expect("layout overflow");
let ptr = unsafe {
let raw = std::alloc::alloc(layout);
if raw.is_null() {
std::alloc::handle_alloc_error(layout);
}
NonNull::new_unchecked(raw)
};
let tracked = TrackedAllocation {
ptr,
layout,
drop_fn: None,
};
self.allocations.lock().unwrap().push(tracked);
self.count.fetch_add(1, Ordering::Relaxed);
TaskSlice {
ptr: ptr.cast(),
len,
_marker: PhantomData,
}
}
pub fn allocation_count(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
pub fn allocator(&self) -> &SmartAlloc {
&self.alloc
}
}
impl Drop for TaskAlloc {
fn drop(&mut self) {
let allocations = self.allocations.get_mut().unwrap();
for tracked in allocations.drain(..) {
unsafe {
if let Some(drop_fn) = tracked.drop_fn {
drop_fn(tracked.ptr);
}
std::alloc::dealloc(tracked.ptr.as_ptr(), tracked.layout);
}
}
}
}
// SAFETY: all interior mutability in `TaskAlloc` goes through a `Mutex` or an
// atomic.
// NOTE(review): these impls also assert that `SmartAlloc` is `Send + Sync`,
// which is not visible here. If it is, these impls are redundant (the auto
// traits would apply); if it is not, they may be unsound — confirm.
unsafe impl Send for TaskAlloc {}
unsafe impl Sync for TaskAlloc {}
/// Type-erased destructor: runs `T`'s drop glue on the value at `ptr`.
///
/// Stored as a plain `unsafe fn(NonNull<u8>)` so that values of different
/// types can be destroyed without knowing `T` at drop time. The memory itself
/// is not freed.
///
/// # Safety
/// `ptr` must point to a valid, initialised, properly aligned `T` that is not
/// used (or dropped) again after this call.
unsafe fn drop_typed<T>(ptr: NonNull<u8>) {
    let value_ptr: *mut T = ptr.cast::<T>().as_ptr();
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::ptr::drop_in_place(value_ptr) }
}
/// Handle to a single `T` stored in memory owned by a `TaskAlloc`.
///
/// `TaskBox` has no `Drop` impl here: the value is destroyed and its memory
/// released when the owning `TaskAlloc` is dropped.
///
/// NOTE(review): nothing ties this handle's lifetime to its `TaskAlloc`, so
/// dereferencing it after the allocator is dropped is a use-after-free.
pub struct TaskBox<T> {
    ptr: NonNull<T>,
    _marker: PhantomData<T>,
}
impl<T> TaskBox<T> {
    /// Read-only raw pointer to the stored value.
    pub fn as_ptr(&self) -> *const T {
        NonNull::as_ptr(self.ptr)
    }

    /// Mutable raw pointer to the stored value.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        NonNull::as_ptr(self.ptr)
    }
}
impl<T> Deref for TaskBox<T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: `ptr` points to an initialised `T` for as long as the owning
        // `TaskAlloc` is alive (see the NOTE on the type).
        unsafe { &*self.as_ptr() }
    }
}
impl<T> DerefMut for TaskBox<T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: as for `deref`; `&mut self` guarantees unique access through
        // this handle.
        unsafe { &mut *self.as_mut_ptr() }
    }
}
// SAFETY: a `TaskBox<T>` behaves like an owning pointer to `T`, so it may cross
// threads exactly when `T` may.
unsafe impl<T: Send> Send for TaskBox<T> {}
unsafe impl<T: Sync> Sync for TaskBox<T> {}
/// Handle to `len` contiguous `T` slots in memory owned by a `TaskAlloc`.
///
/// The slots start out uninitialised; the owning `TaskAlloc` frees the memory
/// when it is dropped.
pub struct TaskSlice<T> {
    ptr: NonNull<T>,
    len: usize,
    _marker: PhantomData<T>,
}
impl<T> TaskSlice<T> {
    /// Number of `T` slots in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer has no slots.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Read-only raw pointer to the first slot.
    pub fn as_ptr(&self) -> *const T {
        NonNull::as_ptr(self.ptr)
    }

    /// Mutable raw pointer to the first slot.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        NonNull::as_ptr(self.ptr)
    }

    /// Views the buffer as a shared slice.
    ///
    /// # Safety
    /// All `len` elements must have been initialised, and the owning
    /// `TaskAlloc` must still be alive.
    pub unsafe fn as_slice(&self) -> &[T] {
        let (data, count) = (self.as_ptr(), self.len);
        // SAFETY: upheld by the caller per the contract above.
        unsafe { std::slice::from_raw_parts(data, count) }
    }

    /// Views the buffer as a mutable slice.
    ///
    /// # Safety
    /// All `len` elements must have been initialised, and the owning
    /// `TaskAlloc` must still be alive.
    pub unsafe fn as_mut_slice(&mut self) -> &mut [T] {
        let count = self.len;
        let data = self.as_mut_ptr();
        // SAFETY: upheld by the caller per the contract above.
        unsafe { std::slice::from_raw_parts_mut(data, count) }
    }
}
// SAFETY: a `TaskSlice<T>` behaves like an owning pointer to `[T]`, so it may
// cross threads exactly when `T` may.
unsafe impl<T: Send> Send for TaskSlice<T> {}
unsafe impl<T: Sync> Sync for TaskSlice<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AllocConfig;

    /// Fresh `SmartAlloc` with the default configuration.
    fn default_smart_alloc() -> SmartAlloc {
        SmartAlloc::new(AllocConfig::default())
    }

    #[test]
    fn task_alloc_creates_and_drops() {
        let smart = default_smart_alloc();
        {
            let arena = TaskAlloc::new(&smart);
            let _number = arena.alloc_box(42u32);
            let _text = arena.alloc_box(String::from("hello"));
            assert_eq!(arena.allocation_count(), 2);
            // `arena` drops at the end of this scope, running the `String`
            // destructor and freeing both allocations.
        }
    }

    #[test]
    fn task_box_deref() {
        let smart = default_smart_alloc();
        let arena = TaskAlloc::new(&smart);
        let boxed = arena.alloc_box(vec![1, 2, 3]);
        let expected = vec![1, 2, 3];
        assert_eq!(&*boxed, &expected);
    }

    #[test]
    fn task_slice_basic() {
        let smart = default_smart_alloc();
        let arena = TaskAlloc::new(&smart);
        let mut buf: TaskSlice<u32> = arena.alloc_slice(10);
        assert_eq!(buf.len(), 10);
        let base = buf.as_mut_ptr();
        for (offset, value) in (0..10u32).enumerate() {
            // SAFETY: `offset < 10 == buf.len()`, so the write stays in bounds.
            unsafe { base.add(offset).write(value) };
        }
        // SAFETY: all ten slots were initialised by the loop above.
        let view = unsafe { buf.as_slice() };
        assert_eq!(view[5], 5);
    }
}