#![no_std]
use core::{
any::TypeId,
fmt,
hint::unreachable_unchecked,
marker::PhantomData,
mem::{self, MaybeUninit},
};
/// A fixed-capacity map from a `'static` type `T` to a borrowed `&T` or
/// `&mut T`, keyed by `TypeId::of::<T>()`.
///
/// At most `N` distinct types (default 4) can be stored; inserting a type that
/// is already present overwrites its existing slot. Insertion consumes the
/// array and returns it with a lifetime no longer than the new borrow.
pub struct PointerArray<'a, const N: usize = 4> {
/// Storage slots. Only `queue[..head]` are initialised; the rest are not.
queue: [MaybeUninit<(TypeId, RefVariant)>; N],
/// Number of initialised slots at the front of `queue`.
head: usize,
/// Carries the lifetime `'a` of the borrows behind the type-erased pointers.
_marker: PhantomData<&'a ()>,
}
impl Default for PointerArray<'_> {
fn default() -> Self {
Self::new()
}
}
impl<'a, const N: usize> PointerArray<'a, N> {
pub const fn new() -> Self {
PointerArray {
queue: unsafe { MaybeUninit::uninit().assume_init() },
head: 0,
_marker: PhantomData,
}
}
}
/// Error returned by `PointerArray::insert_mut` / `PointerArray::insert_ref`
/// when the inserted type has no slot yet and every slot is in use.
///
/// It hands back the array the failed insertion consumed.
pub struct Full<'a, const N: usize>(PointerArray<'a, N>);
impl<'a, const N: usize> Full<'a, N> {
    /// Recovers the array passed to the failed insertion.
    pub fn into_inner(self) -> PointerArray<'a, N> {
        let Full(array) = self;
        array
    }
}
impl<const N: usize> fmt::Debug for Full<'_, N> {
    /// Prints the fixed text `Full(..)`; the wrapped array is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Full(..)")
    }
}
impl<const N: usize> fmt::Display for Full<'_, N> {
    /// Writes the message `PointerArray is full`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PointerArray is full")
    }
}
impl<'a, const N: usize> PointerArray<'a, N> {
    /// Stores an exclusive reference to a `T`, keyed by `TypeId::of::<T>()`.
    ///
    /// Consumes the array and returns it with its lifetime shortened to `'b`,
    /// so `value` stays borrowed for as long as the returned array is used.
    /// If `T` already owns a slot (including one emptied by a `remove_*`
    /// call), that slot is overwritten whatever it held; otherwise the next
    /// free slot is used.
    ///
    /// # Errors
    ///
    /// Returns [`Full`] wrapping the unchanged array when `T` has no slot yet
    /// and all `N` slots are in use.
    pub fn insert_mut<'b, T: 'static>(
        self,
        value: &'b mut T,
    ) -> Result<PointerArray<'b, N>, Full<'a, N>>
    where
        'a: 'b,
    {
        self.insert_variant(TypeId::of::<T>(), value.into())
    }

    /// Stores a shared reference to a `T`, keyed by `TypeId::of::<T>()`.
    ///
    /// Same slot and lifetime behaviour as [`insert_mut`](Self::insert_mut).
    ///
    /// # Errors
    ///
    /// Returns [`Full`] wrapping the unchanged array when `T` has no slot yet
    /// and all `N` slots are in use.
    pub fn insert_ref<'b, T: 'static>(
        self,
        value: &'b T,
    ) -> Result<PointerArray<'b, N>, Full<'a, N>>
    where
        'a: 'b,
    {
        self.insert_variant(TypeId::of::<T>(), value.into())
    }

