use super::{Error, Id, Response, Result};
use dashmap::{mapref::entry::Entry, DashMap};
use futures::{channel::oneshot, future};
use serde::Serialize;
use std::{
fmt::{self, Debug, Formatter},
future::Future,
sync::Arc,
};
pub struct ServerRequests(Arc<DashMap<Id, future::AbortHandle>>);
impl ServerRequests {
pub fn new() -> Self {
ServerRequests(Arc::new(DashMap::new()))
}
pub fn execute<F, T>(&self, id: Id, fut: F) -> impl Future<Output = Response> + Send + 'static
where
F: Future<Output = Result<T>> + Send + 'static,
T: Serialize,
{
if let Entry::Vacant(entry) = self.0.entry(id.clone()) {
let (handler_fut, abort_handle) = future::abortable(fut);
entry.insert(abort_handle);
let requests = self.0.clone();
future::Either::Left(async move {
let abort_result = handler_fut.await;
requests.remove(&id);
if let Ok(handler_result) = abort_result {
let result = handler_result.map(|v| serde_json::to_value(v).unwrap());
Response::from_parts(id, result)
} else {
Response::error(Some(id), Error::request_cancelled())
}
})
} else {
future::Either::Right(async { Response::error(Some(id), Error::invalid_request()) })
}
}
pub fn cancel(&self, id: &Id) {
if let Some((_, handle)) = self.0.remove(id) {
handle.abort();
log::info!("successfully cancelled request with ID: {}", id);
} else {
log::warn!(
"client asked to cancel request {}, but no such pending request exists, ignoring",
id
);
}
}
pub fn cancel_all(&self) {
self.0.retain(|_, handle| {
handle.abort();
false
});
}
}
/// Renders the IDs of all in-flight requests as a set; abort handles are not shown.
impl Debug for ServerRequests {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut ids = f.debug_set();
        for pending in self.0.iter() {
            // Each key is written immediately, so borrowing from the guard is enough.
            ids.entry(pending.key());
        }
        ids.finish()
    }
}
/// Requests awaiting a [`Response`], keyed by request ID.
///
/// Each entry holds the sending half of a oneshot channel; the receiving half is owned by the
/// future returned from [`wait`](ClientRequests::wait).
pub struct ClientRequests(pub(crate) DashMap<Id, oneshot::Sender<Response>>);

impl ClientRequests {
    /// Creates an empty set of pending requests.
    pub fn new() -> Self {
        Self(DashMap::new())
    }

    /// Delivers `r` to whoever is waiting on its request ID.
    ///
    /// Responses with a `null` ID, or with an ID nobody is waiting on, are logged as warnings
    /// and dropped.
    pub fn insert(&self, r: Response) {
        let id = match r.id() {
            Some(id) => id,
            None => {
                log::warn!("received response with request ID of `null`, ignoring");
                return;
            }
        };

        if let Some((_, sender)) = self.0.remove(id) {
            // A send error only means the waiter was dropped; nothing to do in that case.
            let _ = sender.send(r);
        } else {
            log::warn!("received response with unknown request ID: {}", id);
        }
    }

