use std::{
any::Any,
collections::{hash_map::Entry, HashMap},
future::ready,
pin::Pin,
sync::Arc,
};
use backtrace::Backtrace;
use futures::{Future, FutureExt};
use tokio::sync::RwLock;
use super::{provider, Dependency, Error, Injector, Key, Result};
/// A dependency-injection container: a map from `Key` to `Injector`, guarded by an async
/// `RwLock` and held behind an `Arc`.
///
/// Cloning an `Inject` clones the `Arc`, so every clone is a handle to the same underlying map.
/// `Inject::default()` starts with an empty map.
#[derive(Default, Clone)]
pub struct Inject(pub(crate) Arc<RwLock<HashMap<Key, Injector>>>);
impl Inject {
pub async fn get_key<T: Any + Send + Sync>(&self, key: Key) -> Result<Arc<T>> {
let available = self.get_available_keys().await;
if let Some(dep) = self.get_key_opt::<T>(key.clone()).await? {
Ok(dep)
} else {
Err(Error::NotFound {
missing: key,
available,
backtrace: Arc::new(Backtrace::new()),
})
}
}
pub async fn get_key_opt<T: Any + Send + Sync>(&self, key: Key) -> Result<Option<Arc<T>>> {
if let Some(injector) = self.0.read().await.get(&key) {
let pending = injector.request(self.clone()).await;
let value = pending.await?;
return value
.downcast::<T>()
.map(Some)
.map_err(|_err| Error::TypeMismatch(key));
}
Ok(None)
}
pub async fn inject_key<T: Any + Send + Sync>(&self, key: Key, dep: T) -> Result<()> {
if self.0.read().await.contains_key(&key) {
Err(Error::Occupied(key))
} else {
let _ = self.override_key(key, dep).await?;
Ok(())
}
}
pub async fn replace_key<T: Any + Send + Sync>(&self, key: Key, dep: T) -> Result<()> {
let available = self.get_available_keys().await;
if self.0.read().await.contains_key(&key) {
let _ = self.override_key(key, dep).await?;
Ok(())
} else {
Err(Error::NotFound {
missing: key,
available,
backtrace: Arc::new(Backtrace::new()),
})
}
}
pub async fn override_key<T: Any + Send + Sync>(&self, key: Key, dep: T) -> Result<bool> {
match self.0.write().await.entry(key.clone()) {
Entry::Occupied(mut entry) => {
let pending: Pin<
Box<dyn Future<Output = provider::Result<Arc<Dependency>>> + Send>,
> = Box::pin(ready::<provider::Result<Arc<dyn Any + Send + Sync>>>(Ok(
Arc::new(dep),
)));
let _ = entry.insert(Injector::from_pending(pending.shared()));
Ok(true)
}
Entry::Vacant(entry) => {
let pending: Pin<
Box<dyn Future<Output = provider::Result<Arc<Dependency>>> + Send>,
> = Box::pin(ready::<provider::Result<Arc<dyn Any + Send + Sync>>>(Ok(
Arc::new(dep),
)));
let _ = entry.insert(Injector::from_pending(pending.shared()));
Ok(false)
}
}
}
pub async fn consume_key<T: Any + Send + Sync>(&self, key: Key) -> Result<T> {
let available = self.get_available_keys().await;
self.consume_key_opt(key.clone())
.await?
.ok_or(Error::NotFound {
missing: key,
available,
backtrace: Arc::new(Backtrace::new()),
})
}
pub async fn consume_key_opt<T: Any + Send + Sync>(&self, key: Key) -> Result<Option<T>> {
if let Some(dep) = self.get_key_opt::<T>(key.clone()).await? {
self.remove_key(key.clone()).await?;
return Arc::try_unwrap(dep)
.map(Some)
.map_err(|arc| Error::CannotConsume {
key,
strong_count: Arc::strong_count(&arc),
});
};
Ok(None)
}
pub async fn modify_key<T, F>(&self, key: Key, modify: F) -> Result<()>
where
T: Any + Send + Sync,
F: FnOnce(T) -> Result<T>,
{
if let Some(dep) = self.get_key_opt::<T>(key.clone()).await? {
self.remove_key(key.clone()).await?;
let dep = Arc::try_unwrap(dep)
.map(Some)
.map_err(|arc| Error::CannotConsume {
key: key.clone(),
strong_count: Arc::strong_count(&arc),
})?;
if let Some(dep) = dep {
self.inject_key(key.clone(), modify(dep)?).await?;
}
return Ok(());
};
Err(Error::NotFound {
missing: key,
available: self.get_available_keys().await,
backtrace: Arc::new(Backtrace::new()),
})
}
pub async fn remove_key(&self, key: Key) -> Result<()> {
let available = self.get_available_keys().await;
match self.0.write().await.entry(key.clone()) {
Entry::Occupied(entry) => {
let _ = entry.remove();
Ok(())
}
Entry::Vacant(_) => Err(Error::NotFound {
missing: key,
available,
backtrace: Arc::new(Backtrace::new()),
}),
}
}
pub async fn eject_key<T: Any + Send + Sync>(self, key: Key) -> Result<T> {
let available = self.get_available_keys().await;
self.eject_key_opt(key.clone())
.await?
.ok_or(Error::NotFound {
missing: key,
available,
backtrace: Arc::new(Backtrace::new()),
})
}
pub async fn eject_key_opt<T: Any + Send + Sync>(self, key: Key) -> Result<Option<T>> {
let dep = self.get_key_opt::<T>(key.clone()).await?;
drop(self);
if let Some(value) = dep {
return Arc::try_unwrap(value)
.map(Some)
.map_err(|arc| Error::CannotConsume {
key,
strong_count: Arc::strong_count(&arc),
});
}
Ok(None)
}
pub async fn get_available_keys(&self) -> Vec<Key> {
let container = self.0.read().await;
container.keys().cloned().collect()
}
}
/// Fixture types shared by this crate's unit tests.
#[cfg(test)]
pub(crate) mod test {
    use derive_new::new;

    /// Anything that can report an identifier as an owned `String`.
    pub trait HasId: Send + Sync {
        /// Return this value's identifier.
        fn get_id(&self) -> String;
    }

    /// A simple service fixture identified by its `id` field.
    #[derive(Debug, Clone, new)]
    pub struct TestService {
        pub(crate) id: String,
    }

    impl HasId for TestService {
        /// Return a copy of the `id` field.
        fn get_id(&self) -> String {
            self.id.to_owned()
        }
    }

    /// A second service fixture, identified by its `other_id` field.
    #[derive(new)]
    pub struct OtherService {
        pub(crate) other_id: String,
    }

    impl HasId for OtherService {
        /// Return a copy of the `other_id` field.
        fn get_id(&self) -> String {
            self.other_id.to_owned()
        }
    }
}