wsrpc/
lib.rs

1/*
2 * Copyright 2021 Actyx AG
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16mod formats;
17mod util;
18
19use formats::*;
20use futures::channel::{mpsc, oneshot};
21use futures::stream;
22use futures::stream::BoxStream;
23use futures::{future, Future, Sink};
24use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
25use serde::de::DeserializeOwned;
26use serde::Serialize;
27use serde_json::value::{RawValue, Value};
28use std::collections::{BTreeMap, HashMap};
29use std::panic::AssertUnwindSafe;
30use std::sync::Arc;
31use util::UtilStreamExt;
32use warp::filters::ws::{Message, WebSocket};
33
34const WS_SEND_BUFFER_SIZE: usize = 1024;
35const REQUEST_GC_THRESHOLD: usize = 64;
36const INTER_STREAM_FAIRNESS: u64 = 64;
37
38pub trait Service {
39    type Req: DeserializeOwned;
40    type Resp: Serialize + 'static;
41    type Error: Serialize + 'static;
42    type Ctx: Clone;
43
44    fn serve(
45        &self,
46        ctx: Self::Ctx,
47        req: Self::Req,
48    ) -> BoxStream<'static, Result<Self::Resp, Self::Error>>;
49
50    fn boxed(self) -> BoxedService<Self::Ctx>
51    where
52        Self: Send + Sized + Sync + 'static,
53    {
54        Box::new(self)
55    }
56}
57
58pub trait WebsocketService<Ctx: Clone> {
59    fn serve_ws(
60        &self,
61        ctx: Ctx,
62        raw_req: Value,
63        service_id: &str,
64    ) -> BoxStream<'static, Result<Box<RawValue>, ErrorKind>>;
65}
66
67impl<Req, Resp, Ctx, S> WebsocketService<Ctx> for S
68where
69    S: Service<Req = Req, Resp = Resp, Ctx = Ctx>,
70    Req: DeserializeOwned,
71    Resp: Serialize + 'static,
72    Ctx: Clone,
73{
74    fn serve_ws(
75        &self,
76        ctx: Ctx,
77        raw_req: Value,
78        service_id: &str,
79    ) -> BoxStream<'static, Result<Box<RawValue>, ErrorKind>> {
80        tracing::trace!(
81            "Serving raw request for service {}: {:?}",
82            service_id,
83            raw_req
84        );
85        match serde_json::from_value(raw_req) {
86            Ok(req) => self
87                .serve(ctx, req)
88                .map(|resp_result| {
89                    resp_result
90                        .map(|resp| {
91                            serde_json::value::to_raw_value(&resp)
92                                .expect("Could not serialize service response")
93                        })
94                        .map_err(|err| ErrorKind::ServiceError {
95                            value: serde_json::to_value(&err)
96                                .expect("Could not serialize service error response"),
97                        })
98                })
99                .boxed(),
100            Err(cause) => {
101                let message = format!("{}", cause);
102                tracing::warn!(
103                    "Error deserializing request for service {}: {}",
104                    service_id,
105                    message
106                );
107                stream::once(future::err(ErrorKind::BadRequest { message })).boxed()
108            }
109        }
110    }
111}
112
/// Heap-allocated, thread-safe [`WebsocketService`] trait object; this is the form in which
/// services are stored in the lookup table passed to [`serve`].
pub type BoxedService<Ctx> = Box<dyn WebsocketService<Ctx> + Send + Sync>;
114
115pub async fn serve<Ctx: Clone + Send + 'static>(
116    ws: warp::ws::Ws,
117    services: Arc<BTreeMap<&'static str, BoxedService<Ctx>>>,
118    ctx: Ctx,
119) -> Result<impl warp::Reply, warp::Rejection> {
120    // Set the max frame size to 64 MB (defaults to 16 MB which we have hit at CTA)
121    Ok(ws
122        .max_frame_size(64 << 20)
123        // Set the max message size to 128 MB (defaults to 64 MB which we have hit for an humongous snapshot)
124        .max_message_size(128 << 20)
125        .on_upgrade(move |socket| client_connected(socket, ctx, services).map(|_| ())))
126    // on_upgrade does not take in errors any longer
127}
128
129#[allow(clippy::cognitive_complexity)]
130fn client_connected<Ctx: Clone + Send + 'static>(
131    ws: WebSocket,
132    ctx: Ctx,
133    services: Arc<BTreeMap<&'static str, BoxedService<Ctx>>>,
134) -> impl Future<Output = Result<(), ()>> {
135    let (ws_out, ws_in) = ws.split();
136
137    // Create an MPSC channel to merge outbound WS messages
138    let (mut mux_in, mux_out) = mpsc::channel::<Result<Message, warp::Error>>(WS_SEND_BUFFER_SIZE);
139
140    // Map of request IDs to the reference counted boolean that will terminate the response
141    // stream upon cancellation. There is no need for a concurrent map because we simply share
142    // the entries with the running streams. This also means that the running response stream
143    // does not need to actually look up the entry every time.
144    let mut active_responses: HashMap<ReqId, oneshot::Sender<()>> = HashMap::new();
145
146    // Pipe the merged stream into the websocket output;
147    tokio::spawn(mux_out.fuse().forward(ws_out).map(|_| ()));
148
149    ws_in
150        .try_for_each(move |raw_msg| {
151            if active_responses.len() > REQUEST_GC_THRESHOLD {
152                active_responses.retain(|_, canceled| !canceled.is_canceled());
153            }
154
155            // Do some parsing first...
156            if let Ok(text_msg) = raw_msg.to_str() {
157                match serde_json::from_str::<Incoming>(text_msg) {
158                    Ok(req_env) => match req_env {
159                        Incoming::Request(body) => {
160                            // Locate the service matching the request
161                            if let Some(srv) = services.get(body.service_id) {
162                                // Set up cancellation signal
163                                let (snd_cancel, rcv_cancel) = oneshot::channel();
164
165                                if let Some(previous) =
166                                    active_responses.insert(body.request_id, snd_cancel)
167                                {
168                                    cancel_response_stream(previous);
169                                };
170
171                                tokio::spawn(serve_request(
172                                    rcv_cancel,
173                                    srv,
174                                    ctx.clone(),
175                                    body.service_id,
176                                    body.request_id,
177                                    body.payload,
178                                    mux_in.clone(),
179                                ));
180                            } else {
181                                tokio::spawn(serve_error(
182                                    body.request_id,
183                                    ErrorKind::UnknownEndpoint {
184                                        endpoint: body.service_id.to_string(),
185                                        valid_endpoints: services
186                                            .keys()
187                                            .map(|e| e.to_string())
188                                            .collect::<Vec<String>>(),
189                                    },
190                                    mux_in.clone(),
191                                ));
192                                tracing::warn!(
193                                    "Client tried to access unknown service: {}",
194                                    body.service_id
195                                );
196                            }
197                        }
198                        Incoming::Cancel { request_id } => {
199                            if let Some(snd_cancel) = active_responses.remove(&request_id) {
200                                cancel_response_stream(snd_cancel);
201                            }
202                        }
203                    },
204                    Err(cause) => {
205                        tracing::warn!(
206                            "Could not deserialize client request {}: {}",
207                            text_msg,
208                            cause
209                        );
210                        cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
211                    }
212                }
213            } else if raw_msg.is_ping() {
214                // No way to send pong??
215            } else if raw_msg.is_close() {
216                tracing::debug!("Closing websocket connection (client disconnected)");
217                cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
218            } else {
219                tracing::warn!("Expected TEXT Websocket message but got binary");
220                cancel_response_streams_close_channel(&mut active_responses, &mut mux_in);
221            };
222            future::ok(())
223        })
224        .map_err(|err| {
225            tracing::info!("Websocket closed with error {}", err);
226        })
227}
228
229// Wtf, clippy?
230#[allow(clippy::cognitive_complexity)]
231fn cancel_response_stream(snd_cancel: oneshot::Sender<()>) {
232    if snd_cancel.is_canceled() {
233        tracing::trace!("Not trying to cancel response stream whose cancel rcv has already dropped")
234    } else {
235        // Let it be said that we could just as well just drop the Sender here,
236        // which would also signal the Receiver (with a 'Cancel' error).
237        match snd_cancel.send(()) {
238            Ok(_) => tracing::debug!("Merged Cancel signal into ongoing response stream"),
239            Err(_) => tracing::debug!("Response stream we are trying to stop has already stopped"),
240        }
241    }
242}
243
244fn cancel_response_streams_close_channel(
245    active_responses: &mut HashMap<ReqId, oneshot::Sender<()>>,
246    mux_in: &mut mpsc::Sender<Result<Message, warp::Error>>,
247) {
248    for (_, snd_cancel) in active_responses.drain() {
249        cancel_response_stream(snd_cancel);
250    }
251    mux_in.close_channel();
252}
253
254fn serve_request_stream<Ctx: Clone>(
255    srv: &BoxedService<Ctx>,
256    ctx: Ctx,
257    service_id: &str,
258    req_id: ReqId,
259    payload: Value,
260) -> impl Stream<Item = Result<Message, warp::Error>> {
261    let resp_stream = srv
262        .serve_ws(ctx, payload, service_id)
263        .take_until_condition(|resp| future::ready(resp.is_err()))
264        .ready_chunks(128)
265        .flat_map(move |payload_results| {
266            let mut err = None;
267            let mut payload = Vec::with_capacity(payload_results.len());
268            for payload_result in payload_results {
269                match payload_result {
270                    Ok(value) => payload.push(value),
271                    Err(kind) => err = Some(kind), // always comes last
272                }
273            }
274            let mut res = Vec::with_capacity(1);
275            if !payload.is_empty() {
276                res.push(Outgoing::Next {
277                    request_id: req_id,
278                    payload,
279                });
280            }
281            if let Some(kind) = err {
282                res.push(Outgoing::Error {
283                    request_id: req_id,
284                    kind,
285                });
286            }
287            stream::iter(res)
288        });
289
290    AssertUnwindSafe(resp_stream)
291        .catch_unwind()
292        .map(move |msg_result| match msg_result {
293            Ok(msg) => msg,
294            Err(_) => Outgoing::Error {
295                request_id: req_id,
296                kind: ErrorKind::InternalError,
297            },
298        })
299        .chain(stream::once(future::ready(Outgoing::Complete {
300            request_id: req_id,
301        })))
302        .map(|env| Ok(Message::text(serde_json::to_string(&env).unwrap())))
303}
304
305fn serve_request<T: std::fmt::Debug, Ctx: Clone>(
306    canceled: oneshot::Receiver<()>,
307    srv: &BoxedService<Ctx>,
308    ctx: Ctx,
309    service_id: &str,
310    req_id: ReqId,
311    payload: Value,
312    output: impl Sink<Result<Message, warp::Error>, Error = T>,
313) -> impl Future<Output = ()> {
314    let response_stream = serve_request_stream(srv, ctx, service_id, req_id, payload)
315        .take_until_signaled(canceled)
316        .map(|item| {
317            // We need to re-wrap in an outer result because Sink requires SinkError as the error type
318            // but it will pass our inner error unmodified
319            Ok(item)
320        });
321
322    let service_id = service_id.to_owned();
323    response_stream
324        .yield_after(INTER_STREAM_FAIRNESS)
325        .forward(output)
326        .map(move |result| {
327            if let Err(cause) = result {
328                tracing::warn!(%service_id, "Multiplexing error {:?}", cause);
329            };
330        })
331}
332
333fn serve_error<S>(req_id: ReqId, error_kind: ErrorKind, output: S) -> impl Future<Output = ()>
334where
335    S: Sink<Result<Message, warp::Error>>,
336    S::Error: std::fmt::Debug,
337{
338    let msg = Outgoing::Error {
339        request_id: req_id,
340        kind: error_kind,
341    };
342
343    let raw_msg = Message::text(serde_json::to_string_pretty(&msg).unwrap());
344
345    stream::once(future::ok(Ok(raw_msg)))
346        .forward(output)
347        .map(|result| {
348            if let Err(err) = result {
349                tracing::warn!("Could not send Error message: {:?}", err);
350            };
351        })
352}
353
354#[cfg(test)]
355mod tests {
356    use super::*;
357    use crate::Service;
358    use futures::stream;
359    use futures::stream::BoxStream;
360    use futures::stream::StreamExt;
361    use futures::task::Poll;
362    use serde::{Deserialize, Serialize};
363    use std::net::SocketAddr;
364    use std::thread::JoinHandle;
365    use warp::Filter;
366    use websocket::{ClientBuilder, OwnedMessage};
367
    /// Requests understood by `TestService`.
    #[derive(Serialize, Deserialize)]
    enum Request {
        /// Returns the numbers 0..N.
        Count(u64),
        /// Returns the size of the given data.
        Size(String),
        /// Returns the provided context.
        Ctx,
        /// Fails the service normally with the given reason.
        Fail(String),
        /// Panics the service.
        Panic,
    }
376
    /// Payload that does not match any `Request` variant; used to provoke a bad-request
    /// error from the service.
    #[derive(Serialize, Deserialize)]
    struct BadRequest {
        bad_field: String,
    }
381
    /// Single numeric response item emitted by `TestService`.
    #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
    struct Response(u64);
384
385    struct TestService();
386
387    impl TestService {
388        fn new() -> TestService {
389            TestService()
390        }
391    }
392
393    impl Service for TestService {
394        type Req = Request;
395        type Resp = Response;
396        type Error = String;
397        type Ctx = u64;
398
399        fn serve(&self, ctx: u64, req: Request) -> BoxStream<'static, Result<Response, String>> {
400            match req {
401                Request::Count(cnt) => {
402                    let mut ctr = 0;
403                    stream::poll_fn(move |_| {
404                        let output = ctr;
405                        ctr += 1;
406                        if ctr <= cnt {
407                            Poll::Ready(Some(Ok(Response(output))))
408                        } else {
409                            Poll::Ready(None)
410                        }
411                    })
412                    .boxed()
413                }
414                Request::Size(data) => {
415                    stream::once(future::ok(Response(data.len() as u64))).boxed()
416                }
417                Request::Ctx => stream::once(future::ok(Response(ctx))).boxed(),
418                Request::Fail(reason) => stream::once(future::err(reason)).boxed(),
419                Request::Panic => stream::poll_fn(|_| panic!("Test panic")).boxed(),
420            }
421        }
422    }
423
    /// Copy of `Outgoing` that uses `Value` over `Box<RawValue>` so that server messages can
    /// be deserialized in the tests.
    // Needed due to https://github.com/serde-rs/json/issues/779
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    #[serde(tag = "type")]
    #[serde(rename_all = "camelCase")]
    pub enum OutgoingAst {
        /// A batch of response payloads for a request.
        #[serde(rename_all = "camelCase")]
        Next {
            request_id: ReqId,
            payload: Vec<Value>,
        },
        /// The response stream of a request has ended.
        #[serde(rename_all = "camelCase")]
        Complete { request_id: ReqId },
        /// The request failed with the given kind of error.
        #[serde(rename_all = "camelCase")]
        Error { request_id: ReqId, kind: ErrorKind },
    }
440
441    impl OutgoingAst {
442        pub fn request_id(&self) -> ReqId {
443            match self {
444                OutgoingAst::Next { request_id, .. } => *request_id,
445                OutgoingAst::Complete { request_id, .. } => *request_id,
446                OutgoingAst::Error { request_id, .. } => *request_id,
447            }
448        }
449    }
450
451    fn test_client<Req: Serialize, Resp: DeserializeOwned>(
452        addr: SocketAddr,
453        endpoint: &str,
454        id: u64,
455        req: Req,
456    ) -> (Vec<Resp>, OutgoingAst) {
457        let addr = format!("ws://{}/test_ws", addr);
458        let client = ClientBuilder::new(&*addr)
459            .expect("Could not setup client")
460            .connect_insecure()
461            .expect("Could not connect to test server");
462
463        let (mut receiver, mut sender) = client.split().unwrap();
464
465        let payload = serde_json::to_value(req).expect("Could not serialize request");
466        let req_env = Incoming::Request(RequestBody {
467            service_id: endpoint,
468            request_id: ReqId(id),
469            payload,
470        });
471        let req_env_json =
472            serde_json::to_string(&req_env).expect("Could not serialize request envelope");
473
474        sender
475            .send_message(&OwnedMessage::Text(req_env_json))
476            .expect("Could not send request");
477
478        let mut completion: Option<OutgoingAst> = None;
479
480        let msgs = receiver
481            .incoming_messages()
482            .filter_map(move |msg| {
483                let msg_ok = msg.expect("Expected message but got websocket error");
484                if let OwnedMessage::Text(raw_resp) = msg_ok {
485                    let resp_env: OutgoingAst = serde_json::from_str(&*raw_resp)
486                        .expect("Could not deserialize response envelope");
487                    if resp_env.request_id().0 == id {
488                        Some(resp_env)
489                    } else {
490                        None
491                    }
492                } else {
493                    None
494                }
495            })
496            .take_while(|env| {
497                if let OutgoingAst::Next { .. } = env {
498                    true
499                } else {
500                    completion = Some(env.clone());
501                    false
502                }
503            })
504            .flat_map(|env| {
505                if let OutgoingAst::Next { payload, .. } = env {
506                    payload
507                        .into_iter()
508                        .map(|p| {
509                            serde_json::from_value::<Resp>(p)
510                                .expect("Could not deserialize response")
511                        })
512                        .collect()
513                } else {
514                    vec![]
515                }
516            })
517            .collect();
518        (msgs, completion.expect("Expected a completion message"))
519    }
520
521    async fn start_test_service() -> SocketAddr {
522        let services = Arc::new(maplit::btreemap! {"test" => TestService::new().boxed()});
523        let ws = warp::path("test_ws")
524            .and(warp::ws())
525            .and(warp::any().map(move || services.clone()))
526            .and(warp::any().map(move || 23))
527            .and_then(super::serve);
528        let (addr, task) = warp::serve(ws).bind_ephemeral(([127, 0, 0, 1], 0));
529        tokio::spawn(task);
530        addr
531    }
532
533    #[tokio::test(flavor = "multi_thread")]
534    async fn properly_serve_single_request() {
535        let addr = start_test_service().await;
536
537        assert_eq!(
538            test_client::<Request, Response>(addr, "test", 0, Request::Count(5)).0,
539            vec![
540                Response(0),
541                Response(1),
542                Response(2),
543                Response(3),
544                Response(4)
545            ]
546        );
547    }
548
549    #[tokio::test(flavor = "multi_thread")]
550    async fn properly_serve_single_request_ctx() {
551        let addr = start_test_service().await;
552
553        assert_eq!(
554            test_client::<Request, Response>(addr, "test", 0, Request::Ctx).0,
555            vec![Response(23)]
556        );
557    }
558
559    #[tokio::test(flavor = "multi_thread")]
560    async fn properly_serve_large_request() {
561        let addr = start_test_service().await;
562        let len = 20_000_000;
563        let data: String = std::iter::repeat('x').take(len).collect::<String>();
564
565        assert_eq!(
566            test_client::<Request, Response>(addr, "test", 0, Request::Size(data)).0,
567            vec![Response(len as u64)]
568        );
569    }
570
571    #[tokio::test(flavor = "multi_thread")]
572    async fn multiplex_multiple_queries() {
573        let addr = start_test_service().await;
574
575        let client_cnt = 50;
576        let request_cnt = 100;
577        let start_barrier = Arc::new(std::sync::Barrier::new(client_cnt));
578
579        let join_handles: Vec<JoinHandle<Vec<Response>>> = (0..client_cnt)
580            .map(|i| {
581                let b = start_barrier.clone();
582                std::thread::spawn(move || {
583                    b.wait();
584                    test_client::<Request, Response>(
585                        addr,
586                        "test",
587                        i as u64,
588                        Request::Count(request_cnt),
589                    )
590                    .0
591                })
592            })
593            .collect();
594        let expected: Vec<Response> = (0..request_cnt).map(|i| Response(i as u64)).collect();
595
596        for handle in join_handles {
597            assert_eq!(handle.join().unwrap(), expected)
598        }
599    }
600
601    #[tokio::test(flavor = "multi_thread")]
602    async fn report_wrong_endpoint() {
603        let addr = start_test_service().await;
604
605        let (msgs, completion) =
606            test_client::<Request, Response>(addr, "no_such_service", 49, Request::Count(5));
607
608        assert_eq!(msgs, vec![]);
609
610        assert_eq!(
611            completion,
612            OutgoingAst::Error {
613                request_id: ReqId(49),
614                kind: ErrorKind::UnknownEndpoint {
615                    endpoint: "no_such_service".to_string(),
616                    valid_endpoints: vec!["test".to_string()],
617                }
618            }
619        );
620    }
621
622    #[tokio::test(flavor = "multi_thread")]
623    async fn report_badly_formatted_request() {
624        let addr = start_test_service().await;
625
626        let (msgs, completion) = test_client::<BadRequest, Response>(
627            addr,
628            "test",
629            49,
630            BadRequest {
631                bad_field: "xzy".to_string(),
632            },
633        );
634
635        assert_eq!(msgs, vec![]);
636
637        if let OutgoingAst::Error {
638            request_id: ReqId(49),
639            kind: ErrorKind::BadRequest { message },
640        } = completion
641        {
642            assert!(message.starts_with("unknown variant"));
643        } else {
644            panic!();
645        }
646    }
647
648    #[tokio::test(flavor = "multi_thread")]
649    async fn report_service_error() {
650        let addr = start_test_service().await;
651
652        let (msgs, completion) = test_client::<Request, Response>(
653            addr,
654            "test",
655            49,
656            Request::Fail("Test reason".to_string()),
657        );
658
659        assert_eq!(msgs, vec![]);
660
661        assert_eq!(
662            completion,
663            OutgoingAst::Error {
664                request_id: ReqId(49),
665                kind: ErrorKind::ServiceError {
666                    value: Value::String("Test reason".to_string())
667                },
668            }
669        );
670    }
671
672    #[tokio::test(flavor = "multi_thread")]
673    async fn report_service_panic() {
674        let addr = start_test_service().await;
675
676        let (msgs, completion) = test_client::<Request, Response>(addr, "test", 49, Request::Panic);
677
678        assert_eq!(msgs, vec![]);
679
680        assert_eq!(
681            completion,
682            OutgoingAst::Error {
683                request_id: ReqId(49),
684                kind: ErrorKind::InternalError,
685            }
686        );
687    }
688
689    // Handle service panic
690}