use std::any::TypeId;
use std::ptr;
use std::sync::{Arc, RwLock};
use crate::error::Result;
use crate::utils::map::TypedMap;
use super::{BoxedFactory, BoxedResolver, Container, Factory, Resolver};
/// Type-erased owner of one strong reference to an `Arc<T>` whose `T` has been forgotten.
///
/// The two function pointers are the per-`T` operations needed to duplicate and release
/// that reference without naming `T` at the use site.
#[derive(Debug)]
pub struct ResolvedRaw {
/// Pointer produced by `Arc::into_raw`, cast to `*const ()`. `recover` sets it to null so
/// that `Drop` does not release the reference a second time.
ptr: *const (),
/// Increments the strong count behind `ptr` and returns the pointer for the new handle.
clone_fn: unsafe fn(*const ()) -> *const (),
/// Releases the strong reference behind `ptr`.
drop_fn: unsafe fn(*const ()),
}
// The raw pointer field makes `ResolvedRaw` neither `Send` nor `Sync` automatically; these
// impls opt back in so handles can live inside the `Arc<RwLock<..>>` maps of the container.
//
// NOTE(review): the `From<Arc<T>>` and `IntoResolvedRaw` impls in this file only require
// `T: 'static`, so nothing visible here guarantees that the erased pointee is itself
// `Send + Sync`. Soundness therefore depends on callers only storing thread-safe types —
// confirm, or tighten the bounds where handles are created.
unsafe impl Send for ResolvedRaw {}
unsafe impl Sync for ResolvedRaw {}
/// Adds one strong reference to the `Arc<T>` behind `ptr` and returns the same pointer.
///
/// # Safety
///
/// `ptr` must have been obtained from `Arc::<T>::into_raw` for this exact `T`, and the
/// allocation it points to must still be alive for the duration of the call.
unsafe fn clone_arc<T>(ptr: *const ()) -> *const ()
where
    T: 'static,
{
    let typed = ptr.cast::<T>();
    // SAFETY: the caller guarantees `typed` came from `Arc::<T>::into_raw` and is still live.
    unsafe { Arc::increment_strong_count(typed) };
    // The new reference shares the allocation, so the erased pointer is unchanged.
    ptr
}
/// Releases the strong reference to the `Arc<T>` behind `ptr`, destroying the value if it
/// was the last one.
///
/// # Safety
///
/// `ptr` must have been obtained from `Arc::<T>::into_raw` for this exact `T`, must still
/// own a strong reference, and must not be used again after this call.
unsafe fn drop_arc<T>(ptr: *const ())
where
    T: 'static,
{
    let typed = ptr.cast::<T>();
    // SAFETY: the caller guarantees `typed` came from `Arc::<T>::into_raw` and still owns
    // one strong reference, which is exactly what is given up here.
    unsafe { Arc::decrement_strong_count(typed) };
}
/// Erases an `Arc<T>` into a [`ResolvedRaw`], taking over the strong reference held by
/// the argument and remembering the `T`-specific clone and drop routines.
impl<T> From<Arc<T>> for ResolvedRaw
where
    T: 'static + Sized,
{
    fn from(arc: Arc<T>) -> Self {
        // `into_raw` does not touch the reference count: the reference formerly owned by
        // `arc` is now owned by the returned handle.
        let erased = Arc::into_raw(arc).cast::<()>();
        Self {
            ptr: erased,
            clone_fn: clone_arc::<T>,
            drop_fn: drop_arc::<T>,
        }
    }
}
impl Clone for ResolvedRaw {
fn clone(&self) -> Self {
ResolvedRaw {
ptr: unsafe { (self.clone_fn)(self.ptr) },
clone_fn: self.clone_fn,
drop_fn: self.drop_fn,
}
}
}
/// Releases the owned strong reference, unless it was already handed out by `recover`.
impl Drop for ResolvedRaw {
    fn drop(&mut self) {
        // A null pointer marks a handle whose reference was moved out by `recover`.
        if !self.ptr.is_null() {
            // SAFETY: `drop_fn` was stored together with `ptr` when the handle was created,
            // and a non-null `ptr` means the reference has not been given away yet.
            unsafe { (self.drop_fn)(self.ptr) };
        }
    }
}
impl ResolvedRaw {
unsafe fn recover<T>(mut self) -> Arc<T> {
let ptr = self.ptr as *const T;
self.ptr = ptr::null::<()>();
unsafe { Arc::from_raw(ptr) }
}
}
/// Conversion of an owned value into a type-erased [`ResolvedRaw`] handle.
trait IntoResolvedRaw {
/// Consumes `self` and returns the handle that owns it.
fn into_resolved_raw(self) -> ResolvedRaw;
}
impl<T> IntoResolvedRaw for T
where
T: 'static + Sized,
{
fn into_resolved_raw(self) -> ResolvedRaw {
let arc = Arc::new(self);
ResolvedRaw {
ptr: Arc::into_raw(arc) as *const (),
clone_fn: clone_arc::<T>,
drop_fn: drop_arc::<T>,
}
}
}
/// Adapter around a resolver `T`; its `Resolver` impl converts the wrapped resolver's
/// output into a [`ResolvedRaw`].
#[derive(Clone)]
struct MapIntoResolvedRaw<T>(T);
impl<Rg, R> Resolver<Rg> for MapIntoResolvedRaw<R>
where
R: Resolver<Rg> + Clone + Send + Sync + 'static,
R::Return: IntoResolvedRaw + 'static,
{
type Return = ResolvedRaw;
fn resolve(&self, registry: &Rg) -> Result<ResolvedRaw> {
self
.0
.resolve(registry)
.map(|resolved| resolved.into_resolved_raw())
}
}
/// Newtype over a `BoxedResolver` for `Registry` whose output type is always [`ResolvedRaw`].
struct RawResolver<Registry>(BoxedResolver<Registry, ResolvedRaw>);
impl<Registry> RawResolver<Registry> {
    /// Wraps `resolver` so that its output is erased into a [`ResolvedRaw`], then boxes the
    /// adapted resolver.
    fn new<R>(resolver: R) -> Self
    where
        R: Resolver<Registry> + Clone + Send + Sync + 'static,
        R::Return: IntoResolvedRaw + 'static,
    {
        let erasing = MapIntoResolvedRaw(resolver);
        let boxed = BoxedResolver::new(erasing);
        Self(boxed)
    }
}
/// Container that stores resolvers and resolved instances in type-erased form.
///
/// Both fields are `Arc`s, so the derived `Clone` yields a container that shares the same
/// underlying maps as the original.
#[derive(Default, Clone)]
pub struct RawContainer {
/// Registered resolvers, looked up by the `TypeId` of the type they were added for.
map: Arc<RwLock<TypedMap<RawResolver<Self>>>>,
/// Instances that were bound directly or already resolved, looked up by `TypeId`.
cache: Arc<RwLock<TypedMap<ResolvedRaw>>>,
}
impl RawContainer {
    /// Creates a container from the type's `Default` implementation.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
/// [`Container`] implementation on top of two lock-protected maps keyed by `TypeId`:
/// `map` holds the registered resolvers, `cache` holds bound or already resolved instances.
///
/// # Panics
///
/// Every method unwraps the result of acquiring its lock and therefore panics if that
/// lock has been poisoned.
impl Container for RawContainer {
    /// Registers `factory` under the `TypeId` of `T` and removes any cached `T`, so the
    /// next `get::<T>()` goes through the newly registered factory.
    fn add<F, T, Deps>(&mut self, factory: F)
    where
        F: Factory<T, Deps>,
        T: 'static,
        Deps: 'static,
    {
        let id: TypeId = TypeId::of::<T>();
        self.map
            .write()
            .unwrap()
            .insert(id, RawResolver::new(BoxedFactory::new(factory)));
        self.cache.write().unwrap().remove(&id);
    }

