use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex, Weak},
task::{Context, Poll, Waker},
};
/// State shared between an [`Observer`] and the [`Promise`] able to resolve it.
struct Shared<T> {
    /// Output of the event. `None` until a promise stores a value, and `None` again once the
    /// observer has taken the value out while being polled.
    result: Option<T>,
    /// Waker registered by the most recent poll that found no result yet.
    waker: Option<Waker>,
}
/// Future handed out by [`AsyncEvents::output_of`]. It completes with the output its event has
/// been resolved with.
pub struct Observer<T> {
    /// Strong handle to the shared state. The registry only holds a weak reference to it, so
    /// dropping the observer turns the matching promise into an orphan.
    shared: Arc<Mutex<Shared<T>>>,
}
impl<T> Future for Observer<T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut shared = self.shared.lock().unwrap();
match &shared.result {
None => {
if let Some(ref mut waker) = &mut shared.waker {
waker.clone_from(cx.waker())
} else {
shared.waker = Some(cx.waker().clone());
}
Poll::Pending
}
Some(_) => Poll::Ready(shared.result.take().unwrap()),
}
}
}
/// Registry of observers keyed by an event id of type `K`, each awaiting an output of type `T`.
pub struct AsyncEvents<K, T> {
    /// One promise for every observer that has been handed out and not been pruned since.
    wakers: Mutex<Vec<Promise<K, T>>>,
}
impl<K, T> AsyncEvents<K, T> {
    /// Creates a registry without any registered observers.
    pub fn new() -> Self {
        let wakers = Mutex::new(Vec::new());
        Self { wakers }
    }
}
impl<K, T> Default for AsyncEvents<K, T> {
fn default() -> Self {
Self::new()
}
}
impl<K, V> AsyncEvents<K, V>
where
    K: Eq,
{
    /// Registers an observer for `event_id`.
    ///
    /// Retained for backwards compatibility only; forwards to [`AsyncEvents::output_of`].
    #[deprecated(note = "Please use output_of instead")]
    pub fn wait_for_output(&self, event_id: K) -> Observer<V> {
        self.output_of(event_id)
    }

    /// Registers an observer for `event_id` and returns it.
    ///
    /// The returned future completes once one of the `resolve_*` methods resolves a promise
    /// with a matching id. Registration also prunes promises whose observers have been dropped,
    /// so abandoned observers do not accumulate.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn output_of(&self, event_id: K) -> Observer<V> {
        let strong = Arc::new(Mutex::new(Shared {
            result: None,
            waker: None,
        }));
        // The registry holds a weak reference only: the observer owns the state, and dropping
        // it is what marks the promise as orphaned.
        let promise = Promise {
            key: event_id,
            shared: Arc::downgrade(&strong),
        };
        {
            let mut wakers = self.wakers.lock().unwrap();
            wakers.retain(|registered| !registered.is_orphan());
            wakers.push(promise);
        }
        Observer { shared: strong }
    }

    /// Resolves every pending observer whose id is contained in `event_ids` with a clone of
    /// `output`.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex or the mutex of a resolved observer is poisoned.
    pub fn resolve_all_with(&self, event_ids: &[K], output: V)
    where
        V: Clone,
    {
        self.resolve_all_if(|event_id| {
            if event_ids.contains(event_id) {
                Some(output.clone())
            } else {
                None
            }
        })
    }

    /// Resolves every pending observer for which `f` returns `Some(output)`.
    ///
    /// `f` is invoked once per registered promise, in registration order. A resolved promise is
    /// removed from the registry, so an observer is resolved at most once and later calls only
    /// consider observers that are still pending. The output does not need to be `Clone`,
    /// because `f` produces a fresh value for every observer.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex or the mutex of a resolved observer is poisoned.
    pub fn resolve_all_if(&self, f: impl Fn(&K) -> Option<V>) {
        let mut wakers = self.wakers.lock().unwrap();
        wakers.retain_mut(|promise| match f(&promise.key) {
            Some(output) => {
                promise.resolve(output);
                // Fulfilled: drop it from the registry.
                false
            }
            None => true,
        });
    }

    /// Resolves the oldest pending observer registered for `event_id` with `output`.
    ///
    /// Promises whose observer has been dropped are skipped, and the resolved promise is
    /// removed from the registry. Calling this repeatedly therefore resolves the observers of
    /// one id one after another. Nothing happens if no pending observer matches.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex or the mutex of the resolved observer is poisoned.
    pub fn resolve_one(&self, event_id: K, output: V) {
        let mut wakers = self.wakers.lock().unwrap();
        let pending = wakers
            .iter()
            .position(|promise| promise.key == event_id && !promise.is_orphan());
        if let Some(index) = pending {
            // `remove` (not `swap_remove`) keeps registration order for the remaining promises.
            wakers.remove(index).resolve(output);
        }
    }
}
/// Registry-side handle used to resolve a single [`Observer`].
struct Promise<K, T> {
    /// Id of the event the observer is waiting for.
    key: K,
    /// Weak reference to the observer's state; it can no longer be upgraded once the observer
    /// has been dropped.
    shared: Weak<Mutex<Shared<T>>>,
}
impl<K, T> Promise<K, T> {
    /// Stores `result` for the observer and wakes the task waiting on it, if any.
    ///
    /// Does nothing if the observer has already been dropped.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the shared state is poisoned.
    fn resolve(&mut self, result: T) {
        if let Some(strong) = self.shared.upgrade() {
            // Release the lock before waking. A waker may poll the observer right away, and
            // `poll` takes this same mutex: waking under the lock would deadlock wakers that
            // run inline and make threaded executors block immediately.
            let waker = {
                let mut shared = strong.lock().unwrap();
                shared.result = Some(result);
                shared.waker.take()
            };
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    /// Returns `true` once the observer belonging to this promise has been dropped.
    fn is_orphan(&self) -> bool {
        self.shared.strong_count() == 0
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::AsyncEvents;
    use tokio::{self, time::timeout};

    /// Timeout used to inspect the state of an observer without actually waiting for it.
    const ZERO: Duration = Duration::from_secs(0);

    /// An observer nobody resolved must not complete.
    #[tokio::test]
    async fn pending() {
        let events = AsyncEvents::<i32, ()>::new();
        let observer = events.output_of(1);

        let outcome = timeout(ZERO, observer).await;

        outcome.unwrap_err();
    }

    /// An observer completes with the output its event id has been resolved with.
    #[tokio::test]
    async fn resolved() {
        let events = AsyncEvents::new();
        let observer = events.output_of(1);

        events.resolve_all_with(&[1], 42);

        let output = timeout(ZERO, observer).await.unwrap();
        assert_eq!(42, output);
    }

    /// `resolve_all_with` reaches every observer registered for the same id.
    #[tokio::test]
    async fn multiple_observers_resolve_all() {
        let events = AsyncEvents::new();
        let first = events.output_of(1);
        let second = events.output_of(1);

        events.resolve_all_with(&[1], 42);

        let first_output = timeout(ZERO, first).await.unwrap();
        let second_output = timeout(ZERO, second).await.unwrap();
        assert_eq!(42, first_output);
        assert_eq!(42, second_output);
    }

    /// `resolve_one` completes exactly one of several observers sharing an id.
    #[tokio::test]
    async fn multiple_observers_resolve_one() {
        let events = AsyncEvents::new();
        let first = events.output_of(1);
        let second = events.output_of(1);

        events.resolve_one(1, 42);

        let first_output = timeout(ZERO, first).await.unwrap();
        assert_eq!(42, first_output);
        let second_outcome = timeout(ZERO, second).await;
        assert!(second_outcome.is_err());
    }

    /// A predicate accepting every id resolves observers of different ids alike.
    #[tokio::test]
    async fn resolve_all_observers_with_the_same_output() {
        let events = AsyncEvents::new();
        let first = events.output_of(1);
        let second = events.output_of(2);

        events.resolve_all_if(|_event_id| Some(42));

        let first_output = timeout(ZERO, first).await.unwrap();
        let second_output = timeout(ZERO, second).await.unwrap();
        assert_eq!(42, first_output);
        assert_eq!(42, second_output);
    }
}