use std::{ any::{ Any, TypeId}, pin::Pin, hash::BuildHasherDefault };
use crate::prelude::{HashMap,Mutex};
/// A borrowed handle to a value of type `T` held by a [`StateManager`].
///
/// Dereferences to `T`. Use [`State::inner`] to obtain a reference that carries the
/// handle's full `'r` lifetime instead of the lifetime of the handle itself.
pub struct State<'r, T: Send + Sync + 'static>(&'r T);

impl<'r, T: Send + Sync + 'static> State<'r, T> {
    /// Returns the wrapped reference with the full lifetime `'r`.
    ///
    /// Unlike going through `Deref`, the returned reference is not bound to the
    /// lifetime of `self`, so it may outlive this particular handle.
    #[inline(always)]
    pub fn inner(&self) -> &'r T {
        self.0
    }
}

impl<T: Send + Sync + 'static> std::ops::Deref for State<'_, T> {
    type Target = T;

    #[inline(always)]
    fn deref(&self) -> &T {
        self.0
    }
}

// Implemented by hand: `#[derive(Clone)]` would add an unnecessary `T: Clone` bound,
// while copying the inner reference never needs it.
impl<T: Send + Sync + 'static> Clone for State<'_, T> {
    fn clone(&self) -> Self {
        State(self.0)
    }
}

/// Compares the referenced values, not the reference addresses.
impl<T: Send + Sync + 'static + PartialEq> PartialEq for State<'_, T> {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

// Bounds spelled out to match the struct definition and the sibling impls above.
impl<T: Send + Sync + 'static + std::fmt::Debug> std::fmt::Debug for State<'_, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("State").field(&self.0).finish()
    }
}
/// Pass-through hasher used for the `TypeId`-keyed state map.
///
/// A `u64` write is taken verbatim as the hash value. Byte writes are shifted in one
/// byte at a time from the right, so only the last eight bytes written survive.
#[derive(Default)]
struct IdentHash(u64);

impl std::hash::Hasher for IdentHash {
    fn write_u64(&mut self, i: u64) {
        // Overwrites, rather than mixes into, whatever was written before.
        self.0 = i;
    }

    fn write_u8(&mut self, i: u8) {
        self.0 = (self.0 << 8) | u64::from(i);
    }

    fn write(&mut self, bytes: &[u8]) {
        // Equivalent to feeding each byte through `write_u8` in order.
        self.0 = bytes
            .iter()
            .fold(self.0, |acc, &byte| (acc << 8) | u64::from(byte));
    }

    fn finish(&self) -> u64 {
        self.0
    }
}
/// Map from a state type's [`TypeId`] to its boxed value.
///
/// Each value lives in its own heap allocation, so its address does not change when the
/// map itself reallocates. `StateManager::try_get` relies on this when it hands out
/// references that outlive the lock guard.
type TypeIdMap = HashMap<TypeId, Pin<Box<dyn Any + Sync + Send>>, BuildHasherDefault<IdentHash>>;

/// Registry holding at most one value per type, guarded by a mutex.
///
/// Values are registered with [`StateManager::set`] and read back as [`State`] handles
/// through [`StateManager::get`] / [`StateManager::try_get`].
#[derive(Debug, Default)]
pub struct StateManager {
    map: Mutex<TypeIdMap>,
}

impl StateManager {
    /// Creates an empty state manager (same as [`StateManager::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `state` as the value for type `T`.
    ///
    /// Returns `true` if the value was stored. Returns `false` if a value of type `T` was
    /// already registered; in that case the existing value is kept and `state` is dropped.
    pub fn set<T: Send + Sync + 'static>(&self, state: T) -> bool {
        let mut map = self.map.lock();
        let type_id = TypeId::of::<T>();
        // Never replace an existing entry: `State` references handed out by `try_get`
        // may still point at it.
        if map.contains_key(&type_id) {
            return false;
        }
        let pinned: Pin<Box<dyn Any + Sync + Send>> = Box::pin(state);
        map.insert(type_id, pinned);
        true
    }

    /// Removes the value registered for `T` and returns it, or `None` if there is none.
    ///
    /// # Safety
    ///
    /// [`State`] handles returned by [`StateManager::get`] / [`StateManager::try_get`]
    /// borrow the stored value for the lifetime of `&self`, not for the duration of the
    /// internal lock. The borrow checker therefore does not prevent this call from freeing
    /// a value that is still referenced. The caller must guarantee that no `State<'_, T>`
    /// obtained from this manager is alive when this is called.
    pub unsafe fn unmanage<T: Send + Sync + 'static>(&self) -> Option<T> {
        let mut map = self.map.lock();
        let type_id = TypeId::of::<T>();
        let pinned_ptr = map.remove(&type_id)?;
        // SAFETY: per this function's contract no `State<T>` reference to the value is
        // alive, so nothing observes its address while it is moved out of the box.
        let ptr = unsafe { Pin::into_inner_unchecked(pinned_ptr) };
        // SAFETY: `set` stores every value under `TypeId::of::<T>()` for its own `T`, so
        // the value found under this key is always a `T`.
        let value = unsafe { ptr.downcast::<T>().unwrap_unchecked() };
        Some(*value)
    }

    /// Returns the value registered for `T`.
    ///
    /// # Panics
    ///
    /// Panics if no value of type `T` has been registered with [`StateManager::set`].
    pub fn get<T: Send + Sync + 'static>(&self) -> State<'_, T> {
        self.try_get()
            .unwrap_or_else(|| panic!("state not found for type {}", std::any::type_name::<T>()))
    }

    /// Returns the value registered for `T`, or `None` if there is none.
    pub fn try_get<T: Send + Sync + 'static>(&self) -> Option<State<'_, T>> {
        let map = self.map.lock();
        let type_id = TypeId::of::<T>();
        let ptr = map.get(&type_id)?;
        // SAFETY: `set` stores every value under `TypeId::of::<T>()` for its own `T`, so
        // the value found under this key is always a `T`.
        let value = unsafe { ptr.downcast_ref::<T>().unwrap_unchecked() };
        // SAFETY: this extends the borrow past the lock guard to the lifetime of `&self`.
        // That is sound for three reasons:
        // - the value sits in its own heap allocation, whose address survives map reallocation;
        // - `set` never replaces an existing entry;
        // - the only way to free the value is the `unsafe` `unmanage`, whose contract forbids
        //   outstanding references.
        let v_ref = unsafe { &*(value as *const T) };
        Some(State(v_ref))
    }
}