use crate::{
binary_search,
lifetime::{
DanglingCheckerResult, DanglingCheckerResults, LifetimeError, OutlivedLifetimeErrorVariants,
},
strategy::{Identifyable, Strategy},
untyped::{ArcAutoFreePointer, AutoFreePointer, FromArcAutoFreePointer, UntypedFn},
AllRegistered, AnyStrategy, InternalBuildResult, Registered, Resolvable, ServiceProducer,
TypeNamed, UntypedFnFactory, UntypedFnFactoryContext,
};
use abi_stable::std_types::{RArc, RVec};
use alloc::vec::Vec;
use core::{
any::{type_name, Any},
fmt,
fmt::{Debug, Formatter},
marker::PhantomData,
mem::swap,
};
use std::sync::OnceLock;
/// Gives access to registered services: single resolution (`get`/`resolve`),
/// iteration over every registration of a type (`get_all`) and lazily
/// initialized shared instances (`get_or_initialize_pos`).
pub struct ServiceProvider<TS: Strategy + 'static = AnyStrategy> {
    /// Registered type ids and their producers; the `RArc` is cloned into every
    /// `WeakServiceProvider` derived from this provider.
    immutable_state: RArc<ServiceProviderImmutableState<TS>>,
    /// The optional `base` value and the lazily initialized shared services.
    service_states: RArc<ServiceProviderMutableState>,
    /// `true` only for providers created by `ServiceProvider::new`; the
    /// debug-build lifetime check on drop runs for root providers only.
    #[cfg(debug_assertions)]
    is_root: bool,
}
impl<TS: Strategy + 'static> Debug for ServiceProvider<TS> {
    /// Prints the number of producers and the number of shared-service slots.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let services = self.immutable_state.producers.len();
        let with_state = self.service_states.shared_services.len();
        write!(
            f,
            "ServiceProvider (services: {}, with_state: {})",
            services, with_state
        )
    }
}
/// Debug-only lifetime check performed when the root provider is dropped.
///
/// Non-root copies (the ones wrapped in `WeakServiceProvider`, which set
/// `is_root: false`) return immediately. The root provider takes its mutable
/// state out of the shared `RArc` and drops every initialized shared service.
/// It then reports, through `crate::MINFAC_ERROR_HANDLER`, either the shared
/// services that are still strongly referenced afterwards, or the number of
/// other holders of the state that are still alive.
#[cfg(debug_assertions)]
impl<TS: Strategy + 'static> Drop for ServiceProvider<TS> {
    fn drop(&mut self) {
        if !self.is_root {
            return;
        }
        // Swap an empty placeholder into `self` so the real state can be taken
        // by value below; `self` then only drops the placeholder.
        let mut swapped_service_states = RArc::new(ServiceProviderMutableState {
            base: None,
            shared_services: RVec::new(),
        });
        swap(&mut swapped_service_states, &mut self.service_states);
        match RArc::try_unwrap(swapped_service_states) {
            // This provider was the last holder of the state.
            Ok(service_states) => {
                // Keep only a weak handle for every initialized shared service.
                // Each `OnceLock` `c` (and the strong reference inside it) is
                // dropped when its closure call returns. Collecting into a `Vec`
                // forces all of them to be dropped before any count is read
                // below, presumably so that services referencing one another
                // are all released before the check.
                let checkers: Vec<_> = service_states
                    .shared_services
                    .into_iter()
                    .filter_map(|c| {
                        // Slots that were never initialized are skipped.
                        let x = c.get()?;
                        let weak = x.inner.downgrade();
                        // NOTE(review): `c` still owns a strong reference at this
                        // point, so this condition appears to always be true.
                        if weak.strong_count() > 0 {
                            Some(TypeNamed {
                                inner: weak,
                                type_name: x.type_name,
                            })
                        } else {
                            None
                        }
                    })
                    .collect();
                // Any service that is still strongly referenced now has outlived
                // the provider; record its remaining count and type name.
                let errors = checkers
                    .into_iter()
                    .filter_map(|x| {
                        (x.inner.strong_count() > 0).then(|| {
                            DanglingCheckerResult::new(x.inner.strong_count(), x.type_name.into())
                        })
                    })
                    .collect::<DanglingCheckerResults>();
                if !errors.is_empty() {
                    // SAFETY(review): `MINFAC_ERROR_HANDLER` requires `unsafe`
                    // access, presumably because it is a `static mut` function
                    // pointer; confirm it is never reassigned concurrently.
                    unsafe {
                        (crate::MINFAC_ERROR_HANDLER)(&LifetimeError::new(
                            OutlivedLifetimeErrorVariants::SharedServices(errors),
                        ))
                    };
                }
            }
            // Some other holder (e.g. a `WeakServiceProvider`) still shares the
            // state. Report how many holders remain besides `x` itself.
            Err(x) => unsafe {
                let remaining_references = RArc::strong_count(&x) - 1;
                crate::MINFAC_ERROR_HANDLER(&LifetimeError::new(
                    OutlivedLifetimeErrorVariants::WeakServiceProvider {
                        remaining_references,
                    },
                ));
            },
        }
    }
}
impl<TS: Strategy + 'static> ServiceProvider<TS> {
    /// Resolves `T`, panicking if it cannot be resolved.
    ///
    /// # Panics
    /// Panics if `T::precheck` fails for this provider's registered type ids.
    pub fn resolve_unchecked<T: Resolvable<TS>>(&self) -> T::ItemPreChecked {
        let precheck_key =
            T::precheck(&self.immutable_state.types).expect("Resolve unkwnown service");
        T::resolve_prechecked(self, &precheck_key)
    }

    /// Resolves a single `T` through `Registered<T>`.
    pub fn get<T: Identifyable<TS::Id>>(&self) -> Option<T> {
        self.resolve_item::<Registered<T>>()
    }

    /// Returns an iterator over every registered `T` through `AllRegistered<T>`.
    pub fn get_all<T: Identifyable<TS::Id>>(&self) -> ServiceIterator<T, TS> {
        self.resolve_item::<AllRegistered<T>>()
    }

    /// Returns `true` if `identifier` is one of the registered type ids.
    ///
    /// Uses a binary search, so the registered ids must be sorted.
    pub fn has(&self, identifier: &TS::Id) -> bool {
        self.immutable_state.types.binary_search(identifier).is_ok()
    }

    /// Resolves `T`, returning `None` instead of panicking when `T::precheck`
    /// fails.
    pub fn resolve<T: Resolvable<TS>>(&self) -> Option<T> {
        let precheck_key = T::precheck(&self.immutable_state.types).ok()?;
        Some(T::resolve_prechecked_self(self, &precheck_key))
    }

    /// Resolves `T` via `Resolvable::resolve`, returning its `Item`.
    pub(crate) fn resolve_item<T: Resolvable<TS>>(&self) -> T::Item {
        T::resolve(self)
    }

    /// The producers of this provider, in the same order as its type ids.
    pub(crate) fn get_producers(&self) -> &RVec<UntypedFn<TS>> {
        &self.immutable_state.producers
    }

    /// Creates a root provider (`is_root` is `true` in debug builds) over the
    /// given immutable state, shared-service slots and optional `base` value.
    pub(crate) fn new(
        immutable_state: RArc<ServiceProviderImmutableState<TS>>,
        shared_services: RVec<OnceLock<TypeNamed<ArcAutoFreePointer>>>,
        base: Option<AutoFreePointer>,
    ) -> Self {
        Self {
            immutable_state,
            service_states: RArc::new(ServiceProviderMutableState {
                shared_services,
                base,
            }),
            #[cfg(debug_assertions)]
            is_root: true,
        }
    }

    /// Builds a factory whose resolved services return a clone of the `T`
    /// stored in the resolving provider's `base`.
    ///
    /// Resolution panics if that provider's `base` is `None`.
    pub(crate) fn build_service_producer_for_base<T: Identifyable<TS::Id> + Clone + Send + Sync>(
    ) -> UntypedFnFactory<TS> {
        extern "C" fn factory<
            T: Identifyable<TS::Id> + Clone + 'static + Send + Sync,
            TS: Strategy + 'static,
        >(
            stage_1_data: AutoFreePointer,
            _ctx: &mut UntypedFnFactoryContext<TS>,
        ) -> InternalBuildResult<TS> {
            extern "C" fn creator<
                T: Identifyable<TS::Id> + Clone + 'static + Send + Sync,
                TS: Strategy + 'static,
            >(
                provider: *const ServiceProvider<TS>,
                _stage_2_data: *const AutoFreePointer,
            ) -> T {
                // SAFETY: `provider` is assumed to be a valid, live provider
                // passed in by the untyped-function machinery for this call.
                let provider = unsafe { &*provider as &ServiceProvider<TS> };
                match &provider.service_states.base {
                    // SAFETY: assumes `base` was populated with a `T` when the
                    // provider was created. NOTE(review): confirm this
                    // invariant at the construction site.
                    Some(x) => unsafe { &*(x.get_pointer() as *const T) }.clone(),
                    None => panic!("Expected ServiceProviderFactory to set a value for `base`"),
                }
            }
            Ok(UntypedFn::create(creator::<T, TS>, stage_1_data)).into()
        }
        // The factory needs no captured data (`creator` ignores its stage-2
        // data), so a null pointer is passed.
        UntypedFnFactory::no_alloc(core::ptr::null(), factory::<T, TS>)
    }

    /// Returns the shared service stored in slot `index`, creating it with
    /// `initializer` on first access.
    ///
    /// `initializer` is only invoked from inside `OnceLock::get_or_init`, so it
    /// runs at most once per call and an `FnOnce` is sufficient. Every `Fn`
    /// closure still satisfies this bound.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds for the shared-service slots.
    pub(crate) fn get_or_initialize_pos<
        T: Any + Send + Sync + Into<ArcAutoFreePointer> + FromArcAutoFreePointer,
        TFn: FnOnce() -> T,
    >(
        &self,
        index: usize,
        initializer: TFn,
    ) -> T {
        let pointer = self.service_states.shared_services[index].get_or_init(|| TypeNamed {
            inner: initializer().into(),
            type_name: type_name::<T>().into(),
        });
        // SAFETY: relies on slot `index` only ever being initialized with a `T`.
        // NOTE(review): callers must use a stable, type-unique index.
        unsafe { T::from_ref(&pointer.inner) }
    }
}
/// A handle that shares a [`ServiceProvider`]'s state (both `RArc`s are cloned)
/// but is never the root provider, so dropping it never runs the debug
/// lifetime check.
///
/// In debug builds, a handle still alive when the root `ServiceProvider` is
/// dropped is reported through `crate::MINFAC_ERROR_HANDLER`.
pub struct WeakServiceProvider<TS: Strategy + 'static = AnyStrategy>(ServiceProvider<TS>);
impl<TS: Strategy + 'static> WeakServiceProvider<TS> {
    /// Returns one `ServiceProducer` per producer of this provider, each tagged
    /// with the same type id. Building one binds the corresponding parent
    /// producer to this provider via `bind`.
    ///
    /// # Safety
    /// `self` is reinterpreted through a raw pointer as a reference with an
    /// unbounded (effectively `'static`) lifetime, and that reference is
    /// captured by every returned factory. The caller must keep `self` alive,
    /// and at the same address, for as long as the returned producers or
    /// anything built from them may be used.
    pub(crate) unsafe fn clone_producers(&self) -> impl Iterator<Item = ServiceProducer<TS>> {
        // Context captured by each factory: the parent producer and the
        // provider it is bound to, both with erased lifetimes.
        type OuterContextType<TS> = (&'static UntypedFn<TS>, &'static WeakServiceProvider<TS>);
        // Erase the lifetime of `self`; see `# Safety`.
        let static_self = &*(self as *const Self);
        static_self
            .0
            .immutable_state
            .producers
            .iter()
            .zip(static_self.0.immutable_state.types.iter())
            .map(move |(parent_producer, parent_type)| {
                // Rebuilds the parent producer bound to the captured provider.
                extern "C" fn factory<TS: Strategy + 'static>(
                    outer_ctx: AutoFreePointer,
                    _: &mut UntypedFnFactoryContext<TS>,
                ) -> InternalBuildResult<TS> {
                    // NOTE(review): assumes `outer_ctx` points to the tuple passed
                    // to `UntypedFnFactory::boxed` below; confirm.
                    let ptr = outer_ctx.get_pointer() as *const OuterContextType<TS>;
                    unsafe {
                        let (parent_producer, static_self) = &*ptr;
                        Ok(parent_producer.bind(&static_self.0)).into()
                    }
                }
                let factory =
                    UntypedFnFactory::boxed((parent_producer, static_self), factory::<TS>);
                ServiceProducer::<TS>::new_with_type(factory, *parent_type)
            })
    }
    /// Delegates to [`ServiceProvider::resolve_unchecked`].
    pub fn resolve_unchecked<T: Resolvable<TS>>(&self) -> T::ItemPreChecked {
        self.0.resolve_unchecked::<T>()
    }
    /// Delegates to [`ServiceProvider::resolve`].
    pub fn resolve<T: Resolvable<TS>>(&self) -> Option<T> {
        self.0.resolve::<T>()
    }
    /// Delegates to [`ServiceProvider::get`].
    pub fn get<T: Identifyable<TS::Id>>(&self) -> Option<T> {
        self.0.get::<T>()
    }
    /// Delegates to [`ServiceProvider::get_all`].
    pub fn get_all<T: Identifyable<TS::Id>>(&self) -> ServiceIterator<T, TS> {
        self.0.get_all::<T>()
    }
    /// Delegates to [`ServiceProvider::has`].
    pub fn has(&self, identifier: &TS::Id) -> bool {
        self.0.has(identifier)
    }
}
impl<TS: Strategy + 'static> Clone for WeakServiceProvider<TS> {
    /// Produces another non-root handle that shares the same immutable and
    /// mutable state. The conversion from `&ServiceProvider` performs exactly
    /// that: both `RArc`s are cloned and `is_root` is `false`.
    fn clone(&self) -> Self {
        Self::from(&self.0)
    }
}
impl<'a, TS: Strategy + 'static> From<&'a ServiceProvider<TS>> for WeakServiceProvider<TS> {
    /// Creates a non-root handle that shares `provider`'s state.
    ///
    /// Both `RArc`s are cloned, so the handle keeps the state alive on its own.
    fn from(provider: &'a ServiceProvider<TS>) -> Self {
        // Bind the two shared fields by reference; nothing is moved out.
        let ServiceProvider {
            immutable_state,
            service_states,
            ..
        } = provider;
        let shared = ServiceProvider {
            immutable_state: immutable_state.clone(),
            service_states: service_states.clone(),
            #[cfg(debug_assertions)]
            is_root: false,
        };
        WeakServiceProvider(shared)
    }
}
/// Registration data held behind an `RArc` and shared by a provider and every
/// handle derived from it; it is only read after construction.
pub(crate) struct ServiceProviderImmutableState<TS: Strategy + 'static> {
    /// Registered type ids. `ServiceProvider::has` binary-searches this, so it
    /// must be sorted.
    types: RVec<TS::Id>,
    /// Producers, walked in lockstep with `types` by `clone_producers`.
    producers: RVec<UntypedFn<TS>>,
    /// Parent provider handles. This field is never read; presumably it is held
    /// to keep the parents' state alive.
    _parents: RVec<WeakServiceProvider<TS>>,
}
impl<TS: Strategy + 'static> ServiceProviderImmutableState<TS> {
    /// Bundles the given type ids, producers and parent handles as-is.
    pub(crate) fn new(
        types: RVec<TS::Id>,
        producers: RVec<UntypedFn<TS>>,
        _parents: RVec<WeakServiceProvider<TS>>,
    ) -> Self {
        Self {
            _parents,
            producers,
            types,
        }
    }
}
/// Per-provider mutable state, shared through an `RArc` with every handle
/// derived from the provider.
///
/// NOTE(review): `#[repr(C)]` presumably keeps the layout stable across the
/// ABI boundary; keep the field order unchanged.
#[repr(C)]
pub(crate) struct ServiceProviderMutableState {
    /// Value cloned by producers built with
    /// `ServiceProvider::build_service_producer_for_base`; those producers
    /// panic when it is `None`.
    base: Option<AutoFreePointer>,
    /// One lazily initialized slot per shared service, addressed by index in
    /// `ServiceProvider::get_or_initialize_pos`. Each slot records its type
    /// name, which the debug-build drop check uses in its error report.
    shared_services: RVec<OnceLock<TypeNamed<ArcAutoFreePointer>>>,
}
/// Iterator over every registered service of type `T`, resolved lazily from
/// the provider's producers.
pub struct ServiceIterator<T, TS: Strategy + 'static = AnyStrategy> {
    /// Index of the next producer to resolve, or `None` once exhausted.
    next_pos: Option<usize>,
    /// Handle to the provider whose producers are resolved.
    provider: WeakServiceProvider<TS>,
    item_type: PhantomData<T>,
}
impl<T, TS: Strategy + 'static> ServiceIterator<T, TS> {
    /// Creates an iterator that starts at producer `next_pos`; `None` yields an
    /// empty iterator.
    pub(crate) fn new(provider: WeakServiceProvider<TS>, next_pos: Option<usize>) -> Self {
        Self {
            next_pos,
            provider,
            item_type: PhantomData,
        }
    }
}
impl<TS: Strategy + 'static, T: Identifyable<TS::Id>> ServiceIterator<T, TS> {
    /// Number of items this iterator will still yield.
    ///
    /// Like `next`, this assumes all producers of `T` are stored contiguously.
    /// The remaining items therefore run from `next_pos` up to the last producer
    /// whose result type id equals `T::get_id()`.
    fn remaining_len(&self) -> usize {
        self.next_pos
            .map(|i| {
                let pos = binary_search::binary_search_last_by_key(
                    &self.provider.0.immutable_state.producers[i..],
                    &T::get_id(),
                    UntypedFn::get_result_type_id,
                )
                .expect("having at least one item because has next_pos");
                pos + 1
            })
            .unwrap_or(0)
    }
}
impl<TS: Strategy + 'static, T: Identifyable<TS::Id>> Iterator for ServiceIterator<T, TS> {
    type Item = T;

    /// Resolves the producer at `next_pos`. The iterator advances to the
    /// following producer only if that producer also yields `T`; otherwise it
    /// is exhausted afterwards.
    fn next(&mut self) -> Option<Self::Item> {
        self.next_pos.map(|i| {
            self.next_pos = self
                .provider
                .0
                .immutable_state
                .producers
                .get(i + 1)
                .and_then(|next| (next.get_result_type_id() == &T::get_id()).then(|| i + 1));
            // SAFETY: assumes `i` indexes a producer of `T`. Later positions are
            // checked above; the first one is the responsibility of whoever
            // built the iterator. NOTE(review): confirm.
            unsafe { crate::resolvable::resolve_unchecked::<TS, T>(&self.provider.0, i) }
        })
    }

    /// Exact: the number of remaining producers of `T` is known up front. This
    /// lets `collect` and similar adapters preallocate instead of seeing the
    /// default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining_len();
        (remaining, Some(remaining))
    }

    /// Jumps straight to the last producer of `T` instead of resolving every
    /// intermediate item.
    fn last(self) -> Option<Self::Item>
    where
        Self: Sized,
    {
        self.next_pos.map(|i| {
            let pos = binary_search::binary_search_last_by_key(
                &self.provider.0.immutable_state.producers[i..],
                &T::get_id(),
                UntypedFn::<TS>::get_result_type_id,
            );
            let pos = pos.expect("to be present if next_pos has value");
            // SAFETY: `i + pos` is the last producer whose result type id
            // equals `T::get_id()`.
            unsafe { crate::resolvable::resolve_unchecked::<TS, T>(&self.provider.0, i + pos) }
        })
    }

    /// Counts the remaining items without resolving any of them.
    fn count(self) -> usize
    where
        Self: Sized,
    {
        self.remaining_len()
    }
}