Skip to main content

axum_test/
test_web_socket.rs

1use crate::WsMessage;
2use crate::internals::ErrorMessage;
3use anyhow::Result;
4use anyhow::anyhow;
5use bytes::Bytes;
6use expect_json::expect;
7use expect_json::expect_json_eq;
8use futures_util::sink::SinkExt;
9use futures_util::stream::StreamExt;
10use hyper::upgrade::Upgraded;
11use hyper_util::rt::TokioIo;
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use serde_json::Value;
15use std::fmt::Debug;
16use std::fmt::Display;
17use tokio_tungstenite::WebSocketStream;
18use tokio_tungstenite::tungstenite::protocol::Role;
19
20#[cfg(feature = "pretty-assertions")]
21use pretty_assertions::assert_eq;
22
23#[derive(Debug)]
24pub struct TestWebSocket {
25    stream: WebSocketStream<TokioIo<Upgraded>>,
26}
27
28impl TestWebSocket {
29    pub(crate) async fn new(upgraded: Upgraded) -> Self {
30        let upgraded_io = TokioIo::new(upgraded);
31        let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
32
33        Self { stream }
34    }
35
36    pub async fn close(mut self) {
37        self.stream
38            .close(None)
39            .await
40            .error_message("Failed to close WebSocket stream");
41    }
42
43    pub async fn send_text<T>(&mut self, raw_text: T) -> &mut Self
44    where
45        T: Display,
46    {
47        let text = raw_text.to_string();
48        self.send_message(WsMessage::Text(text.into())).await;
49
50        self
51    }
52
53    pub async fn send_json<J>(&mut self, body: &J) -> &mut Self
54    where
55        J: ?Sized + Serialize,
56    {
57        let raw_json = ::serde_json::to_string(body).error_message("Failed to serialize into Json");
58
59        self.send_message(WsMessage::Text(raw_json.into())).await;
60
61        self
62    }
63
64    #[cfg(feature = "yaml")]
65    pub async fn send_yaml<Y>(&mut self, body: &Y) -> &mut Self
66    where
67        Y: ?Sized + Serialize,
68    {
69        let raw_yaml = ::serde_yaml::to_string(body).error_message("Failed to serialize into Yaml");
70
71        self.send_message(WsMessage::Text(raw_yaml.into())).await;
72
73        self
74    }
75
76    #[cfg(feature = "msgpack")]
77    pub async fn send_msgpack<M>(&mut self, body: &M) -> &mut Self
78    where
79        M: ?Sized + Serialize,
80    {
81        let body_bytes =
82            ::rmp_serde::to_vec(body).error_message("Failed to serialize into MsgPack");
83
84        self.send_message(WsMessage::Binary(body_bytes.into()))
85            .await;
86
87        self
88    }
89
90    pub async fn send_message(&mut self, message: WsMessage) -> &mut Self {
91        self.stream
92            .send(message)
93            .await
94            .error_message("Failed to send websocket message");
95
96        self
97    }
98
99    #[must_use]
100    pub async fn receive_text(&mut self) -> String {
101        let message = self.receive_message().await;
102
103        message_to_text(message).error_message("Failed to receive websocket response as text")
104    }
105
106    #[must_use]
107    pub async fn receive_json<T>(&mut self) -> T
108    where
109        T: DeserializeOwned,
110    {
111        let bytes = self.receive_bytes().await;
112        serde_json::from_slice::<T>(&bytes)
113            .error_message_with_body("Failed to deserialize Json websocket response", &bytes)
114    }
115
116    #[cfg(feature = "yaml")]
117    #[must_use]
118    pub async fn receive_yaml<T>(&mut self) -> T
119    where
120        T: DeserializeOwned,
121    {
122        let bytes = self.receive_bytes().await;
123        serde_yaml::from_slice::<T>(&bytes)
124            .error_message_with_body("Failed to deserialize Yaml websocket response", &bytes)
125    }
126
127    #[cfg(feature = "msgpack")]
128    #[must_use]
129    pub async fn receive_msgpack<T>(&mut self) -> T
130    where
131        T: DeserializeOwned,
132    {
133        let received_bytes = self.receive_bytes().await;
134        rmp_serde::from_slice::<T>(&received_bytes)
135            .error_message("Failed to deserialize MsgPack websocket response")
136    }
137
138    #[must_use]
139    pub async fn receive_bytes(&mut self) -> Bytes {
140        let message = self.receive_message().await;
141        message_to_bytes(message).error_message("Failed to receive websocket response as bytes")
142    }
143
144    #[must_use]
145    pub async fn receive_message(&mut self) -> WsMessage {
146        self.maybe_receive_message()
147            .await
148            .expect("No message found on WebSocket stream")
149    }
150
151    #[must_use]
152    async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
153        let maybe_message = self.stream.next().await;
154
155        match maybe_message {
156            None => None,
157            Some(message_result) => {
158                let message =
159                    message_result.error_message("Failed to receive message from WebSocket stream");
160
161                Some(message)
162            }
163        }
164    }
165
166    pub async fn assert_receive_json<T>(&mut self, expected: &T) -> &mut Self
167    where
168        T: Serialize + DeserializeOwned + PartialEq<T> + Debug,
169    {
170        let received = self.receive_json::<T>().await;
171
172        if *expected != received {
173            if let Err(error) = expect_json_eq(&received, &expected) {
174                panic!(
175                    "
176{error}
177",
178                );
179            }
180        }
181
182        self
183    }
184
185    pub async fn assert_receive_json_contains<T>(&mut self, expected: &T) -> &mut Self
186    where
187        T: Serialize,
188    {
189        let received = self.receive_json::<Value>().await;
190        let expected_value = serde_json::to_value(expected).unwrap();
191        let result = expect_json_eq(
192            &received,
193            &expect::object().propagated_contains(expected_value),
194        );
195        if let Err(error) = result {
196            panic!(
197                "
198{error}
199",
200            );
201        }
202
203        self
204    }
205
206    pub async fn assert_receive_text<C>(&mut self, expected: C) -> &mut Self
207    where
208        C: AsRef<str>,
209    {
210        let expected_contents = expected.as_ref();
211        assert_eq!(expected_contents, &self.receive_text().await);
212
213        self
214    }
215
216    pub async fn assert_receive_text_contains<C>(&mut self, expected: C) -> &mut Self
217    where
218        C: AsRef<str>,
219    {
220        let expected_contents = expected.as_ref();
221        let received = self.receive_text().await;
222        let is_contained = received.contains(expected_contents);
223
224        assert!(
225            is_contained,
226            "Failed to find '{expected_contents}', received '{received}'"
227        );
228
229        self
230    }
231
232    #[cfg(feature = "yaml")]
233    pub async fn assert_receive_yaml<T>(&mut self, expected: &T) -> &mut Self
234    where
235        T: DeserializeOwned + PartialEq<T> + Debug,
236    {
237        assert_eq!(*expected, self.receive_yaml::<T>().await);
238
239        self
240    }
241
242    #[cfg(feature = "msgpack")]
243    pub async fn assert_receive_msgpack<T>(&mut self, expected: &T) -> &mut Self
244    where
245        T: DeserializeOwned + PartialEq<T> + Debug,
246    {
247        assert_eq!(*expected, self.receive_msgpack::<T>().await);
248
249        self
250    }
251}
252
253fn message_to_text(message: WsMessage) -> Result<String> {
254    let text = match message {
255        WsMessage::Text(text) => text.to_string(),
256        WsMessage::Binary(data) => {
257            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
258        }
259        WsMessage::Ping(data) => {
260            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
261        }
262        WsMessage::Pong(data) => {
263            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
264        }
265        WsMessage::Close(None) => String::new(),
266        WsMessage::Close(Some(frame)) => frame.reason.to_string(),
267        WsMessage::Frame(_) => {
268            return Err(anyhow!(
269                "Unexpected Frame, did not expect Frame message whilst reading"
270            ));
271        }
272    };
273
274    Ok(text)
275}
276
277fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
278    let bytes = match message {
279        WsMessage::Text(string) => string.into(),
280        WsMessage::Binary(data) => data,
281        WsMessage::Ping(data) => data,
282        WsMessage::Pong(data) => data,
283        WsMessage::Close(None) => Bytes::new(),
284        WsMessage::Close(Some(frame)) => frame.reason.into(),
285        WsMessage::Frame(_) => {
286            return Err(anyhow!(
287                "Unexpected Frame, did not expect Frame message whilst reading"
288            ));
289        }
290    };
291
292    Ok(bytes)
293}
294
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;

    /// Builds a server whose `/ws-ping-pong` endpoint answers every incoming
    /// message twice. The first reply is a Text message `Text: <received text>`.
    /// The second is a Binary message holding the bytes of
    /// `Binary: <received text>`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // 16 bytes under the 16 MiB (16 * 1024 * 1024 = 16_777_216 bytes) max
        // websocket size. This leaves headroom for the `Text: ` (6 bytes) and
        // `Binary: ` (8 bytes) prefixes the server adds to its replies.
        const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024 - 16;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    /// The reply carries a `Text: ` prefix, so an exact match on the sent text
    /// alone must fail.
    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text("🦊").await;
    }
}
395
#[cfg(test)]
mod test_assert_receive_text_contains {
    use super::TestWebSocket;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;

