mod formats;
mod util;
use formats::*;
use futures::channel::{mpsc, oneshot};
use futures::stream;
use futures::stream::BoxStream;
use futures::{future, Future, Sink};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::value::{RawValue, Value};
use std::collections::{BTreeMap, HashMap};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use util::UtilStreamExt;
use warp::filters::ws::{Message, WebSocket};
/// Capacity of the channel through which all response streams of one connection are
/// multiplexed onto its websocket sink.
const WS_SEND_BUFFER_SIZE: usize = 1 << 10;
/// Once more than this many request ids are tracked for a connection, ids whose
/// cancellation receiver has already been dropped are pruned from the tracking map.
const REQUEST_GC_THRESHOLD: usize = 1 << 6;
/// Item count handed to `yield_after` for every response stream; presumably makes a busy
/// stream yield to the executor periodically so it cannot starve its siblings.
const INTER_STREAM_FAIRNESS: u64 = 1 << 6;
/// A typed service whose requests each produce a stream of `Result`s, exposed over the
/// websocket multiplexer through the blanket `WebsocketService` implementation.
pub trait Service {
    /// Request type; deserialized from the JSON payload of an incoming request.
    type Req: DeserializeOwned;
    /// Item type of the response stream; serialized to JSON before sending.
    type Resp: Serialize + 'static;
    /// Error type of the response stream; serialized to JSON before sending.
    type Error: Serialize + 'static;
    /// Context value cloned into every call of `serve`.
    type Ctx: Clone;

    /// Handles one request, returning the stream of responses for it.
    fn serve(
        &self,
        ctx: Self::Ctx,
        req: Self::Req,
    ) -> BoxStream<'static, Result<Self::Resp, Self::Error>>;

