Skip to main content

axum_test/
test_web_socket.rs

1use crate::WsMessage;
2use crate::internals::ErrorMessage;
3use anyhow::Result;
4use anyhow::anyhow;
5use bytes::Bytes;
6use expect_json::expect;
7use expect_json::expect_json_eq;
8use futures_util::sink::SinkExt;
9use futures_util::stream::StreamExt;
10use hyper::upgrade::Upgraded;
11use hyper_util::rt::TokioIo;
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use serde_json::Value;
15use std::fmt::Debug;
16use std::fmt::Display;
17use tokio_tungstenite::WebSocketStream;
18use tokio_tungstenite::tungstenite::protocol::Role;
19
20#[cfg(feature = "pretty-assertions")]
21use pretty_assertions::assert_eq;
22
23#[derive(Debug)]
24pub struct TestWebSocket {
25    stream: WebSocketStream<TokioIo<Upgraded>>,
26}
27
28impl TestWebSocket {
29    pub(crate) async fn new(upgraded: Upgraded) -> Self {
30        let upgraded_io = TokioIo::new(upgraded);
31        let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
32
33        Self { stream }
34    }
35
36    pub async fn close(mut self) {
37        self.stream
38            .close(None)
39            .await
40            .error_message("Failed to close WebSocket stream");
41    }
42
43    pub async fn send_text<T>(&mut self, raw_text: T) -> &mut Self
44    where
45        T: Display,
46    {
47        let text = raw_text.to_string();
48        self.send_message(WsMessage::Text(text.into())).await;
49
50        self
51    }
52
53    pub async fn send_json<J>(&mut self, body: &J) -> &mut Self
54    where
55        J: ?Sized + Serialize,
56    {
57        let raw_json = ::serde_json::to_string(body).error_message("Failed to serialize into Json");
58
59        self.send_message(WsMessage::Text(raw_json.into())).await;
60
61        self
62    }
63
64    #[cfg(feature = "yaml")]
65    pub async fn send_yaml<Y>(&mut self, body: &Y) -> &mut Self
66    where
67        Y: ?Sized + Serialize,
68    {
69        let raw_yaml = ::serde_yaml::to_string(body).error_message("Failed to serialize into Yaml");
70
71        self.send_message(WsMessage::Text(raw_yaml.into())).await;
72
73        self
74    }
75
76    #[cfg(feature = "msgpack")]
77    pub async fn send_msgpack<M>(&mut self, body: &M) -> &mut Self
78    where
79        M: ?Sized + Serialize,
80    {
81        let body_bytes =
82            ::rmp_serde::to_vec(body).error_message("Failed to serialize into MsgPack");
83
84        self.send_message(WsMessage::Binary(body_bytes.into()))
85            .await;
86
87        self
88    }
89
90    pub async fn send_message(&mut self, message: WsMessage) -> &mut Self {
91        self.stream
92            .send(message)
93            .await
94            .error_message("Failed to send websocket message");
95
96        self
97    }
98
99    #[must_use]
100    pub async fn receive_text(&mut self) -> String {
101        let message = self.receive_message().await;
102
103        message_to_text(message).error_message("Failed to receive websocket response as text")
104    }
105
106    #[must_use]
107    pub async fn receive_json<T>(&mut self) -> T
108    where
109        T: DeserializeOwned,
110    {
111        let bytes = self.receive_bytes().await;
112        serde_json::from_slice::<T>(&bytes)
113            .error_message_with_body("Failed to deserialize Json websocket response", &bytes)
114    }
115
116    #[cfg(feature = "yaml")]
117    #[must_use]
118    pub async fn receive_yaml<T>(&mut self) -> T
119    where
120        T: DeserializeOwned,
121    {
122        let bytes = self.receive_bytes().await;
123        serde_yaml::from_slice::<T>(&bytes)
124            .error_message_with_body("Failed to deserialize Yaml websocket response", &bytes)
125    }
126
127    #[cfg(feature = "msgpack")]
128    #[must_use]
129    pub async fn receive_msgpack<T>(&mut self) -> T
130    where
131        T: DeserializeOwned,
132    {
133        let received_bytes = self.receive_bytes().await;
134        rmp_serde::from_slice::<T>(&received_bytes)
135            .error_message("Failed to deserialize MsgPack websocket response")
136    }
137
138    #[must_use]
139    pub async fn receive_bytes(&mut self) -> Bytes {
140        let message = self.receive_message().await;
141        message_to_bytes(message).error_message("Failed to receive websocket response as bytes")
142    }
143
144    #[must_use]
145    pub async fn receive_message(&mut self) -> WsMessage {
146        self.maybe_receive_message()
147            .await
148            .expect("No message found on WebSocket stream")
149    }
150
151    #[must_use]
152    async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
153        let maybe_message = self.stream.next().await;
154
155        match maybe_message {
156            None => None,
157            Some(message_result) => {
158                let message =
159                    message_result.error_message("Failed to receive message from WebSocket stream");
160
161                Some(message)
162            }
163        }
164    }
165
166    pub async fn assert_receive_json<T>(&mut self, expected: &T) -> &mut Self
167    where
168        T: Serialize + DeserializeOwned + PartialEq<T> + Debug,
169    {
170        let received = self.receive_json::<T>().await;
171
172        if *expected != received {
173            if let Err(error) = expect_json_eq(&received, &expected) {
174                panic!(
175                    "
176{error:?}
177",
178                );
179            }
180        }
181
182        self
183    }
184
185    pub async fn assert_receive_json_contains<T>(&mut self, expected: &T) -> &mut Self
186    where
187        T: Serialize,
188    {
189        let received = self.receive_json::<Value>().await;
190        let expected_value = serde_json::to_value(expected).unwrap();
191        let result = expect_json_eq(
192            &received,
193            &expect::object().propagated_contains(expected_value),
194        );
195        if let Err(error) = result {
196            panic!(
197                "
198{error}
199",
200            );
201        }
202
203        self
204    }
205
206    pub async fn assert_receive_text<C>(&mut self, expected: C) -> &mut Self
207    where
208        C: AsRef<str>,
209    {
210        let expected_contents = expected.as_ref();
211        assert_eq!(expected_contents, &self.receive_text().await);
212
213        self
214    }
215
216    pub async fn assert_receive_text_contains<C>(&mut self, expected: C) -> &mut Self
217    where
218        C: AsRef<str>,
219    {
220        let expected_contents = expected.as_ref();
221        let received = self.receive_text().await;
222        let is_contained = received.contains(expected_contents);
223
224        assert!(
225            is_contained,
226            "Failed to find '{expected_contents}', received '{received}'"
227        );
228
229        self
230    }
231
232    #[cfg(feature = "yaml")]
233    pub async fn assert_receive_yaml<T>(&mut self, expected: &T) -> &mut Self
234    where
235        T: DeserializeOwned + PartialEq<T> + Debug,
236    {
237        assert_eq!(*expected, self.receive_yaml::<T>().await);
238
239        self
240    }
241
242    #[cfg(feature = "msgpack")]
243    pub async fn assert_receive_msgpack<T>(&mut self, expected: &T) -> &mut Self
244    where
245        T: DeserializeOwned + PartialEq<T> + Debug,
246    {
247        assert_eq!(*expected, self.receive_msgpack::<T>().await);
248
249        self
250    }
251}
252
253fn message_to_text(message: WsMessage) -> Result<String> {
254    let text = match message {
255        WsMessage::Text(text) => text.to_string(),
256        WsMessage::Binary(data) => {
257            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
258        }
259        WsMessage::Ping(data) => {
260            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
261        }
262        WsMessage::Pong(data) => {
263            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
264        }
265        WsMessage::Close(None) => String::new(),
266        WsMessage::Close(Some(frame)) => frame.reason.to_string(),
267        WsMessage::Frame(_) => {
268            return Err(anyhow!(
269                "Unexpected Frame, did not expect Frame message whilst reading"
270            ));
271        }
272    };
273
274    Ok(text)
275}
276
277fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
278    let bytes = match message {
279        WsMessage::Text(string) => string.into(),
280        WsMessage::Binary(data) => data,
281        WsMessage::Ping(data) => data,
282        WsMessage::Pong(data) => data,
283        WsMessage::Close(None) => Bytes::new(),
284        WsMessage::Close(Some(frame)) => frame.reason.into(),
285        WsMessage::Frame(_) => {
286            return Err(anyhow!(
287                "Unexpected Frame, did not expect Frame message whilst reading"
288            ));
289        }
290    };
291
292    Ok(bytes)
293}
294
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;
    use crate::testing::assert_error_message;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;

    /// Builds a server with a `/ws-ping-pong` endpoint that answers every incoming
    /// message twice: first a text message `Text: <received>`, then a binary
    /// message `Binary: <received>`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        // `assert_receive_text` accepts both the text and the binary reply.
        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // 16 bytes under the 16mb (16_777_216 bytes) max websocket size, leaving
        // headroom for the reply prefixes: `Text: ` is 6 bytes, `Binary: ` is 8.
        const LARGE_BLOB_SIZE: usize = 16777200;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    #[tokio::test]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        // The reply is `Text: Hello World!`; a mere substring must not pass.
        let message =
            catch_panic_error_message_async(websocket.assert_receive_text("Hello World!")).await;
        assert_error_message(
            "assertion failed: `(left == right)`

Diff < left / right > :
<Hello World!
>Text: Hello World!

",
            message,
        );
    }

    #[tokio::test]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        let message = catch_panic_error_message_async(websocket.assert_receive_text("🦊")).await;
        assert_error_message(
            "assertion failed: `(left == right)`

Diff < left / right > :
<🦊
>Text: Hello World!

",
            message,
        );
    }
}
417
#[cfg(test)]
mod test_assert_receive_text_contains {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;

