1mod formats;
17mod util;
18
19use formats::*;
20use futures::channel::{mpsc, oneshot};
21use futures::stream;
22use futures::stream::BoxStream;
23use futures::{future, Future, Sink};
24use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
25use serde::de::DeserializeOwned;
26use serde::Serialize;
27use serde_json::value::{RawValue, Value};
28use std::collections::{BTreeMap, HashMap};
29use std::panic::AssertUnwindSafe;
30use std::sync::Arc;
31use util::UtilStreamExt;
32use warp::filters::ws::{Message, WebSocket};
33
/// Capacity of the bounded channel through which all response streams of one
/// connection are funnelled into its websocket sink.
const WS_SEND_BUFFER_SIZE: usize = 1 << 10;
/// Once more than this many request ids are tracked for a connection, entries
/// whose cancel receiver has been dropped are pruned before the next incoming
/// message is handled.
const REQUEST_GC_THRESHOLD: usize = 1 << 6;
/// Argument passed to `yield_after` for every response stream; presumably the
/// number of items forwarded before the task yields to its siblings — confirm
/// against `util::UtilStreamExt::yield_after`.
const INTER_STREAM_FAIRNESS: u64 = 1 << 6;
37
38pub trait Service {
39 type Req: DeserializeOwned;
40 type Resp: Serialize + 'static;
41 type Error: Serialize + 'static;
42 type Ctx: Clone;
43
44 fn serve(
45 &self,
46 ctx: Self::Ctx,
47 req: Self::Req,
48 ) -> BoxStream<'static, Result<Self::Resp, Self::Error>>;
49
50 fn boxed(self) -> BoxedService<Self::Ctx>
51 where
52 Self: Send + Sized + Sync + 'static,
53 {
54 Box::new(self)
55 }
56}
57
58pub trait WebsocketService<Ctx: Clone> {
59 fn serve_ws(
60 &self,
61 ctx: Ctx,
62 raw_req: Value,
63 service_id: &str,
64 ) -> BoxStream<'static, Result<Box<RawValue>, ErrorKind>>;
65}
66
67impl<Req, Resp, Ctx, S> WebsocketService<Ctx> for S
68where
69 S: Service<Req = Req, Resp = Resp, Ctx = Ctx>,
70 Req: DeserializeOwned,
71 Resp: Serialize + 'static,
72 Ctx: Clone,
73{
74 fn serve_ws(
75 &self,
76 ctx: Ctx,
77 raw_req: Value,
78 service_id: &str,
79 ) -> BoxStream<'static, Result<Box<RawValue>, ErrorKind>> {
80 tracing::trace!(
81 "Serving raw request for service {}: {:?}",
82 service_id,
83 raw_req
84 );
85 match serde_json::from_value(raw_req) {
86 Ok(req) => self
87 .serve(ctx, req)
88 .map(|resp_result| {
89 resp_result
90 .map(|resp| {
91 serde_json::value::to_raw_value(&resp)
92 .expect("Could not serialize service response")
93 })
94 .map_err(|err| ErrorKind::ServiceError {
95 value: serde_json::to_value(&err)
96 .expect("Could not serialize service error response"),
97 })
98 })
99 .boxed(),
100 Err(cause) => {
101 let message = format!("{}", cause);
102 tracing::warn!(
103 "Error deserializing request for service {}: {}",
104 service_id,
105 message
106 );
107 stream::once(future::err(ErrorKind::BadRequest { message })).boxed()
108 }
109 }
110 }
111}
112
113pub type BoxedService<Ctx> = Box<dyn WebsocketService<Ctx> + Send + Sync>;
114
115pub async fn serve<Ctx: Clone + Send + 'static>(
116 ws: warp::ws::Ws,
117 services: Arc<BTreeMap<&'static str, BoxedService<Ctx>>>,
118 ctx: Ctx,
119) -> Result<impl warp::Reply, warp::Rejection> {
120 Ok(ws
122 .max_frame_size(64 << 20)
123 .max_message_size(128 << 20)
125 .on_upgrade(move |socket| client_connected(socket, ctx, services).map(|_| ())))
126 }
128
129#[allow(clippy::cognitive_complexity)]
130fn client_connected<Ctx: Clone + Send + 'static>(
131 ws: WebSocket,
132 ctx: Ctx,
133 services: Arc<BTreeMap<&'static str, BoxedService<Ctx>>>,
134) -> impl Future<Output = Result<(), ()>> {
135 let (ws_out, ws_in) = ws.split();
136
137 let (mut mux_in, mux_out) = mpsc::channel::<Result<Message, warp::Error>>(WS_SEND_BUFFER_SIZE);
139
140 let mut active_responses: HashMap<ReqId, oneshot::Sender<()>> = HashMap::new();
145
146 tokio::spawn(mux_out.fuse().forward(ws_out).map(|_| ()));
148
149 ws_in
150 .try_for_each(move |raw_msg| {
151 if active_responses.len() > REQUEST_GC_THRESHOLD {
152 active_responses.retain(|_, canceled| !canceled.is_canceled());
153 }
154
155 if let Ok(text_msg) = raw_msg.to_str() {
157 match serde_json::from_str::<Incoming>(text_msg) {
158 Ok(req_env) => match req_env {
159 Incoming::Request(body) => {
160 if let Some(srv) = services.get(body.service_id) {
162 let (snd_cancel, rcv_cancel) = oneshot::channel();
164
165 if let Some(previous) =
166 active_responses.insert(body.request_id, snd_cancel)
167 {
168 cancel_response_stream(previous);
169 };
170
171 tokio::spawn(serve_request(
172 rcv_cancel,
173 srv,
174 ctx.clone(),
175 body.service_id,
176 body.request_id,
177 body.payload,
178 mux_in.clone(),
179 ));
180 } else {
181 tokio::spawn(serve_error(
182 body.request_id,
183 ErrorKind::UnknownEndpoint {
184 endpoint: body.service_id.to_string(),
185 valid_endpoints: services
186 .keys()
187 .map(|e| e.to_string())
188 .collect::<Vec<String>>(),
189 },
190 mux_in.clone(),
191 ));
192 tracing::warn!(
193 "Client tried to access unknown service: {}",
194 body.service_id
195 );
196 }
197 }
198 Incoming::Cancel { request_id } => {
199 if let Some(snd_cancel) = active_responses.remove(&request_id) {
200 cancel_response_stream(snd_cancel);
201 }
202 }
203 },
204 Err(cause) => {
205 tracing::warn!(
206 "Could not deserialize client request {}: {}",
207 text_msg,
208 cause
209 );
210 cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
211 }
212 }
213 } else if raw_msg.is_ping() {
214 } else if raw_msg.is_close() {
216 tracing::debug!("Closing websocket connection (client disconnected)");
217 cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
218 } else {
219 tracing::warn!("Expected TEXT Websocket message but got binary");
220 cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
221 };
222 future::ok(())
223 })
224 .map_err(|err| {
225 tracing::info!("Websocket closed with error {}", err);
226 })
227}
228
229#[allow(clippy::cognitive_complexity)]
231fn cancel_response_stream(snd_cancel: oneshot::Sender<()>) {
232 if snd_cancel.is_canceled() {
233 tracing::trace!("Not trying to cancel response stream whose cancel rcv has already dropped")
234 } else {
235 match snd_cancel.send(()) {
238 Ok(_) => tracing::debug!("Merged Cancel signal into ongoing response stream"),
239 Err(_) => tracing::debug!("Response stream we are trying to stop has already stopped"),
240 }
241 }
242}
243
244fn cancel_response_streams_close_channel(
245 active_responses: &mut HashMap<ReqId, oneshot::Sender<()>>,
246 mux_in: &mut mpsc::Sender<Result<Message, warp::Error>>,
247) {
248 for (_, snd_cancel) in active_responses.drain() {
249 cancel_response_stream(snd_cancel);
250 }
251 mux_in.close_channel();
252}
253
254fn serve_request_stream<Ctx: Clone>(
255 srv: &BoxedService<Ctx>,
256 ctx: Ctx,
257 service_id: &str,
258 req_id: ReqId,
259 payload: Value,
260) -> impl Stream<Item = Result<Message, warp::Error>> {
261 let resp_stream = srv
262 .serve_ws(ctx, payload, service_id)
263 .take_until_condition(|resp| future::ready(resp.is_err()))
264 .ready_chunks(128)
265 .flat_map(move |payload_results| {
266 let mut err = None;
267 let mut payload = Vec::with_capacity(payload_results.len());
268 for payload_result in payload_results {
269 match payload_result {
270 Ok(value) => payload.push(value),
271 Err(kind) => err = Some(kind), }
273 }
274 let mut res = Vec::with_capacity(1);
275 if !payload.is_empty() {
276 res.push(Outgoing::Next {
277 request_id: req_id,
278 payload,
279 });
280 }
281 if let Some(kind) = err {
282 res.push(Outgoing::Error {
283 request_id: req_id,
284 kind,
285 });
286 }
287 stream::iter(res)
288 });
289
290 AssertUnwindSafe(resp_stream)
291 .catch_unwind()
292 .map(move |msg_result| match msg_result {
293 Ok(msg) => msg,
294 Err(_) => Outgoing::Error {
295 request_id: req_id,
296 kind: ErrorKind::InternalError,
297 },
298 })
299 .chain(stream::once(future::ready(Outgoing::Complete {
300 request_id: req_id,
301 })))
302 .map(|env| Ok(Message::text(serde_json::to_string(&env).unwrap())))
303}
304
305fn serve_request<T: std::fmt::Debug, Ctx: Clone>(
306 canceled: oneshot::Receiver<()>,
307 srv: &BoxedService<Ctx>,
308 ctx: Ctx,
309 service_id: &str,
310 req_id: ReqId,
311 payload: Value,
312 output: impl Sink<Result<Message, warp::Error>, Error = T>,
313) -> impl Future<Output = ()> {
314 let response_stream = serve_request_stream(srv, ctx, service_id, req_id, payload)
315 .take_until_signaled(canceled)
316 .map(|item| {
317 Ok(item)
320 });
321
322 let service_id = service_id.to_owned();
323 response_stream
324 .yield_after(INTER_STREAM_FAIRNESS)
325 .forward(output)
326 .map(move |result| {
327 if let Err(cause) = result {
328 tracing::warn!(%service_id, "Multiplexing error {:?}", cause);
329 };
330 })
331}
332
333fn serve_error<S>(req_id: ReqId, error_kind: ErrorKind, output: S) -> impl Future<Output = ()>
334where
335 S: Sink<Result<Message, warp::Error>>,
336 S::Error: std::fmt::Debug,
337{
338 let msg = Outgoing::Error {
339 request_id: req_id,
340 kind: error_kind,
341 };
342
343 let raw_msg = Message::text(serde_json::to_string_pretty(&msg).unwrap());
344
345 stream::once(future::ok(Ok(raw_msg)))
346 .forward(output)
347 .map(|result| {
348 if let Err(err) = result {
349 tracing::warn!("Could not send Error message: {:?}", err);
350 };
351 })
352}
353
#[cfg(test)]
mod tests {
    //! End-to-end tests. A warp server exposing `TestService` as `"test"` under
    //! `/test_ws` is started on an ephemeral local port. It is then driven by
    //! the synchronous `websocket` crate client.