    /// Type-erases this service so it can be registered in a service map next to services
    /// with different request/response types.
    fn boxed(self) -> BoxedService<Self::Ctx>
    where
        Self: Sized + Send + Sync + 'static,
    {
        Box::new(self)
    }
}
/// Object-safe, JSON-level view of a service: takes a raw JSON request and yields raw JSON
/// responses or protocol-level `ErrorKind`s.
pub trait WebsocketService<Ctx>
where
    Ctx: Clone,
{
    /// Serves `raw_req`; `service_id` is the name the service is registered under and is
    /// used for logging.
    fn serve_ws(
        &self,
        ctx: Ctx,
        raw_req: Value,
        service_id: &str,
    ) -> BoxStream<'static, Result<Box<RawValue>, ErrorKind>>;
}
/// Blanket adapter that makes every typed `Service` usable as a `WebsocketService`.
///
/// The raw request is deserialized into the service's request type. If that fails, the
/// result is a single-item stream holding `ErrorKind::BadRequest` with the serde message.
/// Otherwise each response is serialized to a `RawValue`, and each service error is
/// serialized to a JSON `Value` wrapped in `ErrorKind::ServiceError`. Serialization
/// failures panic.
impl<Req, Resp, Ctx, S> WebsocketService<Ctx> for S
where
    S: Service<Req = Req, Resp = Resp, Ctx = Ctx>,
    Req: DeserializeOwned,
    Resp: Serialize + 'static,
    Ctx: Clone,
{
    fn serve_ws(
        &self,
        ctx: Ctx,
        raw_req: Value,
        service_id: &str,
    ) -> BoxStream<'static, Result<Box<RawValue>, ErrorKind>> {
        tracing::trace!(
            "Serving raw request for service {}: {:?}",
            service_id,
            raw_req
        );
        let typed_req: Req = match serde_json::from_value(raw_req) {
            Ok(parsed) => parsed,
            Err(cause) => {
                let message = cause.to_string();
                tracing::warn!(
                    "Error deserializing request for service {}: {}",
                    service_id,
                    message
                );
                return stream::once(future::err(ErrorKind::BadRequest { message })).boxed();
            }
        };
        self.serve(ctx, typed_req)
            .map(|item| match item {
                Ok(resp) => Ok(serde_json::value::to_raw_value(&resp)
                    .expect("Could not serialize service response")),
                Err(err) => Err(ErrorKind::ServiceError {
                    value: serde_json::to_value(&err)
                        .expect("Could not serialize service error response"),
                }),
            })
            .boxed()
    }
}
/// A type-erased, thread-shareable `WebsocketService`, as stored in the service registry.
pub type BoxedService<Ctx> = Box<dyn WebsocketService<Ctx> + Sync + Send>;
/// Warp handler that upgrades the request to a websocket and hands the socket to
/// `client_connected` together with the service registry and the context value.
///
/// Frames are limited to 64 MiB and complete messages to 128 MiB. This handler never
/// rejects; the connection future's result is discarded once it finishes.
pub async fn serve<Ctx: Clone + Send + 'static>(
    ws: warp::ws::Ws,
    services: Arc<BTreeMap<&'static str, BoxedService<Ctx>>>,
    ctx: Ctx,
) -> Result<impl warp::Reply, warp::Rejection> {
    const MAX_FRAME_BYTES: usize = 64 << 20;
    const MAX_MESSAGE_BYTES: usize = 128 << 20;
    let reply = ws
        .max_frame_size(MAX_FRAME_BYTES)
        .max_message_size(MAX_MESSAGE_BYTES)
        .on_upgrade(move |socket| {
            client_connected(socket, ctx, services).map(|_outcome| ())
        });
    Ok(reply)
}
/// Drives one websocket connection until the client's incoming stream ends.
///
/// Every incoming text frame is parsed as an `Incoming` envelope:
/// * `Request`: dispatched to the named service on its own spawned task. Response envelopes
///   from all requests go through one shared channel, which a separate spawned task forwards
///   to the socket. Reusing a still-tracked request id cancels the previous stream with that
///   id. An unknown service name is answered with an `UnknownEndpoint` error envelope.
/// * `Cancel`: stops the response stream tracked under that request id, if any.
///
/// Ping and pong control frames are ignored. Any of the following cancels all active
/// responses and closes the outgoing channel:
/// * a close frame,
/// * any other non-text frame,
/// * a text frame that is not a valid envelope.
///
/// The returned future resolves to `Err(())` only when reading from the socket fails.
// The allow is presumably needed because the tracing macro expansions inflate clippy's metric.
#[allow(clippy::cognitive_complexity)]
fn client_connected<Ctx: Clone + Send + 'static>(
    ws: WebSocket,
    ctx: Ctx,
    services: Arc<BTreeMap<&'static str, BoxedService<Ctx>>>,
) -> impl Future<Output = Result<(), ()>> {
    let (ws_out, ws_in) = ws.split();
    // All response streams write into `mux_in`; one task drains `mux_out` into the socket.
    let (mut mux_in, mux_out) = mpsc::channel::<Result<Message, warp::Error>>(WS_SEND_BUFFER_SIZE);
    let mut active_responses: HashMap<ReqId, oneshot::Sender<()>> = HashMap::new();
    tokio::spawn(mux_out.fuse().forward(ws_out).map(|_| ()));
    ws_in
        .try_for_each(move |raw_msg| {
            // Prune ids whose response task has dropped its cancel receiver, so the map
            // does not keep growing with every finished request.
            if active_responses.len() > REQUEST_GC_THRESHOLD {
                active_responses.retain(|_, canceled| !canceled.is_canceled());
            }
            if let Ok(text_msg) = raw_msg.to_str() {
                match serde_json::from_str::<Incoming>(text_msg) {
                    Ok(Incoming::Request(body)) => {
                        if let Some(srv) = services.get(body.service_id) {
                            let (snd_cancel, rcv_cancel) = oneshot::channel();
                            // A reused request id supersedes the stream still running under it.
                            if let Some(previous) =
                                active_responses.insert(body.request_id, snd_cancel)
                            {
                                cancel_response_stream(previous);
                            }
                            tokio::spawn(serve_request(
                                rcv_cancel,
                                srv,
                                ctx.clone(),
                                body.service_id,
                                body.request_id,
                                body.payload,
                                mux_in.clone(),
                            ));
                        } else {
                            tokio::spawn(serve_error(
                                body.request_id,
                                ErrorKind::UnknownEndpoint {
                                    endpoint: body.service_id.to_string(),
                                    valid_endpoints: services
                                        .keys()
                                        .map(|e| e.to_string())
                                        .collect::<Vec<String>>(),
                                },
                                mux_in.clone(),
                            ));
                            tracing::warn!(
                                "Client tried to access unknown service: {}",
                                body.service_id
                            );
                        }
                    }
                    Ok(Incoming::Cancel { request_id }) => {
                        if let Some(snd_cancel) = active_responses.remove(&request_id) {
                            cancel_response_stream(snd_cancel);
                        }
                    }
                    Err(cause) => {
                        tracing::warn!(
                            "Could not deserialize client request {}: {}",
                            text_msg,
                            cause
                        );
                        cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
                    }
                }
            } else if raw_msg.is_ping() || raw_msg.is_pong() {
                // Keep-alive control frames carry no request, so they are ignored. Replying
                // to pings is presumably left to the websocket implementation. Pongs must be
                // matched explicitly. Otherwise an unsolicited pong, which RFC 6455 permits,
                // would reach the non-text branch below and cancel every active response.
            } else if raw_msg.is_close() {
                tracing::debug!("Closing websocket connection (client disconnected)");
                cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
            } else {
                tracing::warn!("Expected TEXT Websocket message but got binary");
                cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
            }
            future::ok(())
        })
        .map_err(|err| {
            tracing::info!("Websocket closed with error {}", err);
        })
}
/// Sends the stop signal to a single response stream through its cancellation sender.
///
/// If the receiving side is already gone, only a trace message is logged. Otherwise a debug
/// message records whether the signal was delivered. Never fails.
// The allow is presumably needed because the tracing macro expansions inflate clippy's metric.
#[allow(clippy::cognitive_complexity)]
fn cancel_response_stream(snd_cancel: oneshot::Sender<()>) {
    if snd_cancel.is_canceled() {
        tracing::trace!("Not trying to cancel response stream whose cancel rcv has already dropped");
        return;
    }
    let delivered = snd_cancel.send(()).is_ok();
    if delivered {
        tracing::debug!("Merged Cancel signal into ongoing response stream");
    } else {
        tracing::debug!("Response stream we are trying to stop has already stopped");
    }
}
fn cancel_response_streams_close_channel(
active_responses: &mut HashMap<ReqId, oneshot::Sender<()>>,
mux_in: &mut mpsc::Sender<Result<Message, warp::Error>>,
) {
for (_, snd_cancel) in active_responses.drain() {
cancel_response_stream(snd_cancel);
}
mux_in.close_channel();
}
fn serve_request_stream<Ctx: Clone>(
srv: &BoxedService<Ctx>,
ctx: Ctx,
service_id: &str,
req_id: ReqId,
payload: Value,
) -> impl Stream<Item = Result<Message, warp::Error>> {
let resp_stream = srv
.serve_ws(ctx, payload, service_id)
.take_until_condition(|resp| future::ready(resp.is_err()))
.ready_chunks(128)
.flat_map(move |payload_results| {
let mut err = None;
let mut payload = Vec::with_capacity(payload_results.len());
for payload_result in payload_results {
match payload_result {
Ok(value) => payload.push(value),
Err(kind) => err = Some(kind), }
}
let mut res = Vec::with_capacity(1);
if !payload.is_empty() {
res.push(Outgoing::Next {
request_id: req_id,
payload,
});
}
if let Some(kind) = err {
res.push(Outgoing::Error {
request_id: req_id,
kind,
});
}
stream::iter(res)
});
AssertUnwindSafe(resp_stream)
.catch_unwind()
.map(move |msg_result| match msg_result {
Ok(msg) => msg,
Err(_) => Outgoing::Error {
request_id: req_id,
kind: ErrorKind::InternalError,
},
})
.chain(stream::once(future::ready(Outgoing::Complete {
request_id: req_id,
})))
.map(|env| Ok(Message::text(serde_json::to_string(&env).unwrap())))
}
fn serve_request<T: std::fmt::Debug, Ctx: Clone>(
canceled: oneshot::Receiver<()>,
srv: &BoxedService<Ctx>,
ctx: Ctx,
service_id: &str,
req_id: ReqId,
payload: Value,
output: impl Sink<Result<Message, warp::Error>, Error = T>,
) -> impl Future<Output = ()> {
let response_stream = serve_request_stream(srv, ctx, service_id, req_id, payload)
.take_until_signaled(canceled)
.map(|item| {
Ok(item)
});
let service_id = service_id.to_owned();
response_stream
.yield_after(INTER_STREAM_FAIRNESS)
.forward(output)
.map(move |result| {
if let Err(cause) = result {
tracing::warn!(%service_id, "Multiplexing error {:?}", cause);
};
})
}
fn serve_error<S>(req_id: ReqId, error_kind: ErrorKind, output: S) -> impl Future<Output = ()>
where
S: Sink<Result<Message, warp::Error>>,
S::Error: std::fmt::Debug,
{
let msg = Outgoing::Error {
request_id: req_id,
kind: error_kind,
};
let raw_msg = Message::text(serde_json::to_string_pretty(&msg).unwrap());
stream::once(future::ok(Ok(raw_msg)))
.forward(output)
.map(|result| {
if let Err(err) = result {
tracing::warn!("Could not send Error message: {:?}", err);
};
})
}
#[cfg(test)]
mod tests {
    // End-to-end tests. Each test starts a warp server on an ephemeral localhost port that
    // exposes `TestService` under the name "test". The test then talks to it through the
    // synchronous `websocket` client crate.
    use super::*;
    use crate::Service;
    use futures::stream;
    use futures::stream::BoxStream;
    use futures::stream::StreamExt;
    use futures::task::Poll;
    use serde::{Deserialize, Serialize};
    use std::net::SocketAddr;
    use std::thread::JoinHandle;
    use warp::Filter;
    use websocket::{ClientBuilder, OwnedMessage};