    /// Builds a server with a `/ws-ping-pong` endpoint that answers every incoming
    /// message with a single text message: `Text: <received>`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The whole reply counts as "contained" too.
        websocket
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The reply is `Text: Hello World!`; only a substring of it is asserted here.
        websocket.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        // The panic message must name both the missing substring and the received text.
        let message =
            catch_panic_error_message_async(websocket.assert_receive_text_contains("🦊")).await;
        assert_str_eq!(
            "Failed to find '🦊', received 'Text: Hello World!'",
            message
        );
    }
}
498
#[cfg(test)]
mod test_assert_receive_json {
    use crate::TestServer;
    use crate::testing::ExpectStrMinLen;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use expect_json::expect;
    use pretty_assertions::assert_str_eq;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server with a `/ws-ping-pong` endpoint that parses each incoming
    /// message as Json and replies twice, wrapping the parsed value as
    /// `{"format": ..., "message": <received>}`. The first reply is a text message
    /// with `"format": "text"`, the second a binary message with `"format": "binary"`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_json::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_json::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_json::to_vec(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_json(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_json(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }

    // Expected values may embed expect operations in place of literal values;
    // these must match even though they are not `==` to the received Json.
    #[tokio::test]
    async fn it_should_work_with_custom_expect_op() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_json(&json!({
                "format": "text",
                "message": {
                    "hello": ExpectStrMinLen { min: 3 },
                    "numbers": expect::array().len(3).all(expect::integer()),
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_json(&json!({
                "format": "binary",
                "message": {
                    "hello": ExpectStrMinLen { min: 3 },
                    "numbers": expect::array().len(3).all(expect::integer()),
                },
            }))
            .await;
    }

    #[tokio::test]
    async fn it_should_panic_if_custom_expect_op_fails() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        // `"world"` is 5 characters, below the required minimum of 10, so the custom
        // op's failure message should surface as the panic message.
        let message = catch_panic_error_message_async(websocket.assert_receive_json(&json!({
            "format": "text",
            "message": {
                "hello": ExpectStrMinLen { min: 10 },
                "numbers": expect::array().len(3).all(expect::integer()),
            },
        })))
        .await;
        assert_str_eq!("String is too short, received: world", message);
    }
}
655
#[cfg(test)]
mod test_assert_receive_json_contains {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server with a `/ws-ping-pong` endpoint that parses each incoming
    /// message as Json and replies twice, wrapping the parsed value as
    /// `{"format": ..., "message": <received>}`. The first reply is a text message
    /// with `"format": "text"`, the second a binary message with `"format": "binary"`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_json::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_json::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_json::to_vec(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    // Top-level fields missing from the expected value must be ignored.
    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary_with_root_content_missing_in_contains() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_json_contains(&json!({
                // "format" is missing here
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_json_contains(&json!({
                "format": "binary",
                // "message" is missing here
            }))
            .await;
    }

    // Fields missing inside nested objects must also be ignored.
    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary_with_nested_content_missing_in_contains() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_json_contains(&json!({
                "format": "text",
                "message": {
                    // "hello" is missing here
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_json_contains(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    // "numbers" is missing here
                },
            }))
            .await;
    }
}
778
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server with a `/ws-ping-pong` endpoint that parses each incoming
    /// message as Yaml and replies twice, as Yaml, wrapping the parsed value as
    /// `{format: ..., message: <received>}`. The first reply is a text message with
    /// `format: text`, the second a binary message with `format: binary`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        // The request is sent as Json; the server parses it with the Yaml parser.
        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
866
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a server with a `/ws-ping-pong` endpoint that decodes each incoming
    /// message's payload as MsgPack. It replies with a single binary MsgPack message
    /// wrapping the decoded value as `{"format": "binary", "message": <received>}`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_data = maybe_message.unwrap().into_data();
                    let decoded = rmp_serde::from_slice::<Value>(&message_data).unwrap();

                    let encoded_data = ::rmp_serde::to_vec(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_msgpack(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Only a binary reply is sent for MsgPack, so a single receive is asserted.
        websocket
            .assert_receive_msgpack(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
934
#[cfg(test)]
mod test_receive_json {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use pretty_assertions::assert_eq;
    use pretty_assertions::assert_str_eq;
    use serde::Deserialize;
    use serde::Serialize;

    /// Reply payload sent back by the test server for a `"good"` request.
    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
    struct PingPongMessage {
        ping: String,
    }

    /// Builds a server with a `/ws-ping-pong` endpoint with two responses:
    /// - a `"good"` request gets `{"ping":"pong"}`;
    /// - a `"bad"` request gets `🦊`, which is not valid Json.
    ///
    /// Any other message panics the handler.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(received_message) = socket.recv().await {
                    let received = received_message.unwrap();
                    let received_text = received.to_text().unwrap();
                    // Requests arrive Json-encoded via `send_json`, so the string
                    // values include their surrounding double quotes.
                    let encoded_text = match received_text {
                        r#""good""# => serde_json::to_string(&PingPongMessage {
                            ping: "pong".to_string(),
                        })
                        .unwrap(),
                        r#""bad""# => "🦊".to_string(),
                        _ => panic!("unknown message given '{received_text}'"),
                    };

                    socket
                        .send(Message::Text(encoded_text.into()))
                        .await
                        .unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_parse_when_correct_structure() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_json(&"good").await;

        let received = websocket.receive_json::<PingPongMessage>().await;

        assert_eq!(
            received,
            PingPongMessage {
                ping: "pong".to_string(),
            }
        );
    }

    // A failed parse must report both the serde error and the raw body received.
    #[tokio::test]
    async fn it_should_display_error_with_body_on_parse_fail() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_json(&"bad").await;

        let error_message =
            catch_panic_error_message_async(websocket.receive_json::<PingPongMessage>()).await;

        assert_str_eq!(
            r#"Failed to deserialize Json websocket response,
    expected value at line 1 column 1

received:
    🦊
"#,
            error_message
        );
    }
}