// Standard-library, `futures`, `serde_derive` and `tracing` items, plus everything in
// this module (`super::*`) and the crate's `Error`/`Result`, gathered under one path.
// `#[doc(hidden)]` keeps it out of the rendered docs. NOTE(review): presumably this
// exists so that macro-generated code (hence the `concat`/`stringify` re-exports) can
// refer to everything via `re_export::…` without the user's crate depending on these
// crates directly — confirm against the macro definitions.
#[doc(hidden)]
pub mod re_export {
// Whole crates, re-exported so generated code can reach them by path.
pub extern crate serde;
pub extern crate tracing;
// Std items and macros.
pub use std::{
boxed::Box,
concat,
pin::Pin,
stringify,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
pub use futures::{Future, Sink, Stream};
pub use serde_derive::{Deserialize, Serialize};
pub use tracing::Instrument;
// Everything defined in the parent module (traits, `serve`, `RpcClient`, …).
pub use super::*;
pub use crate::error::{Error, Result};
}
use std::{collections::HashMap, fmt::Debug, marker::PhantomData, pin::Pin, sync::Arc};
use futures::{
channel::{mpsc, oneshot},
future::{select, Either},
Future, FutureExt, Sink, SinkExt, Stream, StreamExt,
};
use tracing::Instrument;
use crate::error::{Error, Result};
/// Describes an RPC protocol by its request and response payload types.
///
/// Within this module the trait is used only at the type level: its associated
/// types parameterize the frame types (`RpcFrame<R::Request>` /
/// `RpcFrame<R::Response>`), and `RpcClient` carries it as `PhantomData`.
pub trait Rpc {
/// Payload carried by request frames.
type Request;
/// Payload carried by response frames.
type Response;
}
/// Server-side request handler for protocol `R`.
///
/// `I` is the incoming request frame type and `O` the outgoing response frame type.
/// `serve` calls `make_response` once per received request, each time on a clone of
/// the shared `Arc<Self>`, and runs the returned future on its own `tokio` task. That
/// is why the future must be `Send` (and, being a boxed trait object, `'static`).
pub trait RpcServerStub<R: Rpc, I: RpcFrame<R::Request>, O: RpcFrame<R::Response>> {
/// Handles one request frame.
///
/// Resolving to `Some(frame)` makes `serve` write `frame` to the peer. Resolving to
/// `None` sends nothing for this request. The client driver routes responses by
/// `get_id()` and discards (with a warning) any response whose id has no pending
/// request, so the returned frame should normally carry `req.get_id()`.
fn make_response(self: Arc<Self>, req: I) -> Pin<Box<dyn Future<Output = Option<O>> + Send>>;
}
/// Identifier that correlates a request frame with its response frame.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct RequestId(pub u64);

impl std::fmt::Display for RequestId {
    /// Renders the id as exactly 16 zero-padded, upper-case hex digits wrapped in
    /// square brackets, e.g. `[00000000000000FF]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RequestId(raw) = *self;
        f.write_fmt(format_args!("[{:016X}]", raw))
    }
}
/// A frame pairing a [`RequestId`] with a payload of type `T`.
///
/// The id is what correlates a response with its request. The client driver keys its
/// pending requests by `get_id()`, and `serve` records it in the handler's tracing span.
/// `(RequestId, T)` implements this trait directly as the trivial frame.
pub trait RpcFrame<T>: Sized + Send + 'static {
/// Builds a frame from an id and a payload.
fn from_parts(id: RequestId, data: T) -> Result<Self>;
/// Returns the id carried by this frame.
fn get_id(&self) -> RequestId;
/// Consumes the frame and returns its payload. This is fallible so implementations
/// can reject a frame; the `(RequestId, T)` implementation never fails.
fn get_data(self) -> Result<T>;
}
/// The simplest possible frame: a bare `(id, payload)` pair.
///
/// Every operation just packs or unpacks the tuple, so none of them can fail.
impl<T> RpcFrame<T> for (RequestId, T)
where
    T: Sized + Send + 'static,
{
    fn from_parts(id: RequestId, data: T) -> Result<Self> {
        Ok((id, data))
    }

    fn get_id(&self) -> RequestId {
        let (id, _) = self;
        *id
    }

    fn get_data(self) -> Result<T> {
        let (_, data) = self;
        Ok(data)
    }
}
pub async fn serve<R, S, I, O, T, U>(
stub: impl Into<Arc<S>>,
mut recv: T,
mut send: U,
) -> Result<()>
where
R: Rpc,
S: RpcServerStub<R, I, O>,
I: RpcFrame<R::Request>,
O: RpcFrame<R::Response>,
T: Stream<Item = Result<I>> + Unpin,
U: Sink<O, Error = Error> + Unpin,
{
let stub: Arc<S> = stub.into();
let (tx, mut rx) = mpsc::channel::<O>(128);
let mut fut = select(recv.next(), rx.next());
loop {
match fut.await {
Either::Left((Some(req), r)) => {
let req = req?;
let id = req.get_id();
let stub = stub.clone();
let mut tx = tx.clone();
tokio::spawn(
stub.make_response(req)
.instrument(debug_span!("server", %id))
.then(move |res| async move {
if let Some(res) = res {
if let Err(e) = tx.send(res).await {
assert!(e.is_disconnected());
error!("driver closed unexpectedly");
}
}
}),
);
fut = select(recv.next(), r);
}
Either::Right((Some(rsp), r)) => {
send.send(rsp).await?;
fut = select(r, rx.next());
}
_ => {
break Ok(());
}
}
}
}
/// Cloneable handle for sending requests of protocol `R` through a client driver.
///
/// `O` is the request frame type sent and `I` the response frame type received. The
/// handle holds only the sending half of the channel to the driver (created by
/// `RpcClient::new_with_driver`), so all clones feed the same driver. `'a` is the
/// lifetime bound that `new_with_driver` places on the transports; `RpcClient::new`
/// exists only for `'static`.
pub struct RpcClient<'a, R: Rpc, I: RpcFrame<R::Response>, O: RpcFrame<R::Request>>(
// Each request is paired with the oneshot sender the driver answers on.
mpsc::Sender<(oneshot::Sender<Result<I>>, O)>,
// Ties the handle to the protocol `R` and the transport lifetime `'a`.
PhantomData<&'a R>,
);
impl<R, I, O> RpcClient<'static, R, I, O>
where
    R: Rpc,
    I: RpcFrame<R::Response>,
    O: RpcFrame<R::Request>,
{
    /// Creates a client and spawns its driver onto the current `tokio` runtime.
    ///
    /// This is a convenience wrapper around `RpcClient::new_with_driver` for `'static`
    /// transports. The spawned task is detached, so it runs until the driver finishes
    /// on its own.
    ///
    /// # Panics
    ///
    /// Panics when called outside a `tokio` runtime, as [`tokio::spawn`] does.
    pub fn new<T, U>(recv: T, send: U) -> Self
    where
        T: Stream<Item = Result<I>> + Unpin + Send + 'static,
        U: Sink<O, Error = Error> + Unpin + Send + 'static,
    {
        let (driver, client) = Self::new_with_driver(recv, send);
        tokio::spawn(driver);
        client
    }
}
impl<'a, R: Rpc, I: RpcFrame<R::Response>, O: RpcFrame<R::Request>> RpcClient<'a, R, I, O> {
pub fn new_with_driver<T, U>(recv: T, send: U) -> (impl Future<Output = ()> + 'a, Self)
where
T: Stream<Item = Result<I>> + Unpin + 'a,
U: Sink<O, Error = Error> + Unpin + 'a,
{
async fn driver<'a, R, I, O, T, U>(
mut rx: mpsc::Receiver<(oneshot::Sender<Result<I>>, O)>,
mut recv: T,
mut send: U,
) where
R: Rpc,
I: RpcFrame<R::Response>,
O: RpcFrame<R::Request>,
T: Stream<Item = Result<I>> + Unpin + 'a,
U: Sink<O, Error = Error> + Unpin + 'a,
{
let mut fut = select(rx.next(), recv.next());
let mut req_map = HashMap::with_capacity(128);
let ret = loop {
match fut.await {
Either::Left((Some((callback, req)), r)) => {
let id = req.get_id();
if let Err(e) = send.send(req).await {
callback
.send(Err(e))
.unwrap_or_else(|_| error!("client closed unexpectedly"));
} else if req_map.insert(id, callback).is_some() {
panic!("request id is not unique")
}
fut = select(rx.next(), r);
}
Either::Right((Some(rsp), r)) => {
let rsp = match rsp {
Ok(rsp) => rsp,
Err(e) => break Err(e),
};
let id = rsp.get_id();
if let Some(callback) = req_map.remove(&id) {
callback
.send(Ok(rsp))
.unwrap_or_else(|_| error!("client closed unexpectedly"));
} else {
warn!("Server responeded for nonexist request: {}", id);
}
fut = select(r, recv.next());
}
_ => {
break Ok(());
}
}
};
if let Err(e) = ret {
let mut e = Some(e);
for (_id, r) in req_map.into_iter() {
match r.send(Err(e.take().unwrap())) {
Ok(()) => break,
Err(r) => {
e = match r {
Err(x) => Some(x),
_ => unreachable!(),
};
}
}
}
if let Some(e) = e {
error!("failed to send error in driver: {}", e);
}
}
}
let (tx, rx) = mpsc::channel::<(oneshot::Sender<Result<I>>, O)>(128);
(
driver::<'a, R, I, O, T, U>(rx, recv, send),
Self(tx, PhantomData),
)
}
pub async fn make_request(&mut self, req: O) -> Result<I> {
let (tx, rx) = oneshot::channel();
self.0.send((tx, req)).await.map_err(|e| {
assert!(e.is_disconnected());
Error::DriverStopped
})?;
rx.await.unwrap_or(Err(Error::DriverStopped))
}
}
/// Opaque debug output: prints only the type name; the inner channel is not shown.
impl<'a, R, I, O> Debug for RpcClient<'a, R, I, O>
where
    R: Rpc,
    I: RpcFrame<R::Response>,
    O: RpcFrame<R::Request>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("RpcClient")
    }
}
impl<'a, R: Rpc, I: RpcFrame<R::Response>, O: RpcFrame<R::Request>> Clone
for RpcClient<'a, R, I, O>
{
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone(), PhantomData)
}
}