    /// Stores `instance` directly in the cache under the `TypeId` of `T`; the resolver map
    /// is not touched.
    fn bind<T>(&mut self, instance: T)
    where
        T: 'static,
    {
        let id: TypeId = TypeId::of::<T>();
        self.cache
            .write()
            .unwrap()
            .insert(id, instance.into_resolved_raw());
    }

    /// Returns the shared instance of `T`.
    ///
    /// The cache is consulted first. On a miss the resolver registered for `T` is cloned
    /// and run with this container as its registry, and the result is offered to the
    /// cache. Whatever the cache holds afterwards is what gets returned, so callers that
    /// miss at the same time all end up with the same instance.
    ///
    /// # Errors
    ///
    /// Returns a `TypeNotRegistered` error when `T` is neither cached nor registered, and
    /// forwards any error produced by the resolver.
    fn get<T>(&self) -> Result<Arc<T>>
    where
        T: 'static,
    {
        use crate::error::TypedError;
        let id: TypeId = TypeId::of::<T>();

        // Keep this as its own statement: the read guard is a temporary that is released
        // at the end of the `let`, before any write lock is requested below.
        let cached = self.cache.read().unwrap().get(&id).cloned();

        let raw = match cached {
            Some(hit) => hit,
            None => {
                // Clone the resolver out of the map so no lock is held while it runs; the
                // resolver may call back into `get` for its own dependencies.
                let resolver = self
                    .map
                    .read()
                    .unwrap()
                    .get(&id)
                    .ok_or_else(|| {
                        super::TypeNotRegistered::error(format!(
                            "Type '{}' is not registered in the container",
                            std::any::type_name::<T>()
                        ))
                    })?
                    .0
                    .clone();
                let fresh = resolver.resolve(self)?;

                // Another caller may have filled the cache while the resolver was running.
                // `or_insert` keeps an existing entry, so hand out a clone of whatever is
                // stored now rather than `fresh`, which is simply dropped in that case.
                let mut cache = self.cache.write().unwrap();
                cache.entry(id).or_insert(fresh).clone()
            }
        };

        // SAFETY: entries are stored under `TypeId::of::<T>()` either by `bind::<T>`, which
        // erases a `T`, or by the resolver registered through `add` for `T`.
        // NOTE(review): the latter assumes the boxed factory's output type is `T` — confirm
        // against `BoxedFactory`.
        Ok(unsafe { raw.recover() })
    }