    /// Takes the shared reference to `T` out of the array.
    ///
    /// The slot is left empty; it still counts towards the `N` capacity and is
    /// reused if `T` is inserted again. Returns `None` when `T` is absent,
    /// already removed, or was stored with [`insert_mut`](Self::insert_mut)
    /// (in that case the exclusive reference is left in place).
    pub fn remove_ref<T: 'static>(&mut self) -> Option<&'a T> {
        self.remove::<T, _, _>(|r| match r {
            RefVariant::Ref(_) => match mem::replace(r, RefVariant::None) {
                // SAFETY: the slot is keyed by `TypeId::of::<T>()`, so the
                // pointer was erased from a `&T` whose borrow outlives `'a`.
                // Shared references may alias, so a `&'a T` is sound.
                RefVariant::Ref(t) => Some(unsafe { &*(t as *const T) }),
                // SAFETY: the outer arm just matched `Ref` and `mem::replace`
                // returns exactly that value.
                _ => unsafe { unreachable_unchecked() },
            },
            _ => None,
        })
    }

    /// Takes the exclusive reference to `T` out of the array.
    ///
    /// The slot is left empty; it still counts towards the `N` capacity and is
    /// reused if `T` is inserted again. Returns `None` when `T` is absent,
    /// already removed, or was stored with [`insert_ref`](Self::insert_ref)
    /// (in that case the shared reference is left in place).
    pub fn remove_mut<T: 'static>(&mut self) -> Option<&'a mut T> {
        self.remove::<T, _, _>(|r| match r {
            RefVariant::Mut(_) => match mem::replace(r, RefVariant::None) {
                // SAFETY: the pointer was erased from a `&mut T` whose borrow
                // outlives `'a`. Clearing the slot means the array will never
                // produce another reference to this value, and `get`/`get_mut`
                // only lend references tied to a borrow of `self`, which has
                // ended by the time `&mut self` is available here.
                RefVariant::Mut(t) => Some(unsafe { &mut *(t as *mut T) }),
                // SAFETY: the outer arm just matched `Mut` and `mem::replace`
                // returns exactly that value.
                _ => unsafe { unreachable_unchecked() },
            },
            _ => None,
        })
    }

