async_graphql/http/
websocket.rs

//! WebSocket transport for GraphQL subscriptions.
2
3use std::{
4    collections::HashMap,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::{Duration, Instant},
10};
11
12use futures_util::{
13    FutureExt, StreamExt,
14    future::{BoxFuture, Ready},
15    stream::Stream,
16};
17use pin_project_lite::pin_project;
18use serde::{Deserialize, Serialize};
19
20use crate::{Data, Error, Executor, Request, Response, Result, util::Delay};
21
/// All known protocols based on WebSocket.
///
/// These are the `Sec-WebSocket-Protocol` values this module understands.
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = [
    // Protocol of the `graphql-ws` library (`Protocols::GraphQLWS`).
    "graphql-transport-ws",
    // Legacy Apollo protocol (`Protocols::SubscriptionsTransportWS`).
    "graphql-ws",
];
24
/// An enum representing the various forms of a WebSocket message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text frame, carrying a serialized protocol message.
    Text(String),

    /// A close frame: the close code followed by the close reason.
    Close(u16, String),
}

impl WsMessage {
    /// Consumes the message and returns the payload of a [`WsMessage::Text`].
    ///
    /// Because this function may panic, its use is generally discouraged;
    /// prefer matching on the variants.
    ///
    /// # Panics
    ///
    /// Panics if the message is not a [`WsMessage::Text`].
    pub fn unwrap_text(self) -> String {
        let Self::Text(text) = self else {
            panic!("Not a text message");
        };
        text
    }