    /// Removes the cached instance of `T`, if any; a registered resolver stays in place.
    fn drop<T>(&mut self)
    where
        T: 'static,
    {
        let id: TypeId = TypeId::of::<T>();
        self.cache.write().unwrap().remove(&id);
    }
}
#[cfg(test)]
mod tests {
    use super::{RawContainer, ResolvedRaw};
    use crate::container::{Container, DependencyResolutionFailed, TypeNotRegistered, Value};
    use std::sync::Arc;

    /// Leaf dependency used by the container tests.
    #[derive(PartialEq, Debug, Clone)]
    struct Service;

    /// Type with a single dependency on [`Service`].
    #[derive(PartialEq, Debug)]
    struct App {
        service: Arc<Service>,
    }

    /// Erasing an `Arc` and recovering it moves the reference without changing the count.
    #[test]
    fn test_create_from_arc() {
        let value = Arc::new(42);
        assert_eq!(Arc::strong_count(&value), 1);
        let raw: ResolvedRaw = value.clone().into();
        assert_eq!(Arc::strong_count(&value), 2);
        let recovered = unsafe { raw.recover::<i32>() };
        assert_eq!(Arc::strong_count(&value), 2);
        assert_eq!(*recovered, 42);
    }

    /// Cloning a handle adds a strong reference; dropping the handles gives them back, so
    /// the test leaves no leaked references behind.
    #[test]
    fn test_resolved_raw_clone_increments_strong_count() {
        let value = Arc::new(123_i32);
        assert_eq!(Arc::strong_count(&value), 1);
        let raw: ResolvedRaw = value.clone().into();
        assert_eq!(Arc::strong_count(&value), 2);
        let raw2 = raw.clone();
        assert_eq!(Arc::strong_count(&value), 3);
        drop(raw2);
        assert_eq!(Arc::strong_count(&value), 2);
        drop(raw);
        assert_eq!(Arc::strong_count(&value), 1);
    }

