#![allow(dead_code)]
use crate::factory::AnyFactory;
use ahash::RandomState;
use dashmap::DashMap;
use std::any::{Any, TypeId};
use std::sync::Arc;
#[cfg(feature = "perfect-hash")]
use std::hash::{Hash, Hasher};
/// Reinterprets a type-erased `Arc<dyn Any + Send + Sync>` as an `Arc<T>` without a
/// runtime type check in release builds.
///
/// Ownership of the single strong reference held by `arc` is transferred unchanged to
/// the returned `Arc<T>`: no clone and no new allocation take place.
///
/// # Safety
///
/// The concrete type of the value behind `arc` must be exactly `T` (i.e. the `Arc` must
/// have been created as an `Arc<T>` and then unsize-coerced). Passing any other type is
/// undefined behavior. Debug builds check this with a `debug_assert!` before any pointer
/// conversion happens.
#[inline]
pub(crate) unsafe fn downcast_arc_unchecked<T: Send + Sync + 'static>(
    arc: Arc<dyn Any + Send + Sync>,
) -> Arc<T> {
    // Deref first: `arc.type_id()` would report the TypeId of the `Arc` wrapper itself,
    // while `(*arc).type_id()` dispatches through the vtable to the erased value.
    debug_assert_eq!(
        (*arc).type_id(),
        TypeId::of::<T>(),
        "downcast_arc_unchecked called with a mismatched type"
    );
    let ptr = Arc::into_raw(arc);
    // SAFETY: the caller guarantees the erased value is a `T`, so this pointer came from
    // `Arc::<T>::into_raw` after an unsized coercion, which is exactly the case
    // `Arc::from_raw` documents as valid. Casting the fat `*const dyn Any` to a thin
    // `*const T` only discards the vtable metadata; the data address is unchanged.
    unsafe { Arc::from_raw(ptr as *const T) }
}
/// Concurrent registry mapping a service's [`TypeId`] to the factory that produces it,
/// optionally chained to a parent registry that is consulted by the `*_in_chain` /
/// `*_from_chain` lookups.
pub struct ServiceStorage {
    /// Locally registered factories, keyed by the `TypeId` of the service they produce.
    factories: DashMap<TypeId, AnyFactory, RandomState>,
    /// Enclosing storage that chained lookups fall back to; `None` for a root storage.
    parent: Option<Arc<ServiceStorage>>,
}
impl ServiceStorage {
    /// Creates an empty root storage (no parent) with 8 shards and no preallocated
    /// capacity.
    ///
    /// This is exactly the configuration `with_capacity(0)` produces.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty root storage preallocated for `capacity` registrations.
    ///
    /// The shard count scales with the expected size: 8 shards for up to 16 entries,
    /// 16 shards for up to 64 entries, and 32 shards beyond that.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let shard_amount = match capacity {
            0..=16 => 8,
            17..=64 => 16,
            _ => 32,
        };
        Self {
            factories: DashMap::with_capacity_and_hasher_and_shard_amount(
                capacity,
                RandomState::new(),
                shard_amount,
            ),
            parent: None,
        }
    }

    /// Creates an empty storage whose chained lookups fall back to `parent`.
    ///
    /// Uses 4 shards, whereas [`ServiceStorage::child`] builds the same kind of storage
    /// with 8.
    #[inline]
    pub fn with_parent(parent: Arc<ServiceStorage>) -> Self {
        Self {
            factories: DashMap::with_capacity_and_hasher_and_shard_amount(
                0,
                RandomState::new(),
                4,
            ),
            parent: Some(parent),
        }
    }

    /// Registers `factory` under `type_id` in this storage, replacing any factory
    /// previously stored under the same key here. Parent storages are never modified.
    ///
    /// Callers must ensure the factory produces values whose concrete type is the one
    /// `type_id` identifies: [`ServiceStorage::get`] and
    /// [`ServiceStorage::get_with_transient_flag`] downcast without checking.
    #[inline]
    pub(crate) fn insert(&self, type_id: TypeId, factory: AnyFactory) {
        self.factories.insert(type_id, factory);
    }

    /// Returns `true` if a factory is registered for `type_id` in this storage itself.
    ///
    /// Parents are not consulted; use [`ServiceStorage::contains_in_chain`] for that.
    #[inline]
    pub fn contains(&self, type_id: &TypeId) -> bool {
        self.factories.contains_key(type_id)
    }

    /// Resolves the service registered for `type_id` in this storage only.
    ///
    /// Returns `None` when nothing is registered here; parents are not consulted.
    #[inline]
    pub fn resolve(&self, type_id: &TypeId) -> Option<Arc<dyn Any + Send + Sync>> {
        // NOTE(review): the DashMap shard guard returned by `get` is still held while
        // `AnyFactory::resolve` runs. If a factory can re-enter this storage (to resolve
        // its own dependencies or to register services), confirm that cannot deadlock.
        self.factories.get(type_id).map(|f| f.resolve())
    }

    /// Resolves the service registered for `T` in this storage and returns it typed.
    ///
    /// Returns `None` when `T` is not registered here; parents are not consulted.
    #[inline]
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        self.resolve(&TypeId::of::<T>()).map(|any| {
            // SAFETY: the factory was looked up under `TypeId::of::<T>()`, and the
            // contract of `insert` requires that factory to produce a `T`.
            unsafe { downcast_arc_unchecked(any) }
        })
    }

    /// Like [`ServiceStorage::get`], but also returns the factory's transient flag.
    ///
    /// The flag is read before the factory is invoked, and both happen within a single
    /// map lookup.
    #[inline]
    pub fn get_with_transient_flag<T: Send + Sync + 'static>(&self) -> Option<(Arc<T>, bool)> {
        let type_id = TypeId::of::<T>();
        self.factories.get(&type_id).map(|factory| {
            let is_transient = factory.is_transient();
            let service = factory.resolve();
            // SAFETY: the factory stored under `TypeId::of::<T>()` produces a `T`, per
            // the contract of `insert`.
            let typed = unsafe { downcast_arc_unchecked(service) };
            (typed, is_transient)
        })
    }

    /// Resolves `type_id` from this storage or, failing that, from the nearest ancestor
    /// that has it registered.
    ///
    /// Lookup is nearest-first, so a registration in a child shadows one in its parent.
    #[inline]
    pub fn resolve_from_chain(&self, type_id: &TypeId) -> Option<Arc<dyn Any + Send + Sync>> {
        std::iter::successors(Some(self), |storage| storage.parent.as_deref())
            .find_map(|storage| storage.resolve(type_id))
    }

    /// Returns `true` if `type_id` is registered in this storage or in any ancestor.
    #[inline]
    pub fn contains_in_chain(&self, type_id: &TypeId) -> bool {
        std::iter::successors(Some(self), |storage| storage.parent.as_deref())
            .any(|storage| storage.contains(type_id))
    }

    /// Returns the parent storage, if any.
    #[inline]
    pub fn parent(&self) -> Option<&Arc<ServiceStorage>> {
        self.parent.as_ref()
    }

    /// Creates an empty storage whose parent is `self`.
    ///
    /// Uses 8 shards, unlike [`ServiceStorage::with_parent`], which uses 4.
    #[inline]
    pub fn child(self: &Arc<Self>) -> Self {
        Self {
            factories: DashMap::with_capacity_and_hasher_and_shard_amount(
                0,
                RandomState::new(),
                8,
            ),
            parent: Some(Arc::clone(self)),
        }
    }

    /// Number of factories registered in this storage (parents excluded).
    #[inline]
    pub fn len(&self) -> usize {
        self.factories.len()
    }

    /// Returns `true` if this storage itself has no registrations (parents excluded).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.factories.is_empty()
    }

    /// Removes every registration from this storage; parents are left untouched.
    #[inline]
    pub fn clear(&self) {
        self.factories.clear();
    }

    /// Returns `true` if this storage has a parent to fall back to.
    #[inline]
    pub fn has_parent(&self) -> bool {
        self.parent.is_some()
    }

    /// Removes the factory registered for `type_id` in this storage.
    ///
    /// Returns `true` if an entry was present and removed, `false` otherwise.
    #[inline]
    pub fn remove(&self, type_id: &TypeId) -> bool {
        self.factories.remove(type_id).is_some()
    }

    /// Collects the `TypeId`s registered in this storage (parents excluded).
    ///
    /// The order follows DashMap iteration order and is not specified.
    pub fn type_ids(&self) -> Vec<TypeId> {
        self.factories.iter().map(|r| *r.key()).collect()
    }

    /// Returns the transient flag of the factory registered for `type_id` here, or
    /// `false` when nothing is registered in this storage.
    #[inline]
    pub fn is_transient(&self, type_id: &TypeId) -> bool {
        self.factories
            .get(type_id)
            .map(|f| f.is_transient())
            .unwrap_or(false)
    }
}
impl Default for ServiceStorage {
fn default() -> Self {
Self::new()
}
}
/// Debug output shows the number of locally registered factories and whether a parent
/// is attached. The factories themselves are not printed. The fields mirror the Debug
/// output of the frozen storage variant.
impl std::fmt::Debug for ServiceStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceStorage")
            .field("count", &self.len())
            .field("has_parent", &self.parent.is_some())
            .finish()
    }
}
/// Newtype around [`TypeId`] used as the key type of the `boomphf` perfect hash.
#[cfg(feature = "perfect-hash")]
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
struct HashableTypeId(TypeId);
/// Hashes the wrapped `TypeId` through `TypeId`'s own [`Hash`] implementation.
///
/// `TypeId`'s internal representation is unspecified and has changed between Rust
/// releases, so reinterpreting its bytes (as a previous `transmute_copy` to `u64` did)
/// depends on layout details the standard library does not guarantee. Delegating keeps
/// equal `TypeId`s hashing equally without any `unsafe`.
#[cfg(feature = "perfect-hash")]
impl Hash for HashableTypeId {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
/// Read-only snapshot of a service storage chain, indexed by a minimal perfect hash
/// instead of a concurrent map.
#[cfg(feature = "perfect-hash")]
pub struct FrozenStorage {
    /// Perfect hash mapping each frozen `TypeId` to its slot index.
    mphf: boomphf::Mphf<HashableTypeId>,
    /// One factory per slot, ordered by perfect-hash index.
    factories: Vec<AnyFactory>,
    /// The `TypeId` stored in each slot; compared on lookup to reject keys never frozen.
    type_ids: Vec<TypeId>,
    /// Frozen copy of the source storage's parent, if it had one.
    parent: Option<Arc<FrozenStorage>>,
}
#[cfg(feature = "perfect-hash")]
impl FrozenStorage {
    /// Load factor (`gamma`) passed to `boomphf::Mphf::new`.
    const MPHF_GAMMA: f64 = 1.7;