    use super::*;
    use crate::Service;
    use futures::stream;
    use futures::stream::BoxStream;
    use futures::stream::StreamExt;
    use futures::task::Poll;
    use serde::{Deserialize, Serialize};
    use std::net::SocketAddr;
    use std::thread::JoinHandle;
    use warp::Filter;
    use websocket::{ClientBuilder, OwnedMessage};

    /// Requests understood by `TestService`.
    #[derive(Serialize, Deserialize)]
    enum Request {
        // Streams Response(0) .. Response(n - 1).
        Count(u64),
        // Responds once with the byte length of the string.
        Size(String),
        // Responds once with the connection context.
        Ctx,
        // Fails with the given reason as the service error.
        Fail(String),
        // Panics when the response stream is polled.
        Panic,
    }

    /// A payload that matches no `Request` variant, used to provoke a
    /// `BadRequest` error.
    #[derive(Serialize, Deserialize)]
    struct BadRequest {
        bad_field: String,
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
    struct Response(u64);

    struct TestService();

    impl TestService {
        fn new() -> TestService {
            TestService()
        }
    }

    impl Service for TestService {
        type Req = Request;
        type Resp = Response;
        type Error = String;
        type Ctx = u64;

        fn serve(&self, ctx: u64, req: Request) -> BoxStream<'static, Result<Response, String>> {
            match req {
                Request::Count(cnt) => {
                    let mut ctr = 0;
                    stream::poll_fn(move |_| {
                        let output = ctr;
                        ctr += 1;
                        if ctr <= cnt {
                            Poll::Ready(Some(Ok(Response(output))))
                        } else {
                            Poll::Ready(None)
                        }
                    })
                    .boxed()
                }
                Request::Size(data) => {
                    stream::once(future::ok(Response(data.len() as u64))).boxed()
                }
                Request::Ctx => stream::once(future::ok(Response(ctx))).boxed(),
                Request::Fail(reason) => stream::once(future::err(reason)).boxed(),
                Request::Panic => stream::poll_fn(|_| panic!("Test panic")).boxed(),
            }
        }
    }