    /// Builds a server whose `/ws-ping-pong` endpoint echoes each incoming
    /// message back once, as a Text message prefixed with `Text: `.
    fn new_test_app() -> TestServer {
        async fn echo_with_prefix(mut socket: WebSocket) {
            while let Some(incoming) = socket.recv().await {
                let text = incoming.unwrap().into_text().unwrap();
                let reply = format!("Text: {text}").try_into().unwrap();

                socket.send(Message::Text(reply)).await.unwrap();
            }
        }

        async fn upgrade(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo_with_prefix)
        }

        let app = Router::new().route("/ws-ping-pong", get(upgrade));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a websocket to the echo endpoint of `server`.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_text("Hello World!").await;
        websocket
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text_contains("🦊").await;
    }
}
469
#[cfg(test)]
mod test_assert_receive_json {
    use super::TestWebSocket;
    use crate::TestServer;
    use crate::testing::ExpectStrMinLen;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use expect_json::expect;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server whose `/ws-ping-pong` endpoint parses each incoming text
    /// message as JSON and replies twice with
    /// `{"format": ..., "message": <parsed>}`. The first reply is a Text message
    /// with `"format": "text"`; the second is a Binary message with
    /// `"format": "binary"`.
    fn new_test_app() -> TestServer {
        async fn echo_json_twice(mut socket: WebSocket) {
            while let Some(incoming) = socket.recv().await {
                let text = incoming.unwrap().into_text().unwrap();
                let parsed = serde_json::from_str::<Value>(&text).unwrap();
                let wrap = |format: &str| json!({ "format": format, "message": parsed });

                let as_text = serde_json::to_string(&wrap("text"))
                    .unwrap()
                    .try_into()
                    .unwrap();
                let as_binary = serde_json::to_vec(&wrap("binary")).unwrap().into();

                socket.send(Message::Text(as_text)).await.unwrap();
                socket.send(Message::Binary(as_binary)).await.unwrap();
            }
        }

        async fn upgrade(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo_json_twice)
        }

