async_graphql/http/
websocket.rs

//! WebSocket transport for GraphQL subscriptions.
2
3use std::{
4    collections::HashMap,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::{Duration, Instant},
10};
11
12use futures_util::{
13    FutureExt, StreamExt,
14    future::{BoxFuture, Ready},
15    stream::Stream,
16};
17use pin_project_lite::pin_project;
18use serde::{Deserialize, Serialize};
19
20use crate::{Data, Error, Executor, Request, Response, Result, runtime::Timer as RtTimer};
21
/// Every `Sec-WebSocket-Protocol` value understood by this module.
///
/// Each entry is accepted by `Protocols::from_str` (case-insensitively).
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = [
    // Selects `Protocols::GraphQLWS`.
    "graphql-transport-ws",
    // Selects `Protocols::SubscriptionsTransportWS`.
    "graphql-ws",
];
24
/// A WebSocket frame produced by this module: either text or a close frame.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text frame carrying the given payload.
    Text(String),

    /// A close frame carrying a close code and a human-readable reason.
    Close(u16, String),
}

impl WsMessage {
    /// Consumes `self` and returns the payload of a [`WsMessage::Text`].
    ///
    /// Prefer pattern matching where possible: this method panics when
    /// called on the wrong variant.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [`WsMessage::Text`].
    pub fn unwrap_text(self) -> String {
        let Self::Text(text) = self else {
            panic!("Not a text message")
        };
        text
    }

    /// Consumes `self` and returns the `(code, reason)` pair of a
    /// [`WsMessage::Close`].
    ///
    /// Prefer pattern matching where possible: this method panics when
    /// called on the wrong variant.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [`WsMessage::Close`].
    pub fn unwrap_close(self) -> (u16, String) {
        if let Self::Close(code, reason) = self {
            (code, reason)
        } else {
            panic!("Not a close message")
        }
    }
}
66
67struct Timer {
68    interval: Duration,
69    rt_timer: Box<dyn RtTimer>,
70    future: BoxFuture<'static, ()>,
71}
72
73impl Timer {
74    #[inline]
75    fn new<T>(rt_timer: T, interval: Duration) -> Self
76    where
77        T: RtTimer,
78    {
79        Self {
80            interval,
81            future: rt_timer.delay(interval),
82            rt_timer: Box::new(rt_timer),
83        }
84    }
85
86    #[inline]
87    fn reset(&mut self) {
88        self.future = self.rt_timer.delay(self.interval);
89    }
90}
91
92impl Stream for Timer {
93    type Item = ();
94
95    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96        let this = &mut *self;
97        match this.future.poll_unpin(cx) {
98            Poll::Ready(_) => {
99                this.reset();
100                Poll::Ready(Some(()))
101            }
102            Poll::Pending => Poll::Pending,
103        }
104    }
105}
106
pin_project! {
    /// A GraphQL connection over websocket.
    ///
    /// This type is a [`Stream`]: polling it reads and handles the client's
    /// messages and yields the [`WsMessage`]s that should be sent back.
    ///
    /// # References
    ///
    /// - [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md)
    /// - [graphql-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
    pub struct WebSocket<S, E, OnInit, OnPing> {
        // Connection-init callback. It is taken on the first `connection_init`
        // message, so a second such message finds `None` and closes the connection.
        on_connection_init: Option<OnInit>,
        // Ping callback; cloned for every `ping` message received.
        on_ping: OnPing,
        // In-flight connection-init callback. While set, no client messages are read.
        init_fut: Option<BoxFuture<'static, Result<Data>>>,
        // In-flight ping callback. While set, no client messages are read.
        ping_fut: Option<BoxFuture<'static, Result<Option<serde_json::Value>>>>,
        // Data supplied by the integration; merged with the init callback's result.
        connection_data: Option<Data>,
        // Final context data. It stays `None` until the init handshake succeeds,
        // and subscriptions started before then close the connection.
        data: Option<Arc<Data>>,
        // Executes the subscription requests.
        executor: E,
        // Active subscriptions, keyed by the client-supplied operation id.
        streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
        // Incoming client messages.
        #[pin]
        stream: S,
        // Wire protocol negotiated for this connection.
        protocol: Protocols,
        // Time the last client message was received (only written in this file).
        last_msg_at: Instant,
        // Optional inactivity timer, reset on every received client message.
        keepalive_timer: Option<Timer>,
        // Once set, the stream yields `None` on every subsequent poll.
        close: bool,
    }
}
131
// Stream type produced by `WebSocket::new`: the raw byte stream mapped through
// `ClientMessage::from_bytes`. A `fn` pointer is used instead of a closure so that
// this alias can name the mapping function's type.
type MessageMapStream<S> =
    futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