    /// Client-side, deserializable view of the server's response envelopes.
    /// Envelopes are tagged by a `type` field, and each `Next` payload item is
    /// kept as a generic JSON `Value`.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    #[serde(tag = "type")]
    #[serde(rename_all = "camelCase")]
    pub enum OutgoingAst {
        #[serde(rename_all = "camelCase")]
        Next {
            request_id: ReqId,
            payload: Vec<Value>,
        },
        #[serde(rename_all = "camelCase")]
        Complete { request_id: ReqId },
        #[serde(rename_all = "camelCase")]
        Error { request_id: ReqId, kind: ErrorKind },
    }

    impl OutgoingAst {
        pub fn request_id(&self) -> ReqId {
            match self {
                OutgoingAst::Next { request_id, .. } => *request_id,
                OutgoingAst::Complete { request_id, .. } => *request_id,
                OutgoingAst::Error { request_id, .. } => *request_id,
            }
        }
    }

    /// Opens a fresh websocket connection to the test server at `addr`.
    /// It sends one request (`req` for service `endpoint`, request id `id`),
    /// then reads messages until the first non-`Next` envelope for `id`.
    ///
    /// Returns the deserialized payload items of all `Next` envelopes for `id`,
    /// in order, together with the terminating envelope (`Complete` or
    /// `Error`). Non-text frames and envelopes for other ids are skipped.
    /// Panics on websocket or (de)serialization failures, or if no terminating
    /// envelope arrives before the message iterator ends.
    fn test_client<Req: Serialize, Resp: DeserializeOwned>(
        addr: SocketAddr,
        endpoint: &str,
        id: u64,
        req: Req,
    ) -> (Vec<Resp>, OutgoingAst) {
        let addr = format!("ws://{}/test_ws", addr);
        let client = ClientBuilder::new(&*addr)
            .expect("Could not setup client")
            .connect_insecure()
            .expect("Could not connect to test server");

        let (mut receiver, mut sender) = client.split().unwrap();

        let payload = serde_json::to_value(req).expect("Could not serialize request");
        let req_env = Incoming::Request(RequestBody {
            service_id: endpoint,
            request_id: ReqId(id),
            payload,
        });
        let req_env_json =
            serde_json::to_string(&req_env).expect("Could not serialize request envelope");

        sender
            .send_message(&OwnedMessage::Text(req_env_json))
            .expect("Could not send request");

        // Filled in by `take_while` with the first non-`Next` envelope.
        let mut completion: Option<OutgoingAst> = None;

        let msgs = receiver
            .incoming_messages()
            .filter_map(move |msg| {
                let msg_ok = msg.expect("Expected message but got websocket error");
                if let OwnedMessage::Text(raw_resp) = msg_ok {
                    let resp_env: OutgoingAst = serde_json::from_str(&*raw_resp)
                        .expect("Could not deserialize response envelope");
                    if resp_env.request_id().0 == id {
                        Some(resp_env)
                    } else {
                        None
                    }
                } else {
                    None
                }
            })
            // `take_while` consumes the terminating envelope, so it is stashed
            // in `completion` before iteration stops.
            .take_while(|env| {
                if let OutgoingAst::Next { .. } = env {
                    true
                } else {
                    completion = Some(env.clone());
                    false
                }
            })
            .flat_map(|env| {
                if let OutgoingAst::Next { payload, .. } = env {
                    payload
                        .into_iter()
                        .map(|p| {
                            serde_json::from_value::<Resp>(p)
                                .expect("Could not deserialize response")
                        })
                        .collect()
                } else {
                    vec![]
                }
            })
            .collect();
        (msgs, completion.expect("Expected a completion message"))
    }