    /// Requests understood by `TestService`.
    #[derive(Serialize, Deserialize)]
    enum Request {
        /// Streams `Response(0)` up to `Response(n - 1)`.
        Count(u64),
        /// Responds once with the byte length of the string.
        Size(String),
        /// Responds once with the context value.
        Ctx,
        /// Fails immediately, using the given reason as the service error.
        Fail(String),
        /// Panics when the response stream is polled.
        Panic,
    }

    /// Payload that does not deserialize as `Request`, used to provoke a `BadRequest`.
    #[derive(Serialize, Deserialize)]
    struct BadRequest {
        bad_field: String,
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
    struct Response(u64);

    struct TestService();

    impl TestService {
        fn new() -> TestService {
            TestService()
        }
    }

    impl Service for TestService {
        type Req = Request;
        type Resp = Response;
        type Error = String;
        type Ctx = u64;
        fn serve(&self, ctx: u64, req: Request) -> BoxStream<'static, Result<Response, String>> {
            match req {
                Request::Count(cnt) => {
                    // Each poll emits the current counter value; the stream ends once
                    // `cnt` items have been produced.
                    let mut ctr = 0;
                    stream::poll_fn(move |_| {
                        let output = ctr;
                        ctr += 1;
                        if ctr <= cnt {
                            Poll::Ready(Some(Ok(Response(output))))
                        } else {
                            Poll::Ready(None)
                        }
                    })
                    .boxed()
                }
                Request::Size(data) => {
                    stream::once(future::ok(Response(data.len() as u64))).boxed()
                }
                Request::Ctx => stream::once(future::ok(Response(ctx))).boxed(),
                Request::Fail(reason) => stream::once(future::err(reason)).boxed(),
                Request::Panic => stream::poll_fn(|_| panic!("Test panic")).boxed(),
            }
        }
    }