        let app = Router::new().route("/ws-ping-pong", get(upgrade));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a websocket to the echo endpoint of `server`.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    /// The JSON document every test sends to the server.
    fn hello_numbers() -> Value {
        json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        })
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&hello_numbers()).await;

        // The text reply arrives first ...
        websocket
            .assert_receive_json(&json!({
                "format": "text",
                "message": hello_numbers(),
            }))
            .await;

        // ... followed by the binary reply.
        websocket
            .assert_receive_json(&json!({
                "format": "binary",
                "message": hello_numbers(),
            }))
            .await;
    }

    #[tokio::test]
    async fn it_should_work_with_custom_expect_op() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&hello_numbers()).await;

        for format in ["text", "binary"] {
            websocket
                .assert_receive_json(&json!({
                    "format": format,
                    "message": {
                        "hello": ExpectStrMinLen { min: 3 },
                        "numbers": expect::array().len(3).all(expect::integer()),
                    },
                }))
                .await;
        }
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_custom_expect_op_fails() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&hello_numbers()).await;

        // "world" is shorter than 10 characters, so the text reply must fail.
        websocket
            .assert_receive_json(&json!({
                "format": "text",
                "message": {
                    "hello": ExpectStrMinLen { min: 10 },
                    "numbers": expect::array().len(3).all(expect::integer()),
                },
            }))
            .await;
    }
}
625
#[cfg(test)]
mod test_assert_receive_json_contains {
    use super::TestWebSocket;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server whose `/ws-ping-pong` endpoint parses each incoming text
    /// message as JSON and replies twice with
    /// `{"format": ..., "message": <parsed>}`. The first reply is a Text message
    /// with `"format": "text"`; the second is a Binary message with
    /// `"format": "binary"`.
    fn new_test_app() -> TestServer {
        async fn echo_json_twice(mut socket: WebSocket) {
            while let Some(incoming) = socket.recv().await {
                let text = incoming.unwrap().into_text().unwrap();
                let parsed = serde_json::from_str::<Value>(&text).unwrap();
                let wrap = |format: &str| json!({ "format": format, "message": parsed });

                let as_text = serde_json::to_string(&wrap("text"))
                    .unwrap()
                    .try_into()
                    .unwrap();
                let as_binary = serde_json::to_vec(&wrap("binary")).unwrap().into();

                socket.send(Message::Text(as_text)).await.unwrap();
                socket.send(Message::Binary(as_binary)).await.unwrap();
            }
        }

        async fn upgrade(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo_json_twice)
        }