    /// Starts a server on an ephemeral port of 127.0.0.1 that serves
    /// `TestService` as `"test"` under `/test_ws` with context `23`, and
    /// returns its address.
    async fn start_test_service() -> SocketAddr {
        let services = Arc::new(maplit::btreemap! {"test" => TestService::new().boxed()});
        let ws = warp::path("test_ws")
            .and(warp::ws())
            .and(warp::any().map(move || services.clone()))
            .and(warp::any().map(move || 23))
            .and_then(super::serve);
        let (addr, task) = warp::serve(ws).bind_ephemeral(([127, 0, 0, 1], 0));
        tokio::spawn(task);
        addr
    }

    // A streaming request delivers all items in order.
    #[tokio::test(flavor = "multi_thread")]
    async fn properly_serve_single_request() {
        let addr = start_test_service().await;

        assert_eq!(
            test_client::<Request, Response>(addr, "test", 0, Request::Count(5)).0,
            vec![
                Response(0),
                Response(1),
                Response(2),
                Response(3),
                Response(4)
            ]
        );
    }

    // The context supplied to `serve` (23 in `start_test_service`) reaches
    // the service.
    #[tokio::test(flavor = "multi_thread")]
    async fn properly_serve_single_request_ctx() {
        let addr = start_test_service().await;

        assert_eq!(
            test_client::<Request, Response>(addr, "test", 0, Request::Ctx).0,
            vec![Response(23)]
        );
    }

