use std::{
any::{Any, TypeId},
panic::{AssertUnwindSafe, catch_unwind},
sync::{Arc, Condvar, Mutex, MutexGuard},
};
use crate::{Container, NidusError, RequestScope, Result, resolution};
/// How long an instance produced by a provider entry lives.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProviderLifetime {
    /// Built on the first successful resolution, then shared by every later resolution.
    Singleton,
    /// Built anew on every resolution.
    Transient,
    /// Built through the entry's request factory (when it has one) when resolved inside a
    /// `RequestScope`; resolved directly against a `Container` it is built anew each time.
    Request,
}
/// Marker for types that can be registered as providers: anything that is thread-safe and
/// `'static`.
///
/// It is blanket-implemented for every such type, so it never needs a manual impl.
pub trait Provider: Send + Sync + 'static {}

impl<T: Send + Sync + 'static> Provider for T {}
/// Type-erased provider instance; callers downcast it back to the concrete type.
type ErasedProvider = dyn Send + Sync + Any;

/// Builds a provider instance from a `Container`.
type ProviderFactory = dyn Send + Sync + Fn(&Container) -> Result<Arc<ErasedProvider>>;

/// Builds a provider instance from a request scope. The bound is higher-ranked over the
/// scope lifetime, so one factory serves scopes of any lifetime.
type RequestProviderFactory = dyn Send
    + Sync
    + for<'scope> Fn(&RequestScope<'scope>) -> Result<Arc<ErasedProvider>>;
/// Registration record for one provided type: how to build it, how long its instances
/// live, and (for singletons) the cached instance.
pub struct ProviderEntry {
    /// `TypeId` of the provided type; passed to `resolution::enter` / `resolution::is_active`
    /// to detect circular resolution.
    type_id: TypeId,
    /// Human-readable type name, reported in errors.
    type_name: &'static str,
    /// Caching policy applied when the entry is resolved.
    lifetime: ProviderLifetime,
    /// Builds an instance from a `Container`; also the fallback inside a request scope when
    /// `request_factory` is `None`.
    factory: Arc<ProviderFactory>,
    /// Scope-aware factory; `Some` only for entries built with `new_request_scoped`.
    request_factory: Option<Arc<RequestProviderFactory>>,
    /// Singleton slot; only used when `lifetime` is `ProviderLifetime::Singleton`.
    singleton: Mutex<SingletonState>,
    /// Notified whenever the singleton slot leaves `Initializing`, waking blocked resolvers.
    singleton_ready: Condvar,
}
/// Lifecycle of a singleton slot.
enum SingletonState {
    /// No instance yet, or the last attempt failed or panicked; the next resolver builds it.
    Empty,
    /// A resolver is running the factory outside the lock; other threads wait on
    /// `singleton_ready`.
    Initializing,
    /// The shared instance, handed out to every subsequent resolver.
    Ready(Arc<ErasedProvider>),
}
impl ProviderEntry {
pub fn new(
type_id: TypeId,
type_name: &'static str,
lifetime: ProviderLifetime,
factory: Arc<ProviderFactory>,
) -> Self {
Self {
type_id,
type_name,
lifetime,
factory,
request_factory: None,
singleton: Mutex::new(SingletonState::Empty),
singleton_ready: Condvar::new(),
}
}
pub fn new_request_scoped(
type_id: TypeId,
type_name: &'static str,
factory: Arc<ProviderFactory>,
request_factory: Arc<RequestProviderFactory>,
) -> Self {
Self {
type_id,
type_name,
lifetime: ProviderLifetime::Request,
factory,
request_factory: Some(request_factory),
singleton: Mutex::new(SingletonState::Empty),
singleton_ready: Condvar::new(),
}
}
pub fn type_name(&self) -> &'static str {
self.type_name
}
pub fn lifetime(&self) -> ProviderLifetime {
self.lifetime
}
pub(crate) fn resolve_erased(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
match self.lifetime {
ProviderLifetime::Singleton => self.resolve_singleton(container),
ProviderLifetime::Transient | ProviderLifetime::Request => {
self.create_erased(container)
}
}
}
pub(crate) fn resolve_erased_in_scope(
&self,
scope: &RequestScope<'_>,
) -> Result<Arc<ErasedProvider>> {
match self.lifetime {
ProviderLifetime::Request => self.create_erased_in_scope(scope),
ProviderLifetime::Singleton | ProviderLifetime::Transient => {
self.resolve_erased(scope.container())
}
}
}
fn create_erased(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
(self.factory)(container).map_err(|source| NidusError::ProviderFactory {
type_name: self.type_name,
source: Box::new(source),
})
}
fn resolve_singleton(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
loop {
let mut singleton = lock_unpoisoned(&self.singleton);
match &*singleton {
SingletonState::Ready(instance) => return Ok(Arc::clone(instance)),
SingletonState::Initializing => {
if resolution::is_active(self.type_id) {
return Err(NidusError::CircularProviderResolution {
type_name: self.type_name,
});
}
drop(wait_unpoisoned(&self.singleton_ready, singleton));
}
SingletonState::Empty => {
let _guard = resolution::enter(self.type_id, self.type_name)?;
*singleton = SingletonState::Initializing;
drop(singleton);
let instance =
match catch_unwind(AssertUnwindSafe(|| self.create_erased(container))) {
Ok(outcome) => outcome,
Err(panic_payload) => {
let mut singleton = lock_unpoisoned(&self.singleton);
*singleton = SingletonState::Empty;
self.singleton_ready.notify_all();
drop(singleton);
std::panic::resume_unwind(panic_payload);
}
};
let mut singleton = lock_unpoisoned(&self.singleton);
match instance {
Ok(instance) => {
*singleton = SingletonState::Ready(Arc::clone(&instance));
self.singleton_ready.notify_all();
return Ok(instance);
}
Err(error) => {
*singleton = SingletonState::Empty;
self.singleton_ready.notify_all();
return Err(error);
}
}
}
}
}
}
fn create_erased_in_scope(&self, scope: &RequestScope<'_>) -> Result<Arc<ErasedProvider>> {
if let Some(factory) = &self.request_factory {
factory(scope).map_err(|source| NidusError::ProviderFactory {
type_name: self.type_name,
source: Box::new(source),
})
} else {
self.create_erased(scope.container())
}
}
}
/// Acquires `mutex`, treating a poisoned lock as usable.
///
/// Poisoning is deliberately ignored: the guard is recovered from the `PoisonError` and
/// returned as if the lock had been acquired normally.
fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poison) => poison.into_inner(),
    }
}
/// Blocks on `condvar`, releasing `guard` while waiting, and returns the re-acquired guard
/// even if the mutex was poisoned in the meantime.
///
/// Like [`Condvar::wait`], this may wake spuriously; callers must re-check their condition.
fn wait_unpoisoned<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    let outcome = condvar.wait(guard);
    match outcome {
        Ok(reacquired) => reacquired,
        Err(poison) => poison.into_inner(),
    }
}
#[cfg(test)]
mod tests {
    use std::{
        any::{Any, TypeId, type_name},
        panic::{AssertUnwindSafe, catch_unwind},
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        thread,
    };
    use super::{ProviderEntry, ProviderLifetime};
    use crate::Container;