        let app = Router::new().route("/ws-ping-pong", get(upgrade));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a websocket to the echo endpoint of `server` and sends the shared
    /// `{"hello": "world", "numbers": [1, 2, 3]}` document.
    async fn connect_and_send(server: &TestServer) -> TestWebSocket {
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        websocket
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary_with_root_content_missing_in_contains() {
        let server = new_test_app();
        let mut websocket = connect_and_send(&server).await;

        // Text reply: the root "format" key is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Binary reply: the root "message" key is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "format": "binary",
            }))
            .await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary_with_nested_content_missing_in_contains() {
        let server = new_test_app();
        let mut websocket = connect_and_send(&server).await;

        // Text reply: the nested "hello" key is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "format": "text",
                "message": {
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Binary reply: the nested "numbers" key is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                },
            }))
            .await;
    }
}
748
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server whose `/ws-ping-pong` endpoint parses each incoming text
    /// message as YAML and replies twice with a YAML document
    /// `{format: ..., message: <parsed>}`. The first reply is a Text message with
    /// `format: text`; the second is a Binary message with `format: binary`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    /// Sends JSON (which is also valid YAML) and asserts both YAML replies.
    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }

    /// Exercises `send_yaml`, which the test above does not cover, by
    /// round-tripping a YAML-encoded body.
    #[tokio::test]
    async fn it_should_ping_pong_yaml_sent_with_send_yaml() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_yaml(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
836
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use super::TestWebSocket;
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server whose `/ws-ping-pong` endpoint decodes each incoming
    /// payload as MsgPack. It replies with a single Binary MsgPack message
    /// `{"format": "binary", "message": <decoded>}`.
    fn new_test_app() -> TestServer {
        async fn echo_msgpack(mut socket: WebSocket) {
            while let Some(incoming) = socket.recv().await {
                let payload = incoming.unwrap().into_data();
                let decoded = rmp_serde::from_slice::<Value>(&payload).unwrap();
                let reply = json!({
                    "format": "binary",
                    "message": decoded
                });

                let encoded = ::rmp_serde::to_vec(&reply).unwrap().into();
                socket.send(Message::Binary(encoded)).await.unwrap();
            }
        }

        async fn upgrade(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo_msgpack)
        }

        let app = Router::new().route("/ws-ping-pong", get(upgrade));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a websocket to the echo endpoint of `server`.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        let body = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });
        websocket.send_msgpack(&body).await;

        websocket
            .assert_receive_msgpack(&json!({
                "format": "binary",
                "message": body,
            }))
            .await;
    }
}
904
#[cfg(test)]
mod test_receive_json {
    use super::TestWebSocket;
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use pretty_assertions::assert_eq;
    use pretty_assertions::assert_str_eq;
    use serde::Deserialize;
    use serde::Serialize;

    /// Shape of the reply the server sends for a `"good"` request.
    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
    struct PingPongMessage {
        ping: String,
    }

    /// The reply expected for a `"good"` request.
    fn pong() -> PingPongMessage {
        PingPongMessage {
            ping: "pong".to_string(),
        }
    }

    /// Builds a server whose `/ws-ping-pong` endpoint answers the JSON string
    /// `"good"` with a serialised [`PingPongMessage`], and the JSON string
    /// `"bad"` with a non-JSON fox emoji. Any other message panics the handler.
    fn new_test_app() -> TestServer {
        async fn reply_by_request(mut socket: WebSocket) {
            while let Some(received_message) = socket.recv().await {
                let received = received_message.unwrap();
                let received_text = received.to_text().unwrap();

                let reply = if received_text == r#""good""# {
                    serde_json::to_string(&pong()).unwrap()
                } else if received_text == r#""bad""# {
                    "🦊".to_string()
                } else {
                    panic!("unknown message given '{received_text}'")
                };

                socket.send(Message::Text(reply.into())).await.unwrap();
            }
        }

        async fn upgrade(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(reply_by_request)
        }

        let app = Router::new().route("/ws-ping-pong", get(upgrade));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a websocket to the endpoint of `server`.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_parse_when_correct_structure() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&"good").await;
        let received: PingPongMessage = websocket.receive_json().await;

        assert_eq!(received, pong());
    }

    #[tokio::test]
    async fn it_should_display_error_with_body_on_parse_fail() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&"bad").await;

        let error_message =
            catch_panic_error_message_async(websocket.receive_json::<PingPongMessage>()).await;

        assert_str_eq!(
            r#"Failed to deserialize Json websocket response,
    expected value at line 1 column 1

received:
    🦊
"#,
            error_message
        );
    }
}