    // A request of about 20 MB is accepted under the frame/message size limits
    // configured in `serve`.
    #[tokio::test(flavor = "multi_thread")]
    async fn properly_serve_large_request() {
        let addr = start_test_service().await;
        let len = 20_000_000;
        let data: String = std::iter::repeat('x').take(len).collect::<String>();

        assert_eq!(
            test_client::<Request, Response>(addr, "test", 0, Request::Size(data)).0,
            vec![Response(len as u64)]
        );
    }

    // 50 client threads, released together by a barrier, each send a
    // `Count(100)` request (each on its own connection). Every client must
    // receive the full sequence.
    #[tokio::test(flavor = "multi_thread")]
    async fn multiplex_multiple_queries() {
        let addr = start_test_service().await;

        let client_cnt = 50;
        let request_cnt = 100;
        let start_barrier = Arc::new(std::sync::Barrier::new(client_cnt));

        let join_handles: Vec<JoinHandle<Vec<Response>>> = (0..client_cnt)
            .map(|i| {
                let b = start_barrier.clone();
                std::thread::spawn(move || {
                    b.wait();
                    test_client::<Request, Response>(
                        addr,
                        "test",
                        i as u64,
                        Request::Count(request_cnt),
                    )
                    .0
                })
            })
            .collect();
        let expected: Vec<Response> = (0..request_cnt).map(|i| Response(i as u64)).collect();

        for handle in join_handles {
            assert_eq!(handle.join().unwrap(), expected)
        }
    }

    // An unregistered service id yields `UnknownEndpoint`, which lists the
    // registered ids.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_wrong_endpoint() {
        let addr = start_test_service().await;

        let (msgs, completion) =
            test_client::<Request, Response>(addr, "no_such_service", 49, Request::Count(5));

        assert_eq!(msgs, vec![]);

        assert_eq!(
            completion,
            OutgoingAst::Error {
                request_id: ReqId(49),
                kind: ErrorKind::UnknownEndpoint {
                    endpoint: "no_such_service".to_string(),
                    valid_endpoints: vec!["test".to_string()],
                }
            }
        );
    }

    // A payload that does not deserialize into `Request` yields `BadRequest`
    // with serde's "unknown variant" message.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_badly_formatted_request() {
        let addr = start_test_service().await;

        let (msgs, completion) = test_client::<BadRequest, Response>(
            addr,
            "test",
            49,
            BadRequest {
                bad_field: "xzy".to_string(),
            },
        );

        assert_eq!(msgs, vec![]);

        if let OutgoingAst::Error {
            request_id: ReqId(49),
            kind: ErrorKind::BadRequest { message },
        } = completion
        {
            assert!(message.starts_with("unknown variant"));
        } else {
            panic!();
        }
    }

    // A service error is forwarded as `ServiceError` carrying the serialized
    // error value.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_service_error() {
        let addr = start_test_service().await;

        let (msgs, completion) = test_client::<Request, Response>(
            addr,
            "test",
            49,
            Request::Fail("Test reason".to_string()),
        );

        assert_eq!(msgs, vec![]);

        assert_eq!(
            completion,
            OutgoingAst::Error {
                request_id: ReqId(49),
                kind: ErrorKind::ServiceError {
                    value: Value::String("Test reason".to_string())
                },
            }
        );
    }

    // A panicking service stream is reported as `InternalError`.
    #[tokio::test(flavor = "multi_thread")]
    async fn report_service_panic() {
        let addr = start_test_service().await;

        let (msgs, completion) = test_client::<Request, Response>(addr, "test", 49, Request::Panic);

        assert_eq!(msgs, vec![]);

        assert_eq!(
            completion,
            OutgoingAst::Error {
                request_id: ReqId(49),
                kind: ErrorKind::InternalError,
            }
        );
    }
}