    /// Consumes the message and returns the `(code, reason)` pair of a
    /// [`WsMessage::Close`].
    ///
    /// Because this function may panic, its use is generally discouraged;
    /// prefer matching on the variants.
    ///
    /// # Panics
    ///
    /// Panics if the message is not a [`WsMessage::Close`].
    pub fn unwrap_close(self) -> (u16, String) {
        let Self::Close(code, reason) = self else {
            panic!("Not a close message");
        };
        (code, reason)
    }
}
66
67struct Timer {
68    interval: Duration,
69    delay: Delay,
70}
71
72impl Timer {
73    #[inline]
74    fn new(interval: Duration) -> Self {
75        Self {
76            interval,
77            delay: Delay::new(interval),
78        }
79    }
80
81    #[inline]
82    fn reset(&mut self) {
83        self.delay.reset(self.interval);
84    }
85}
86
87impl Stream for Timer {
88    type Item = ();
89
90    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91        let this = &mut *self;
92        match this.delay.poll_unpin(cx) {
93            Poll::Ready(_) => {
94                this.delay.reset(this.interval);
95                Poll::Ready(Some(()))
96            }
97            Poll::Pending => Poll::Pending,
98        }
99    }
100}
101
102pin_project! {
103    /// A GraphQL connection over websocket.
104    ///
105    /// # References
106    ///
107    /// - [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md)
108    /// - [graphql-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
109    pub struct WebSocket<S, E, OnInit, OnPing> {
110        on_connection_init: Option<OnInit>,
111        on_ping: OnPing,
112        init_fut: Option<BoxFuture<'static, Result<Data>>>,
113        ping_fut: Option<BoxFuture<'static, Result<Option<serde_json::Value>>>>,
114        connection_data: Option<Data>,
115        data: Option<Arc<Data>>,
116        executor: E,
117        streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
118        #[pin]
119        stream: S,
120        protocol: Protocols,
121        last_msg_at: Instant,
122        keepalive_timer: Option<Timer>,
123        close: bool,
124    }
125}
126
127type MessageMapStream<S> =
128    futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
129
130/// Default connection initializer type.
131pub type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;
132
133/// Default ping handler type.
134pub type DefaultOnPingType =
135    fn(Option<&Data>, Option<serde_json::Value>) -> Ready<Result<Option<serde_json::Value>>>;
136
137/// Default connection initializer function.
138pub fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
139    futures_util::future::ready(Ok(Data::default()))
140}
141
142/// Default ping handler function.
143pub fn default_on_ping(
144    _: Option<&Data>,
145    _: Option<serde_json::Value>,
146) -> Ready<Result<Option<serde_json::Value>>> {
147    futures_util::future::ready(Ok(None))
148}
149
150impl<S, E> WebSocket<S, E, DefaultOnConnInitType, DefaultOnPingType>
151where
152    E: Executor,
153    S: Stream<Item = serde_json::Result<ClientMessage>>,
154{
155    /// Create a new websocket from [`ClientMessage`] stream.
156    pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
157        WebSocket {
158            on_connection_init: Some(default_on_connection_init),
159            on_ping: default_on_ping,
160            init_fut: None,
161            ping_fut: None,
162            connection_data: None,
163            data: None,
164            executor,
165            streams: HashMap::new(),
166            stream,
167            protocol,
168            last_msg_at: Instant::now(),
169            keepalive_timer: None,
170            close: false,
171        }
172    }
173}
174
175impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType, DefaultOnPingType>
176where
177    E: Executor,
178    S: Stream,
179    S::Item: AsRef<[u8]>,
180{
181    /// Create a new websocket from bytes stream.
182    pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
183        let stream = stream
184            .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
185        WebSocket::from_message_stream(executor, stream, protocol)
186    }
187}
188
189impl<S, E, OnInit, OnPing> WebSocket<S, E, OnInit, OnPing>
190where
191    E: Executor,
192    S: Stream<Item = serde_json::Result<ClientMessage>>,
193{
194    /// Specify a connection data.
195    ///
196    /// This data usually comes from HTTP requests.
197    /// When the `GQL_CONNECTION_INIT` message is received, this data will be
198    /// merged with the data returned by the closure specified by
199    /// `with_initializer` into the final subscription context data.
200    #[must_use]
201    pub fn connection_data(mut self, data: Data) -> Self {
202        self.connection_data = Some(data);
203        self
204    }
205
206    /// Specify a connection initialize callback function.
207    ///
208    /// This function if present, will be called with the data sent by the
209    /// client in the [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
210    /// From that point on the returned data will be accessible to all requests.
211    #[must_use]
212    pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F, OnPing>
213    where
214        F: FnOnce(serde_json::Value) -> R + Send + 'static,
215        R: Future<Output = Result<Data>> + Send + 'static,
216    {
217        WebSocket {
218            on_connection_init: Some(callback),
219            on_ping: self.on_ping,
220            init_fut: self.init_fut,
221            ping_fut: self.ping_fut,
222            connection_data: self.connection_data,
223            data: self.data,
224            executor: self.executor,
225            streams: self.streams,
226            stream: self.stream,
227            protocol: self.protocol,
228            last_msg_at: self.last_msg_at,
229            keepalive_timer: self.keepalive_timer,
230            close: self.close,
231        }
232    }
233
234    /// Specify a ping callback function.
235    ///
236    /// This function if present, will be called with the data sent by the
237    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
238    ///
239    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
240    ///
241    /// NOTE: Only used for the `graphql-ws` protocol.
242    #[must_use]
243    pub fn on_ping<F, R>(self, callback: F) -> WebSocket<S, E, OnInit, F>
244    where
245        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
246        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
247    {
248        WebSocket {
249            on_connection_init: self.on_connection_init,
250            on_ping: callback,
251            init_fut: self.init_fut,
252            ping_fut: self.ping_fut,
253            connection_data: self.connection_data,
254            data: self.data,
255            executor: self.executor,
256            streams: self.streams,
257            stream: self.stream,
258            protocol: self.protocol,
259            last_msg_at: self.last_msg_at,
260            keepalive_timer: self.keepalive_timer,
261            close: self.close,
262        }
263    }
264
265    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
266    ///
267    /// If the ping is not acknowledged within the timeout, the connection will
268    /// be closed.
269    ///
270    /// NOTE: Only used for the `graphql-ws` protocol.
271    #[must_use]
272    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
273        Self {
274            keepalive_timer: timeout.into().map(Timer::new),
275            ..self
276        }
277    }
278}
279
280impl<S, E, OnInit, InitFut, OnPing, PingFut> Stream for WebSocket<S, E, OnInit, OnPing>
281where
282    E: Executor,
283    S: Stream<Item = serde_json::Result<ClientMessage>>,
284    OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
285    InitFut: Future<Output = Result<Data>> + Send + 'static,
286    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> PingFut + Clone + Send + 'static,
287    PingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
288{
289    type Item = WsMessage;
290
291    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
292        let mut this = self.project();
293
294        if *this.close {
295            return Poll::Ready(None);
296        }
297
298        if let Some(keepalive_timer) = this.keepalive_timer
299            && let Poll::Ready(Some(())) = keepalive_timer.poll_next_unpin(cx)
300        {
301            return match this.protocol {
302                Protocols::SubscriptionsTransportWS => {
303                    *this.close = true;
304                    Poll::Ready(Some(WsMessage::Text(
305                        serde_json::to_string(&ServerMessage::ConnectionError {
306                            payload: Error::new("timeout"),
307                        })
308                        .unwrap(),
309                    )))
310                }
311                Protocols::GraphQLWS => {
312                    *this.close = true;
313                    Poll::Ready(Some(WsMessage::Close(3008, "timeout".to_string())))
314                }
315            };
316        }
317
318        if this.init_fut.is_none() && this.ping_fut.is_none() {
319            while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
320                let message = match message {
321                    Some(message) => message,
322                    None => return Poll::Ready(None),
323                };
324
325                let message: ClientMessage = match message {
326                    Ok(message) => message,
327                    Err(err) => {
328                        *this.close = true;
329                        return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())));
330                    }
331                };
332
333                *this.last_msg_at = Instant::now();
334                if let Some(keepalive_timer) = this.keepalive_timer {
335                    keepalive_timer.reset();
336                }
337
338                match message {
339                    ClientMessage::ConnectionInit { payload } => {
340                        if let Some(on_connection_init) = this.on_connection_init.take() {
341                            *this.init_fut = Some(Box::pin(async move {
342                                on_connection_init(payload.unwrap_or_default()).await
343                            }));
344                            break;
345                        } else {
346                            *this.close = true;
347                            match this.protocol {
348                                Protocols::SubscriptionsTransportWS => {
349                                    return Poll::Ready(Some(WsMessage::Text(
350                                        serde_json::to_string(&ServerMessage::ConnectionError {
351                                            payload: Error::new(
352                                                "Too many initialisation requests.",
353                                            ),
354                                        })
355                                        .unwrap(),
356                                    )));
357                                }
358                                Protocols::GraphQLWS => {
359                                    return Poll::Ready(Some(WsMessage::Close(
360                                        4429,
361                                        "Too many initialisation requests.".to_string(),
362                                    )));
363                                }
364                            }
365                        }
366                    }
367                    ClientMessage::Start {
368                        id,
369                        payload: request,
370                    } => {
371                        if let Some(data) = this.data.clone() {
372                            this.streams.insert(
373                                id,
374                                Box::pin(this.executor.execute_stream(request, Some(data))),
375                            );
376                        } else {
377                            *this.close = true;
378                            return Poll::Ready(Some(WsMessage::Close(
379                                1011,
380                                "The handshake is not completed.".to_string(),
381                            )));
382                        }
383                    }
384                    ClientMessage::Stop { id } => {
385                        if this.streams.remove(&id).is_some() {
386                            return Poll::Ready(Some(WsMessage::Text(
387                                serde_json::to_string(&ServerMessage::Complete { id: &id })
388                                    .unwrap(),
389                            )));
390                        }
391                    }
392                    // Note: in the revised `graphql-ws` spec, there is no equivalent to the
393                    // `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is
394                    // handled by disconnecting the websocket
395                    ClientMessage::ConnectionTerminate => {
396                        *this.close = true;
397                        return Poll::Ready(None);
398                    }
399                    // Pong must be sent in response from the receiving party as soon as possible.
400                    ClientMessage::Ping { payload } => {
401                        let on_ping = this.on_ping.clone();
402                        let data = this.data.clone();
403                        *this.ping_fut =
404                            Some(Box::pin(
405                                async move { on_ping(data.as_deref(), payload).await },
406                            ));
407                        break;
408                    }
409                    ClientMessage::Pong { .. } => {
410                        // Do nothing...
411                    }
412                }
413            }
414        }
415
416        if let Some(init_fut) = this.init_fut {
417            return init_fut.poll_unpin(cx).map(|res| {
418                *this.init_fut = None;
419                match res {
420                    Ok(data) => {
421                        let mut ctx_data = this.connection_data.take().unwrap_or_default();
422                        ctx_data.merge(data);
423                        *this.data = Some(Arc::new(ctx_data));
424                        Some(WsMessage::Text(
425                            serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
426                        ))
427                    }
428                    Err(err) => {
429                        *this.close = true;
430                        match this.protocol {
431                            Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
432                                serde_json::to_string(&ServerMessage::ConnectionError {
433                                    payload: Error::new(err.message),
434                                })
435                                .unwrap(),
436                            )),
437                            Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
438                        }
439                    }
440                }
441            });
442        }
443
444        if let Some(ping_fut) = this.ping_fut {
445            return ping_fut.poll_unpin(cx).map(|res| {
446                *this.ping_fut = None;
447                match res {
448                    Ok(payload) => Some(WsMessage::Text(
449                        serde_json::to_string(&ServerMessage::Pong { payload }).unwrap(),
450                    )),
451                    Err(err) => {
452                        *this.close = true;
453                        match this.protocol {
454                            Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
455                                serde_json::to_string(&ServerMessage::ConnectionError {
456                                    payload: Error::new(err.message),
457                                })
458                                .unwrap(),
459                            )),
460                            Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
461                        }
462                    }
463                }
464            });
465        }
466
467        for (id, stream) in &mut *this.streams {
468            match Pin::new(stream).poll_next(cx) {
469                Poll::Ready(Some(payload)) => {
470                    return Poll::Ready(Some(WsMessage::Text(
471                        serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
472                    )));
473                }
474                Poll::Ready(None) => {
475                    let id = id.clone();
476                    this.streams.remove(&id);
477                    return Poll::Ready(Some(WsMessage::Text(
478                        serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
479                    )));
480                }
481                Poll::Pending => {}
482            }
483        }
484
485        Poll::Pending
486    }
487}
488
489/// Specification of which GraphQL Over WebSockets protocol is being utilized
490#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
491pub enum Protocols {
492    /// [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
493    SubscriptionsTransportWS,
494    /// [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
495    GraphQLWS,
496}
497
498impl Protocols {
499    /// Returns the `Sec-WebSocket-Protocol` header value for the protocol
500    pub fn sec_websocket_protocol(&self) -> &'static str {
501        match self {
502            Protocols::SubscriptionsTransportWS => "graphql-ws",
503            Protocols::GraphQLWS => "graphql-transport-ws",
504        }
505    }
506
507    #[inline]
508    fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
509        match self {
510            Protocols::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
511            Protocols::GraphQLWS => ServerMessage::Next { id, payload },
512        }
513    }
514}
515
516impl std::str::FromStr for Protocols {
517    type Err = Error;
518
519    fn from_str(protocol: &str) -> Result<Self, Self::Err> {
520        if protocol.eq_ignore_ascii_case("graphql-ws") {
521            Ok(Protocols::SubscriptionsTransportWS)
522        } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
523            Ok(Protocols::GraphQLWS)
524        } else {
525            Err(Error::new(format!(
526                "Unsupported Sec-WebSocket-Protocol: {}",
527                protocol
528            )))
529        }
530    }
531}
532
533/// A websocket message received from the client
534#[derive(Deserialize)]
535#[serde(tag = "type", rename_all = "snake_case")]
536#[allow(clippy::large_enum_variant)] // Request is at fault
537pub enum ClientMessage {
538    /// A new connection
539    ConnectionInit {
540        /// Optional init payload from the client
541        payload: Option<serde_json::Value>,
542    },
543    /// The start of a Websocket subscription
544    #[serde(alias = "subscribe")]
545    Start {
546        /// Message ID
547        id: String,
548        /// The GraphQL Request - this can be modified by protocol implementors
549        /// to add files uploads.
550        payload: Request,
551    },
552    /// The end of a Websocket subscription
553    #[serde(alias = "complete")]
554    Stop {
555        /// Message ID
556        id: String,
557    },
558    /// Connection terminated by the client
559    ConnectionTerminate,
560    /// Useful for detecting failed connections, displaying latency metrics or
561    /// other types of network probing.
562    ///
563    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping>
564    Ping {
565        /// Additional details about the ping.
566        payload: Option<serde_json::Value>,
567    },
568    /// The response to the Ping message.
569    ///
570    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong>
571    Pong {
572        /// Additional details about the pong.
573        payload: Option<serde_json::Value>,
574    },
575}
576
577impl ClientMessage {
578    /// Creates a ClientMessage from an array of bytes
579    pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
580    where
581        T: AsRef<[u8]>,
582    {
583        serde_json::from_slice(message.as_ref())
584    }
585}
586
587#[derive(Serialize)]
588#[serde(tag = "type", rename_all = "snake_case")]
589enum ServerMessage<'a> {
590    ConnectionError {
591        payload: Error,
592    },
593    ConnectionAck,
594    /// subscriptions-transport-ws protocol next payload
595    Data {
596        id: &'a str,
597        payload: Response,
598    },
599    /// graphql-ws protocol next payload
600    Next {
601        id: &'a str,
602        payload: Response,
603    },
604    // Not used by this library, as it's not necessary to send
605    // Error {
606    //     id: &'a str,
607    //     payload: serde_json::Value,
608    // },
609    Complete {
610        id: &'a str,
611    },
612    /// The response to the Ping message.
613    ///
614    /// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong
615    Pong {
616        #[serde(skip_serializing_if = "Option::is_none")]
617        payload: Option<serde_json::Value>,
618    },
619    // Not used by this library
620    // #[serde(rename = "ka")]
621    // KeepAlive
622}