    /// Registers interest in the response to request `id` and returns a future resolving to it.
    ///
    /// # Panics
    ///
    /// Panics immediately if `id` is already being waited on. The returned future panics if
    /// its sender is dropped without a response being delivered.
    pub fn wait(&self, id: Id) -> impl Future<Output = Response> + Send + 'static {
        match self.0.entry(id) {
            Entry::Occupied(_) => {
                panic!("concurrent waits for the same request ID can't happen, this is a bug")
            }
            Entry::Vacant(slot) => {
                let (tx, rx) = oneshot::channel();
                slot.insert(tx);
                async move { rx.await.expect("sender already dropped") }
            }
        }
    }
}
/// Renders the IDs of all requests still awaiting a response as a set.
impl Debug for ClientRequests {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut ids = f.debug_set();
        for waiting in self.0.iter() {
            // The key is formatted right away, so no clone is needed.
            ids.entry(waiting.key());
        }
        ids.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    mod client_requests {
        use super::*;
        use serde_json::json;

        #[test]
        fn debug() {
            let client_requests = ClientRequests::new();
            // With nothing pending, the map renders as an empty set.
            assert_eq!(format!("{:?}", client_requests), "{}");
        }

        // A second waiter on an ID that is already pending is a bug and must panic with the
        // dedicated message, not some unrelated panic.
        #[tokio::test]
        #[should_panic(expected = "concurrent waits for the same request ID")]
        async fn wait_concurrent() {
            let pending = ClientRequests::new();
            let id = Id::Number(1);
            tokio::spawn(pending.wait(id.clone()));
            tokio::spawn(pending.wait(id));
        }

        #[tokio::test]
        async fn wait_insert() {
            let pending = ClientRequests::new();
            let id = Id::Number(1);
            let wait_fut = tokio::spawn(pending.wait(id.clone()));
            let expected = Response::ok(id.clone(), json!({}));
            pending.insert(expected.clone());
            let actual = wait_fut.await.expect("task panicked");
            assert_eq!(expected, actual);
        }

        // A response nobody is waiting for is logged and dropped rather than panicking.
        #[test]
        fn unbalanced_insert() {
            let pending = ClientRequests::new();
            let id = Id::Number(1);
            let expected = Response::ok(id, json!({}));
            pending.insert(expected);
        }
    }

    mod server_requests {
        use super::*;
        use serde_json::json;
        use std::time::Duration;

        #[test]
        fn debug() {
            let server_requests = ServerRequests::new();
            // With nothing in flight, the map renders as an empty set.
            assert_eq!(format!("{:?}", server_requests), "{}");
        }

        #[tokio::test]
        async fn execute() {
            let pending = ServerRequests::new();
            let id = Id::Number(1);
            let response = pending.execute(id.clone(), async { Ok(json!({})) }).await;
            assert_eq!(response, Response::ok(id, json!({})));
        }

        // While the first request is still registered, a second one with the same ID is
        // rejected as invalid.
        #[tokio::test]
        async fn execute_concurrent() {
            let pending = ServerRequests::new();
            let id = Id::Number(1);
            let fut0 = pending.execute(id.clone(), async { Ok(json!({})) });
            let fut1 = pending.execute(id.clone(), async { Ok(json!({})) });
            assert_eq!(fut0.await, Response::ok(id.clone(), json!({})));
            assert_eq!(fut1.await, Response::error(Some(id.clone()), Error::invalid_request()));
        }

        #[tokio::test]
        async fn cancel() {
            let pending = ServerRequests::new();
            let id = Id::Number(1);
            let handler_fut = tokio::spawn(pending.execute(id.clone(), async {
                tokio::time::sleep(Duration::from_secs(50)).await;
                Ok(json!({}))
            }));
            // Give the spawned handler a chance to start before cancelling it mid-flight.
            tokio::time::sleep(Duration::from_millis(30)).await;
            pending.cancel(&id);
            let res = handler_fut.await.expect("task panicked");
            assert_eq!(res, Response::error(Some(id), Error::request_cancelled()));
        }

        // Cancelling an unknown ID only logs a warning.
        #[test]
        fn cancel_non_existent() {
            let pending = ServerRequests::new();
            let id = Id::Number(1);
            pending.cancel(&id);
        }

        #[tokio::test]
        async fn cancel_all() {
            let pending = ServerRequests::new();
            let id1 = Id::Number(1);
            let handler_fut1 = tokio::spawn(pending.execute(id1.clone(), async {
                tokio::time::sleep(Duration::from_secs(50)).await;
                Ok(json!({}))
            }));
            let id2 = Id::Number(2);
            let handler_fut2 = tokio::spawn(pending.execute(id2.clone(), async {
                tokio::time::sleep(Duration::from_secs(50)).await;
                Ok(json!({}))
            }));
            tokio::time::sleep(Duration::from_millis(30)).await;
            pending.cancel_all();
            let res1 = handler_fut1.await.expect("task panicked");
            assert_eq!(res1, Response::error(Some(id1), Error::request_cancelled()));
            let res2 = handler_fut2.await.expect("task panicked");
            assert_eq!(res2, Response::error(Some(id2), Error::request_cancelled()));
        }
    }
}