    /// Client-side mirror of the server's outgoing envelope, with payload items kept as
    /// generic JSON `Value`s so the tests can deserialize whatever the server sends.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    #[serde(tag = "type")]
    #[serde(rename_all = "camelCase")]
    pub enum OutgoingAst {
        #[serde(rename_all = "camelCase")]
        Next {
            request_id: ReqId,
            payload: Vec<Value>,
        },
        #[serde(rename_all = "camelCase")]
        Complete { request_id: ReqId },
        #[serde(rename_all = "camelCase")]
        Error { request_id: ReqId, kind: ErrorKind },
    }

    impl OutgoingAst {
        /// Request id carried by any envelope variant.
        pub fn request_id(&self) -> ReqId {
            match self {
                OutgoingAst::Next { request_id, .. } => *request_id,
                OutgoingAst::Complete { request_id, .. } => *request_id,
                OutgoingAst::Error { request_id, .. } => *request_id,
            }
        }
    }

    /// Opens a new connection to `addr` and sends one request with id `id` to `endpoint`.
    /// It then blocks until the first non-`Next` envelope for that id arrives.
    ///
    /// Returns every payload item from the preceding `Next` envelopes, deserialized as
    /// `Resp`, together with the terminating envelope (`Complete` or `Error`). Envelopes for
    /// other ids and non-text frames are skipped. Panics on any connection, (de)serialization
    /// or missing-completion failure.
    fn test_client<Req: Serialize, Resp: DeserializeOwned>(
        addr: SocketAddr,
        endpoint: &str,
        id: u64,
        req: Req,
    ) -> (Vec<Resp>, OutgoingAst) {
        let addr = format!("ws://{}/test_ws", addr);
        let client = ClientBuilder::new(&*addr)
            .expect("Could not setup client")
            .connect_insecure()
            .expect("Could not connect to test server");
        let (mut receiver, mut sender) = client.split().unwrap();
        let payload = serde_json::to_value(req).expect("Could not serialize request");
        let req_env = Incoming::Request(RequestBody {
            service_id: endpoint,
            request_id: ReqId(id),
            payload,
        });
        let req_env_json =
            serde_json::to_string(&req_env).expect("Could not serialize request envelope");
        sender
            .send_message(&OwnedMessage::Text(req_env_json))
            .expect("Could not send request");
        // Set by `take_while` to the first envelope that is not `Next`.
        let mut completion: Option<OutgoingAst> = None;
        let msgs = receiver
            .incoming_messages()
            .filter_map(move |msg| {
                let msg_ok = msg.expect("Expected message but got websocket error");
                if let OwnedMessage::Text(raw_resp) = msg_ok {
                    let resp_env: OutgoingAst = serde_json::from_str(&*raw_resp)
                        .expect("Could not deserialize response envelope");
                    if resp_env.request_id().0 == id {
                        Some(resp_env)
                    } else {
                        None
                    }
                } else {
                    None
                }
            })
            .take_while(|env| {
                if let OutgoingAst::Next { .. } = env {
                    true
                } else {
                    completion = Some(env.clone());
                    false
                }
            })
            .flat_map(|env| {
                if let OutgoingAst::Next { payload, .. } = env {
                    payload
                        .into_iter()
                        .map(|p| {
                            serde_json::from_value::<Resp>(p)
                                .expect("Could not deserialize response")
                        })
                        .collect()
                } else {
                    vec![]
                }
            })
            .collect();
        (msgs, completion.expect("Expected a completion message"))
    }