    /// A poisoned singleton mutex must not make the provider unusable.
    #[test]
    fn singleton_provider_recovers_from_poisoned_cache() {
        let provider = Arc::new(ProviderEntry::new(
            TypeId::of::<String>(),
            type_name::<String>(),
            ProviderLifetime::Singleton,
            Arc::new(|_container| Ok(Arc::new("ready".to_owned()) as Arc<dyn Any + Send + Sync>)),
        ));
        let poisoned_provider = Arc::clone(&provider);
        let panic = thread::spawn(move || {
            let _singleton = poisoned_provider.singleton.lock().unwrap();
            panic!("poison singleton cache");
        });
        assert!(panic.join().is_err());
        let value = provider
            .resolve_erased(&Container::new())
            .unwrap()
            .downcast::<String>()
            .unwrap();
        assert_eq!(&*value, "ready");
    }

    /// A singleton factory runs once, and every resolution shares the same instance.
    #[test]
    fn singleton_provider_constructs_once_and_shares_instance() {
        let constructions = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&constructions);
        let provider = ProviderEntry::new(
            TypeId::of::<u32>(),
            type_name::<u32>(),
            ProviderLifetime::Singleton,
            Arc::new(move |_container| {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(Arc::new(7_u32) as Arc<dyn Any + Send + Sync>)
            }),
        );
        let container = Container::new();
        let first = provider.resolve_erased(&container).unwrap();
        let second = provider.resolve_erased(&container).unwrap();
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(constructions.load(Ordering::SeqCst), 1);
    }

    /// A panicking factory must leave the slot open so that a later resolution retries.
    #[test]
    fn singleton_slot_is_released_after_factory_panic() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let provider = ProviderEntry::new(
            TypeId::of::<u32>(),
            type_name::<u32>(),
            ProviderLifetime::Singleton,
            Arc::new(move |_container| {
                if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                    panic!("first construction fails");
                }
                Ok(Arc::new(7_u32) as Arc<dyn Any + Send + Sync>)
            }),
        );
        let container = Container::new();
        let first_attempt =
            catch_unwind(AssertUnwindSafe(|| provider.resolve_erased(&container)));
        assert!(first_attempt.is_err());
        let value = provider
            .resolve_erased(&container)
            .unwrap()
            .downcast::<u32>()
            .unwrap();
        assert_eq!(*value, 7);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }
}