#![feature(async_await)]
use futures::{
channel::mpsc,
ready,
stream::{self, StreamExt as _},
};
use hashbrown::HashMap;
use parking_lot::{Mutex, RwLock};
use serde_hashkey as hashkey;
use std::{
any::{Any, TypeId},
error, fmt,
future::Future,
marker,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
#[macro_use]
#[allow(unused_imports)]
extern crate async_injector_derive;
#[doc(hidden)]
pub use self::async_injector_derive::*;
pub use async_trait::async_trait;
/// Trait for types that can asynchronously produce an output value.
///
/// Both methods have default implementations that simply return `None`.
///
/// NOTE(review): the code that invokes `clear` / `build` (presumably code
/// generated by `async_injector_derive`) is not visible here — confirm the
/// exact semantics of each hook against that caller.
#[async_trait]
pub trait Provider
where
    Self: Sized,
{
    /// The type of value produced by this provider.
    type Output;

    /// Associated-function hook named `clear`; takes no provider instance.
    ///
    /// The default implementation returns `None`.
    async fn clear() -> Option<Self::Output> {
        None
    }

    /// Consume the provider and build its output, if any.
    ///
    /// The default implementation returns `None`.
    async fn build(self) -> Option<Self::Output> {
        None
    }
}
/// Errors produced by the injector.
#[derive(Debug)]
pub enum Error {
    /// The injector is shutting down.
    Shutdown,
    /// The stream of driver futures consumed by `Injector::drive` ended.
    EndOfDriverStream,
    /// `Injector::drive` was called after the driver queue had already been
    /// taken by an earlier call (on this injector or any clone of it).
    DriverAlreadyConfigured,
    /// Serializing a key tag (see `Key::tagged`) failed.
    SerializationError(serde_hashkey::Error),
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Error::Shutdown => "injector is shutting down".fmt(fmt),
Error::EndOfDriverStream => "end of driver stream".fmt(fmt),
Error::DriverAlreadyConfigured => "driver already configured".fmt(fmt),
Error::SerializationError(..) => "serialization error".fmt(fmt),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Error::SerializationError(e) => Some(e),
_ => None,
}
}
}
impl From<serde_hashkey::Error> for Error {
fn from(value: serde_hashkey::Error) -> Self {
Error::SerializationError(value)
}
}
/// Sending half of a single subscriber's update channel.
///
/// Messages are `Some(value)` when the stored value is updated and `None`
/// when it is cleared (see `Injector::update_key` / `Injector::clear_key`).
struct Sender {
    tx: mpsc::UnboundedSender<Option<Box<dyn Any + Send + Sync + 'static>>>,
}
/// Stream of updates for an injected value of type `T`.
///
/// Created by `Injector::stream` / `Injector::stream_key`. Values travel
/// type-erased through the channel and are downcast back to `T` when polled.
pub struct Stream<T> {
    // Receiving half of this subscription's channel.
    rx: mpsc::UnboundedReceiver<Option<Box<dyn Any + Send + Sync + 'static>>>,
    // Records the element type without storing a `T`.
    marker: marker::PhantomData<T>,
}
impl<T> stream::Stream for Stream<T>
where
    T: Unpin + Any + Send + Sync + 'static,
{
    type Item = Option<T>;

    /// Yield the next update.
    ///
    /// Produces `Some(Some(value))` for a new value, `Some(None)` when the
    /// value was cleared, and `None` once the underlying channel has closed.
    ///
    /// # Panics
    ///
    /// Panics if a received value is not a `T`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let boxed = match ready!(Pin::new(&mut self.rx).poll_next(cx)) {
            None => return Poll::Ready(None),
            Some(None) => return Poll::Ready(Some(None)),
            Some(Some(boxed)) => boxed,
        };

        // Drop the `Send + Sync` bounds so the by-value `Box<dyn Any>` downcast applies.
        let erased = boxed as Box<dyn Any + 'static>;

        if let Ok(typed) = erased.downcast::<T>() {
            Poll::Ready(Some(Some(*typed)))
        } else {
            panic!("downcast failed")
        }
    }
}
impl<T> stream::FusedStream for Stream<T> {
    /// Report whether this stream has finished.
    ///
    /// Delegates to the underlying receiver. It becomes terminated once every
    /// sender is gone and the channel has been drained, i.e. after the stream
    /// has yielded `None`. Always answering `false` would make combinators
    /// such as `select!` keep polling an exhausted stream.
    fn is_terminated(&self) -> bool {
        stream::FusedStream::is_terminated(&self.rx)
    }
}
/// Per-key slot: the current value (if any) and the subscribers to notify
/// when it changes.
#[derive(Default)]
struct Storage {
    // Current value, type-erased; `None` when unset or cleared.
    value: Option<Box<dyn Any + Send + Sync + 'static>>,
    // One sender per stream subscribed to this key.
    subs: Vec<Sender>,
}
impl Storage {
    /// Send a freshly built message to every subscriber.
    ///
    /// `send` is called once per subscriber so each receives its own boxed
    /// value. Subscribers whose receiving half has been dropped are removed.
    /// Any other send failure is logged and the subscriber is kept.
    fn try_send<S>(&mut self, send: S)
    where
        S: Fn() -> Option<Box<dyn Any + Send + Sync + 'static>>,
    {
        // `retain` drops disconnected subscribers in a single pass without
        // index bookkeeping. Pairing `swap_remove` with shifted indices (as for
        // `remove`) deletes the wrong entries once more than one sender is dead.
        self.subs.retain(|sub| {
            let e = match sub.tx.unbounded_send(send()) {
                Ok(()) => return true,
                Err(e) => e,
            };

            if e.is_disconnected() {
                return false;
            }

            log::warn!("failed to send resource update: {}", e);
            true
        });
    }
}
/// State shared by every clone of an `Injector`.
struct Inner {
    // Value slots, indexed by type-erased key.
    storage: RwLock<HashMap<RawKey, Storage>>,
    // Queue for background futures created by `Injector::var_key`.
    drivers: mpsc::UnboundedSender<Driver>,
    // Receiving end of `drivers`; taken (and never put back) by the first
    // call to `Injector::drive`.
    drivers_rx: Mutex<Option<mpsc::UnboundedReceiver<Driver>>>,
}
/// Handle to a dependency injector.
///
/// Cloning is cheap: every clone shares the same state through an `Arc`.
#[derive(Clone)]
pub struct Injector {
    inner: Arc<Inner>,
}
impl Injector {
pub fn new() -> Self {
let (drivers, drivers_rx) = mpsc::unbounded();
Self {
inner: Arc::new(Inner {
storage: Default::default(),
drivers,
drivers_rx: Mutex::new(Some(drivers_rx)),
}),
}
}
pub fn clear<T>(&self)
where
T: Clone + Any + Send + Sync + 'static,
{
self.clear_key::<T>(&Key::<T>::of())
}
pub fn clear_key<T>(&self, key: &Key<T>)
where
T: Clone + Any + Send + Sync + 'static,
{
let key = key.as_raw_key();
let mut storage = self.inner.storage.write();
let storage = match storage.get_mut(&key) {
Some(storage) => storage,
None => return,
};
if let None = storage.value.take() {
return;
}
storage.try_send(|| None);
}
pub fn update<T>(&self, value: T)
where
T: Any + Send + Sync + 'static + Clone,
{
self.update_key(&Key::<T>::of(), value)
}
pub fn update_key<T>(&self, key: &Key<T>, value: T)
where
T: Any + Send + Sync + 'static + Clone,
{
let key = key.as_raw_key();
let mut storage = self.inner.storage.write();
let storage = storage.entry(key).or_default();
storage.try_send(|| Some(Box::new(value.clone())));
storage.value = Some(Box::new(value));
}
pub fn get<T>(&self) -> Option<T>
where
T: Any + Send + Sync + 'static + Clone,
{
self.get_key(&Key::<T>::of())
}
pub fn get_key<T>(&self, key: &Key<T>) -> Option<T>
where
T: Any + Send + Sync + 'static + Clone,
{
let key = key.as_raw_key();
let storage = self.inner.storage.read();
let storage = storage.get(&key)?;
let value = storage.value.as_ref()?;
match value.downcast_ref::<T>() {
Some(value) => Some(value.clone()),
None => panic!("downcast failed"),
}
}
pub fn stream<T>(&self) -> (Stream<T>, Option<T>)
where
T: Any + Send + Sync + 'static + Clone,
{
self.stream_key(&Key::<T>::of())
}
pub fn stream_key<T>(&self, key: &Key<T>) -> (Stream<T>, Option<T>)
where
T: Any + Send + Sync + 'static + Clone,
{
let key = key.as_raw_key();
let (tx, rx) = mpsc::unbounded();
let value = {
let mut storage = self.inner.storage.write();
let storage = storage.entry(key).or_default();
storage.subs.push(Sender { tx: tx.clone() });
match storage.value.as_ref() {
Some(value) => match value.downcast_ref::<T>() {
Some(value) => Some(value.clone()),
None => panic!("downcast failed"),
},
None => None,
}
};
let stream = Stream {
rx,
marker: marker::PhantomData,
};
(stream, value)
}
pub fn var<T>(&self) -> Result<Arc<RwLock<Option<T>>>, Error>
where
T: Any + Send + Sync + 'static + Clone + Unpin,
{
self.var_key(&Key::<T>::of())
}
pub fn var_key<T>(&self, key: &Key<T>) -> Result<Arc<RwLock<Option<T>>>, Error>
where
T: Any + Send + Sync + 'static + Clone + Unpin,
{
use futures::StreamExt as _;
let (mut stream, value) = self.stream_key(key);
let value = Arc::new(RwLock::new(value));
let future_value = value.clone();
let future = async move {
while let Some(update) = stream.next().await {
*future_value.write() = update;
}
};
let result = self.inner.drivers.unbounded_send(Driver {
future: Box::pin(future),
});
if let Err(e) = result {
if !e.is_disconnected() {
return Err(Error::Shutdown);
}
}
Ok(value)
}
pub async fn drive(self) -> Result<(), Error> {
let mut rx = self
.inner
.drivers_rx
.lock()
.take()
.ok_or(Error::DriverAlreadyConfigured)?;
let mut drivers = stream::FuturesUnordered::new();
loop {
while drivers.is_empty() {
drivers.push(rx.next().await.ok_or(Error::EndOfDriverStream)?);
}
while !drivers.is_empty() {
futures::select! {
driver = rx.next() => drivers.push(driver.ok_or(Error::EndOfDriverStream)?),
() = drivers.select_next_some() => (),
}
}
}
}
}
/// Uninhabited, module-private type whose `TypeId` is used as the tag type
/// of untagged keys created by `Key::of`.
enum Empty {}
/// Type-erased form of `Key<T>`, used to index the injector's storage map.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct RawKey {
    // `TypeId` of the stored value type.
    type_id: TypeId,
    // `TypeId` of the tag type (`Empty` for keys built by `Key::of`).
    tag_type_id: TypeId,
    // Serialized tag value (`hashkey::Key::Unit` for keys built by `Key::of`).
    tag: hashkey::Key,
}
/// Typed key identifying an injected value of type `T`.
///
/// Keys compare equal when the value type, tag type and serialized tag all
/// match. Build them with `Key::of` (untagged) or `Key::tagged`.
///
/// NOTE(review): the derives add `T: Trait` bounds to each generated impl
/// (e.g. `Key<T>: Clone` only when `T: Clone`), even though `T` is only held
/// as `PhantomData`.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Key<T>
where
    T: Any,
{
    // `TypeId` of the value type `T`.
    type_id: TypeId,
    // `TypeId` of the tag's type.
    tag_type_id: TypeId,
    // Tag value serialized into a hashable, comparable form.
    tag: hashkey::Key,
    // Ties the key to `T` without storing one.
    marker: std::marker::PhantomData<T>,
}
impl<T> Key<T>
where
T: Any,
{
pub fn of() -> Self {
Self {
type_id: TypeId::of::<T>(),
tag_type_id: TypeId::of::<Empty>(),
tag: hashkey::Key::Unit,
marker: std::marker::PhantomData,
}
}
pub fn tagged<K>(tag: K) -> Result<Self, Error>
where
K: Any + serde::Serialize,
{
Ok(Self {
type_id: TypeId::of::<T>(),
tag_type_id: TypeId::of::<K>(),
tag: hashkey::to_key(&tag)?,
marker: std::marker::PhantomData,
})
}
fn as_raw_key(&self) -> RawKey {
RawKey {
type_id: self.type_id,
tag_type_id: self.tag_type_id,
tag: self.tag.clone(),
}
}
}
/// A boxed background future queued by `Injector::var_key` and polled by
/// `Injector::drive`.
struct Driver {
    future: Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
}
impl Future for Driver {
    type Output = ();

    /// Forward polling to the wrapped future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Driver` only holds an already-pinned box, so it is `Unpin` and can be
        // taken out of its `Pin` directly.
        let driver = self.get_mut();
        driver.future.as_mut().poll(cx)
    }
}