use std::collections::{HashMap, VecDeque};
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::{Mutex};
use notify_future::{Notify};
/// Failure modes reported by the waiters in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WaiterError {
    /// A future that has not been canceled is already registered for this callback.
    AlreadyExist,
    /// No result was delivered before the requested timeout elapsed.
    Timeout,
    /// A result was delivered, but no live future was registered to receive it.
    NoWaiter,
}

impl Display for WaiterError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The user-facing text is simply the variant name.
        let name = match self {
            Self::AlreadyExist => "AlreadyExist",
            Self::Timeout => "Timeout",
            Self::NoWaiter => "NoWaiter",
        };
        f.write_str(name)
    }
}

impl std::error::Error for WaiterError {}

/// Shorthand for a `Result` whose error type is [`WaiterError`].
pub type WaiterResult<T> = Result<T, WaiterError>;
pub struct ResultFuture<'a, R> {
future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>,
}
impl <'a, R> ResultFuture<'a, R> {
pub fn new(future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>) -> Self {
Self {
future,
}
}
}
impl <'a, R> Future for ResultFuture<'a, R> {
type Output = Result<R, WaiterError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
self.get_mut().future.as_mut().poll(cx)
}
}
/// Mutable bookkeeping for [`CallbackWaiter`], kept behind its `Mutex`.
struct CallbackWaiterState<K, R> {
    // Registrations per callback id. `Some(notify)` holds the sending half for a registered
    // future; a `None` value marks an entry whose sender has already been taken out.
    result_notifies: HashMap<K, Option<Notify<R>>>,
    // Results delivered through `set_result_with_cache` while no live future was registered,
    // queued per callback id (pushed at the back, handed out from the front).
    result_cache: HashMap<K, VecDeque<R>>,
}
/// Matches results delivered by one party to futures awaited by another, keyed by a callback
/// id of type `K`.
///
/// A consumer obtains a future with `create_result_future` / `create_timeout_result_future`;
/// a producer completes it with `set_result` / `set_result_with_cache`.
pub struct CallbackWaiter<K, R> {
    state: Mutex<CallbackWaiterState<K, R>>,
}
impl <K: Hash + Eq + Clone + 'static + Send, R: 'static + Send> CallbackWaiter<K, R> {
pub fn new() -> Self {
Self {
state: Mutex::new(CallbackWaiterState {
result_notifies: HashMap::new(),
result_cache: HashMap::new(),
})
}
}
pub fn create_result_future(&self, callback_id: K) -> WaiterResult<ResultFuture<R>> {
let waiter = {
let mut state = self.state.lock().unwrap();
let notifies = state.result_notifies.get(&callback_id);
if let Some(notifies) = notifies {
if let Some(notifies) = notifies {
if!notifies.is_canceled() {
return Err(WaiterError::AlreadyExist);
}
}
}
if let Some(result) = state.result_cache.get_mut(&callback_id) {
if let Some(ret) = result.pop_front() {
return Ok(ResultFuture::new(Box::pin(async move {
Ok(ret)
})));
}
}
let (notify, waiter) = Notify::new();
state.result_notifies.insert(callback_id.clone(), Some(notify));
waiter
};
Ok(ResultFuture::new(Box::pin(async move {
let ret = waiter.await;
{
let mut state = self.state.lock().unwrap();
state.result_notifies.remove(&callback_id);
}
Ok(ret)
})))
}
pub fn create_timeout_result_future(&self, callback_id: K, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
let waiter = {
let mut state = self.state.lock().unwrap();
let notifies = state.result_notifies.get(&callback_id);
if let Some(notifies) = notifies {
if let Some(notifies) = notifies {
if!notifies.is_canceled() {
return Err(WaiterError::AlreadyExist);
}
}
}
if let Some(result) = state.result_cache.get_mut(&callback_id) {
if let Some(ret) = result.pop_front() {
return Ok(ResultFuture::new(Box::pin(async move {
Ok(ret)
})));
}
}
let (notify, waiter) = Notify::new();
state.result_notifies.insert(callback_id.clone(), Some(notify));
waiter
};
Ok(ResultFuture::new(Box::pin(async move {
let ret = tokio::time::timeout(timeout, waiter).await;
{
let mut state = self.state.lock().unwrap();
state.result_notifies.remove(&callback_id);
}
match ret {
Ok(ret) => Ok(ret),
Err(_) => Err(WaiterError::Timeout)
}
})))
}
pub fn set_result(&self, callback_id: K, result: R) -> Result<(), WaiterError> {
let mut state = self.state.lock().unwrap();
if let Some(future) = state.result_notifies.get_mut(&callback_id) {
if let Some(future) = future.take() {
if !future.is_canceled() {
future.notify(result);
return Ok(());
}
}
}
Err(WaiterError::NoWaiter)
}
pub fn set_result_with_cache(&self, callback_id: K, result: R) {
let mut state = self.state.lock().unwrap();
if let Some(future) = state.result_notifies.get_mut(&callback_id) {
if let Some(future) = future.take() {
if !future.is_canceled() {
future.notify(result);
return;
}
}
}
if let Some(cache) = state.result_cache.get_mut(&callback_id) {
cache.push_back(result);
} else {
let mut cache = VecDeque::new();
cache.push_back(result);
state.result_cache.insert(callback_id, cache);
}
}
}
/// Mutable bookkeeping for [`SingleCallbackWaiter`], kept behind its `Mutex`.
struct SingleCallbackWaiterState<R> {
    // The current registration. `Some(Some(notify))` holds the sending half for a registered
    // future; `Some(None)` and `None` both mean no sender is currently stored.
    result_notify: Option<Option<Notify<R>>>,
    // Results delivered through `set_result_with_cache` while no live future was registered
    // (pushed at the back, handed out from the front).
    result_cache: VecDeque<R>,
}
/// Single-slot counterpart of [`CallbackWaiter`]: there is no callback id, and at most one
/// live future may be waiting at a time.
pub struct SingleCallbackWaiter<R> {
    state: Mutex<SingleCallbackWaiterState<R>>,
}
// Every method below locks `state` with `lock().unwrap()`, so each one panics if the mutex was
// poisoned by a thread that panicked while holding it.
impl<R: 'static + Send> SingleCallbackWaiter<R> {
    /// Creates a waiter with no registered future and no cached results.
    pub fn new() -> Self {
        Self {
            state: Mutex::new(SingleCallbackWaiterState {
                result_notify: None,
                result_cache: VecDeque::new(),
            }),
        }
    }

    /// Returns a future that resolves with the next delivered result.
    ///
    /// A result previously queued by [`Self::set_result_with_cache`] is returned immediately
    /// (oldest first). Otherwise a future is registered and completes on the next
    /// [`Self::set_result`] / [`Self::set_result_with_cache`].
    ///
    /// # Errors
    ///
    /// [`WaiterError::AlreadyExist`] if another future is still waiting. A registration whose
    /// future has been dropped may be replaced.
    pub fn create_result_future(&self) -> WaiterResult<ResultFuture<R>> {
        self.create_future(None)
    }

    /// Like [`Self::create_result_future`], but the returned future yields
    /// [`WaiterError::Timeout`] if nothing is delivered within `timeout`.
    ///
    /// The deadline is enforced with `tokio::time::timeout`, so the future must be polled inside
    /// a Tokio runtime.
    ///
    /// # Errors
    ///
    /// [`WaiterError::AlreadyExist`] if another future is still waiting.
    pub fn create_timeout_result_future(
        &self,
        timeout: std::time::Duration,
    ) -> WaiterResult<ResultFuture<R>> {
        self.create_future(Some(timeout))
    }

    /// Shared body of the two `create_*_result_future` methods; `None` means wait forever.
    fn create_future(
        &self,
        timeout: Option<std::time::Duration>,
    ) -> WaiterResult<ResultFuture<R>> {
        let waiter = {
            let mut state = self.state.lock().unwrap();
            if let Some(Some(notify)) = &state.result_notify {
                // A canceled registration (its future was dropped) may be replaced.
                if !notify.is_canceled() {
                    return Err(WaiterError::AlreadyExist);
                }
            }
            if let Some(ret) = state.result_cache.pop_front() {
                return Ok(ResultFuture::new(Box::pin(async move { Ok(ret) })));
            }
            let (notify, waiter) = Notify::new();
            state.result_notify = Some(Some(notify));
            waiter
        };
        Ok(ResultFuture::new(Box::pin(async move {
            let ret = match timeout {
                Some(timeout) => tokio::time::timeout(timeout, waiter)
                    .await
                    .map_err(|_| WaiterError::Timeout),
                None => Ok(waiter.await),
            };
            self.release();
            ret
        })))
    }

    /// Clears the registration left behind by a finished future.
    ///
    /// The slot is cleared only when it is dead (sender already taken, or its future dropped or
    /// timed out). A newer, still-live registration is left in place; clearing it would strand
    /// that future forever.
    fn release(&self) {
        let mut state = self.state.lock().unwrap();
        let stale = match &state.result_notify {
            Some(Some(notify)) => notify.is_canceled(),
            Some(None) => true,
            None => false,
        };
        if stale {
            state.result_notify = None;
        }
    }

    /// Delivers `result` to the currently waiting future.
    ///
    /// # Errors
    ///
    /// [`WaiterError::NoWaiter`] (and `result` is dropped) if no future is registered, or if the
    /// registered future has been dropped or has timed out.
    pub fn set_result(&self, result: R) -> Result<(), WaiterError> {
        let mut state = self.state.lock().unwrap();
        // The registration is consumed either way. A live one is notified; a canceled one is
        // discarded.
        match state.result_notify.take() {
            Some(Some(notify)) if !notify.is_canceled() => {
                notify.notify(result);
                Ok(())
            }
            _ => Err(WaiterError::NoWaiter),
        }
    }

    /// Delivers `result` to the waiting future, or queues it if no live future is registered.
    pub fn set_result_with_cache(&self, result: R) {
        let mut state = self.state.lock().unwrap();
        if let Some(Some(notify)) = state.result_notify.take() {
            if !notify.is_canceled() {
                notify.notify(result);
                return;
            }
        }
        state.result_cache.push_back(result);
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    // A panic inside `tokio::spawn` only fails the test if the task's `JoinHandle` is awaited.
    // Spawned helpers therefore return their outcome, and the test asserts on the joined value.

    #[tokio::test]
    async fn test_waiter() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter.create_result_future(callback_id).unwrap();
        assert!(matches!(
            waiter.create_result_future(callback_id),
            Err(WaiterError::AlreadyExist)
        ));
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            tmp.set_result(callback_id, 1)
        });
        assert_eq!(result_future.await, Ok(1));
        assert_eq!(setter.await.unwrap(), Ok(()));
    }

    #[tokio::test]
    async fn test_waiter1() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            tmp.set_result_with_cache(callback_id, 1);
        });
        let result_future = waiter.create_result_future(callback_id).unwrap();
        assert_eq!(result_future.await, Ok(1));
        setter.await.unwrap();
    }

    #[tokio::test]
    async fn test_waiter_timout() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            tmp.set_result(callback_id, 1)
        });
        assert_eq!(result_future.await, Ok(1));
        assert_eq!(setter.await.unwrap(), Ok(()));
    }

    #[tokio::test]
    async fn test_waiter_timout2() {
        let waiter = CallbackWaiter::<i32, i32>::new();
        let callback_id = 1;
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
        // A result arriving after the timeout has no waiter to go to.
        assert_eq!(waiter.set_result(callback_id, 1), Err(WaiterError::NoWaiter));
    }

    #[tokio::test]
    async fn test_waiter_timout3() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let tmp = waiter.clone();
        let early = tokio::spawn(async move { tmp.set_result(callback_id, 1) })
            .await
            .unwrap();
        assert_eq!(early, Err(WaiterError::NoWaiter));
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        assert!(matches!(
            waiter.create_timeout_result_future(callback_id, Duration::from_secs(2)),
            Err(WaiterError::AlreadyExist)
        ));
        // The early result was not cached, so this future can only time out.
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
    }

    #[tokio::test]
    async fn test_signle_waiter() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter.create_result_future().unwrap();
        assert!(matches!(
            waiter.create_result_future(),
            Err(WaiterError::AlreadyExist)
        ));
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            tmp.set_result(1)
        });
        assert_eq!(result_future.await, Ok(1));
        assert_eq!(setter.await.unwrap(), Ok(()));
    }

    #[tokio::test]
    async fn test_single_waiter1() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            tmp.set_result_with_cache(1);
        });
        let result_future = waiter.create_result_future().unwrap();
        assert_eq!(result_future.await, Ok(1));
        setter.await.unwrap();
    }

    #[tokio::test]
    async fn test_single_waiter_timout() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        assert!(matches!(
            waiter.create_timeout_result_future(Duration::from_secs(2)),
            Err(WaiterError::AlreadyExist)
        ));
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            tmp.set_result(1)
        });
        assert_eq!(result_future.await, Ok(1));
        assert_eq!(setter.await.unwrap(), Ok(()));
    }

    #[tokio::test]
    async fn test_single_waiter_timout2() {
        let waiter = SingleCallbackWaiter::<i32>::new();
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
        // A result arriving after the timeout has no waiter to go to.
        assert_eq!(waiter.set_result(1), Err(WaiterError::NoWaiter));
    }

    #[tokio::test]
    async fn test_single_waiter_timout3() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let tmp = waiter.clone();
        let early = tokio::spawn(async move { tmp.set_result(1) })
            .await
            .unwrap();
        assert_eq!(early, Err(WaiterError::NoWaiter));
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        // The early result was not cached, so this future can only time out.
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
    }

    #[tokio::test]
    async fn test_waiter_reregister_after_future_drop() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 42;
        let dropped_future = waiter.create_result_future(callback_id).unwrap();
        drop(dropped_future);
        sleep(Duration::from_millis(10)).await;
        let result_future = waiter.create_result_future(callback_id).unwrap();
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move { tmp.set_result(callback_id, 7) });
        assert_eq!(result_future.await, Ok(7));
        assert_eq!(setter.await.unwrap(), Ok(()));
    }

    #[tokio::test]
    async fn test_waiter_cache_fifo_under_load() {
        let waiter = CallbackWaiter::new();
        let callback_id = 1;
        let total = 200;
        for i in 0..total {
            waiter.set_result_with_cache(callback_id, i);
        }
        for expected in 0..total {
            let ret = waiter
                .create_result_future(callback_id)
                .unwrap()
                .await
                .unwrap();
            assert_eq!(ret, expected);
        }
    }

    #[tokio::test]
    async fn test_waiter_timeout_set_result_race() {
        for callback_id in 0..50 {
            let waiter = Arc::new(CallbackWaiter::new());
            let result_future = waiter
                .create_timeout_result_future(callback_id, Duration::from_millis(50))
                .unwrap();
            let tmp = waiter.clone();
            let set_task = tokio::spawn(async move {
                sleep(Duration::from_millis(50)).await;
                tmp.set_result(callback_id, 1)
            });
            let future_result = result_future.await;
            let set_result = set_task.await.unwrap();
            // Either the result won the race or the timeout did; anything else is a bug.
            match (future_result, set_result) {
                (Ok(1), Ok(())) => {}
                (Err(WaiterError::Timeout), Err(WaiterError::NoWaiter)) => {}
                (other_future, other_set) => {
                    panic!("unexpected race outcome: {:?}, {:?}", other_future, other_set);
                }
            }
        }
    }
}