1use std::pin::Pin;
4use std::task::Poll;
5use std::time::Duration;
6
7use futures::Future;
8use futures::Sink;
9use futures::SinkExt;
10use futures::Stream;
11use futures::StreamExt;
12use futures::future;
13use futures::stream::SplitStream;
14use http::HeaderValue;
15use pin_project_lite::pin_project;
16use schemars::JsonSchema;
17use serde::Deserialize;
18use serde::Serialize;
19use serde_json_bytes::Value;
20use tokio::io::AsyncRead;
21use tokio::io::AsyncWrite;
22use tokio_stream::wrappers::IntervalStream;
23use tokio_tungstenite::WebSocketStream;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::protocol::CloseFrame;
26use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
27
28use crate::graphql;
29
/// Maximum time `GraphqlWebSocket::new` waits, after sending the connection
/// init message, for the first non-ping message from the server.
const CONNECTION_ACK_TIMEOUT: Duration = Duration::from_secs(5);

/// Subprotocol name used as header value for `WebSocketProtocol::GraphqlWs`.
const GRAPHQL_WS_SUBPROTOCOL: &str = "graphql-transport-ws";
/// Subprotocol name used as header value for
/// `WebSocketProtocol::SubscriptionsTransportWs`.
const SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL: &str = "graphql-ws";
38
39#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema, Copy)]
40#[serde(rename_all = "snake_case")]
41pub(crate) enum WebSocketProtocol {
42 #[default]
46 GraphqlWs,
47 #[serde(rename = "graphql_transport_ws")]
48 SubscriptionsTransportWs,
53}
54
55impl From<WebSocketProtocol> for HeaderValue {
56 fn from(value: WebSocketProtocol) -> Self {
57 match value {
58 WebSocketProtocol::GraphqlWs => HeaderValue::from_static(GRAPHQL_WS_SUBPROTOCOL),
59 WebSocketProtocol::SubscriptionsTransportWs => {
60 HeaderValue::from_static(SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL)
61 }
62 }
63 }
64}
65
66impl WebSocketProtocol {
67 fn subscribe(&self, id: String, payload: graphql::Request) -> ClientMessage {
69 match self {
70 WebSocketProtocol::GraphqlWs => ClientMessage::Subscribe { id, payload },
71 WebSocketProtocol::SubscriptionsTransportWs => ClientMessage::OldStart { id, payload },
72 }
73 }
74
75 fn complete(&self, id: String) -> ClientMessage {
77 match self {
78 WebSocketProtocol::GraphqlWs => ClientMessage::Complete { id },
79 WebSocketProtocol::SubscriptionsTransportWs => ClientMessage::OldStop { id },
80 }
81 }
82}
83
84#[derive(Deserialize, Serialize, Debug)]
89#[serde(tag = "type", rename_all = "snake_case")]
90pub(crate) enum ClientMessage {
91 ConnectionInit {
93 payload: Option<serde_json_bytes::Value>,
95 },
96 Subscribe {
98 id: String,
100 payload: graphql::Request,
103 },
104 #[serde(rename = "start")]
106 OldStart {
107 id: String,
109 payload: graphql::Request,
112 },
113 Complete {
115 id: String,
117 },
118 #[serde(rename = "stop")]
120 OldStop {
121 id: String,
123 },
124 #[serde(rename = "connection_terminate")]
126 OldConnectionTerminate,
127 CloseWebsocket,
129 Ping {
134 #[serde(skip_serializing_if = "Option::is_none")]
136 payload: Option<serde_json_bytes::Value>,
137 },
138 Pong {
142 #[serde(skip_serializing_if = "Option::is_none")]
144 payload: Option<serde_json_bytes::Value>,
145 },
146}
147
148#[derive(Deserialize, Serialize, Debug)]
150#[serde(tag = "type", rename_all = "snake_case")]
151pub(crate) enum ServerMessage {
152 ConnectionAck,
153 #[serde(alias = "data")]
156 Next {
157 id: String,
158 payload: graphql::Response,
159 },
160 #[serde(alias = "connection_error")]
161 Error {
162 id: Option<String>,
163 payload: ServerError,
164 },
165 Complete {
166 id: String,
167 },
168 #[serde(alias = "ka")]
169 KeepAlive,
170 Pong {
174 payload: Option<serde_json::Value>,
175 },
176 Ping {
177 payload: Option<serde_json::Value>,
178 },
179}
180
181#[derive(Deserialize, Serialize, Debug, Clone)]
182#[serde(untagged)]
183pub(crate) enum ServerError {
184 Error(graphql::Error),
185 Errors(Vec<graphql::Error>),
186}
187
188impl From<ServerError> for Vec<graphql::Error> {
189 fn from(value: ServerError) -> Self {
190 match value {
191 ServerError::Error(e) => vec![e],
192 ServerError::Errors(e) => e,
193 }
194 }
195}
196
197impl ServerMessage {
198 fn into_graphql_response(self) -> (Option<graphql::Response>, bool) {
199 match self {
200 ServerMessage::Next { id: _, mut payload } => {
201 payload.subscribed = Some(true);
202 (Some(payload), false)
203 }
204 ServerMessage::Error { id: _, payload } => (
205 Some(
206 graphql::Response::builder()
207 .errors(payload.into())
208 .subscribed(false)
209 .build(),
210 ),
211 true,
212 ),
213 ServerMessage::Complete { .. } => (None, true),
214 ServerMessage::ConnectionAck | ServerMessage::Pong { .. } => (None, false),
215 ServerMessage::Ping { .. } => (None, false),
216 ServerMessage::KeepAlive => (None, false),
217 }
218 }
219
220 fn id(&self) -> Option<String> {
221 match self {
222 ServerMessage::ConnectionAck
223 | ServerMessage::KeepAlive
224 | ServerMessage::Ping { .. }
225 | ServerMessage::Pong { .. } => None,
226 ServerMessage::Next { id, .. } | ServerMessage::Complete { id } => Some(id.to_string()),
227 ServerMessage::Error { id, .. } => id.clone(),
228 }
229 }
230}
231
/// A websocket connection on which no subscription has been started yet.
///
/// Built by `GraphqlWebSocket::new` and consumed by `into_subscription`.
pub(crate) struct GraphqlWebSocket<S> {
    // Underlying stream of server messages / sink of client messages.
    stream: S,
    // Id sent in the subscribe message built by `into_subscription`.
    id: String,
    // Protocol used to pick the subscribe/complete message shapes.
    protocol: WebSocketProtocol,
}
237
238impl<S> GraphqlWebSocket<S>
239where
240 S: Stream<Item = serde_json::Result<ServerMessage>>
241 + Sink<ClientMessage>
242 + std::marker::Unpin
243 + std::marker::Send
244 + 'static,
245{
246 pub(crate) async fn new(
247 mut stream: S,
248 id: String,
249 protocol: WebSocketProtocol,
250 connection_params: Option<Value>,
251 ) -> Result<Self, graphql::Error> {
252 let connection_init_msg = match connection_params {
253 Some(connection_params) => ClientMessage::ConnectionInit {
254 payload: Some(serde_json_bytes::json!({
255 "connectionParams": connection_params
256 })),
257 },
258 None => ClientMessage::ConnectionInit { payload: None },
259 };
260 stream.send(connection_init_msg).await.map_err(|_err| {
261 graphql::Error::builder()
262 .message("cannot send connection init through websocket connection")
263 .extension_code("WEBSOCKET_INIT_ERROR")
264 .build()
265 })?;
266
267 let first_non_ping_payload = async {
268 loop {
269 match stream.next().await {
270 Some(Ok(ServerMessage::Ping { .. })) => {
271 let _ = stream.flush().await;
278 }
279 other => {
280 return other;
281 }
282 }
283 }
284 };
285
286 let resp = tokio::time::timeout(CONNECTION_ACK_TIMEOUT, first_non_ping_payload)
287 .await
288 .map_err(|_| {
289 graphql::Error::builder()
290 .message("cannot receive connection ack from websocket connection")
291 .extension_code("WEBSOCKET_ACK_ERROR_TIMEOUT")
292 .build()
293 })?;
294 if !matches!(resp, Some(Ok(ServerMessage::ConnectionAck))) {
295 return Err(graphql::Error::builder()
296 .message(format!("didn't receive the connection ack from websocket connection but instead got: {resp:?}"))
297 .extension_code("WEBSOCKET_ACK_ERROR")
298 .build());
299 }
300
301 Ok(Self {
302 stream,
303 id,
304 protocol,
305 })
306 }
307
308 pub(crate) async fn into_subscription(
309 mut self,
310 request: graphql::Request,
311 heartbeat_interval: Option<tokio::time::Duration>,
312 ) -> Result<SubscriptionStream<S>, graphql::Error> {
313 self.stream
314 .send(self.protocol.subscribe(self.id.to_string(), request))
315 .await
316 .map(|_| {
317 SubscriptionStream::new(self.stream, self.id, self.protocol, heartbeat_interval)
318 })
319 .map_err(|_err| {
320 graphql::Error::builder()
321 .message("cannot send to websocket connection")
322 .extension_code("WEBSOCKET_CONNECTION_ERROR")
323 .build()
324 })
325 }
326}
327
328#[derive(thiserror::Error, Debug)]
329pub(crate) enum Error {
330 #[error("websocket error")]
331 WebSocketError(#[from] tokio_tungstenite::tungstenite::Error),
332 #[error("deserialization/serialization error")]
333 SerdeError(#[from] serde_json::Error),
334}
335
336pub(crate) fn convert_websocket_stream<T>(
339 stream: WebSocketStream<T>,
340 id: String,
341) -> impl Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage, Error = Error>
342where
343 T: AsyncRead + AsyncWrite + Unpin,
344{
345 stream
346 .with(|client_message: ClientMessage| {
348 match client_message {
349 ClientMessage::CloseWebsocket => {
350 future::ready(Ok(Message::Close(Some(CloseFrame{
351 code: CloseCode::Normal,
352 reason: Default::default(),
353 }))))
354 },
355 message => {
356 future::ready(match serde_json::to_string(&message) {
357 Ok(client_message_str) => Ok(Message::text(client_message_str)),
358 Err(err) => Err(Error::SerdeError(err)),
359 })
360 },
361 }
362 })
363 .inspect(|msg| if let Ok(Message::Text(_) | Message::Binary(_)) = msg {
364 u64_counter!(
365 "apollo.router.operations.subscriptions.events",
366 "Number of subscription events",
367 1,
368 subscriptions.mode = "passthrough"
369 );
370 })
371 .map(move |msg| match msg {
373 Ok(Message::Text(text)) => serde_json::from_str(&text),
374 Ok(Message::Binary(bin)) => serde_json::from_slice(&bin),
375 Ok(Message::Ping(payload)) => Ok(ServerMessage::Ping {
376 payload: serde_json::from_slice(&payload).ok(),
377 }),
378 Ok(Message::Pong(payload)) => Ok(ServerMessage::Pong {
379 payload: serde_json::from_slice(&payload).ok(),
380 }),
381 Ok(Message::Close(None)) => Ok(ServerMessage::Complete { id: id.to_string() }),
382 Ok(Message::Close(Some(CloseFrame{ code, reason }))) => {
383 if code == CloseCode::Normal {
384 Ok(ServerMessage::Complete { id: id.to_string() })
385 } else {
386 Ok(ServerMessage::Error {
387 id: Some(id.to_string()),
388 payload: ServerError::Error(
389 graphql::Error::builder()
390 .message(format!("websocket connection has been closed with error code '{code}' and reason '{reason}'"))
391 .extension_code("WEBSOCKET_CLOSE_ERROR")
392 .build(),
393 ),
394 })
395 }
396 }
397 Ok(Message::Frame(frame)) => serde_json::from_slice(frame.payload()),
398 Err(err) => {
399 tracing::trace!("cannot consume more message on websocket stream: {err:?}");
400
401 Ok(ServerMessage::Error {
402 id: Some(id.to_string()),
403 payload: ServerError::Error(
404 graphql::Error::builder()
405 .message("cannot read message from websocket")
406 .extension_code("WEBSOCKET_MESSAGE_ERROR")
407 .build(),
408 ),
409 })
410 }
411 })
412}
413
/// Stream of GraphQL responses for one subscription.
///
/// Dropping it fires `close_signal`, which the background task spawned in
/// `SubscriptionStream::new` waits on before closing the sink half.
pub(crate) struct SubscriptionStream<S> {
    // Read half of the split `InnerStream`.
    inner_stream: SplitStream<InnerStream<S>>,
    // Fired once on drop; `Option` so it can be taken out of `&mut self`.
    close_signal: Option<tokio::sync::oneshot::Sender<()>>,
}
418
419impl<S> SubscriptionStream<S>
420where
421 S: Stream<Item = serde_json::Result<ServerMessage>>
422 + Sink<ClientMessage>
423 + std::marker::Unpin
424 + std::marker::Send
425 + 'static,
426{
427 pub(crate) fn new(
428 stream: S,
429 id: String,
430 protocol: WebSocketProtocol,
431 heartbeat_interval: Option<tokio::time::Duration>,
432 ) -> Self {
433 let (mut sink, inner_stream) = InnerStream::new(stream, id, protocol).split();
434 let (close_signal, close_sentinel) = tokio::sync::oneshot::channel::<()>();
435
436 tokio::task::spawn(async move {
437 if let (WebSocketProtocol::GraphqlWs, Some(duration)) = (protocol, heartbeat_interval) {
438 let mut interval =
439 tokio::time::interval_at(tokio::time::Instant::now() + duration, duration);
440 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
441 let mut heartbeat_stream = IntervalStream::new(interval)
442 .map(|_| Ok(ClientMessage::Ping { payload: None }))
443 .take_until(close_sentinel);
444 if let Err(err) = sink.send_all(&mut heartbeat_stream).await {
445 tracing::trace!("cannot send heartbeat: {err:?}");
446 if let Some(close_sentinel) = heartbeat_stream.take_future()
447 && let Err(err) = close_sentinel.await
448 {
449 tracing::trace!("cannot shutdown sink: {err:?}");
450 }
451 }
452 } else if let Err(err) = close_sentinel.await {
453 tracing::trace!("cannot shutdown sink: {err:?}");
454 };
455
456 u64_counter!(
457 "apollo.router.operations.subscriptions.events",
458 "Number of subscription events",
459 1,
460 subscriptions.mode = "passthrough",
461 subscriptions.complete = true
462 );
463
464 if let Err(err) = sink.close().await {
465 tracing::trace!("cannot close the websocket stream: {err:?}");
466 }
467 });
468
469 Self {
470 inner_stream,
471 close_signal: Some(close_signal),
472 }
473 }
474}
475
476impl<S> Drop for SubscriptionStream<S> {
477 fn drop(&mut self) {
478 if let Some(close_signal) = self.close_signal.take()
479 && let Err(err) = close_signal.send(())
480 {
481 tracing::trace!("cannot close the websocket stream: {err:?}");
482 }
483 }
484}
485
486impl<S> Stream for SubscriptionStream<S>
487where
488 S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage> + std::marker::Unpin,
489{
490 type Item = graphql::Response;
491
492 fn poll_next(
493 mut self: Pin<&mut Self>,
494 cx: &mut std::task::Context<'_>,
495 ) -> Poll<Option<Self::Item>> {
496 self.inner_stream.poll_next_unpin(cx)
497 }
498}
499
pin_project! {
    // Combined stream/sink for a single subscription.
    //
    // As a stream it turns server messages into GraphQL responses; as a sink
    // it forwards client messages and runs the close sequence in `poll_close`.
    struct InnerStream<S> {
        // Underlying typed websocket.
        #[pin]
        stream: S,
        // Id of the subscription this stream belongs to.
        id: String,
        // Protocol used to build the messages sent on close.
        protocol: WebSocketProtocol,
        // Set once the complete/stop message step of `poll_close` is done.
        completed: bool,
        // Set once the connection-terminate step of `poll_close` is done.
        terminated: bool,
        // Set once the close-websocket step of `poll_close` is done.
        closed: bool,
    }
}
517
518impl<S> InnerStream<S>
519where
520 S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage> + std::marker::Unpin,
521{
522 fn new(stream: S, id: String, protocol: WebSocketProtocol) -> Self {
523 Self {
524 stream,
525 id,
526 protocol,
527 completed: false,
528 terminated: false,
529 closed: false,
530 }
531 }
532}
533
534impl<S> Stream for InnerStream<S>
535where
536 S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage>,
537{
538 type Item = graphql::Response;
539
540 fn poll_next(
541 mut self: Pin<&mut Self>,
542 cx: &mut std::task::Context<'_>,
543 ) -> Poll<Option<Self::Item>> {
544 let mut this = self.as_mut().project();
545
546 match Pin::new(&mut this.stream).poll_next(cx) {
547 Poll::Ready(message) => match message {
548 Some(server_message) => match server_message {
549 Ok(server_message) => {
550 if let Some(id) = &server_message.id()
551 && this.id != id
552 {
553 tracing::error!(
554 "we should not receive data from other subscriptions, closing the stream"
555 );
556 return Poll::Ready(None);
557 }
558 if let ServerMessage::Ping { .. } = server_message {
559 let _ = Pin::new(
563 &mut Pin::new(&mut this.stream)
564 .send(ClientMessage::Pong { payload: None }),
565 )
566 .poll(cx);
567 }
568 match server_message.into_graphql_response() {
569 (None, true) => Poll::Ready(None),
570 (None, false) => self.poll_next(cx),
572 (Some(resp), _) => Poll::Ready(Some(resp)),
573 }
574 }
575 Err(err) => Poll::Ready(
576 graphql::Response::builder()
577 .error(
578 graphql::Error::builder()
579 .message(format!(
580 "cannot deserialize websocket server message: {err:?}"
581 ))
582 .extension_code("INVALID_WEBSOCKET_SERVER_MESSAGE_FORMAT")
583 .build(),
584 )
585 .build()
586 .into(),
587 ),
588 },
589 None => Poll::Ready(None),
590 },
591 Poll::Pending => Poll::Pending,
592 }
593 }
594}
595
596impl<S> Sink<ClientMessage> for InnerStream<S>
597where
598 S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage>,
599{
600 type Error = graphql::Error;
601
602 fn poll_ready(
603 self: Pin<&mut Self>,
604 cx: &mut std::task::Context<'_>,
605 ) -> Poll<Result<(), Self::Error>> {
606 let mut this = self.project();
607
608 match Pin::new(&mut this.stream).poll_ready(cx) {
609 Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
610 Poll::Ready(Err(_err)) => Poll::Ready(Err("websocket connection error")),
611 Poll::Pending => Poll::Pending,
612 }
613 .map_err(|err| {
614 graphql::Error::builder()
615 .message(format!("cannot establish websocket connection: {err}"))
616 .extension_code("WEBSOCKET_CONNECTION_ERROR")
617 .build()
618 })
619 }
620
621 fn start_send(self: Pin<&mut Self>, item: ClientMessage) -> Result<(), Self::Error> {
622 let mut this = self.project();
623
624 Pin::new(&mut this.stream).start_send(item).map_err(|_err| {
625 graphql::Error::builder()
626 .message("cannot send to websocket connection")
627 .extension_code("WEBSOCKET_CONNECTION_ERROR")
628 .build()
629 })
630 }
631
632 fn poll_flush(
633 self: Pin<&mut Self>,
634 cx: &mut std::task::Context<'_>,
635 ) -> Poll<Result<(), Self::Error>> {
636 let mut this = self.project();
637 Pin::new(&mut this.stream).poll_flush(cx).map_err(|_err| {
638 graphql::Error::builder()
639 .message("cannot flush to websocket connection")
640 .extension_code("WEBSOCKET_CONNECTION_ERROR")
641 .build()
642 })
643 }
644
645 fn poll_close(
646 self: Pin<&mut Self>,
647 cx: &mut std::task::Context<'_>,
648 ) -> Poll<Result<(), Self::Error>> {
649 let mut this = self.project();
650 if !*this.completed {
651 match Pin::new(
654 &mut Pin::new(&mut this.stream).send(this.protocol.complete(this.id.to_string())),
655 )
656 .poll(cx)
657 {
658 Poll::Ready(_) => {
659 *this.completed = true;
660 }
661 Poll::Pending => {
662 return Poll::Pending;
663 }
664 }
665 }
666 if let WebSocketProtocol::SubscriptionsTransportWs = this.protocol
667 && !*this.terminated
668 {
669 match Pin::new(
672 &mut Pin::new(&mut this.stream).send(ClientMessage::OldConnectionTerminate),
673 )
674 .poll(cx)
675 {
676 Poll::Ready(_) => {
677 *this.terminated = true;
678 }
679 Poll::Pending => {
680 return Poll::Pending;
681 }
682 }
683 }
684
685 if !*this.closed {
686 match Pin::new(&mut Pin::new(&mut this.stream).send(ClientMessage::CloseWebsocket))
690 .poll(cx)
691 {
692 Poll::Ready(_) => {
693 *this.closed = true;
694 }
695 Poll::Pending => {
696 return Poll::Pending;
697 }
698 }
699 }
700
701 Pin::new(&mut this.stream).poll_close(cx).map_err(|_err| {
702 graphql::Error::builder()
703 .message("cannot close websocket connection")
704 .extension_code("WEBSOCKET_CONNECTION_ERROR")
705 .build()
706 })
707 }
708}
709
710#[cfg(test)]
711mod tests {
712 use std::convert::Infallible;
713 use std::net::SocketAddr;
714
715 use axum::Router;
716 use axum::extract::WebSocketUpgrade;
717 use axum::extract::ws::Message as AxumWsMessage;
718 use axum::routing::get;
719 use bytes::Bytes;
720 use futures::FutureExt;
721 use http::HeaderValue;
722 use tokio_tungstenite::connect_async;
723 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
724 use uuid::Uuid;
725
726 use super::*;
727 use crate::assert_response_eq_ignoring_error_id;
728 use crate::graphql::Request;
729 use crate::metrics::FutureMetricsExt;
730
731 async fn emulate_correct_websocket_server_new_protocol(
732 send_ping: bool,
733 heartbeat_interval: Option<tokio::time::Duration>,
734 port: Option<u16>,
735 ) -> SocketAddr {
736 let ws_handler = move |ws: WebSocketUpgrade| async move {
737 let res = ws.protocols([GRAPHQL_WS_SUBPROTOCOL]).on_upgrade(move |mut socket| async move {
738 let connection_init = socket.recv().await.unwrap().unwrap().into_text().unwrap();
739 let init_msg: ClientMessage = serde_json::from_str(&connection_init).unwrap();
740 if let ClientMessage::ConnectionInit { payload } = init_msg {
741 assert_eq!(payload, Some(serde_json_bytes::json!({"connectionParams": {
742 "token": "XXX"
743 }})));
744 } else {
745 panic!("it should be a connection init message");
746 }
747
748 if send_ping {
749 socket
751 .send(AxumWsMessage::Ping(Bytes::new()))
752 .await
753 .unwrap();
754
755 let pong_message = socket.recv().await.unwrap().unwrap();
756 assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
757 }
758
759 socket
760 .send(AxumWsMessage::text(
761 serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
762 ))
763 .await
764 .unwrap();
765 let new_message = socket.recv().await.unwrap().unwrap().into_text().unwrap();
766 let subscribe_msg: ClientMessage = serde_json::from_str(&new_message).unwrap();
767 assert!(matches!(subscribe_msg, ClientMessage::Subscribe { .. }));
768 #[allow(unused_assignments)]
769 let mut client_id = None;
770 if let ClientMessage::Subscribe { payload, id } = subscribe_msg {
771 client_id = Some(id);
772 assert_eq!(
773 payload,
774 Request::builder()
775 .query("subscription {\n userWasCreated {\n username\n }\n}")
776 .build()
777 );
778 } else {
779 panic!("we should receive a subscribe message");
780 }
781
782 socket
783 .send(AxumWsMessage::text("coucou"))
784 .await
785 .unwrap();
786
787 if let Some(duration) = heartbeat_interval {
788 tokio::time::pause();
789 assert!(
790 socket.next().now_or_never().is_none(),
791 "It should be no pending messages"
792 );
793
794 tokio::time::sleep(duration).await;
795 let ping_message = socket.next().await.unwrap().unwrap();
796 assert_eq!(ping_message, AxumWsMessage::text(
797 serde_json::to_string(&ClientMessage::Ping { payload: None }).unwrap(),
798 ));
799
800 assert!(
801 socket.next().now_or_never().is_none(),
802 "It should be no pending messages"
803 );
804 tokio::time::resume();
805 }
806
807 socket
808 .send(AxumWsMessage::text(
809 serde_json::to_string(&ServerMessage::Next { id: client_id.clone().unwrap(), payload: graphql::Response::builder().data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}})).build() }).unwrap(),
810 ))
811 .await
812 .unwrap();
813
814 socket
815 .send(AxumWsMessage::Ping(Bytes::new()))
816 .await
817 .unwrap();
818
819 let pong_message = socket.next().await.unwrap().unwrap();
820 assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
821
822 socket
823 .send(AxumWsMessage::Ping(Bytes::new()))
824 .await
825 .unwrap();
826
827 let pong_message = socket.next().await.unwrap().unwrap();
828 assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
829
830 socket
831 .send(AxumWsMessage::text(
832 serde_json::to_string(&ServerMessage::Complete { id: client_id.unwrap() }).unwrap(),
833 ))
834 .await
835 .unwrap();
836
837 let terminate_sub = socket.recv().await.unwrap().unwrap().into_text().unwrap();
838 let terminate_msg: ClientMessage = serde_json::from_str(&terminate_sub).unwrap();
839 assert!(matches!(terminate_msg, ClientMessage::OldConnectionTerminate));
840 socket.close().await.unwrap();
841 });
842
843 Ok::<_, Infallible>(res)
844 };
845
846 let app = Router::new().route("/ws", get(ws_handler));
847 let listener =
848 tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port.unwrap_or_default()))
849 .await
850 .unwrap();
851 let server = axum::serve(listener, app);
852 let local_addr = server.local_addr().unwrap();
853 tokio::spawn(async { server.await.unwrap() });
854 local_addr
855 }
856
857 async fn emulate_correct_websocket_server_old_protocol(
858 send_ping: bool,
859 port: Option<u16>,
860 ) -> SocketAddr {
861 let ws_handler = move |ws: WebSocketUpgrade| async move {
862 let res = ws.protocols([SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL]).on_upgrade(move |mut socket| async move {
863 let init_connection = socket.recv().await.unwrap().unwrap().into_text().unwrap();
864 let init_msg: ClientMessage = serde_json::from_str(&init_connection).unwrap();
865 assert!(matches!(init_msg, ClientMessage::ConnectionInit { .. }));
866
867 if send_ping {
868 socket
870 .send(AxumWsMessage::Ping(Bytes::new()))
871 .await
872 .unwrap();
873 let pong_message = socket.recv().await.unwrap().unwrap();
874 assert_eq!(pong_message, AxumWsMessage::Pong(Bytes::new()));
875 }
876 socket
877 .send(AxumWsMessage::text(
878 serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
879 ))
880 .await
881 .unwrap();
882 socket
883 .send(AxumWsMessage::text(
884 serde_json::to_string(&ServerMessage::KeepAlive).unwrap(),
885 ))
886 .await
887 .unwrap();
888 let new_message = socket.recv().await.unwrap().unwrap().into_text().unwrap();
889 let subscribe_msg: ClientMessage = serde_json::from_str(&new_message).unwrap();
890 assert!(matches!(subscribe_msg, ClientMessage::OldStart { .. }));
891 #[allow(unused_assignments)]
892 let mut client_id = None;
893 if let ClientMessage::OldStart { payload, id } = subscribe_msg {
894 client_id = Some(id);
895 assert_eq!(
896 payload,
897 Request::builder()
898 .query("subscription {\n userWasCreated {\n username\n }\n}")
899 .build()
900 );
901 } else {
902 panic!("we should receive a subscribe message");
903 }
904
905 socket
906 .send(AxumWsMessage::text("coucou"))
907 .await
908 .unwrap();
909
910 socket
911 .send(AxumWsMessage::text(
912 serde_json::to_string(&ServerMessage::Next { id: client_id.clone().unwrap(), payload: graphql::Response::builder().data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}})).build() }).unwrap(),
913 ))
914 .await
915 .unwrap();
916 socket
917 .send(AxumWsMessage::text(
918 serde_json::to_string(&ServerMessage::KeepAlive).unwrap(),
919 ))
920 .await
921 .unwrap();
922
923 let stop_sub = socket.recv().await.unwrap().unwrap().into_text().unwrap();
924 let stop_msg: ClientMessage = serde_json::from_str(&stop_sub).unwrap();
925 assert!(matches!(stop_msg, ClientMessage::OldStop { .. }));
926
927 let terminate_sub = socket.recv().await.unwrap().unwrap().into_text().unwrap();
928 let terminate_msg: ClientMessage = serde_json::from_str(&terminate_sub).unwrap();
929 assert!(matches!(terminate_msg, ClientMessage::OldConnectionTerminate));
930
931 socket.close().await.unwrap();
932 });
933
934 Ok::<_, Infallible>(res)
935 };
936
937 let app = Router::new().route("/ws", get(ws_handler));
938 let listener =
939 tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port.unwrap_or_default()))
940 .await
941 .unwrap();
942 let server = axum::serve(listener, app);
943 let local_addr = server.local_addr().unwrap();
944 tokio::spawn(async { server.await.unwrap() });
945 local_addr
946 }
947
948 #[tokio::test]
949 async fn test_ws_connection_new_proto_with_ping() {
950 test_ws_connection_new_proto(true, None, None).await
951 }
952
953 #[tokio::test]
954 async fn test_ws_connection_new_proto_without_ping() {
955 test_ws_connection_new_proto(false, None, None).await
956 }
957
958 #[tokio::test]
959 async fn test_ws_connection_new_proto_with_heartbeat() {
960 test_ws_connection_new_proto(false, Some(tokio::time::Duration::from_secs(60)), None).await
961 }
962
963 async fn test_ws_connection_new_proto(
964 send_ping: bool,
965 heartbeat_interval: Option<tokio::time::Duration>,
966 port: Option<u16>,
967 ) {
968 let socket_addr =
969 emulate_correct_websocket_server_new_protocol(send_ping, heartbeat_interval, port)
970 .await;
971 let url = format!("ws://{socket_addr}/ws");
972 let mut request = url.into_client_request().unwrap();
973 request.headers_mut().insert(
974 http::header::SEC_WEBSOCKET_PROTOCOL,
975 HeaderValue::from_static(GRAPHQL_WS_SUBPROTOCOL),
976 );
977 let (ws_stream, _resp) = connect_async(request).await.unwrap();
978
979 async move {
980 let sub_uuid = Uuid::new_v4();
981 let gql_socket = GraphqlWebSocket::new(
982 convert_websocket_stream(ws_stream, sub_uuid.to_string()),
983 sub_uuid.to_string(),
984 WebSocketProtocol::GraphqlWs,
985 Some(serde_json_bytes::json!({
986 "token": "XXX"
987 })),
988 )
989 .await
990 .unwrap();
991
992 let sub = "subscription {\n userWasCreated {\n username\n }\n}";
993 let mut gql_read_stream = gql_socket
994 .into_subscription(
995 graphql::Request::builder().query(sub).build(),
996 heartbeat_interval,
997 )
998 .await
999 .unwrap();
1000
1001 assert_counter!(
1003 "apollo.router.operations.subscriptions.events",
1004 1,
1005 subscriptions.mode = "passthrough"
1006 );
1007
1008 let next_payload = gql_read_stream.next().await.unwrap();
1009 assert_response_eq_ignoring_error_id!(next_payload, graphql::Response::builder()
1010 .error(
1011 graphql::Error::builder()
1012 .message(
1013 "cannot deserialize websocket server message: Error(\"expected value\", line: 1, column: 1)".to_string())
1014 .extension_code("INVALID_WEBSOCKET_SERVER_MESSAGE_FORMAT")
1015 .build(),
1016 )
1017 .build()
1018 );
1019 assert_counter!(
1021 "apollo.router.operations.subscriptions.events",
1022 2,
1023 subscriptions.mode = "passthrough"
1024 );
1025
1026 let next_payload = gql_read_stream.next().await.unwrap();
1027 assert_eq!(
1028 next_payload,
1029 graphql::Response::builder()
1030 .subscribed(true)
1031 .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
1032 .build()
1033 );
1034 assert_counter!(
1036 "apollo.router.operations.subscriptions.events",
1037 3,
1038 subscriptions.mode = "passthrough"
1039 );
1040
1041 assert!(
1042 gql_read_stream.next().now_or_never().is_none(),
1043 "It should be completed"
1044 );
1045 }
1046 .with_metrics()
1047 .await;
1048 }
1049
1050 #[tokio::test]
1051 async fn test_ws_connection_new_proto_error_on_init() {
1052 let ws_handler = move |ws: WebSocketUpgrade| async move {
1053 let res =
1054 ws.protocols(["graphql-transport-ws"])
1055 .on_upgrade(move |mut socket| async move {
1056 let connection_ack =
1057 socket.recv().await.unwrap().unwrap().into_text().unwrap();
1058 let ack_msg: ClientMessage = serde_json::from_str(&connection_ack).unwrap();
1059 if let ClientMessage::ConnectionInit { payload } = ack_msg {
1060 assert_eq!(
1061 payload,
1062 Some(serde_json_bytes::json!({"connectionParams": {
1063 "token": "XXX"
1064 }}))
1065 );
1066 } else {
1067 panic!("it should be a connection init message");
1068 }
1069
1070 socket
1071 .send(AxumWsMessage::text(
1072 r#"{"type": "connection_error", "payload": {"message": "PAYLOAD_MESSAGE_ERROR"}}"#,
1073 ))
1074 .await
1075 .unwrap();
1076
1077 socket.close().await.unwrap();
1078 });
1079
1080 Ok::<_, Infallible>(res)
1081 };
1082
1083 let app = Router::new().route("/ws", get(ws_handler));
1084 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1085 let server = axum::serve(listener, app);
1086 let socket_addr = server.local_addr().unwrap();
1087 tokio::spawn(async { server.await.unwrap() });
1088
1089 let url = format!("ws://{socket_addr}/ws");
1090 let mut request = url.into_client_request().unwrap();
1091 request.headers_mut().insert(
1092 http::header::SEC_WEBSOCKET_PROTOCOL,
1093 HeaderValue::from_static("graphql-transport-ws"),
1094 );
1095 let (ws_stream, _resp) = connect_async(request).await.unwrap();
1096
1097 let sub_uuid = Uuid::new_v4();
1098 let res = GraphqlWebSocket::new(
1099 convert_websocket_stream(ws_stream, sub_uuid.to_string()),
1100 sub_uuid.to_string(),
1101 WebSocketProtocol::GraphqlWs,
1102 Some(serde_json_bytes::json!({
1103 "token": "XXX"
1104 })),
1105 )
1106 .await;
1107
1108 assert!(res.is_err());
1109 let err = res.err().unwrap();
1110 println!("err: {err:?}");
1111 assert!(
1112 err.message
1113 .as_str()
1114 .starts_with("didn't receive the connection ack from websocket connection")
1115 );
1116 assert!(
1117 err.message
1118 .as_str()
1119 .contains(r#"Error(Error { message: "PAYLOAD_MESSAGE_ERROR"#)
1120 );
1121 assert_eq!(err.extensions.get("code").unwrap(), "WEBSOCKET_ACK_ERROR");
1122 }
1123
1124 #[tokio::test]
1125 async fn test_ws_connection_old_proto_with_ping() {
1126 test_ws_connection_old_proto(true, None).await
1127 }
1128
1129 #[tokio::test]
1130 async fn test_ws_connection_old_proto_without_ping() {
1131 test_ws_connection_old_proto(false, None).await
1132 }
1133
1134 async fn test_ws_connection_old_proto(send_ping: bool, port: Option<u16>) {
1135 let socket_addr = emulate_correct_websocket_server_old_protocol(send_ping, port).await;
1136 let url = format!("ws://{socket_addr}/ws");
1137 let mut request = url.into_client_request().unwrap();
1138 request.headers_mut().insert(
1139 http::header::SEC_WEBSOCKET_PROTOCOL,
1140 HeaderValue::from_static(SUBSCRIPTIONS_TRANSPORT_WS_SUBPROTOCOL),
1141 );
1142 let (ws_stream, _resp) = connect_async(request).await.unwrap();
1143
1144 async move {
1145 let sub_uuid = Uuid::new_v4();
1146 let gql_socket = GraphqlWebSocket::new(
1147 convert_websocket_stream(ws_stream, sub_uuid.to_string()),
1148 sub_uuid.to_string(),
1149 WebSocketProtocol::SubscriptionsTransportWs,
1150 None,
1151 )
1152 .await
1153 .unwrap();
1154
1155 let sub = "subscription {\n userWasCreated {\n username\n }\n}";
1156 let mut gql_read_stream = gql_socket
1157 .into_subscription(graphql::Request::builder().query(sub).build(), None)
1158 .await
1159 .unwrap();
1160
1161 assert_counter!(
1163 "apollo.router.operations.subscriptions.events",
1164 1,
1165 subscriptions.mode = "passthrough"
1166 );
1167
1168 let next_payload = gql_read_stream.next().await.unwrap();
1169 assert_response_eq_ignoring_error_id!(next_payload, graphql::Response::builder()
1170 .error(
1171 graphql::Error::builder()
1172 .message(
1173 "cannot deserialize websocket server message: Error(\"expected value\", line: 1, column: 1)".to_string())
1174 .extension_code("INVALID_WEBSOCKET_SERVER_MESSAGE_FORMAT")
1175 .build(),
1176 )
1177 .build()
1178 );
1179 assert_counter!(
1181 "apollo.router.operations.subscriptions.events",
1182 3,
1183 subscriptions.mode = "passthrough"
1184 );
1185
1186 let next_payload = gql_read_stream.next().await.unwrap();
1187 assert_eq!(
1188 next_payload,
1189 graphql::Response::builder()
1190 .subscribed(true)
1191 .data(serde_json_bytes::json!({"userWasCreated": {"username": "ada_lovelace"}}))
1192 .build()
1193 );
1194 assert_counter!(
1196 "apollo.router.operations.subscriptions.events",
1197 4,
1198 subscriptions.mode = "passthrough"
1199 );
1200
1201 assert!(
1202 gql_read_stream.next().now_or_never().is_none(),
1203 "It should be completed"
1204 );
1205 }
1206 .with_metrics()
1207 .await;
1208 }
1209}