Skip to main content

apollo_router/protocols/
websocket.rs

1//! Implements WebSocket _client_ protocols for GraphQL subscriptions.
2
3use std::pin::Pin;
4use std::task::Poll;
5use std::time::Duration;
6
7use futures::Future;
8use futures::Sink;
9use futures::SinkExt;
10use futures::Stream;
11use futures::StreamExt;
12use futures::future;
13use futures::stream::SplitStream;
14use http::HeaderValue;
15use pin_project_lite::pin_project;
16use schemars::JsonSchema;
17use serde::Deserialize;
18use serde::Serialize;
19use serde_json_bytes::Value;
20use tokio::io::AsyncRead;
21use tokio::io::AsyncWrite;
22use tokio_stream::wrappers::IntervalStream;
23use tokio_tungstenite::WebSocketStream;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::protocol::CloseFrame;
26use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
27
28use crate::graphql;
29
/// Upper bound on how long [`GraphqlWebSocket::new`] waits for the server to acknowledge the
/// connection after the `connection_init` message has been sent.
const CONNECTION_ACK_TIMEOUT: Duration = Duration::from_secs(5);

/// WebSocket subprotocol name of the modern graphql-ws protocol
/// ([`WebSocketProtocol::GraphqlWs`]).
const GRAPHQL_WS_SUBPROTOCOL: &str = "graphql-transport-ws";
/// WebSocket subprotocol name of the legacy subscriptions-transport-ws protocol
/// ([`WebSocketProtocol::SubscriptionsTransportWs`]).
const SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL: &str = "graphql-ws";
38
39#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema, Copy)]
40#[serde(rename_all = "snake_case")]
41pub(crate) enum WebSocketProtocol {
42    /// The modern graphql-ws protocol. The subprotocol name is "graphql-transport-ws".
43    ///
44    /// Spec URL: https://github.com/enisdenjo/graphql-ws/blob/0c0eb499c3a0278c6d9cc799064f22c5d24d2f60/PROTOCOL.md
45    #[default]
46    GraphqlWs,
47    #[serde(rename = "graphql_transport_ws")]
48    /// The legacy subscriptions-transport-ws protocol. Confusingly, the subprotocol name is
49    /// "graphql-ws".
50    ///
51    /// https://github.com/apollographql/subscriptions-transport-ws/blob/36f3f6f780acc1a458b768db13fd39c65e5e6518/PROTOCOL.md
52    SubscriptionsTransportWs,
53}
54
55impl From<WebSocketProtocol> for HeaderValue {
56    fn from(value: WebSocketProtocol) -> Self {
57        match value {
58            WebSocketProtocol::GraphqlWs => HeaderValue::from_static(GRAPHQL_WS_SUBPROTOCOL),
59            WebSocketProtocol::SubscriptionsTransportWs => {
60                HeaderValue::from_static(SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL)
61            }
62        }
63    }
64}
65
66impl WebSocketProtocol {
67    /// Returns a subscription start message appropriate for the active protocol.
68    fn subscribe(&self, id: String, payload: graphql::Request) -> ClientMessage {
69        match self {
70            WebSocketProtocol::GraphqlWs => ClientMessage::Subscribe { id, payload },
71            WebSocketProtocol::SubscriptionsTransportWs => ClientMessage::OldStart { id, payload },
72        }
73    }
74
75    /// Returns a subscription completion message appropriate for the active protocol.
76    fn complete(&self, id: String) -> ClientMessage {
77        match self {
78            WebSocketProtocol::GraphqlWs => ClientMessage::Complete { id },
79            WebSocketProtocol::SubscriptionsTransportWs => ClientMessage::OldStop { id },
80        }
81    }
82}
83
84/// WebSocket messages sent from the client.
85///
86/// Branches prefixed with "Old" are specific to the subscriptions-transport-ws protocol, other
87/// branches are either part of the graphql-ws protocol or shared by both protocols.
88#[derive(Deserialize, Serialize, Debug)]
89#[serde(tag = "type", rename_all = "snake_case")]
90pub(crate) enum ClientMessage {
91    /// A new connection
92    ConnectionInit {
93        /// Optional init payload from the client
94        payload: Option<serde_json_bytes::Value>,
95    },
96    /// The start of a Websocket subscription in the graphql-ws protocol
97    Subscribe {
98        /// Message ID
99        id: String,
100        /// The GraphQL Request - this can be modified by protocol implementors
101        /// to add files uploads.
102        payload: graphql::Request,
103    },
104    /// The start of a Websocket subscription in the subscriptions-transport-ws protocol
105    #[serde(rename = "start")]
106    OldStart {
107        /// Message ID
108        id: String,
109        /// The GraphQL Request - this can be modified by protocol implementors
110        /// to add files uploads.
111        payload: graphql::Request,
112    },
113    /// The end of a Websocket subscription in the graphql-ws protocol
114    Complete {
115        /// Message ID
116        id: String,
117    },
118    /// The end of a Websocket subscription in the subscriptions-transport-ws protocol
119    #[serde(rename = "stop")]
120    OldStop {
121        /// Message ID
122        id: String,
123    },
124    /// Connection terminated by the client, only used in the subscriptions-transport-ws protocol.
125    #[serde(rename = "connection_terminate")]
126    OldConnectionTerminate,
127    /// Close the websocket connection. This is a router-internal message, not part of the protocol
128    CloseWebsocket,
129    /// Useful for detecting failed connections, displaying latency metrics or
130    /// other types of network probing.
131    ///
132    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/0c0eb499c3a0278c6d9cc799064f22c5d24d2f60/PROTOCOL.md#ping>
133    Ping {
134        /// Additional details about the ping.
135        #[serde(skip_serializing_if = "Option::is_none")]
136        payload: Option<serde_json_bytes::Value>,
137    },
138    /// The response to the Ping message.
139    ///
140    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/0c0eb499c3a0278c6d9cc799064f22c5d24d2f60/PROTOCOL.md#pong>
141    Pong {
142        /// Additional details about the pong.
143        #[serde(skip_serializing_if = "Option::is_none")]
144        payload: Option<serde_json_bytes::Value>,
145    },
146}
147
148/// WebSocket messages received from the server.
149#[derive(Deserialize, Serialize, Debug)]
150#[serde(tag = "type", rename_all = "snake_case")]
151pub(crate) enum ServerMessage {
152    ConnectionAck,
153    /// The payload message has type "next" in the graphql-ws protocol, and type "data" in the
154    /// subscriptions-transport-ws protocol.
155    #[serde(alias = "data")]
156    Next {
157        id: String,
158        payload: graphql::Response,
159    },
160    #[serde(alias = "connection_error")]
161    Error {
162        id: Option<String>,
163        payload: ServerError,
164    },
165    Complete {
166        id: String,
167    },
168    #[serde(alias = "ka")]
169    KeepAlive,
170    /// The response to the Ping message.
171    ///
172    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/0c0eb499c3a0278c6d9cc799064f22c5d24d2f60/PROTOCOL.md#pong>
173    Pong {
174        payload: Option<serde_json::Value>,
175    },
176    Ping {
177        payload: Option<serde_json::Value>,
178    },
179}
180
181#[derive(Deserialize, Serialize, Debug, Clone)]
182#[serde(untagged)]
183pub(crate) enum ServerError {
184    Error(graphql::Error),
185    Errors(Vec<graphql::Error>),
186}
187
188impl From<ServerError> for Vec<graphql::Error> {
189    fn from(value: ServerError) -> Self {
190        match value {
191            ServerError::Error(e) => vec![e],
192            ServerError::Errors(e) => e,
193        }
194    }
195}
196
197impl ServerMessage {
198    fn into_graphql_response(self) -> (Option<graphql::Response>, bool) {
199        match self {
200            ServerMessage::Next { id: _, mut payload } => {
201                payload.subscribed = Some(true);
202                (Some(payload), false)
203            }
204            ServerMessage::Error { id: _, payload } => (
205                Some(
206                    graphql::Response::builder()
207                        .errors(payload.into())
208                        .subscribed(false)
209                        .build(),
210                ),
211                true,
212            ),
213            ServerMessage::Complete { .. } => (None, true),
214            ServerMessage::ConnectionAck | ServerMessage::Pong { .. } => (None, false),
215            ServerMessage::Ping { .. } => (None, false),
216            ServerMessage::KeepAlive => (None, false),
217        }
218    }
219
220    fn id(&self) -> Option<String> {
221        match self {
222            ServerMessage::ConnectionAck
223            | ServerMessage::KeepAlive
224            | ServerMessage::Ping { .. }
225            | ServerMessage::Pong { .. } => None,
226            ServerMessage::Next { id, .. } | ServerMessage::Complete { id } => Some(id.to_string()),
227            ServerMessage::Error { id, .. } => id.clone(),
228        }
229    }
230}
231
232pub(crate) struct GraphqlWebSocket<S> {
233    stream: S,
234    id: String,
235    protocol: WebSocketProtocol,
236}
237
238impl<S> GraphqlWebSocket<S>
239where
240    S: Stream<Item = serde_json::Result<ServerMessage>>
241        + Sink<ClientMessage>
242        + std::marker::Unpin
243        + std::marker::Send
244        + 'static,
245{
246    pub(crate) async fn new(
247        mut stream: S,
248        id: String,
249        protocol: WebSocketProtocol,
250        connection_params: Option<Value>,
251    ) -> Result<Self, graphql::Error> {
252        let connection_init_msg = match connection_params {
253            Some(connection_params) => ClientMessage::ConnectionInit {
254                payload: Some(serde_json_bytes::json!({
255                    "connectionParams": connection_params
256                })),
257            },
258            None => ClientMessage::ConnectionInit { payload: None },
259        };
260        stream.send(connection_init_msg).await.map_err(|_err| {
261            graphql::Error::builder()
262                .message("cannot send connection init through websocket connection")
263                .extension_code("WEBSOCKET_INIT_ERROR")
264                .build()
265        })?;
266
267        let first_non_ping_payload = async {
268            loop {
269                match stream.next().await {
270                    Some(Ok(ServerMessage::Ping { .. })) => {
271                        // tungstenite will send a pong automatically when it receives a ping,
272                        // we just need to call flush - see:
273                        // https://docs.rs/tungstenite/latest/tungstenite/protocol/struct.WebSocket.html#method.flush
274                        // we don't mind an error here
275                        // because it will fall through the error below
276                        // if we haven't been able to properly get a ConnectionAck within the `CONNECTION_ACK_TIMEOUT`
277                        let _ = stream.flush().await;
278                    }
279                    other => {
280                        return other;
281                    }
282                }
283            }
284        };
285
286        let resp = tokio::time::timeout(CONNECTION_ACK_TIMEOUT, first_non_ping_payload)
287            .await
288            .map_err(|_| {
289                graphql::Error::builder()
290                    .message("cannot receive connection ack from websocket connection")
291                    .extension_code("WEBSOCKET_ACK_ERROR_TIMEOUT")
292                    .build()
293            })?;
294        if !matches!(resp, Some(Ok(ServerMessage::ConnectionAck))) {
295            return Err(graphql::Error::builder()
296                .message(format!("didn't receive the connection ack from websocket connection but instead got: {resp:?}"))
297                .extension_code("WEBSOCKET_ACK_ERROR")
298                .build());
299        }
300
301        Ok(Self {
302            stream,
303            id,
304            protocol,
305        })
306    }
307
308    pub(crate) async fn into_subscription(
309        mut self,
310        request: graphql::Request,
311        heartbeat_interval: Option<tokio::time::Duration>,
312    ) -> Result<SubscriptionStream<S>, graphql::Error> {
313        self.stream
314            .send(self.protocol.subscribe(self.id.to_string(), request))
315            .await
316            .map(|_| {
317                SubscriptionStream::new(self.stream, self.id, self.protocol, heartbeat_interval)
318            })
319            .map_err(|_err| {
320                graphql::Error::builder()
321                    .message("cannot send to websocket connection")
322                    .extension_code("WEBSOCKET_CONNECTION_ERROR")
323                    .build()
324            })
325    }
326}
327
328#[derive(thiserror::Error, Debug)]
329pub(crate) enum Error {
330    #[error("websocket error")]
331    WebSocketError(#[from] tokio_tungstenite::tungstenite::Error),
332    #[error("deserialization/serialization error")]
333    SerdeError(#[from] serde_json::Error),
334}
335
336/// Convert a bidirectional stream of untyped websocket packets to a [Stream] + [Sink] that speaks the
337/// GraphQL WebSocket protocol ([`ServerMessage`] and [`ClientMessage`]).
338pub(crate) fn convert_websocket_stream<T>(
339    stream: WebSocketStream<T>,
340    id: String,
341) -> impl Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage, Error = Error>
342where
343    T: AsyncRead + AsyncWrite + Unpin,
344{
345    stream
346        // Serialize messages being written into the `Sink`
347        .with(|client_message: ClientMessage| {
348            match client_message {
349                ClientMessage::CloseWebsocket => {
350                    future::ready(Ok(Message::Close(Some(CloseFrame{
351                        code: CloseCode::Normal,
352                        reason: Default::default(),
353                    }))))
354                },
355                message => {
356                    future::ready(match serde_json::to_string(&message) {
357                        Ok(client_message_str) => Ok(Message::text(client_message_str)),
358                        Err(err) => Err(Error::SerdeError(err)),
359                    })
360                },
361            }
362        })
363        .inspect(|msg| if let Ok(Message::Text(_) | Message::Binary(_)) = msg {
364            u64_counter!(
365                "apollo.router.operations.subscriptions.events",
366                "Number of subscription events",
367                1,
368                subscriptions.mode = "passthrough"
369            );
370        })
371        // Parse messages received from the `Stream`
372        .map(move |msg| match msg {
373            Ok(Message::Text(text)) => serde_json::from_str(&text),
374            Ok(Message::Binary(bin)) => serde_json::from_slice(&bin),
375            Ok(Message::Ping(payload)) => Ok(ServerMessage::Ping {
376                payload: serde_json::from_slice(&payload).ok(),
377            }),
378            Ok(Message::Pong(payload)) => Ok(ServerMessage::Pong {
379                payload: serde_json::from_slice(&payload).ok(),
380            }),
381            Ok(Message::Close(None)) => Ok(ServerMessage::Complete { id: id.to_string() }),
382            Ok(Message::Close(Some(CloseFrame{ code, reason }))) => {
383                if code == CloseCode::Normal {
384                    Ok(ServerMessage::Complete { id: id.to_string() })
385                } else {
386                    Ok(ServerMessage::Error {
387                        id: Some(id.to_string()),
388                        payload: ServerError::Error(
389                            graphql::Error::builder()
390                                .message(format!("websocket connection has been closed with error code '{code}' and reason '{reason}'"))
391                                .extension_code("WEBSOCKET_CLOSE_ERROR")
392                                .build(),
393                        ),
394                    })
395                }
396            }
397            Ok(Message::Frame(frame)) => serde_json::from_slice(frame.payload()),
398            Err(err) => {
399                tracing::trace!("cannot consume more message on websocket stream: {err:?}");
400
401                Ok(ServerMessage::Error {
402                    id: Some(id.to_string()),
403                    payload: ServerError::Error(
404                        graphql::Error::builder()
405                            .message("cannot read message from websocket")
406                            .extension_code("WEBSOCKET_MESSAGE_ERROR")
407                            .build(),
408                    ),
409                })
410            }
411        })
412}
413
414pub(crate) struct SubscriptionStream<S> {
415    inner_stream: SplitStream<InnerStream<S>>,
416    close_signal: Option<tokio::sync::oneshot::Sender<()>>,
417}
418
419impl<S> SubscriptionStream<S>
420where
421    S: Stream<Item = serde_json::Result<ServerMessage>>
422        + Sink<ClientMessage>
423        + std::marker::Unpin
424        + std::marker::Send
425        + 'static,
426{
427    pub(crate) fn new(
428        stream: S,
429        id: String,
430        protocol: WebSocketProtocol,
431        heartbeat_interval: Option<tokio::time::Duration>,
432    ) -> Self {
433        let (mut sink, inner_stream) = InnerStream::new(stream, id, protocol).split();
434        let (close_signal, close_sentinel) = tokio::sync::oneshot::channel::<()>();
435
436        tokio::task::spawn(async move {
437            if let (WebSocketProtocol::GraphqlWs, Some(duration)) = (protocol, heartbeat_interval) {
438                let mut interval =
439                    tokio::time::interval_at(tokio::time::Instant::now() + duration, duration);
440                interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
441                let mut heartbeat_stream = IntervalStream::new(interval)
442                    .map(|_| Ok(ClientMessage::Ping { payload: None }))
443                    .take_until(close_sentinel);
444                if let Err(err) = sink.send_all(&mut heartbeat_stream).await {
445                    tracing::trace!("cannot send heartbeat: {err:?}");
446                    if let Some(close_sentinel) = heartbeat_stream.take_future()
447                        && let Err(err) = close_sentinel.await
448                    {
449                        tracing::trace!("cannot shutdown sink: {err:?}");
450                    }
451                }
452            } else if let Err(err) = close_sentinel.await {
453                tracing::trace!("cannot shutdown sink: {err:?}");
454            };
455
456            u64_counter!(
457                "apollo.router.operations.subscriptions.events",
458                "Number of subscription events",
459                1,
460                subscriptions.mode = "passthrough",
461                subscriptions.complete = true
462            );
463
464            if let Err(err) = sink.close().await {
465                tracing::trace!("cannot close the websocket stream: {err:?}");
466            }
467        });
468
469        Self {
470            inner_stream,
471            close_signal: Some(close_signal),
472        }
473    }
474}
475
476impl<S> Drop for SubscriptionStream<S> {
477    fn drop(&mut self) {
478        if let Some(close_signal) = self.close_signal.take()
479            && let Err(err) = close_signal.send(())
480        {
481            tracing::trace!("cannot close the websocket stream: {err:?}");
482        }
483    }
484}
485
486impl<S> Stream for SubscriptionStream<S>
487where
488    S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage> + std::marker::Unpin,
489{
490    type Item = graphql::Response;
491
492    fn poll_next(
493        mut self: Pin<&mut Self>,
494        cx: &mut std::task::Context<'_>,
495    ) -> Poll<Option<Self::Item>> {
496        self.inner_stream.poll_next_unpin(cx)
497    }
498}
499
500pin_project! {
501    /// A wrapper over a stream + sink speaking a GraphQL websocket protocol that:
502    /// - turns internal errors into GraphQL errors
503    /// - filters out messages not related to this stream's subscription ID
504    /// - handles connection shutdown according to the GraphQL websocket protocols
505    struct InnerStream<S> {
506        #[pin]
507        stream: S,
508        id: String,
509        protocol: WebSocketProtocol,
510        // Booleans for state machine when closing the stream
511        completed: bool,
512        terminated: bool,
513        // When the websocket stream is closed (!= graphql sub protocol)
514        closed: bool,
515    }
516}
517
518impl<S> InnerStream<S>
519where
520    S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage> + std::marker::Unpin,
521{
522    fn new(stream: S, id: String, protocol: WebSocketProtocol) -> Self {
523        Self {
524            stream,
525            id,
526            protocol,
527            completed: false,
528            terminated: false,
529            closed: false,
530        }
531    }
532}
533
534impl<S> Stream for InnerStream<S>
535where
536    S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage>,
537{
538    type Item = graphql::Response;
539
540    fn poll_next(
541        mut self: Pin<&mut Self>,
542        cx: &mut std::task::Context<'_>,
543    ) -> Poll<Option<Self::Item>> {
544        let mut this = self.as_mut().project();
545
546        match Pin::new(&mut this.stream).poll_next(cx) {
547            Poll::Ready(message) => match message {
548                Some(server_message) => match server_message {
549                    Ok(server_message) => {
550                        if let Some(id) = &server_message.id()
551                            && this.id != id
552                        {
553                            tracing::error!(
554                                "we should not receive data from other subscriptions, closing the stream"
555                            );
556                            return Poll::Ready(None);
557                        }
558                        if let ServerMessage::Ping { .. } = server_message {
559                            // Send pong asynchronously
560                            // XXX(@goto-bus-stop): We have to pull_flush() to ensure this thing
561                            // finishes, not sure if we're doing that right now?
562                            let _ = Pin::new(
563                                &mut Pin::new(&mut this.stream)
564                                    .send(ClientMessage::Pong { payload: None }),
565                            )
566                            .poll(cx);
567                        }
568                        match server_message.into_graphql_response() {
569                            (None, true) => Poll::Ready(None),
570                            // For ignored message like ACK, Ping, Pong, etc...
571                            (None, false) => self.poll_next(cx),
572                            (Some(resp), _) => Poll::Ready(Some(resp)),
573                        }
574                    }
575                    Err(err) => Poll::Ready(
576                        graphql::Response::builder()
577                            .error(
578                                graphql::Error::builder()
579                                    .message(format!(
580                                        "cannot deserialize websocket server message: {err:?}"
581                                    ))
582                                    .extension_code("INVALID_WEBSOCKET_SERVER_MESSAGE_FORMAT")
583                                    .build(),
584                            )
585                            .build()
586                            .into(),
587                    ),
588                },
589                None => Poll::Ready(None),
590            },
591            Poll::Pending => Poll::Pending,
592        }
593    }
594}
595
596impl<S> Sink<ClientMessage> for InnerStream<S>
597where
598    S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage>,
599{
600    type Error = graphql::Error;
601
602    fn poll_ready(
603        self: Pin<&mut Self>,
604        cx: &mut std::task::Context<'_>,
605    ) -> Poll<Result<(), Self::Error>> {
606        let mut this = self.project();
607
608        match Pin::new(&mut this.stream).poll_ready(cx) {
609            Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
610            Poll::Ready(Err(_err)) => Poll::Ready(Err("websocket connection error")),
611            Poll::Pending => Poll::Pending,
612        }
613        .map_err(|err| {
614            graphql::Error::builder()
615                .message(format!("cannot establish websocket connection: {err}"))
616                .extension_code("WEBSOCKET_CONNECTION_ERROR")
617                .build()
618        })
619    }
620
621    fn start_send(self: Pin<&mut Self>, item: ClientMessage) -> Result<(), Self::Error> {
622        let mut this = self.project();
623
624        Pin::new(&mut this.stream).start_send(item).map_err(|_err| {
625            graphql::Error::builder()
626                .message("cannot send to websocket connection")
627                .extension_code("WEBSOCKET_CONNECTION_ERROR")
628                .build()
629        })
630    }
631
632    fn poll_flush(
633        self: Pin<&mut Self>,
634        cx: &mut std::task::Context<'_>,
635    ) -> Poll<Result<(), Self::Error>> {
636        let mut this = self.project();
637        Pin::new(&mut this.stream).poll_flush(cx).map_err(|_err| {
638            graphql::Error::builder()
639                .message("cannot flush to websocket connection")
640                .extension_code("WEBSOCKET_CONNECTION_ERROR")
641                .build()
642        })
643    }
644
645    fn poll_close(
646        self: Pin<&mut Self>,
647        cx: &mut std::task::Context<'_>,
648    ) -> Poll<Result<(), Self::Error>> {
649        let mut this = self.project();
650        if !*this.completed {
651            // XXX(@goto-bus-stop): We have to pull_flush() to ensure this thing
652            // finishes, not sure if we're doing that right now?
653            match Pin::new(
654                &mut Pin::new(&mut this.stream).send(this.protocol.complete(this.id.to_string())),
655            )
656            .poll(cx)
657            {
658                Poll::Ready(_) => {
659                    *this.completed = true;
660                }
661                Poll::Pending => {
662                    return Poll::Pending;
663                }
664            }
665        }
666        if let WebSocketProtocol::SubscriptionsTransportWs = this.protocol
667            && !*this.terminated
668        {
669            // XXX(@goto-bus-stop): We have to pull_flush() to ensure this thing
670            // finishes, not sure if we're doing that right now?
671            match Pin::new(
672                &mut Pin::new(&mut this.stream).send(ClientMessage::OldConnectionTerminate),
673            )
674            .poll(cx)
675            {
676                Poll::Ready(_) => {
677                    *this.terminated = true;
678                }
679                Poll::Pending => {
680                    return Poll::Pending;
681                }
682            }
683        }
684
685        if !*this.closed {
686            // instead of just calling poll_close we also send a proper CloseWebsocket event to indicate it's a normal close, not an error
687            // XXX(@goto-bus-stop): We have to pull_flush() to ensure this thing
688            // finishes, not sure if we're doing that right now?
689            match Pin::new(&mut Pin::new(&mut this.stream).send(ClientMessage::CloseWebsocket))
690                .poll(cx)
691            {
692                Poll::Ready(_) => {
693                    *this.closed = true;
694                }
695                Poll::Pending => {
696                    return Poll::Pending;
697                }
698            }
699        }
700
701        Pin::new(&mut this.stream).poll_close(cx).map_err(|_err| {
702            graphql::Error::builder()
703                .message("cannot close websocket connection")
704                .extension_code("WEBSOCKET_CONNECTION_ERROR")
705                .build()
706        })
707    }
708}
709
710#[cfg(test)]
711mod tests {
712    use std::convert::Infallible;
713    use std::net::SocketAddr;
714
715    use axum::Router;
716    use axum::extract::WebSocketUpgrade;
717    use axum::extract::ws::Message as AxumWsMessage;
718    use axum::routing::get;
719    use bytes::Bytes;
720    use futures::FutureExt;
721    use http::HeaderValue;
722    use tokio_tungstenite::connect_async;
723    use tokio_tungstenite::tungstenite::client::IntoClientRequest;
724    use uuid::Uuid;
725
726    use super::*;
727    use crate::assert_response_eq_ignoring_error_id;
728    use crate::graphql::Request;
729    use crate::metrics::FutureMetricsExt;
730
731    async fn emulate_correct_websocket_server_new_protocol(
732        send_ping: bool,
733        heartbeat_interval: Option<tokio::time::Duration>,
734        port: Option<u16>,
735    ) -> SocketAddr {
736        let ws_handler = move |ws: WebSocketUpgrade| async move {
737            let res = ws.protocols([GRAPHQL_WS_SUBPROTOCOL]).on_upgrade(move |mut socket| async move {
738                let connection_init = socket.recv().await.unwrap().unwrap().into_text().unwrap();
739                let init_msg: ClientMessage = serde_json::from_str(&connection_init).unwrap();
740                if let ClientMessage::ConnectionInit { payload } = init_msg {
741                    assert_eq!(payload, Some(serde_json_bytes::json!({"connectionParams": {
742                        "token": "XXX"
743                    }})));
744                } else {
745                   panic!("it should be a connection init message");
746                }
747
748                if send_ping {
749                    // It turns out some servers may send Pings before they even ack the connection.
750                    socket
751                        .send(AxumWsMessage::Ping(Bytes::new()))
752                        .await
753                        .unwrap();
754
755                    let pong_message = socket.recv().await.unwrap().unwrap();
756                    assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
757                }
758
759                socket
760                    .send(AxumWsMessage::text(
761                        serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
762                    ))
763                    .await
764                    .unwrap();
765                let new_message = socket.recv().await.unwrap().unwrap().into_text().unwrap();
766                let subscribe_msg: ClientMessage = serde_json::from_str(&new_message).unwrap();
767                assert!(matches!(subscribe_msg, ClientMessage::Subscribe { .. }));
768                #[allow(unused_assignments)]
769                let mut client_id = None;
770                if let ClientMessage::Subscribe { payload, id } = subscribe_msg {
771                    client_id = Some(id);
772                    assert_eq!(
773                        payload,
774                        Request::builder()
775                            .query("subscription {\n  userWasCreated {\n    username\n  }\n}")
776                            .build()
777                    );
778                } else {
779                    panic!("we should receive a subscribe message");
780                }
781
782                socket
783                    .send(AxumWsMessage::text("coucou"))
784                    .await
785                    .unwrap();
786
787                if let Some(duration) = heartbeat_interval {
788                   tokio::time::pause();
789                   assert!(
790                       socket.next().now_or_never().is_none(),
791                       "It should be no pending messages"
792                   );
793
794                   tokio::time::sleep(duration).await;
795                   let ping_message = socket.next().await.unwrap().unwrap();
796                   assert_eq!(ping_message, AxumWsMessage::text(
797                       serde_json::to_string(&ClientMessage::Ping { payload: None }).unwrap(),
798                   ));
799
800                   assert!(
801                       socket.next().now_or_never().is_none(),
802                       "It should be no pending messages"
803                   );
804                   tokio::time::resume();
805                }
806
807                socket
808                    .send(AxumWsMessage::text(
809                        serde_json::to_string(&ServerMessage::Next { id: client_id.clone().unwrap(), payload: graphql::Response::builder().data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}})).build() }).unwrap(),
810                    ))
811                    .await
812                    .unwrap();
813
814                socket
815                    .send(AxumWsMessage::Ping(Bytes::new()))
816                    .await
817                    .unwrap();
818
819                let pong_message = socket.next().await.unwrap().unwrap();
820                assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
821
822                socket
823                    .send(AxumWsMessage::Ping(Bytes::new()))
824                    .await
825                    .unwrap();
826
827                let pong_message = socket.next().await.unwrap().unwrap();
828                assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
829
830                socket
831                    .send(AxumWsMessage::text(
832                        serde_json::to_string(&ServerMessage::Complete { id: client_id.unwrap() }).unwrap(),
833                    ))
834                    .await
835                    .unwrap();
836
837                let terminate_sub = socket.recv().await.unwrap().unwrap().into_text().unwrap();
838                let terminate_msg: ClientMessage = serde_json::from_str(&terminate_sub).unwrap();
839                assert!(matches!(terminate_msg, ClientMessage::OldConnectionTerminate));
840                socket.close().await.unwrap();
841            });
842
843            Ok::<_, Infallible>(res)
844        };
845
846        let app = Router::new().route("/ws", get(ws_handler));
847        let listener =
848            tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port.unwrap_or_default()))
849                .await
850                .unwrap();
851        let server = axum::serve(listener, app);
852        let local_addr = server.local_addr().unwrap();
853        tokio::spawn(async { server.await.unwrap() });
854        local_addr
855    }
856
857    async fn emulate_correct_websocket_server_old_protocol(
858        send_ping: bool,
859        port: Option<u16>,
860    ) -> SocketAddr {
861        let ws_handler = move |ws: WebSocketUpgrade| async move {
862            let res = ws.protocols([SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL]).on_upgrade(move |mut socket| async move {
863                let init_connection = socket.recv().await.unwrap().unwrap().into_text().unwrap();
864                let init_msg: ClientMessage = serde_json::from_str(&init_connection).unwrap();
865                assert!(matches!(init_msg, ClientMessage::ConnectionInit { .. }));
866
867                if send_ping {
868                    // It turns out some servers may send Pings before they even ack the connection.
869                    socket
870                        .send(AxumWsMessage::Ping(Bytes::new()))
871                        .await
872                        .unwrap();
873                    let pong_message = socket.recv().await.unwrap().unwrap();
874                    assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
875                }
876                socket
877                    .send(AxumWsMessage::text(
878                        serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
879                    ))
880                    .await
881                    .unwrap();
882                socket
883                    .send(AxumWsMessage::text(
884                        serde_json::to_string(&ServerMessage::KeepAlive).unwrap(),
885                    ))
886                    .await
887                    .unwrap();
888                let new_message = socket.recv().await.unwrap().unwrap().into_text().unwrap();
889                let subscribe_msg: ClientMessage = serde_json::from_str(&new_message).unwrap();
890                assert!(matches!(subscribe_msg, ClientMessage::OldStart { .. }));
891                #[allow(unused_assignments)]
892                let mut client_id = None;
893                if let ClientMessage::OldStart { payload, id } = subscribe_msg {
894                    client_id = Some(id);
895                    assert_eq!(
896                        payload,
897                        Request::builder()
898                            .query("subscription {\n  userWasCreated {\n    username\n  }\n}")
899                            .build()
900                    );
901                } else {
902                    panic!("we should receive a subscribe message");
903                }
904
905                socket
906                    .send(AxumWsMessage::text("coucou"))
907                    .await
908                    .unwrap();
909
910                socket
911                    .send(AxumWsMessage::text(
912                        serde_json::to_string(&ServerMessage::Next { id: client_id.clone().unwrap(), payload: graphql::Response::builder().data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}})).build() }).unwrap(),
913                    ))
914                    .await
915                    .unwrap();
916                socket
917                    .send(AxumWsMessage::text(
918                        serde_json::to_string(&ServerMessage::KeepAlive).unwrap(),
919                    ))
920                    .await
921                    .unwrap();
922
923                let stop_sub = socket.recv().await.unwrap().unwrap().into_text().unwrap();
924                let stop_msg: ClientMessage = serde_json::from_str(&stop_sub).unwrap();
925                assert!(matches!(stop_msg, ClientMessage::OldStop { .. }));
926
927                let terminate_sub = socket.recv().await.unwrap().unwrap().into_text().unwrap();
928                let terminate_msg: ClientMessage = serde_json::from_str(&terminate_sub).unwrap();
929                assert!(matches!(terminate_msg, ClientMessage::OldConnectionTerminate));
930
931                socket.close().await.unwrap();
932            });
933
934            Ok::<_, Infallible>(res)
935        };
936
937        let app = Router::new().route("/ws", get(ws_handler));
938        let listener =
939            tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port.unwrap_or_default()))
940                .await
941                .unwrap();
942        let server = axum::serve(listener, app);
943        let local_addr = server.local_addr().unwrap();
944        tokio::spawn(async { server.await.unwrap() });
945        local_addr
946    }
947
948    #[tokio::test]
949    async fn test_ws_connection_new_proto_with_ping() {
950        test_ws_connection_new_proto(true, None, None).await
951    }
952
953    #[tokio::test]
954    async fn test_ws_connection_new_proto_without_ping() {
955        test_ws_connection_new_proto(false, None, None).await
956    }
957
958    #[tokio::test]
959    async fn test_ws_connection_new_proto_with_heartbeat() {
960        test_ws_connection_new_proto(false, Some(tokio::time::Duration::from_secs(60)), None).await
961    }
962
963    async fn test_ws_connection_new_proto(
964        send_ping: bool,
965        heartbeat_interval: Option<tokio::time::Duration>,
966        port: Option<u16>,
967    ) {
968        let socket_addr =
969            emulate_correct_websocket_server_new_protocol(send_ping, heartbeat_interval, port)
970                .await;
971        let url = format!("ws://{socket_addr}/ws");
972        let mut request = url.into_client_request().unwrap();
973        request.headers_mut().insert(
974            http::header::SEC_WEBSOCKET_PROTOCOL,
975            HeaderValue::from_static(GRAPHQL_WS_SUBPROTOCOL),
976        );
977        let (ws_stream, _resp) = connect_async(request).await.unwrap();
978
979        async move {
980            let sub_uuid = Uuid::new_v4();
981            let gql_socket = GraphqlWebSocket::new(
982                convert_websocket_stream(ws_stream, sub_uuid.to_string()),
983                sub_uuid.to_string(),
984                WebSocketProtocol::GraphqlWs,
985                Some(serde_json_bytes::json!({
986                    "token": "XXX"
987                })),
988            )
989            .await
990            .unwrap();
991
992            let sub = "subscription {\n  userWasCreated {\n    username\n  }\n}";
993            let mut gql_read_stream = gql_socket
994                .into_subscription(
995                    graphql::Request::builder().query(sub).build(),
996                    heartbeat_interval,
997                )
998                .await
999                .unwrap();
1000
1001            // Starts at 1 for the connection ack message
1002            assert_counter!(
1003                "apollo.router.operations.subscriptions.events",
1004                1,
1005                subscriptions.mode = "passthrough"
1006            );
1007
1008            let next_payload = gql_read_stream.next().await.unwrap();
1009            assert_response_eq_ignoring_error_id!(next_payload, graphql::Response::builder()
1010                .error(
1011                    graphql::Error::builder()
1012                        .message(
1013                            "cannot deserialize websocket server message: Error(\"expected value\", line: 1, column: 1)".to_string())
1014                        .extension_code("INVALID_WEBSOCKET_SERVER_MESSAGE_FORMAT")
1015                        .build(),
1016                )
1017                .build()
1018            );
1019            // Increments to 2 for the invalid message
1020            assert_counter!(
1021                "apollo.router.operations.subscriptions.events",
1022                2,
1023                subscriptions.mode = "passthrough"
1024            );
1025
1026            let next_payload = gql_read_stream.next().await.unwrap();
1027            assert_eq!(
1028                next_payload,
1029                graphql::Response::builder()
1030                    .subscribed(true)
1031                    .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
1032                    .build()
1033            );
1034            // Increments to 3 for the next message
1035            assert_counter!(
1036                "apollo.router.operations.subscriptions.events",
1037                3,
1038                subscriptions.mode = "passthrough"
1039            );
1040
1041            assert!(
1042                gql_read_stream.next().now_or_never().is_none(),
1043                "It should be completed"
1044            );
1045        }
1046        .with_metrics()
1047        .await;
1048    }
1049
1050    #[tokio::test]
1051    async fn test_ws_connection_new_proto_error_on_init() {
1052        let ws_handler = move |ws: WebSocketUpgrade| async move {
1053            let res =
1054                ws.protocols(["graphql-transport-ws"])
1055                    .on_upgrade(move |mut socket| async move {
1056                        let connection_ack =
1057                            socket.recv().await.unwrap().unwrap().into_text().unwrap();
1058                        let ack_msg: ClientMessage = serde_json::from_str(&connection_ack).unwrap();
1059                        if let ClientMessage::ConnectionInit { payload } = ack_msg {
1060                            assert_eq!(
1061                                payload,
1062                                Some(serde_json_bytes::json!({"connectionParams": {
1063                                    "token": "XXX"
1064                                }}))
1065                            );
1066                        } else {
1067                            panic!("it should be a connection init message");
1068                        }
1069
1070                        socket
1071                            .send(AxumWsMessage::text(
1072                                r#"{"type": "connection_error", "payload": {"message": "PAYLOAD_MESSAGE_ERROR"}}"#,
1073                            ))
1074                            .await
1075                            .unwrap();
1076
1077                        socket.close().await.unwrap();
1078                    });
1079
1080            Ok::<_, Infallible>(res)
1081        };
1082
1083        let app = Router::new().route("/ws", get(ws_handler));
1084        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1085        let server = axum::serve(listener, app);
1086        let socket_addr = server.local_addr().unwrap();
1087        tokio::spawn(async { server.await.unwrap() });
1088
1089        let url = format!("ws://{socket_addr}/ws");
1090        let mut request = url.into_client_request().unwrap();
1091        request.headers_mut().insert(
1092            http::header::SEC_WEBSOCKET_PROTOCOL,
1093            HeaderValue::from_static("graphql-transport-ws"),
1094        );
1095        let (ws_stream, _resp) = connect_async(request).await.unwrap();
1096
1097        let sub_uuid = Uuid::new_v4();
1098        let res = GraphqlWebSocket::new(
1099            convert_websocket_stream(ws_stream, sub_uuid.to_string()),
1100            sub_uuid.to_string(),
1101            WebSocketProtocol::GraphqlWs,
1102            Some(serde_json_bytes::json!({
1103                "token": "XXX"
1104            })),
1105        )
1106        .await;
1107
1108        assert!(res.is_err());
1109        let err = res.err().unwrap();
1110        println!("err: {err:?}");
1111        assert!(
1112            err.message
1113                .as_str()
1114                .starts_with("didn't receive the connection ack from websocket connection")
1115        );
1116        assert!(
1117            err.message
1118                .as_str()
1119                .contains(r#"Error(Error { message: "PAYLOAD_MESSAGE_ERROR"#)
1120        );
1121        assert_eq!(err.extensions.get("code").unwrap(), "WEBSOCKET_ACK_ERROR");
1122    }
1123
1124    #[tokio::test]
1125    async fn test_ws_connection_old_proto_with_ping() {
1126        test_ws_connection_old_proto(true, None).await
1127    }
1128
1129    #[tokio::test]
1130    async fn test_ws_connection_old_proto_without_ping() {
1131        test_ws_connection_old_proto(false, None).await
1132    }
1133
1134    async fn test_ws_connection_old_proto(send_ping: bool, port: Option<u16>) {
1135        let socket_addr = emulate_correct_websocket_server_old_protocol(send_ping, port).await;
1136        let url = format!("ws://{socket_addr}/ws");
1137        let mut request = url.into_client_request().unwrap();
1138        request.headers_mut().insert(
1139            http::header::SEC_WEBSOCKET_PROTOCOL,
1140            HeaderValue::from_static(SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL),
1141        );
1142        let (ws_stream, _resp) = connect_async(request).await.unwrap();
1143
1144        async move {
1145            let sub_uuid = Uuid::new_v4();
1146            let gql_socket = GraphqlWebSocket::new(
1147                convert_websocket_stream(ws_stream, sub_uuid.to_string()),
1148                sub_uuid.to_string(),
1149                WebSocketProtocol::SubscriptionsTransportWs,
1150                None,
1151            )
1152            .await
1153            .unwrap();
1154
1155            let sub = "subscription {\n  userWasCreated {\n    username\n  }\n}";
1156            let mut gql_read_stream = gql_socket
1157                .into_subscription(graphql::Request::builder().query(sub).build(), None)
1158                .await
1159                .unwrap();
1160
1161            // Starts at 1 for the connection ack
1162            assert_counter!(
1163                "apollo.router.operations.subscriptions.events",
1164                1,
1165                subscriptions.mode = "passthrough"
1166            );
1167
1168            let next_payload = gql_read_stream.next().await.unwrap();
1169            assert_response_eq_ignoring_error_id!(next_payload, graphql::Response::builder()
1170                .error(
1171                    graphql::Error::builder()
1172                        .message(
1173                            "cannot deserialize websocket server message: Error(\"expected value\", line: 1, column: 1)".to_string())
1174                        .extension_code("INVALID_WEBSOCKET_SERVER_MESSAGE_FORMAT")
1175                        .build(),
1176                )
1177                .build()
1178            );
1179            // Increments to 3 for the keepalive and invalid message
1180            assert_counter!(
1181                "apollo.router.operations.subscriptions.events",
1182                3,
1183                subscriptions.mode = "passthrough"
1184            );
1185
1186            let next_payload = gql_read_stream.next().await.unwrap();
1187            assert_eq!(
1188                next_payload,
1189                graphql::Response::builder()
1190                    .subscribed(true)
1191                    .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
1192                    .build()
1193            );
1194            // Increments to 4 for the next message
1195            assert_counter!(
1196                "apollo.router.operations.subscriptions.events",
1197                4,
1198                subscriptions.mode = "passthrough"
1199            );
1200
1201            assert!(
1202                gql_read_stream.next().now_or_never().is_none(),
1203                "It should be completed"
1204            );
1205        }
1206        .with_metrics()
1207        .await;
1208    }
1209}