134
/// Default connection initializer type.
///
/// The function-pointer type of [`default_on_connection_init`], used by
/// [`WebSocket::new`] and [`WebSocket::from_message_stream`] until another
/// initializer is installed with [`WebSocket::on_connection_init`].
pub type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;

/// Default ping handler type.
///
/// The function-pointer type of [`default_on_ping`], used until another
/// handler is installed with [`WebSocket::on_ping`].
pub type DefaultOnPingType =
    fn(Option<&Data>, Option<serde_json::Value>) -> Ready<Result<Option<serde_json::Value>>>;
141
142/// Default connection initializer function.
143pub fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
144    futures_util::future::ready(Ok(Data::default()))
145}
146
147/// Default ping handler function.
148pub fn default_on_ping(
149    _: Option<&Data>,
150    _: Option<serde_json::Value>,
151) -> Ready<Result<Option<serde_json::Value>>> {
152    futures_util::future::ready(Ok(None))
153}
154
155impl<S, E> WebSocket<S, E, DefaultOnConnInitType, DefaultOnPingType>
156where
157    E: Executor,
158    S: Stream<Item = serde_json::Result<ClientMessage>>,
159{
160    /// Create a new websocket from [`ClientMessage`] stream.
161    pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
162        WebSocket {
163            on_connection_init: Some(default_on_connection_init),
164            on_ping: default_on_ping,
165            init_fut: None,
166            ping_fut: None,
167            connection_data: None,
168            data: None,
169            executor,
170            streams: HashMap::new(),
171            stream,
172            protocol,
173            last_msg_at: Instant::now(),
174            keepalive_timer: None,
175            close: false,
176        }
177    }
178}
179
180impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType, DefaultOnPingType>
181where
182    E: Executor,
183    S: Stream,
184    S::Item: AsRef<[u8]>,
185{
186    /// Create a new websocket from bytes stream.
187    pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
188        let stream = stream
189            .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
190        WebSocket::from_message_stream(executor, stream, protocol)
191    }
192}
193
194impl<S, E, OnInit, OnPing> WebSocket<S, E, OnInit, OnPing>
195where
196    E: Executor,
197    S: Stream<Item = serde_json::Result<ClientMessage>>,
198{
199    /// Specify a connection data.
200    ///
201    /// This data usually comes from HTTP requests.
202    /// When the `GQL_CONNECTION_INIT` message is received, this data will be
203    /// merged with the data returned by the closure specified by
204    /// `with_initializer` into the final subscription context data.
205    #[must_use]
206    pub fn connection_data(mut self, data: Data) -> Self {
207        self.connection_data = Some(data);
208        self
209    }
210
211    /// Specify a connection initialize callback function.
212    ///
213    /// This function if present, will be called with the data sent by the
214    /// client in the [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
215    /// From that point on the returned data will be accessible to all requests.
216    #[must_use]
217    pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F, OnPing>
218    where
219        F: FnOnce(serde_json::Value) -> R + Send + 'static,
220        R: Future<Output = Result<Data>> + Send + 'static,
221    {
222        WebSocket {
223            on_connection_init: Some(callback),
224            on_ping: self.on_ping,
225            init_fut: self.init_fut,
226            ping_fut: self.ping_fut,
227            connection_data: self.connection_data,
228            data: self.data,
229            executor: self.executor,
230            streams: self.streams,
231            stream: self.stream,
232            protocol: self.protocol,
233            last_msg_at: self.last_msg_at,
234            keepalive_timer: self.keepalive_timer,
235            close: self.close,
236        }
237    }
238
239    /// Specify a ping callback function.
240    ///
241    /// This function if present, will be called with the data sent by the
242    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
243    ///
244    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
245    ///
246    /// NOTE: Only used for the `graphql-ws` protocol.
247    #[must_use]
248    pub fn on_ping<F, R>(self, callback: F) -> WebSocket<S, E, OnInit, F>
249    where
250        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
251        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
252    {
253        WebSocket {
254            on_connection_init: self.on_connection_init,
255            on_ping: callback,
256            init_fut: self.init_fut,
257            ping_fut: self.ping_fut,
258            connection_data: self.connection_data,
259            data: self.data,
260            executor: self.executor,
261            streams: self.streams,
262            stream: self.stream,
263            protocol: self.protocol,
264            last_msg_at: self.last_msg_at,
265            keepalive_timer: self.keepalive_timer,
266            close: self.close,
267        }
268    }
269
270    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
271    ///
272    /// If the ping is not acknowledged within the timeout, the connection will
273    /// be closed.
274    ///
275    /// NOTE: Only used for the `graphql-ws` protocol.
276    #[must_use]
277    pub fn keepalive_timeout<T>(self, timer: T, timeout: impl Into<Option<Duration>>) -> Self
278    where
279        T: RtTimer,
280    {
281        Self {
282            keepalive_timer: timeout.into().map(|timeout| Timer::new(timer, timeout)),
283            ..self
284        }
285    }
286}
287
288impl<S, E, OnInit, InitFut, OnPing, PingFut> Stream for WebSocket<S, E, OnInit, OnPing>
289where
290    E: Executor,
291    S: Stream<Item = serde_json::Result<ClientMessage>>,
292    OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
293    InitFut: Future<Output = Result<Data>> + Send + 'static,
294    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> PingFut + Clone + Send + 'static,
295    PingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
296{
297    type Item = WsMessage;
298
299    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
300        let mut this = self.project();
301
302        if *this.close {
303            return Poll::Ready(None);
304        }
305
306        if let Some(keepalive_timer) = this.keepalive_timer
307            && let Poll::Ready(Some(())) = keepalive_timer.poll_next_unpin(cx)
308        {
309            return match this.protocol {
310                Protocols::SubscriptionsTransportWS => {
311                    *this.close = true;
312                    Poll::Ready(Some(WsMessage::Text(
313                        serde_json::to_string(&ServerMessage::ConnectionError {
314                            payload: Error::new("timeout"),
315                        })
316                        .unwrap(),
317                    )))
318                }
319                Protocols::GraphQLWS => {
320                    *this.close = true;
321                    Poll::Ready(Some(WsMessage::Close(3008, "timeout".to_string())))
322                }
323            };
324        }
325
326        if this.init_fut.is_none() && this.ping_fut.is_none() {
327            while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
328                let message = match message {
329                    Some(message) => message,
330                    None => return Poll::Ready(None),
331                };
332
333                let message: ClientMessage = match message {
334                    Ok(message) => message,
335                    Err(err) => {
336                        *this.close = true;
337                        return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())));
338                    }
339                };
340
341                *this.last_msg_at = Instant::now();
342                if let Some(keepalive_timer) = this.keepalive_timer {
343                    keepalive_timer.reset();
344                }
345
346                match message {
347                    ClientMessage::ConnectionInit { payload } => {
348                        if let Some(on_connection_init) = this.on_connection_init.take() {
349                            *this.init_fut = Some(Box::pin(async move {
350                                on_connection_init(payload.unwrap_or_default()).await
351                            }));
352                            break;
353                        } else {
354                            *this.close = true;
355                            match this.protocol {
356                                Protocols::SubscriptionsTransportWS => {
357                                    return Poll::Ready(Some(WsMessage::Text(
358                                        serde_json::to_string(&ServerMessage::ConnectionError {
359                                            payload: Error::new(
360                                                "Too many initialisation requests.",
361                                            ),
362                                        })
363                                        .unwrap(),
364                                    )));
365                                }
366                                Protocols::GraphQLWS => {
367                                    return Poll::Ready(Some(WsMessage::Close(
368                                        4429,
369                                        "Too many initialisation requests.".to_string(),
370                                    )));
371                                }
372                            }
373                        }
374                    }
375                    ClientMessage::Start {
376                        id,
377                        payload: request,
378                    } => {
379                        if let Some(data) = this.data.clone() {
380                            this.streams.insert(
381                                id,
382                                Box::pin(this.executor.execute_stream(request, Some(data))),
383                            );
384                        } else {
385                            *this.close = true;
386                            return Poll::Ready(Some(WsMessage::Close(
387                                1011,
388                                "The handshake is not completed.".to_string(),
389                            )));
390                        }
391                    }
392                    ClientMessage::Stop { id } => {
393                        if this.streams.remove(&id).is_some() {
394                            return Poll::Ready(Some(WsMessage::Text(
395                                serde_json::to_string(&ServerMessage::Complete { id: &id })
396                                    .unwrap(),
397                            )));
398                        }
399                    }
400                    // Note: in the revised `graphql-ws` spec, there is no equivalent to the
401                    // `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is
402                    // handled by disconnecting the websocket
403                    ClientMessage::ConnectionTerminate => {
404                        *this.close = true;
405                        return Poll::Ready(None);
406                    }
407                    // Pong must be sent in response from the receiving party as soon as possible.
408                    ClientMessage::Ping { payload } => {
409                        let on_ping = this.on_ping.clone();
410                        let data = this.data.clone();
411                        *this.ping_fut =
412                            Some(Box::pin(
413                                async move { on_ping(data.as_deref(), payload).await },
414                            ));
415                        break;
416                    }
417                    ClientMessage::Pong { .. } => {
418                        // Do nothing...
419                    }
420                }
421            }
422        }
423
424        if let Some(init_fut) = this.init_fut {
425            return init_fut.poll_unpin(cx).map(|res| {
426                *this.init_fut = None;
427                match res {
428                    Ok(data) => {
429                        let mut ctx_data = this.connection_data.take().unwrap_or_default();
430                        ctx_data.merge(data);
431                        *this.data = Some(Arc::new(ctx_data));
432                        Some(WsMessage::Text(
433                            serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
434                        ))
435                    }
436                    Err(err) => {
437                        *this.close = true;
438                        match this.protocol {
439                            Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
440                                serde_json::to_string(&ServerMessage::ConnectionError {
441                                    payload: Error::new(err.message),
442                                })
443                                .unwrap(),
444                            )),
445                            Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
446                        }
447                    }
448                }
449            });
450        }
451
452        if let Some(ping_fut) = this.ping_fut {
453            return ping_fut.poll_unpin(cx).map(|res| {
454                *this.ping_fut = None;
455                match res {
456                    Ok(payload) => Some(WsMessage::Text(
457                        serde_json::to_string(&ServerMessage::Pong { payload }).unwrap(),
458                    )),
459                    Err(err) => {
460                        *this.close = true;
461                        match this.protocol {
462                            Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
463                                serde_json::to_string(&ServerMessage::ConnectionError {
464                                    payload: Error::new(err.message),
465                                })
466                                .unwrap(),
467                            )),
468                            Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
469                        }
470                    }
471                }
472            });
473        }
474
475        for (id, stream) in &mut *this.streams {
476            match Pin::new(stream).poll_next(cx) {
477                Poll::Ready(Some(payload)) => {
478                    return Poll::Ready(Some(WsMessage::Text(
479                        serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
480                    )));
481                }
482                Poll::Ready(None) => {
483                    let id = id.clone();
484                    this.streams.remove(&id);
485                    return Poll::Ready(Some(WsMessage::Text(
486                        serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
487                    )));
488                }
489                Poll::Pending => {}
490            }
491        }
492
493        Poll::Pending
494    }
495}
496
/// The GraphQL-over-WebSocket protocol spoken on a connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Protocols {
    /// The legacy [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
    SubscriptionsTransportWS,
    /// The [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
    GraphQLWS,
}
505
506impl Protocols {
507    /// Returns the `Sec-WebSocket-Protocol` header value for the protocol
508    pub fn sec_websocket_protocol(&self) -> &'static str {
509        match self {
510            Protocols::SubscriptionsTransportWS => "graphql-ws",
511            Protocols::GraphQLWS => "graphql-transport-ws",
512        }
513    }
514
515    #[inline]
516    fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
517        match self {
518            Protocols::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
519            Protocols::GraphQLWS => ServerMessage::Next { id, payload },
520        }
521    }
522}
523
524impl std::str::FromStr for Protocols {
525    type Err = Error;
526
527    fn from_str(protocol: &str) -> Result<Self, Self::Err> {
528        if protocol.eq_ignore_ascii_case("graphql-ws") {
529            Ok(Protocols::SubscriptionsTransportWS)
530        } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
531            Ok(Protocols::GraphQLWS)
532        } else {
533            Err(Error::new(format!(
534                "Unsupported Sec-WebSocket-Protocol: {}",
535                protocol
536            )))
537        }
538    }
539}
540
/// A websocket message received from the client.
///
/// Messages are JSON objects whose `type` field selects the variant (the
/// variant name in `snake_case`, e.g. `connection_init`). Aliases let the
/// same enum accept both the subscriptions-transport-ws and graphql-ws
/// spellings.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::large_enum_variant)] // Request is at fault
pub enum ClientMessage {
    /// A new connection (`connection_init`).
    ConnectionInit {
        /// Optional init payload from the client
        payload: Option<serde_json::Value>,
    },
    /// The start of a Websocket subscription (`start`, or `subscribe` in
    /// graphql-ws).
    #[serde(alias = "subscribe")]
    Start {
        /// Message ID
        id: String,
        /// The GraphQL Request - this can be modified by protocol implementors
        /// to add files uploads.
        payload: Request,
    },
    /// The end of a Websocket subscription (`stop`, or `complete` in
    /// graphql-ws).
    #[serde(alias = "complete")]
    Stop {
        /// Message ID
        id: String,
    },
    /// Connection terminated by the client (`connection_terminate`).
    ConnectionTerminate,
    /// Useful for detecting failed connections, displaying latency metrics or
    /// other types of network probing (`ping`).
    ///
    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping>
    Ping {
        /// Additional details about the ping.
        payload: Option<serde_json::Value>,
    },
    /// The response to the Ping message (`pong`).
    ///
    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong>
    Pong {
        /// Additional details about the pong.
        payload: Option<serde_json::Value>,
    },
}
584
585impl ClientMessage {
586    /// Creates a ClientMessage from an array of bytes
587    pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
588    where
589        T: AsRef<[u8]>,
590    {
591        serde_json::from_slice(message.as_ref())
592    }
593}
594
/// A websocket message sent to the client.
///
/// Serialized as a JSON object whose `type` field is the variant name in
/// `snake_case` (e.g. `connection_ack`, `next`, `complete`).
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
    /// Reports a connection-level failure (subscriptions-transport-ws).
    ConnectionError {
        payload: Error,
    },
    /// Sent once the connection-init callback has succeeded.
    ConnectionAck,
    /// subscriptions-transport-ws protocol next payload
    Data {
        id: &'a str,
        payload: Response,
    },
    /// graphql-ws protocol next payload
    Next {
        id: &'a str,
        payload: Response,
    },
    // Not used by this library, as it's not necessary to send
    // Error {
    //     id: &'a str,
    //     payload: serde_json::Value,
    // },
    /// The subscription with this id has finished or was stopped.
    Complete {
        id: &'a str,
    },
    /// The response to the Ping message. The `payload` field is omitted
    /// when `None`.
    ///
    /// <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong>
    Pong {
        #[serde(skip_serializing_if = "Option::is_none")]
        payload: Option<serde_json::Value>,
    },
    // Not used by this library
    // #[serde(rename = "ka")]
    // KeepAlive
}