#![recursion_limit="512"]
use serde::{Serialize,Deserialize,de::DeserializeOwned};
use serde_cbor::{ser::to_vec_packed,de::from_slice};
use futures;
use futures::{Future,future::{BoxFuture,FutureExt}};
use futures::io::{AsyncRead,AsyncWrite,AsyncReadExt,AsyncWriteExt};
use futures::channel::oneshot::{channel,Sender};
use futures::stream::{FuturesUnordered,StreamExt};
use futures::lock::Mutex;
use futures::{select_biased};
use futures::channel::mpsc::{unbounded,UnboundedReceiver,UnboundedSender};
use thiserror::Error;
use std::{sync::Arc, collections::{HashMap}};
use log::trace;
#[cfg(test)]
mod test;
/// A bidirectional async byte stream that can be boxed, shared and moved
/// across threads (the connection stores it as `Box<dyn AsyncReadWrite>`).
///
/// Blanket-implemented for every type satisfying the bounds, so it never has
/// to be implemented by hand.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
/// A message type that can travel over a connection: serializable and
/// deserializable via serde, and safe to move between threads.
///
/// Blanket-implemented for every type satisfying the bounds.
pub trait Msg: Serialize + DeserializeOwned + Send + Sync + 'static {}
impl<T: Serialize + DeserializeOwned + Send + Sync + 'static> Msg for T {}
#[derive(Error,Debug)]
pub enum Error {
#[error("queue management error")]
Queue,
#[error("unknown response {0}")]
UnknownResponse(u64),
#[error("response channel {0} dropped")]
ResponseChannelDropped(u64),
#[error(transparent)]
Io(#[from] futures::io::Error),
#[error(transparent)]
Cbor(#[from] serde_cbor::Error),
#[error(transparent)]
Canceled(#[from] futures::channel::oneshot::Canceled),
}
struct IO {
rw: Mutex<Box<dyn AsyncReadWrite>>,
buf: Vec<u8>,
}
impl IO {
fn new<RW:AsyncReadWrite>(rw:RW) -> Self {
IO {
rw: Mutex::new(Box::new(rw)),
buf: Vec::new(),
}
}
async fn send<OneWay:Msg, Request:Msg, Response:Msg>(&mut self, e: Envelope<OneWay,Request,Response>) -> Result<(), Error> {
use byteorder_async::{WriterToByteOrder,LittleEndian};
let bytes = to_vec_packed(&e)?;
let wsz: u64 = bytes.len() as u64;
trace!("sending {}-byte envelope at IO level", wsz);
let mut guard = self.rw.lock().await;
guard.byte_order().write_u64::<LittleEndian>(wsz).await?;
guard.write_all(bytes.as_slice()).await?;
trace!("sent {}-byte envelope at IO level", wsz);
Ok(())
}
async fn recv<OneWay:Msg, Request:Msg, Response:Msg>(&mut self) -> Result<Envelope<OneWay,Request,Response>, Error> {
trace!("receiving envelope at IO level");
use byteorder_async::{ReaderToByteOrder,LittleEndian};
let mut guard = self.rw.lock().await;
let rsz: u64 = guard.byte_order().read_u64::<LittleEndian>().await?;
self.buf.resize(rsz as usize, 0);
guard.read_exact(self.buf.as_mut_slice()).await?;
trace!("received {}-byte envelope at IO level", rsz);
Ok(from_slice(self.buf.as_slice())?)
}
}
/// State shared (behind an `Arc<Mutex<_>>`) between a `Connection` and the
/// futures returned by its `enqueue_oneway` / `enqueue_request` methods.
struct Reception<OneWay, Request, Response>
where
    OneWay: Msg,
    Request: Msg,
    Response: Msg,
{
    /// Id that will be assigned to the next outgoing request.
    next_request: u64,
    /// Senders for requests still awaiting a response, keyed by request id.
    requests: HashMap<u64, Sender<Response>>,
    /// Producer side of the outgoing envelope queue drained by `advance`.
    enqueue: UnboundedSender<Envelope<OneWay, Request, Response>>,
}
/// One end of a message connection over an [`AsyncReadWrite`] stream.
///
/// The connection carries one-way messages, requests and responses over the
/// same stream.
pub struct Connection<OneWay, Request, Response>
where
    OneWay: Msg,
    Request: Msg,
    Response: Msg,
{
    /// Framed reader/writer over the underlying stream.
    io: IO,
    /// Shared with the futures returned by the `enqueue_*` methods.
    reception: Arc<Mutex<Reception<OneWay, Request, Response>>>,
    /// Consumer side of the outgoing envelope queue.
    dequeue: UnboundedReceiver<Envelope<OneWay, Request, Response>>,
    /// Locally served requests whose responses are still being computed,
    /// each result tagged with its request id.
    responses: Mutex<FuturesUnordered<BoxFuture<'static, (u64, Response)>>>,
}
/// The unit exchanged over the wire: every frame carries exactly one envelope.
///
/// `#[serde(bound = "")]` stops serde from adding its own bounds on the type
/// parameters. The `Msg` bounds declared on the enum already provide
/// `Serialize` / `DeserializeOwned`. The helper attribute must come after the
/// `#[derive]` that introduces it; placing it first trips the
/// `legacy_derive_helpers` future-incompatibility lint.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
enum Envelope<OneWay: Msg, Request: Msg, Response: Msg> {
    /// A message that expects no reply.
    OneWay(OneWay),
    /// A request tagged with the sender-assigned id its response will echo.
    Request(u64, Request),
    /// The reply to the request carrying the same id.
    Response(u64, Response),
}
impl<OneWay:Msg, Request:Msg, Response:Msg>
Connection<OneWay, Request, Response> {
pub fn new<RW:AsyncReadWrite>(rw:RW) -> Self {
let io = IO::new(rw);
let next_request = 0;
let requests = HashMap::new();
let responses = Mutex::new(FuturesUnordered::new());
let (enqueue, dequeue) = unbounded();
let reception = Arc::new(Mutex::new(Reception{next_request, requests, enqueue}));
Connection { io, reception, responses, dequeue }
}
pub fn enqueue_oneway(&self, oneway: OneWay) -> impl Future<Output=Result<(), Error>> + 'static
{
let reception = self.reception.clone();
async move {
let env = Envelope::<OneWay,Request,Response>::OneWay(oneway);
let guard = reception.lock().await;
guard.enqueue.unbounded_send(env).map_err(|_| Error::Queue)
}
}
pub fn enqueue_request(&self, req: Request) -> impl Future<Output=Result<Response, Error>> + 'static {
let reception = self.reception.clone();
async move {
let (send_err, recv) = {
let mut guard = reception.lock().await;
let curr = guard.next_request;
let env = Envelope::<OneWay,Request,Response>::Request(curr, req);
let send_err = guard.enqueue.unbounded_send(env);
let (send, recv) = channel();
if send_err.is_ok() {
trace!("enqueued envelope for request {}", curr);
guard.next_request += 1;
guard.requests.insert(curr, send);
}
(send_err, recv)
};
if send_err.is_ok() {
Ok(recv.await?)
} else {
Err(futures::future::ready(Error::Queue).await)
}
}
}
pub async fn advance<ServeRequest, FutureResponse, ServeOneWay>(&mut self, srv_req:ServeRequest, srv_ow:ServeOneWay) -> Result<(), Error>
where ServeRequest: FnOnce(Request)->FutureResponse,
FutureResponse: Future<Output=Response> + Send + 'static,
ServeOneWay: FnOnce(OneWay)->()
{
let mut resp_guard = self.responses.lock().await;
select_biased! {
next_enqueued = self.dequeue.next() => match next_enqueued {
None => Ok(()),
Some(env) => {
trace!("dequeued envelope, sending");
Ok(self.io.send(env).await?)
}
},
next_response = resp_guard.next() => match next_response {
None => Ok(()),
Some((n, response)) => {
let env = Envelope::Response(n, response);
trace!("finished serving request {}, enqueueing response", n);
self.reception.lock().await.enqueue.unbounded_send(env).map_err(|_| Error::Queue)
}
},
read_result = self.io.recv::<OneWay,Request,Response>().fuse() => {
let env = read_result?;
match env {
Envelope::OneWay(ow) => {
trace!("received one-way envelope, calling service function");
Ok(srv_ow(ow))
},
Envelope::Request(n, req) => {
trace!("received request envelope {}, calling service function", n);
let res_fut = srv_req(req);
let boxed : BoxFuture<'static,_> = Box::pin(res_fut.map(move |r| (n, r)));
Ok(resp_guard.push(boxed))
},
Envelope::Response(n, res) => {
trace!("received response envelope {}, transferring to future", n);
match self.reception.lock().await.requests.remove(&n.clone()) {
None => Err(Error::UnknownResponse(n)),
Some(send) => {
match send.send(res) {
Ok(_) => Ok(()),
Err(_) => Err(Error::ResponseChannelDropped(n))
}
}
}
}
}
}
}
}
}