    /// Freezes `storage` and, recursively, its entire parent chain.
    ///
    /// Factories are cloned out of the source map, so later registrations on `storage`
    /// do not affect the returned value. Each call freezes the parent into a fresh `Arc`,
    /// so frozen storages built from sibling children do not share a frozen parent.
    ///
    /// # Panics
    ///
    /// Panics if the perfect hash does not map the registered `TypeId`s one-to-one onto
    /// `0..len`. That would be a bug in hash construction, not a consequence of the input.
    pub fn from_storage(storage: &ServiceStorage) -> Self {
        let entries: Vec<(TypeId, AnyFactory)> = storage
            .factories
            .iter()
            .map(|r| (*r.key(), r.value().clone()))
            .collect();

        // An empty key set produces an empty MPHF and the loop below does nothing, so the
        // empty storage needs no special case.
        let keys: Vec<HashableTypeId> =
            entries.iter().map(|(id, _)| HashableTypeId(*id)).collect();
        let mphf = boomphf::Mphf::new(Self::MPHF_GAMMA, &keys);

        // Place every entry in the slot the MPHF assigns it. The checks enforce the
        // one-to-one guarantee explicitly instead of silently dropping or overwriting
        // entries if it were ever violated.
        let mut slots: Vec<Option<(TypeId, AnyFactory)>> =
            (0..entries.len()).map(|_| None).collect();
        for (type_id, factory) in entries {
            let idx = mphf.hash(&HashableTypeId(type_id)) as usize;
            let slot = slots
                .get_mut(idx)
                .expect("perfect hash returned an out-of-range slot");
            assert!(slot.is_none(), "perfect hash mapped two TypeIds to one slot");
            *slot = Some((type_id, factory));
        }
        let (type_ids, factories): (Vec<TypeId>, Vec<AnyFactory>) = slots
            .into_iter()
            .map(|slot| slot.expect("perfect hash left a slot empty"))
            .unzip();

        let parent = storage
            .parent
            .as_ref()
            .map(|p| Arc::new(Self::from_storage(p)));

        Self {
            mphf,
            factories,
            type_ids,
            parent,
        }
    }