    /// Borrows the `T` stored under `TypeId::of::<T>()`, whether it was
    /// inserted as a shared or as an exclusive reference.
    ///
    /// Returns `None` when `T` is absent or has been removed.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        // The result is deliberately tied to `&self`, not `'a`. A `&'a T` to a
        // value stored via `insert_mut` would outlive this borrow and alias
        // the `&'a mut T` that `remove_mut` can later return.
        let ptr = match self.try_find_init(&TypeId::of::<T>())? {
            RefVariant::Ref(ptr) => *ptr as *const T,
            RefVariant::Mut(ptr) => *ptr as *const T,
            RefVariant::None => return None,
        };
        // SAFETY: the slot is keyed by `TypeId::of::<T>()`, so `ptr` came from
        // a reference to a `T` whose borrow outlives `'a`, which in turn
        // outlives `&self`. While `&self` is held, the array cannot hand out a
        // `&mut T` to the same value.
        Some(unsafe { &*ptr })
    }

    /// Mutably borrows the `T` stored under `TypeId::of::<T>()`.
    ///
    /// Returns `None` when `T` is absent, has been removed, or was stored
    /// with [`insert_ref`](Self::insert_ref).
    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
        // The result is deliberately tied to `&mut self`, not `'a`. With a
        // `&'a mut T`, two consecutive calls would hand out aliasing exclusive
        // references.
        match self.try_find_init_mut(&TypeId::of::<T>())? {
            // SAFETY: `ptr` was erased from a `&mut T` whose borrow outlives
            // `'a` (and hence this `&mut self`). Holding `&mut self` for the
            // result's lifetime ensures no other reference to the value is
            // produced through the array meanwhile.
            RefVariant::Mut(ptr) => Some(unsafe { &mut *(*ptr as *mut T) }),
            RefVariant::Ref(_) | RefVariant::None => None,
        }
    }

    /// Shared implementation of `insert_mut`/`insert_ref`.
    ///
    /// Overwrites the slot already keyed by `id`, or appends a new slot when
    /// there is room. On failure, returns `Full` holding the untouched array.
    fn insert_variant<'b>(
        mut self,
        id: TypeId,
        value: RefVariant,
    ) -> Result<PointerArray<'b, N>, Full<'a, N>>
    where
        'a: 'b,
    {
        let existing = self.try_find_init_mut(&id);
        match existing {
            Some(slot) => *slot = value,
            None => {
                if self.head == N {
                    return Err(Full(self));
                }
                self.push(id, value);
            }
        }
        // `PointerArray` is covariant in its lifetime, so it shrinks to `'b`.
        Ok(self)
    }

    /// Looks up the slot for `T` and, if present, lets `func` decide what to
    /// take out of it.
    fn remove<T, F, R>(&mut self, func: F) -> Option<R>
    where
        T: 'static,
        R: 'a,
        F: for<'r> Fn(&'r mut RefVariant) -> Option<R>,
    {
        let id = TypeId::of::<T>();
        self.try_find_init_mut(&id).and_then(func)
    }

    /// Returns the stored variant for `id`, scanning only initialised slots.
    fn try_find_init(&self, id: &TypeId) -> Option<&RefVariant> {
        self.queue.iter().take(self.head).find_map(|slot| {
            // SAFETY: slots `0..head` were initialised by `push` before `head`
            // was advanced past them.
            let (slot_id, variant) = unsafe { slot.assume_init_ref() };
            (slot_id == id).then(|| variant)
        })
    }

    /// Mutable counterpart of [`try_find_init`](Self::try_find_init).
    fn try_find_init_mut(&mut self, id: &TypeId) -> Option<&mut RefVariant> {
        self.queue.iter_mut().take(self.head).find_map(|slot| {
            // SAFETY: slots `0..head` were initialised by `push` before `head`
            // was advanced past them.
            let (slot_id, variant) = unsafe { slot.assume_init_mut() };
            (slot_id == id).then(|| variant)
        })
    }

    /// Initialises the first unused slot with `(id, value)` and advances
    /// `head`.
    ///
    /// # Panics
    ///
    /// Panics if all `N` slots are in use. `insert_variant` checks
    /// `head == N` before calling, so this does not happen in practice.
    fn push(&mut self, id: TypeId, value: RefVariant) {
        self.queue[self.head].write((id, value));
        self.head += 1;
    }
}
/// A type-erased pointer held in one slot of the array.
///
/// `Ref` and `Mut` record whether the pointer came from a shared or an
/// exclusive borrow; `None` marks a slot whose pointer has been removed.
enum RefVariant {
    /// The slot's pointer has been taken out.
    None,
    /// An erased `&mut T`.
    Mut(*mut ()),
    /// An erased `&T`.
    Ref(*const ()),
}
impl<T> From<&T> for RefVariant {
    /// Erases a shared reference, remembering that it was shared.
    fn from(shared: &T) -> Self {
        let erased: *const T = shared;
        RefVariant::Ref(erased.cast())
    }
}
impl<T> From<&mut T> for RefVariant {
    /// Erases an exclusive reference, remembering that it was exclusive.
    fn from(exclusive: &mut T) -> Self {
        let erased: *mut T = exclusive;
        RefVariant::Mut(erased.cast())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    extern crate alloc;
    use alloc::{boxed::Box, string::String, vec, vec::Vec};

    /// End-to-end: insert both kinds, remove them in a nested scope, re-insert.
    #[test]
    fn test() {
        let map = PointerArray::default();
        let mut s = String::from("hello,string!");
        let s2 = String::from("hello,box!").into_boxed_str();
        let map = map.insert_mut(&mut s).unwrap();
        let map = map.insert_ref(&s2).unwrap();
        fn scope(mut map: PointerArray<'_>) {
            let s2 = map.remove_ref::<Box<str>>().unwrap();
            let mut v = vec![String::from("hello,string!")];
            let s = map.remove_mut::<String>().unwrap();
            let mut map = map.insert_mut(&mut v).unwrap();
            assert_eq!(s, "hello,string!");
            let v = map.remove_mut::<Vec<String>>().unwrap();
            assert_eq!(s, &v.pop().unwrap());
            assert_eq!(&**s2, "hello,box!");
        }
        assert_eq!(&*s2, "hello,box!");
        scope(map);
        assert_eq!(s, "hello,string!");
    }

    /// A new type cannot be inserted once every slot is taken.
    #[test]
    fn out_of_bound() {
        let map = PointerArray::<1>::new();
        let mut s = String::from("hello,string!");
        let b = String::from("hello,box!").into_boxed_str();
        let map = map.insert_mut(&mut s).unwrap();
        assert!(map.insert_ref(&b).is_err());
    }

    /// The array rejected by a full insert is recovered intact.
    #[test]
    fn error_retake() {
        let map = PointerArray::<1>::new();
        let mut s = String::from("hello,string!");
        let b = String::from("hello,box!").into_boxed_str();
        let mut map = map
            .insert_mut(&mut s)
            .unwrap()
            .insert_ref(&b)
            .err()
            .unwrap()
            .into_inner();
        assert_eq!(map.remove_mut::<String>().unwrap(), "hello,string!");
    }

    /// `remove_ref`/`remove_mut` only take the matching variant.
    #[test]
    fn ref_variant() {
        let map = PointerArray::<2>::new();
        let mut s = String::from("hello,string!");
        let b = String::from("hello,box!").into_boxed_str();
        let mut map = map.insert_mut(&mut s).unwrap().insert_ref(&b).unwrap();
        assert!(map.remove_ref::<String>().is_none());
        assert!(map.remove_mut::<Box<str>>().is_none());
        assert!(map.remove_mut::<String>().is_some());
        assert!(map.remove_ref::<Box<str>>().is_some());
    }

    /// `get` sees both variants; `get_mut` only exclusive ones.
    #[test]
    fn get() {
        let map = PointerArray::<2>::new();
        let mut s = String::from("hello,string!");
        let b = String::from("hello,box!").into_boxed_str();
        let mut map = map.insert_mut(&mut s).unwrap().insert_ref(&b).unwrap();
        assert!(map.get::<String>().is_some());
        assert!(map.get_mut::<String>().is_some());
        assert!(map.get::<Box<str>>().is_some());
        assert!(map.get_mut::<Box<str>>().is_none());
    }

    /// A slot emptied by `remove_*` is reused by a later insert of the same
    /// type, so re-inserting succeeds even when every slot is taken.
    #[test]
    fn reinsert_after_remove_reuses_slot() {
        let mut first = String::from("first");
        let mut second = String::from("second");
        let mut map = PointerArray::<1>::new().insert_mut(&mut first).unwrap();
        assert_eq!(map.remove_mut::<String>().unwrap(), "first");
        let map = map.insert_mut(&mut second).unwrap();
        assert_eq!(map.get::<String>().unwrap(), "second");
    }

    /// Inserting a type that already has a slot overwrites it in place and
    /// does not need a free slot.
    #[test]
    fn insert_same_type_overwrites() {
        let a = 1u32;
        let b = 2u32;
        let map = PointerArray::<1>::new()
            .insert_ref(&a)
            .unwrap()
            .insert_ref(&b)
            .unwrap();
        assert_eq!(map.get::<u32>(), Some(&2));
    }

    /// Once removed, a value is no longer visible through `get`/`get_mut`.
    #[test]
    fn get_after_remove_is_none() {
        let mut s = String::from("hello,string!");
        let mut map = PointerArray::<2>::new().insert_mut(&mut s).unwrap();
        assert!(map.remove_mut::<String>().is_some());
        assert!(map.get::<String>().is_none());
        assert!(map.get_mut::<String>().is_none());
    }
}