    /// Spawns a warp server on an ephemeral 127.0.0.1 port. It serves `TestService` as
    /// "test" at path `/test_ws` with context value 23. Returns the bound address.
    async fn start_test_service() -> SocketAddr {
        let services = Arc::new(maplit::btreemap! {"test" => TestService::new().boxed()});
        let ws = warp::path("test_ws")
            .and(warp::ws())
            .and(warp::any().map(move || services.clone()))
            .and(warp::any().map(move || 23))
            .and_then(super::serve);
        let (addr, task) = warp::serve(ws).bind_ephemeral(([127, 0, 0, 1], 0));
        tokio::spawn(task);
        addr
    }

    // A multi-item response stream arrives complete and in order.
    #[tokio::test(flavor = "multi_thread")]
    async fn properly_serve_single_request() {
        let addr = start_test_service().await;
        assert_eq!(
            test_client::<Request, Response>(addr, "test", 0, Request::Count(5)).0,
            vec![
                Response(0),
                Response(1),
                Response(2),
                Response(3),
                Response(4)
            ]
        );
    }

    // The context value configured in `start_test_service` (23) reaches the service.
    #[tokio::test(flavor = "multi_thread")]
    async fn properly_serve_single_request_ctx() {
        let addr = start_test_service().await;
        assert_eq!(
            test_client::<Request, Response>(addr, "test", 0, Request::Ctx).0,
            vec![Response(23)]
        );
    }