    /// Returns the factory frozen for `type_id` in this storage, ignoring parents.
    ///
    /// `try_hash` may yield an index for a key that was never in the construction set,
    /// so the `TypeId` stored in that slot is compared to reject such false hits.
    #[inline]
    fn frozen_factory(&self, type_id: &TypeId) -> Option<&AnyFactory> {
        let idx = self.mphf.try_hash(&HashableTypeId(*type_id))? as usize;
        if self.type_ids.get(idx) != Some(type_id) {
            return None;
        }
        self.factories.get(idx)
    }

    /// Resolves the service frozen for `type_id` in this storage only (no parents).
    #[inline]
    pub fn resolve(&self, type_id: &TypeId) -> Option<Arc<dyn Any + Send + Sync>> {
        self.frozen_factory(type_id).map(|factory| factory.resolve())
    }

    /// Returns `true` if `type_id` was frozen into this storage itself (no parents).
    #[inline]
    pub fn contains(&self, type_id: &TypeId) -> bool {
        self.frozen_factory(type_id).is_some()
    }

    /// Resolves `type_id` from this storage or the nearest frozen ancestor that has it.
    ///
    /// Lookup is nearest-first, so a child's registration shadows its parent's.
    #[inline]
    pub fn resolve_from_chain(&self, type_id: &TypeId) -> Option<Arc<dyn Any + Send + Sync>> {
        std::iter::successors(Some(self), |storage| storage.parent.as_deref())
            .find_map(|storage| storage.resolve(type_id))
    }