    /// A type with a dependency resolves, and repeated lookups return the same instance.
    #[test]
    fn test_all_together() {
        let mut container = RawContainer::default();
        container.add(Value(Service));
        container.add(|service: Arc<Service>| App { service });
        let app = container
            .get::<App>()
            .expect("First get should resolve App");
        let app2 = container
            .get::<App>()
            .expect("Second get should return cached App");
        // Compare by identity without `Arc::into_raw`, which would leak both references.
        assert!(Arc::ptr_eq(&app, &app2));
    }

    /// A bound instance with no factory behind it is gone for good once dropped.
    #[test]
    fn test_drop_removes_instance_from_cache_without_factory() {
        let mut container = RawContainer::default();
        container.bind(Service);
        let service = container.get::<Service>();
        assert!(service.is_ok());
        container.drop::<Service>();
        let service_after_drop = container.get::<Service>();
        assert!(service_after_drop.is_err());
    }

    /// Dropping a cached instance makes the next lookup run the factory again.
    #[test]
    fn test_drop_allows_factory_to_create_new_instance() {
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Debug, Clone)]
        struct CountedService {
            id: u32,
        }

        let mut container = RawContainer::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        // Each factory invocation gets the next id, which makes re-resolution observable.
        container.add(move || {
            let id = counter_clone.fetch_add(1, Ordering::SeqCst);
            CountedService { id }
        });
        let service1 = container
            .get::<CountedService>()
            .expect("First get should create and cache instance");
        assert_eq!(service1.id, 0);
        let service1_again = container
            .get::<CountedService>()
            .expect("Second get should return cached instance");
        assert_eq!(service1_again.id, 0);
        assert!(Arc::ptr_eq(&service1, &service1_again));
        container.drop::<CountedService>();
        let service2 = container
            .get::<CountedService>()
            .expect("Third get should re-resolve from factory");
        assert_eq!(
            service2.id, 1,
            "Factory should create new instance after drop"
        );
        assert!(
            !Arc::ptr_eq(&service1, &service2),
            "New instance should be created, not reusing dropped one"
        );
    }

    /// Looking up an unknown type yields a `TypeNotRegistered` error naming the type.
    #[test]
    fn test_type_not_registered_error() {
        let container = RawContainer::new();
        let result = container.get::<Service>();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is(TypeNotRegistered));
        assert!(err.message().contains("Service"));
    }

    /// A missing direct dependency is reported as `DependencyResolutionFailed` with the
    /// underlying error attached as its source.
    #[test]
    fn test_dependency_chain_error() {
        #[derive(Debug)]
        struct Config;
        #[derive(Debug)]
        struct ServiceWithConfig {
            _config: Arc<Config>,
        }

        let mut container = RawContainer::new();
        container.add(|config: Arc<Config>| ServiceWithConfig { _config: config });
        let result = container.get::<ServiceWithConfig>();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is(DependencyResolutionFailed));
        assert!(err.message().contains("ServiceWithConfig"));
        assert!(err.message().contains("Config"));
        let source = err.source().expect("Should have source error");
        assert!(source.to_string().contains("Config"));
    }

    /// A failure two levels down produces an error chain with one source per level.
    #[test]
    fn test_deep_dependency_chain_error() {
        #[derive(Debug)]
        struct Config;
        #[derive(Debug)]
        struct ServiceWithConfig {
            _config: Arc<Config>,
        }
        #[derive(Debug)]
        struct AppWithService {
            _service: Arc<ServiceWithConfig>,
        }

        let mut container = RawContainer::new();
        container.add(|config: Arc<Config>| ServiceWithConfig { _config: config });
        container.add(|service: Arc<ServiceWithConfig>| AppWithService { _service: service });
        let result = container.get::<AppWithService>();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message().contains("AppWithService"));
        assert!(err.message().contains("ServiceWithConfig"));
        // Walk the `source()` links to count how deep the error chain goes.
        let mut source_opt = err.source();
        let mut depth = 0;
        while let Some(source) = source_opt {
            depth += 1;
            source_opt = source.source();
        }
        assert_eq!(depth, 2, "Should have 2 levels in the error chain");
    }
}