    // A ~20 MB request payload is accepted; presumably this exercises the raised frame and
    // message size limits configured in `serve`.
    #[tokio::test(flavor = "multi_thread")]
    async fn properly_serve_large_request() {
        let addr = start_test_service().await;
        let len = 20_000_000;
        let data: String = std::iter::repeat('x').take(len).collect::<String>();
        assert_eq!(
            test_client::<Request, Response>(addr, "test", 0, Request::Size(data)).0,
            vec![Response(len as u64)]
        );
    }

    // 50 client threads, released together by a barrier, each stream 100 items concurrently.
    // NOTE(review): each `test_client` call opens its own connection, so this covers
    // concurrent connections rather than several requests multiplexed on one connection.
    #[tokio::test(flavor = "multi_thread")]
    async fn multiplex_multiple_queries() {
        let addr = start_test_service().await;
        let client_cnt = 50;
        let request_cnt = 100;
        let start_barrier = Arc::new(std::sync::Barrier::new(client_cnt));
        let join_handles: Vec<JoinHandle<Vec<Response>>> = (0..client_cnt)
            .map(|i| {
                let b = start_barrier.clone();
                std::thread::spawn(move || {
                    b.wait();
                    test_client::<Request, Response>(
                        addr,
                        "test",
                        i as u64,
                        Request::Count(request_cnt),
                    )
                    .0
                })
            })
            .collect();
        let expected: Vec<Response> = (0..request_cnt).map(|i| Response(i as u64)).collect();
        for handle in join_handles {
            assert_eq!(handle.join().unwrap(), expected)
        }
    }

    // An unknown service name yields an `UnknownEndpoint` error listing the valid names.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_wrong_endpoint() {
        let addr = start_test_service().await;
        let (msgs, completion) =
            test_client::<Request, Response>(addr, "no_such_service", 49, Request::Count(5));
        assert_eq!(msgs, vec![]);
        assert_eq!(
            completion,
            OutgoingAst::Error {
                request_id: ReqId(49),
                kind: ErrorKind::UnknownEndpoint {
                    endpoint: "no_such_service".to_string(),
                    valid_endpoints: vec!["test".to_string()],
                }
            }
        );
    }

    // A payload that does not deserialize into `Request` yields `BadRequest` carrying the
    // serde error message.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_badly_formatted_request() {
        let addr = start_test_service().await;
        let (msgs, completion) = test_client::<BadRequest, Response>(
            addr,
            "test",
            49,
            BadRequest {
                bad_field: "xzy".to_string(),
            },
        );
        assert_eq!(msgs, vec![]);
        if let OutgoingAst::Error {
            request_id: ReqId(49),
            kind: ErrorKind::BadRequest { message },
        } = completion
        {
            assert!(message.starts_with("unknown variant"));
        } else {
            panic!();
        }
    }

    // A service-level error is forwarded as `ServiceError` holding the serialized error value.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_service_error() {
        let addr = start_test_service().await;
        let (msgs, completion) = test_client::<Request, Response>(
            addr,
            "test",
            49,
            Request::Fail("Test reason".to_string()),
        );
        assert_eq!(msgs, vec![]);
        assert_eq!(
            completion,
            OutgoingAst::Error {
                request_id: ReqId(49),
                kind: ErrorKind::ServiceError {
                    value: Value::String("Test reason".to_string())
                },
            }
        );
    }

    // A panic inside the service stream is caught and reported as `InternalError`.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_service_panic() {
        let addr = start_test_service().await;
        let (msgs, completion) = test_client::<Request, Response>(addr, "test", 49, Request::Panic);
        assert_eq!(msgs, vec![]);
        assert_eq!(
            completion,
            OutgoingAst::Error {
                request_id: ReqId(49),
                kind: ErrorKind::InternalError,
            }
        );
    }
}