    /// Returns `true` if `type_id` is frozen in this storage or any ancestor.
    #[inline]
    pub fn contains_in_chain(&self, type_id: &TypeId) -> bool {
        std::iter::successors(Some(self), |storage| storage.parent.as_deref())
            .any(|storage| storage.contains(type_id))
    }

    /// Number of factories frozen into this storage (parents excluded).
    #[inline]
    pub fn len(&self) -> usize {
        self.factories.len()
    }

    /// Returns `true` if this storage itself holds no factories (parents excluded).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.factories.is_empty()
    }

    /// Returns the transient flag of the factory frozen for `type_id` here, or `false`
    /// when `type_id` is not frozen in this storage.
    #[inline]
    pub fn is_transient(&self, type_id: &TypeId) -> bool {
        self.frozen_factory(type_id)
            .map(|factory| factory.is_transient())
            .unwrap_or(false)
    }
}
/// Debug output shows the local entry count and whether a frozen parent is attached.
#[cfg(feature = "perfect-hash")]
impl std::fmt::Debug for FrozenStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.len();
        let has_parent = self.parent.is_some();
        let mut builder = f.debug_struct("FrozenStorage");
        builder.field("count", &count);
        builder.field("has_parent", &has_parent);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct TestService {
        value: i32,
    }

    #[cfg(feature = "perfect-hash")]
    #[test]
    fn test_frozen_storage() {
        let storage = ServiceStorage::new();
        storage.insert(
            TypeId::of::<TestService>(),
            AnyFactory::singleton(TestService { value: 42 }),
        );
        storage.insert(TypeId::of::<i32>(), AnyFactory::singleton(123i32));
        storage.insert(
            TypeId::of::<String>(),
            AnyFactory::singleton("hello".to_string()),
        );
        let frozen = FrozenStorage::from_storage(&storage);
        assert!(frozen.contains(&TypeId::of::<TestService>()));
        assert!(frozen.contains(&TypeId::of::<i32>()));
        assert!(frozen.contains(&TypeId::of::<String>()));
        assert!(!frozen.contains(&TypeId::of::<bool>()));
        let service = frozen.resolve(&TypeId::of::<TestService>()).unwrap();
        // SAFETY: the factory was registered under `TypeId::of::<TestService>()`.
        let typed: Arc<TestService> = unsafe { downcast_arc_unchecked(service) };
        assert_eq!(typed.value, 42);
        assert_eq!(frozen.len(), 3);
    }

    #[cfg(feature = "perfect-hash")]
    #[test]
    fn test_frozen_storage_empty() {
        let storage = ServiceStorage::new();
        let frozen = FrozenStorage::from_storage(&storage);
        assert!(frozen.is_empty());
        assert_eq!(frozen.len(), 0);
        assert!(!frozen.contains(&TypeId::of::<TestService>()));
    }

    #[test]
    fn test_storage_insert_and_get() {
        let storage = ServiceStorage::new();
        let type_id = TypeId::of::<TestService>();
        storage.insert(type_id, AnyFactory::singleton(TestService { value: 42 }));
        let service = storage.get::<TestService>().unwrap();
        assert_eq!(service.value, 42);
    }

    #[test]
    fn test_storage_contains() {
        let storage = ServiceStorage::new();
        let type_id = TypeId::of::<TestService>();
        assert!(!storage.contains(&type_id));
        storage.insert(type_id, AnyFactory::singleton(TestService { value: 0 }));
        assert!(storage.contains(&type_id));
    }

    #[test]
    fn test_storage_remove() {
        let storage = ServiceStorage::new();
        let type_id = TypeId::of::<TestService>();
        storage.insert(type_id, AnyFactory::singleton(TestService { value: 0 }));
        assert!(storage.contains(&type_id));
        // `remove` reports whether an entry was actually present.
        assert!(storage.remove(&type_id));
        assert!(!storage.contains(&type_id));
        assert!(!storage.remove(&type_id));
    }

    #[test]
    fn test_storage_resolves_through_parent_chain() {
        let grandparent = Arc::new(ServiceStorage::new());
        grandparent.insert(TypeId::of::<i32>(), AnyFactory::singleton(7i32));
        let parent = Arc::new(ServiceStorage::with_parent(Arc::clone(&grandparent)));
        let child = parent.child();
        let id = TypeId::of::<i32>();
        assert!(child.has_parent());
        assert!(!child.contains(&id));
        assert!(child.contains_in_chain(&id));
        let service = child.resolve_from_chain(&id).unwrap();
        // SAFETY: the factory was registered under `TypeId::of::<i32>()`.
        let typed: Arc<i32> = unsafe { downcast_arc_unchecked(service) };
        assert_eq!(*typed, 7);
        let missing = TypeId::of::<bool>();
        assert!(!child.contains_in_chain(&missing));
        assert!(child.resolve_from_chain(&missing).is_none());
        assert!(!child.is_transient(&missing));
    }

    #[test]
    fn test_storage_child_shadows_parent() {
        let parent = Arc::new(ServiceStorage::new());
        parent.insert(TypeId::of::<i32>(), AnyFactory::singleton(1i32));
        let child = parent.child();
        child.insert(TypeId::of::<i32>(), AnyFactory::singleton(2i32));
        let service = child.resolve_from_chain(&TypeId::of::<i32>()).unwrap();
        // SAFETY: the factory was registered under `TypeId::of::<i32>()`.
        let typed: Arc<i32> = unsafe { downcast_arc_unchecked(service) };
        assert_eq!(*typed, 2);
    }

    #[test]
    fn test_unchecked_downcast_soundness() {
        use std::any::Any;
        let original: Arc<TestService> = Arc::new(TestService { value: 42 });
        let original_ptr = Arc::as_ptr(&original);
        let any_arc: Arc<dyn Any + Send + Sync> = original;
        // SAFETY: `any_arc` was created from an `Arc<TestService>`.
        let recovered: Arc<TestService> = unsafe { super::downcast_arc_unchecked(any_arc) };
        assert_eq!(original_ptr, Arc::as_ptr(&recovered));
        assert_eq!(recovered.value, 42);
    }
}