#![recursion_limit = "512"]
use future::FusedFuture;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::channel::oneshot::{channel, Sender};
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::lock::Mutex;
use futures::select;
use futures::stream::{FuturesUnordered, StreamExt};
use futures::{self, Future};
use futures::{
future,
future::{BoxFuture, FutureExt},
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug;
use std::{collections::HashMap, pin::Pin, sync::Arc};
use thiserror::Error;
use tracing::{instrument, trace, trace_span, Instrument};
#[cfg(test)]
mod test;
/// Any bidirectional async byte stream that [`Connection::new`] can split into
/// read and write halves. Implemented automatically for every type meeting the
/// bounds.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
/// Any message type that can be serialized into, and deserialized out of, an
/// envelope exchanged over a [`Connection`]. Implemented automatically for every
/// type meeting the bounds.
pub trait Msg: Serialize + DeserializeOwned + Send + Sync + 'static {}
impl<T: Serialize + DeserializeOwned + Send + Sync + 'static> Msg for T {}
#[derive(Error, Debug)]
pub enum Error {
#[error("queue management error")]
Queue,
#[error("unknown response {0}")]
UnknownResponse(u64),
#[error("response channel {0} dropped")]
ResponseChannelDropped(u64),
#[error(transparent)]
Io(#[from] futures::io::Error),
#[error(transparent)]
Postcard(#[from] postcard::Error),
#[error(transparent)]
Canceled(#[from] futures::channel::oneshot::Canceled),
}
/// Writes [`Envelope`]s to an async byte sink as length-prefixed frames.
///
/// Frame layout: the payload length as a little-endian `u64`, followed by the
/// postcard encoding of the envelope.
struct EnvelopeWriter {
    /// Type-erased underlying sink.
    wr: Box<dyn AsyncWrite + Unpin + Send + Sync + 'static>,
}

impl EnvelopeWriter {
    /// Wraps `wtr`, erasing its concrete type.
    fn new<W: AsyncWrite + Unpin + Send + Sync + 'static>(wtr: W) -> Self {
        EnvelopeWriter { wr: Box::new(wtr) }
    }

    /// Serializes `e`, writes it as a single frame, and flushes the sink.
    ///
    /// # Errors
    /// [`Error::Postcard`] if `e` cannot be serialized; [`Error::Io`] if writing
    /// or flushing the sink fails.
    #[instrument(skip(self, e))]
    async fn send<OneWay: Msg, Request: Msg, Response: Msg>(
        &mut self,
        e: Envelope<OneWay, Request, Response>,
    ) -> Result<(), Error> {
        let bytes = postcard::to_allocvec(&e)?;
        // usize -> u64 is lossless on every target Rust supports (<= 64-bit).
        let wsz = bytes.len() as u64;
        self.wr.write_all(&wsz.to_le_bytes()).await?;
        self.wr.write_all(&bytes).await?;
        // Without an explicit flush, a buffering sink (e.g. a BufWriter-style
        // wrapper) may hold the frame indefinitely while the peer waits for it.
        self.wr.flush().await?;
        Ok(())
    }
}
struct EnvelopeReader {
rd: Box<dyn AsyncRead + Unpin + Send + Sync + 'static>,
rdbuf: Vec<u8>,
}
impl EnvelopeReader {
fn new<R: AsyncRead + Unpin + Send + Sync + 'static>(rdr: R) -> Self {
EnvelopeReader {
rd: Box::new(rdr),
rdbuf: Vec::new(),
}
}
#[instrument(skip(self))]
async fn recv<OneWay: Msg, Request: Msg, Response: Msg>(
&mut self,
) -> Result<Envelope<OneWay, Request, Response>, Error> {
use byteorder_async::{LittleEndian, ReaderToByteOrder};
let rsz: u64 = self.rd.byte_order().read_u64::<LittleEndian>().await?;
self.rdbuf.resize(rsz as usize, 0);
self.rd.read_exact(self.rdbuf.as_mut_slice()).await?;
Ok(postcard::from_bytes(self.rdbuf.as_slice())?)
}
}
/// State shared by every clone of a [`Queue`]: the outgoing envelope sender and
/// the bookkeeping for requests still awaiting a response.
struct Reception<OneWay: Msg, Request: Msg, Response: Msg> {
    /// Id that will be assigned to the next outgoing request.
    next_request: u64,
    /// Response slots for outstanding requests, keyed by request id.
    requests: HashMap<u64, Sender<Response>>,
    /// Sending end of the outgoing envelope queue.
    enqueue: UnboundedSender<Envelope<OneWay, Request, Response>>,
}

/// Cloneable handle for queueing one-way messages and requests; all clones
/// share one [`Reception`] behind an async mutex.
pub struct Queue<OneWay: Msg, Request: Msg, Response: Msg> {
    reception: Arc<Mutex<Reception<OneWay, Request, Response>>>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add
// `OneWay: Clone` (etc.) bounds, yet only the shared `Arc` is cloned here.
impl<OneWay: Msg, Request: Msg, Response: Msg> Clone for Queue<OneWay, Request, Response> {
    fn clone(&self) -> Self {
        let reception = Arc::clone(&self.reception);
        Queue { reception }
    }
}
impl<OneWay: Msg, Request: Msg, Response: Msg> Queue<OneWay, Request, Response> {
fn new(reception: Reception<OneWay, Request, Response>) -> Self {
Self {
reception: Arc::new(Mutex::new(reception)),
}
}
pub fn enqueue_oneway(
&self,
oneway: OneWay,
) -> impl Future<Output = Result<(), Error>> + 'static {
let reception = self.reception.clone();
let span = tracing::trace_span!("enqueue_oneway");
(async move {
let env = Envelope::<OneWay, Request, Response>::OneWay(oneway);
let guard = reception.lock().await;
guard.enqueue.unbounded_send(env).map_err(|_| Error::Queue)
})
.instrument(span)
}
pub fn enqueue_request(
&self,
req: Request,
) -> impl Future<Output = Result<Response, Error>> + 'static {
let reception = self.reception.clone();
let span = trace_span!("enqueue_request");
(async move {
let (send_err, recv) = {
let mut guard = reception.lock().await;
let curr = guard.next_request;
let env = Envelope::<OneWay, Request, Response>::Request(curr, req);
let send_err = guard.enqueue.unbounded_send(env);
let (send, recv) = channel();
if send_err.is_ok() {
tracing::trace!(?curr, "enqueued envelope for request");
guard.next_request += 1;
guard.requests.insert(curr, send);
}
(send_err, recv)
};
if send_err.is_ok() {
Ok(recv.await?)
} else {
Err(Error::Queue)
}
})
.instrument(span)
}
}
type PendingWrite = Pin<Box<dyn FusedFuture<Output = Result<(), Error>> + Send + Sync + 'static>>;
type PendingRead<OneWay, Request, Response> = Pin<
Box<
dyn FusedFuture<Output = Result<Envelope<OneWay, Request, Response>, Error>>
+ Send
+ Sync
+ 'static,
>,
>;
/// One end of a bidirectional message connection over an async byte stream.
///
/// Progress is driven by [`Connection::advance`]: each call polls the pending
/// reads, pending writes, the outgoing queue, and locally-served requests, and
/// handles one of them that completes.
pub struct Connection<OneWay: Msg, Request: Msg, Response: Msg> {
    /// Frame reader, shared so that in-flight read futures can own a handle.
    reader: Arc<Mutex<EnvelopeReader>>,
    /// Frame writer, shared so that in-flight write futures can own a handle.
    writer: Arc<Mutex<EnvelopeWriter>>,
    /// Outstanding envelope reads; `advance` keeps one in flight.
    reads_in_progress: FuturesUnordered<PendingRead<OneWay, Request, Response>>,
    /// Envelope writes that have started but not yet completed.
    writes_in_progress: FuturesUnordered<PendingWrite>,
    /// Handle for queueing outgoing one-way messages and requests.
    pub queue: Queue<OneWay, Request, Response>,
    /// Receiving end of the outgoing envelope queue fed through `queue`.
    dequeue: UnboundedReceiver<Envelope<OneWay, Request, Response>>,
    /// Locally-served requests, each resolving to `(request id, response)`.
    responses: FuturesUnordered<BoxFuture<'static, (u64, Response)>>,
    /// Number of completed reads so far (failed ones included); labels spans.
    envelope_count: usize,
}
/// A single frame exchanged with the peer; every envelope is serialized with
/// postcard by the envelope writer and decoded by the envelope reader.
///
/// `#[serde(bound = "")]` replaces serde's inferred per-parameter bounds with
/// none: the `Msg` bounds on the type parameters already require
/// `Serialize + DeserializeOwned`. It is placed after the `#[derive]` that
/// introduces the `serde` helper attribute; using a derive helper before its
/// derive triggers the future-incompatible `legacy_derive_helpers` lint.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound = "")]
enum Envelope<OneWay: Msg, Request: Msg, Response: Msg> {
    /// A fire-and-forget message; no reply is expected.
    OneWay(OneWay),
    /// A request tagged with the id its response will carry back.
    Request(u64, Request),
    /// The response to the request with the given id.
    Response(u64, Response),
}
impl<OneWay: Msg, Request: Msg, Response: Msg> Connection<OneWay, Request, Response> {
    /// Builds a connection from separate read and write halves.
    pub fn new_split<R, W>(rdr: R, wtr: W) -> Self
    where
        R: AsyncRead + Unpin + Send + Sync + 'static,
        W: AsyncWrite + Unpin + Send + Sync + 'static,
    {
        let (enqueue, dequeue) = unbounded();
        let queue = Queue::new(Reception {
            next_request: 0,
            requests: HashMap::new(),
            enqueue,
        });
        Connection {
            reader: Arc::new(Mutex::new(EnvelopeReader::new(rdr))),
            writer: Arc::new(Mutex::new(EnvelopeWriter::new(wtr))),
            reads_in_progress: FuturesUnordered::new(),
            writes_in_progress: FuturesUnordered::new(),
            queue,
            dequeue,
            responses: FuturesUnordered::new(),
            envelope_count: 0,
        }
    }

    /// Builds a connection over a single bidirectional stream by splitting it
    /// into read and write halves.
    pub fn new<RW: AsyncReadWrite>(rw: RW) -> Self {
        let (rdr, wtr) = rw.split();
        Self::new_split(rdr, wtr)
    }

    /// Queues a one-way message; delegates to [`Queue::enqueue_oneway`].
    pub fn enqueue_oneway(
        &self,
        oneway: OneWay,
    ) -> impl Future<Output = Result<(), Error>> + 'static {
        self.queue.enqueue_oneway(oneway)
    }

    /// Queues a request and awaits its response; delegates to
    /// [`Queue::enqueue_request`].
    pub fn enqueue_request(
        &self,
        req: Request,
    ) -> impl Future<Output = Result<Response, Error>> + 'static {
        self.queue.enqueue_request(req)
    }

    /// Starts reading the next envelope from the peer.
    fn issue_read(&mut self) {
        let rdr = Arc::clone(&self.reader);
        let fut = Box::pin(
            async move { rdr.lock().await.recv::<OneWay, Request, Response>().await }.fuse(),
        );
        self.reads_in_progress.push(fut);
    }

    /// Starts writing `env` to the peer.
    fn issue_write(&mut self, env: Envelope<OneWay, Request, Response>) {
        let wtr = Arc::clone(&self.writer);
        let fut = Box::pin(async move { wtr.lock().await.send(env).await }.fuse());
        self.writes_in_progress.push(fut);
    }

    /// Waits for, and handles, one event on the connection.
    ///
    /// The event is whichever becomes ready first among:
    /// - a pending write finishing (its error, if any, is returned);
    /// - an envelope dequeued from the outgoing queue (a write is started);
    /// - a locally-served request producing its response (the response is
    ///   queued for sending);
    /// - an envelope read from the peer: a one-way message is passed to
    ///   `srv_ow`; a request is passed to `srv_req` and its future tracked until
    ///   it yields a response; a response completes the matching
    ///   [`Queue::enqueue_request`] future.
    ///
    /// Call this in a loop. It can return `Ok(())` without handling anything
    /// when an idle source reports that it has nothing pending.
    ///
    /// # Errors
    /// Read, write, and decoding failures; [`Error::UnknownResponse`] for a
    /// response with no outstanding request; [`Error::ResponseChannelDropped`]
    /// when the requester is no longer waiting; [`Error::Queue`] if a response
    /// cannot be queued.
    #[instrument(skip(self, srv_req, srv_ow))]
    pub async fn advance<ServeRequest, FutureResponse, ServeOneWay>(
        &mut self,
        srv_req: ServeRequest,
        srv_ow: ServeOneWay,
    ) -> Result<(), Error>
    where
        ServeRequest: FnOnce(Request) -> FutureResponse,
        FutureResponse: Future<Output = Response> + Send + 'static,
        ServeOneWay: FnOnce(OneWay),
    {
        // Keep a read outstanding so incoming envelopes are always noticed.
        if self.reads_in_progress.is_empty() {
            self.issue_read();
        }
        select! {
            next_written = self.writes_in_progress.next() => next_written.unwrap_or(Ok(())),
            next_enqueued = self.dequeue.next() => {
                if let Some(env) = next_enqueued {
                    trace!("dequeued envelope, sending");
                    self.issue_write(env);
                }
                Ok(())
            },
            next_response = self.responses.next() => match next_response {
                None => Ok(()),
                Some((n, response)) => {
                    trace!(n, "finished serving request, enqueueing response");
                    let env = Envelope::Response(n, response);
                    let guard = self.queue.reception.lock().await;
                    guard.enqueue.unbounded_send(env).map_err(|_| Error::Queue)
                }
            },
            next_read = self.reads_in_progress.next() => match next_read {
                None => Ok(()),
                Some(read_result) => {
                    self.envelope_count += 1;
                    // Re-arm before inspecting the result, so a failed read still
                    // leaves the next one queued.
                    self.issue_read();
                    match read_result? {
                        Envelope::OneWay(ow) => {
                            trace!("received one-way, calling service function");
                            let span = trace_span!("oneway", e = self.envelope_count);
                            span.in_scope(|| srv_ow(ow));
                            Ok(())
                        }
                        Envelope::Request(n, req) => {
                            trace!(n, "received request, calling service function");
                            let span = trace_span!("req", e = self.envelope_count);
                            let tagged = srv_req(req).instrument(span).map(move |r| (n, r));
                            self.responses.push(tagged.boxed());
                            Ok(())
                        }
                        Envelope::Response(n, res) => {
                            trace!(n, "received response, fulfilling future");
                            let mut guard = self.queue.reception.lock().await;
                            match guard.requests.remove(&n) {
                                None => Err(Error::UnknownResponse(n)),
                                Some(send) => send
                                    .send(res)
                                    .map_err(|_| Error::ResponseChannelDropped(n)),
                            }
                        }
                    }
                }
            },
        }
    }
}