use core::ops::Not;
use std::cell::UnsafeCell;
use std::marker::{PhantomData, PhantomPinned};
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::ptr;
use crate::util::Span;
use executorch_sys as sys;
/// A memory allocator backed by the ExecuTorch C++ `MemoryAllocator`.
///
/// The lifetime `'a` ties the allocator (and every allocation it hands out)
/// to the memory it manages.
pub trait MemoryAllocator<'a> {
    #[doc(hidden)]
    fn _cpp_ptr(&self) -> *const sys::MemoryAllocator;
    #[cfg(feature = "std")]
    #[doc(hidden)]
    fn _into_unique_ptr(self) -> sys::cxx::UniquePtr<sys::MemoryAllocator>
    where
        Self: Sized;
    private_decl! {}
    /// Allocate `size` bytes aligned to `alignment`.
    ///
    /// Returns `None` when the underlying allocator is exhausted.
    #[allow(clippy::mut_from_ref)]
    fn allocate_raw(&self, size: usize, alignment: usize) -> Option<&mut [u8]> {
        // SAFETY: `_cpp_ptr` yields a valid C++ allocator that lives at least
        // as long as `self`; the C++ API mutates it behind the raw pointer.
        let raw = unsafe {
            sys::executorch_MemoryAllocator_allocate(self._cpp_ptr().cast_mut(), size, alignment)
        };
        if raw.is_null() {
            None
        } else {
            // SAFETY: a non-null result is a valid, exclusive region of
            // `size` bytes with the requested alignment.
            Some(unsafe { std::slice::from_raw_parts_mut(raw as *mut u8, size) })
        }
    }
}
/// Typed-allocation convenience methods layered on top of [`MemoryAllocator`].
///
/// Allocated values are never dropped by the allocator, hence the `NoDrop`
/// bounds.
pub trait MemoryAllocatorExt<'a>: MemoryAllocator<'a> {
    /// Allocate uninitialized, correctly aligned storage for a single `T`.
    #[allow(clippy::mut_from_ref)]
    fn allocate_uninit<T>(&self) -> Option<&mut MaybeUninit<T>> {
        let bytes = self.allocate_raw(std::mem::size_of::<T>(), std::mem::align_of::<T>())?;
        let ptr = bytes.as_mut_ptr() as *mut MaybeUninit<T>;
        // SAFETY: the buffer is large enough and aligned for `T`, and
        // `MaybeUninit<T>` imposes no validity requirement on its contents.
        Some(unsafe { &mut *ptr })
    }
    /// Allocate and default-initialize a single `T`.
    #[allow(clippy::mut_from_ref)]
    fn allocate<T>(&self) -> Option<&mut T>
    where
        T: NoDrop + Default,
    {
        let val = self.allocate_uninit::<T>()?;
        val.write(Default::default());
        // SAFETY: the value was initialized by `write` just above.
        Some(unsafe { val.assume_init_mut() })
    }
    /// Like [`allocate`](Self::allocate), but returns the value pinned.
    #[allow(clippy::mut_from_ref)]
    fn allocate_pinned<T>(&self) -> Option<Pin<&mut T>>
    where
        T: NoDrop + Default,
    {
        let val = self.allocate::<T>()?;
        // SAFETY: the allocation lives for `'a` and the allocator never moves
        // or reuses it, so pinning the reference is sound.
        Some(unsafe { Pin::new_unchecked(val) })
    }
    /// Allocate an array of `len` default-initialized `T`s.
    #[allow(clippy::mut_from_ref)]
    fn allocate_arr<T>(&self, len: usize) -> Option<&mut [T]>
    where
        T: NoDrop + Default,
    {
        self.allocate_arr_fn(len, |_| Default::default())
    }
    /// Allocate an array of `len` `T`s, initializing element `i` with `f(i)`.
    ///
    /// Returns `None` if the allocation fails or if the total byte size would
    /// overflow `usize`.
    #[allow(clippy::mut_from_ref)]
    fn allocate_arr_fn<T>(&self, len: usize, f: impl Fn(usize) -> T) -> Option<&mut [T]>
    where
        T: NoDrop,
    {
        let elm_size = std::mem::size_of::<T>();
        let alignment = std::mem::align_of::<T>();
        // Round the element size up to the alignment to get the stride
        // between consecutive array elements.
        let actual_elm_size = (elm_size + alignment - 1) & !(alignment - 1);
        // Rust guarantees `size_of::<T>()` is a multiple of `align_of::<T>()`,
        // so the stride must equal the element size. The previous check built
        // a 2-element slice over possibly smaller, uninitialized memory
        // (undefined behavior for `len < 2`, and it created references to
        // uninitialized `T`); this equality verifies the same invariant
        // without dereferencing anything.
        assert_eq!(actual_elm_size, elm_size);
        // Guard against arithmetic overflow for huge `len` instead of
        // silently wrapping and under-allocating in release builds.
        let total_size = actual_elm_size.checked_mul(len)?;
        let ptr = self.allocate_raw(total_size, alignment)?.as_mut_ptr() as *mut T;
        for i in 0..len {
            // SAFETY: `i < len`, so the write stays inside the allocation;
            // `write` neither reads nor drops the uninitialized destination.
            unsafe { ptr.add(i).write(f(i)) };
        }
        // SAFETY: all `len` elements were initialized above; the region is
        // aligned and exactly `len` elements long. Forming the slice only
        // now avoids referencing uninitialized values.
        Some(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
    }
}
// Blanket impl: every `MemoryAllocator` gets the extension methods for free.
impl<'a, T: MemoryAllocator<'a> + ?Sized> MemoryAllocatorExt<'a> for T {}
/// Marker for types whose values may be allocated without ever running `Drop`.
///
/// # Safety
/// Implementors must not rely on `Drop` being called: the allocators in this
/// module hand out references into arena memory and never run destructors.
pub unsafe trait NoDrop {}
// SAFETY: `Copy` types cannot implement `Drop`.
unsafe impl<T: Copy> NoDrop for T {}
// SAFETY: `Storage` wraps a `MaybeUninit`, which never drops its contents.
unsafe impl<T: Storable> NoDrop for Storage<T> {}
// SAFETY: presumably `Span` holds only borrowed data with no drop glue —
// TODO(review): confirm against the `Span` definition in `crate::util`.
unsafe impl<T: crate::util::SpanElement> NoDrop for crate::util::Span<'_, T> {}
/// A memory allocator backed by a caller-provided, fixed-size byte buffer.
///
/// The lifetime parameter ties the allocator to the borrowed buffer.
pub struct BufferMemoryAllocator<'a>(
// The underlying C++ allocator object; `UnsafeCell` because the C++ API
// mutates it through pointers obtained from `&self`.
pub(crate) UnsafeCell<sys::MemoryAllocator>,
PhantomData<&'a ()>,
);
impl<'a> BufferMemoryAllocator<'a> {
pub fn new(buffer: &'a mut [u8]) -> Self {
let size = buffer.len().try_into().unwrap();
let base_addr = buffer.as_mut_ptr();
let allocator = unsafe { sys::executorch_MemoryAllocator_new(size, base_addr) };
Self(UnsafeCell::new(allocator), PhantomData)
}
}
impl<'a> MemoryAllocator<'a> for BufferMemoryAllocator<'a> {
    fn _cpp_ptr(&self) -> *const sys::MemoryAllocator {
        self.0.get()
    }
    #[cfg(feature = "std")]
    fn _into_unique_ptr(self) -> sys::cxx::UniquePtr<sys::MemoryAllocator>
    where
        Self: Sized,
    {
        // SAFETY: we own `self`, so this exclusive reborrow of the inner C++
        // object is unique; `self` is forgotten below, transferring ownership
        // of the C++ object to the returned `UniquePtr`.
        let pinned = core::pin::Pin::new(unsafe { &mut *self._cpp_ptr().cast_mut() });
        let unique = sys::BufferMemoryAllocator_into_memory_allocator_unique_ptr(pinned);
        // Don't let Rust touch the C++ object again — it now lives in `unique`.
        #[allow(clippy::forget_non_drop)]
        std::mem::forget(self);
        unique
    }
    private_impl! {}
}
#[cfg(feature = "std")]
pub use malloc_allocator::MallocMemoryAllocator;
// Heap-backed allocator; only available with the `std` feature.
#[cfg(feature = "std")]
mod malloc_allocator {
use super::*;
/// A dynamic allocator wrapping the C++ `MallocMemoryAllocator`.
pub struct MallocMemoryAllocator(UnsafeCell<sys::cxx::UniquePtr<sys::MallocMemoryAllocator>>);
impl Default for MallocMemoryAllocator {
fn default() -> Self {
Self::new()
}
}
impl MallocMemoryAllocator {
/// Create a new malloc-based allocator.
pub fn new() -> Self {
Self(UnsafeCell::new(sys::MallocMemoryAllocator_new()))
}
}
// Not tied to any borrowed buffer, hence the `'static` lifetime.
impl MemoryAllocator<'static> for MallocMemoryAllocator {
fn _cpp_ptr(&self) -> *const sys::MemoryAllocator {
// `unwrap` holds while `self` exists: the `UniquePtr` is only moved out
// by `_into_unique_ptr`, which consumes `self`.
let self_ = unsafe { &mut *self.0.get() }.as_mut().unwrap();
// Upcast to the C++ `MemoryAllocator` base.
unsafe { sys::MallocMemoryAllocator_as_memory_allocator(self_) }
}
fn _into_unique_ptr(self) -> sys::cxx::UniquePtr<sys::MemoryAllocator>
where
Self: Sized,
{
// Transfer ownership of the C++ object, upcast to the base class.
sys::MallocMemoryAllocator_into_memory_allocator_unique_ptr(self.0.into_inner())
}
private_impl! {}
}
}
/// Wrapper over the C++ `HierarchicalAllocator`, which serves allocations from
/// a set of pre-planned buffers borrowed for `'a`.
pub struct HierarchicalAllocator<'a>(pub(crate) sys::HierarchicalAllocator, PhantomData<&'a ()>);
impl<'a> HierarchicalAllocator<'a> {
/// Create a hierarchical allocator over the given memory-plan buffer spans.
pub fn new(buffers: &'a mut [Span<'a, u8>]) -> Self {
// SAFETY: assumes `Span<'a, u8>` is layout-compatible (e.g.
// `repr(transparent)`) with `sys::SpanU8` — TODO(review): confirm against
// the `Span` definition in `crate::util`.
let buffers = unsafe {
std::mem::transmute::<&'a mut [Span<'a, u8>], &'a mut [sys::SpanU8]>(buffers)
};
let buffers = sys::SpanSpanU8 {
data: buffers.as_mut_ptr(),
len: buffers.len(),
};
Self(
// SAFETY: the span points into memory borrowed for `'a`, which outlives
// the returned allocator.
unsafe { sys::executorch_HierarchicalAllocator_new(buffers) },
PhantomData,
)
}
}
impl Drop for HierarchicalAllocator<'_> {
fn drop(&mut self) {
// SAFETY: `self.0` was constructed by `executorch_HierarchicalAllocator_new`
// and its C++ destructor runs exactly once, here.
unsafe { sys::executorch_HierarchicalAllocator_destructor(&mut self.0) };
}
}
/// Aggregates the allocators used during method execution (method allocator,
/// planned memory, temp allocator), mirroring the C++ `MemoryManager`.
pub struct MemoryManager<'a>(
// `UnsafeCell` because the C++ API mutates the manager through pointers
// obtained from `&self`.
pub(crate) UnsafeCell<sys::MemoryManager>,
PhantomData<&'a ()>,
);
impl<'a> MemoryManager<'a> {
pub fn new(
method_allocator: &'a dyn MemoryAllocator<'a>,
planned_memory: Option<&'a mut HierarchicalAllocator>,
temp_allocator: Option<&'a dyn MemoryAllocator<'a>>,
) -> Self {
let planned_memory = planned_memory
.map(|x| &mut x.0 as *mut _)
.unwrap_or(ptr::null_mut());
let temp_allocator = temp_allocator
.map(|x| x._cpp_ptr().cast_mut())
.unwrap_or(ptr::null_mut());
Self(
UnsafeCell::new(unsafe {
sys::executorch_MemoryManager_new(
method_allocator._cpp_ptr().cast_mut(),
planned_memory,
temp_allocator,
)
}),
PhantomData,
)
}
}
/// Uninitialized, pinnable storage for a [`Storable`] type `T`.
///
/// `PhantomPinned` makes the type `!Unpin`, so once pinned the storage cannot
/// be moved out from under code holding a pointer into it.
#[repr(transparent)]
pub struct Storage<T: Storable>(MaybeUninit<T::__Storage>, PhantomPinned);
impl<T: Storable> Default for Storage<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Storable> Storage<T> {
// Const constructor; the backing storage starts uninitialized.
pub(crate) const fn new() -> Self {
Self(MaybeUninit::uninit(), PhantomPinned)
}
// Raw pointer to the (possibly uninitialized) backing storage.
pub(crate) fn as_mut_ptr(&mut self) -> *mut T::__Storage {
self.0.as_mut_ptr()
}
}
/// Create pinned `Storage` values.
///
/// - `storage!(T)` — a single pinned `Storage<T>` on the stack.
/// - `storage!(T, [N])` — a pinned array of `N` storages (`N` must be const).
/// - `storage!(T, (n))` — a pinned boxed slice of `n` storages (runtime size,
///   requires an allocator).
#[macro_export]
macro_rules! storage {
($t:ty) => {
core::pin::pin!($crate::memory::Storage::<$t>::default())
};
// `[0; $n].map(...)` builds the array element-by-element; `$n` is const.
($t:ty, [$n:expr]) => {
core::pin::pin!([0; $n].map(|_| $crate::memory::Storage::<$t>::default()))
};
// Runtime-sized form: fill a Vec, then pin its boxed slice.
($t:ty, ($n:expr)) => {{
let n = $n;
let mut vec = $crate::__private::alloc::Vec::with_capacity(n);
vec.resize_with(n, || $crate::memory::Storage::<$t>::default());
std::pin::Pin::from(vec.into_boxed_slice())
}};
}
/// Types that can be placed in a [`Storage`].
pub trait Storable {
// Concrete backing representation; hidden implementation detail.
#[doc(hidden)]
type __Storage;
}
// Implement `Storable` for primitives that are stored as themselves.
macro_rules! impl_default_storable {
($($t:ty),*) => {
$(
impl Storable for $t {
type __Storage = $t;
}
)*
};
}
impl_default_storable!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64);
#[cfg(test)]
mod tests {
use super::*;
// Exercise `BufferMemoryAllocator`, carving sub-allocators out of a single
// fixed 16 KiB stack buffer.
#[test]
fn buffer_memory_allocator() {
let mut buffer: [u8; 16384] = [0; 16384];
let allocator = BufferMemoryAllocator::new(&mut buffer);
// Each "fresh" allocator is backed by a chunk of the outer allocator.
let allocator_init = |size| {
let buffer = allocator.allocate_raw(size, 1).unwrap();
BufferMemoryAllocator::new(buffer)
};
test_memory_allocator(allocator_init, true);
}
#[cfg(feature = "std")]
#[test]
fn malloc_memory_allocator() {
let mut idx = 0;
// Alternate between `default()` and `new()` to cover both constructors.
let allocator_init = |_size| {
idx += 1;
if idx % 2 == 0 {
MallocMemoryAllocator::default()
} else {
MallocMemoryAllocator::new()
}
};
// Malloc-backed allocators are unbounded, so skip the exhaustion check.
test_memory_allocator(allocator_init, false);
}
// Shared test body: `allocator_init(capacity)` builds an allocator with at
// least `capacity` bytes; `is_bounded` enables the out-of-memory check.
fn test_memory_allocator<'a, T>(mut allocator_init: impl FnMut(usize) -> T, is_bounded: bool)
where
T: super::MemoryAllocator<'a>,
{
let sizes = [1, 2, 4, 5, 8, 13, 31];
let alignments = [1, 2, 4, 8, 16, 32];
// Cartesian product of every size with every alignment.
let allocations = sizes
.into_iter()
.flat_map(|size| alignments.map(|alignment| (size, alignment)));
// Factor of 2 leaves headroom for alignment padding between allocations.
let raw_allocations_size = allocations.clone().map(|(size, _)| size).sum::<usize>() * 2;
let allocator = allocator_init(raw_allocations_size);
for (size, alignment) in allocations.clone() {
let allocation = allocator.allocate_raw(size, alignment).unwrap_or_else(|| {
panic!("Failed to allocate {size} bytes with alignment {alignment}")
});
assert_eq!(allocation.len(), size);
assert_eq!(allocation.as_ptr() as usize % alignment, 0);
}
// A zero-capacity bounded allocator must refuse any allocation.
if is_bounded {
let allocator = allocator_init(0);
assert!(allocator.allocate_raw(5, 8).is_none());
}
let allocator = allocator_init(1024);
assert!(allocator.allocate::<[u8; 1]>().is_some());
assert!(allocator.allocate::<[u8; 2]>().is_some());
assert!(allocator.allocate::<[u8; 4]>().is_some());
assert!(allocator.allocate::<[u8; 8]>().is_some());
assert!(allocator.allocate::<[f64; 15]>().is_some());
let allocator = allocator_init(1024);
assert!(allocator.allocate_pinned::<[u8; 1]>().is_some());
assert!(allocator.allocate_pinned::<[u8; 2]>().is_some());
assert!(allocator.allocate_pinned::<[u8; 4]>().is_some());
assert!(allocator.allocate_pinned::<[u8; 8]>().is_some());
assert!(allocator.allocate_pinned::<[f64; 15]>().is_some());
// `allocate_arr` must yield default-initialized (zeroed) elements.
let allocator = allocator_init(4096);
for sizes in sizes {
let arr = allocator.allocate_arr::<u8>(sizes).unwrap();
assert_eq!(arr.len(), sizes);
assert!(arr.iter().all(|&x| x == 0));
let arr = allocator.allocate_arr::<f32>(sizes).unwrap();
assert_eq!(arr.len(), sizes);
assert!(arr.iter().all(|&x| x == 0.0));
}
// `allocate_arr_fn` must call the initializer with each element's index.
let allocator = allocator_init(4096);
for sizes in sizes {
let arr = allocator.allocate_arr_fn(sizes, |i| i as u8).unwrap();
assert_eq!(arr.len(), sizes);
assert!(arr.iter().enumerate().all(|(i, &x)| x == i as u8));
let arr = allocator.allocate_arr_fn(sizes, |i| i as f32).unwrap();
assert_eq!(arr.len(), sizes);
assert!(arr.iter().enumerate().all(|(i, &x)| x == i as f32));
}
}
// Check all three arms of the `storage!` macro produce pinned storage of
// the expected type and length.
#[test]
fn storage_macro() {
let _: std::pin::Pin<&mut super::Storage<i32>> = storage!(i32);
let s: std::pin::Pin<&mut [super::Storage<i32>; 3]> = storage!(i32, [3]);
assert_eq!(s.len(), 3);
#[cfg(feature = "std")]
{
// Force a size the compiler cannot const-fold.
let dynamic_size = 2 + std::env::var("unknown-at-compile-time").is_ok() as usize;
let s: std::pin::Pin<crate::alloc::Box<[super::Storage<i32>]>> =
storage!(i32, (dynamic_size));
assert_eq!(s.len(), dynamic_size);
}
}
}