1use crate::WsMessage;
2use crate::internals::ErrorMessage;
3use anyhow::Result;
4use anyhow::anyhow;
5use bytes::Bytes;
6use expect_json::expect;
7use expect_json::expect_json_eq;
8use futures_util::sink::SinkExt;
9use futures_util::stream::StreamExt;
10use hyper::upgrade::Upgraded;
11use hyper_util::rt::TokioIo;
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use serde_json::Value;
15use std::fmt::Debug;
16use std::fmt::Display;
17use tokio_tungstenite::WebSocketStream;
18use tokio_tungstenite::tungstenite::protocol::Role;
19
20#[cfg(feature = "pretty-assertions")]
21use pretty_assertions::assert_eq;
22
23#[derive(Debug)]
24pub struct TestWebSocket {
25 stream: WebSocketStream<TokioIo<Upgraded>>,
26}
27
28impl TestWebSocket {
29 pub(crate) async fn new(upgraded: Upgraded) -> Self {
30 let upgraded_io = TokioIo::new(upgraded);
31 let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
32
33 Self { stream }
34 }
35
36 pub async fn close(mut self) {
37 self.stream
38 .close(None)
39 .await
40 .error_message("Failed to close WebSocket stream");
41 }
42
43 pub async fn send_text<T>(&mut self, raw_text: T) -> &mut Self
44 where
45 T: Display,
46 {
47 let text = raw_text.to_string();
48 self.send_message(WsMessage::Text(text.into())).await;
49
50 self
51 }
52
53 pub async fn send_json<J>(&mut self, body: &J) -> &mut Self
54 where
55 J: ?Sized + Serialize,
56 {
57 let raw_json = ::serde_json::to_string(body).error_message("Failed to serialize into Json");
58
59 self.send_message(WsMessage::Text(raw_json.into())).await;
60
61 self
62 }
63
64 #[cfg(feature = "yaml")]
65 pub async fn send_yaml<Y>(&mut self, body: &Y) -> &mut Self
66 where
67 Y: ?Sized + Serialize,
68 {
69 let raw_yaml = ::serde_yaml::to_string(body).error_message("Failed to serialize into Yaml");
70
71 self.send_message(WsMessage::Text(raw_yaml.into())).await;
72
73 self
74 }
75
76 #[cfg(feature = "msgpack")]
77 pub async fn send_msgpack<M>(&mut self, body: &M) -> &mut Self
78 where
79 M: ?Sized + Serialize,
80 {
81 let body_bytes =
82 ::rmp_serde::to_vec(body).error_message("Failed to serialize into MsgPack");
83
84 self.send_message(WsMessage::Binary(body_bytes.into()))
85 .await;
86
87 self
88 }
89
90 pub async fn send_message(&mut self, message: WsMessage) -> &mut Self {
91 self.stream
92 .send(message)
93 .await
94 .error_message("Failed to send websocket message");
95
96 self
97 }
98
99 #[must_use]
100 pub async fn receive_text(&mut self) -> String {
101 let message = self.receive_message().await;
102
103 message_to_text(message).error_message("Failed to receive websocket response as text")
104 }
105
106 #[must_use]
107 pub async fn receive_json<T>(&mut self) -> T
108 where
109 T: DeserializeOwned,
110 {
111 let bytes = self.receive_bytes().await;
112 serde_json::from_slice::<T>(&bytes)
113 .error_message_with_body("Failed to deserialize Json websocket response", &bytes)
114 }
115
116 #[cfg(feature = "yaml")]
117 #[must_use]
118 pub async fn receive_yaml<T>(&mut self) -> T
119 where
120 T: DeserializeOwned,
121 {
122 let bytes = self.receive_bytes().await;
123 serde_yaml::from_slice::<T>(&bytes)
124 .error_message_with_body("Failed to deserialize Yaml websocket response", &bytes)
125 }
126
127 #[cfg(feature = "msgpack")]
128 #[must_use]
129 pub async fn receive_msgpack<T>(&mut self) -> T
130 where
131 T: DeserializeOwned,
132 {
133 let received_bytes = self.receive_bytes().await;
134 rmp_serde::from_slice::<T>(&received_bytes)
135 .error_message("Failed to deserialize MsgPack websocket response")
136 }
137
138 #[must_use]
139 pub async fn receive_bytes(&mut self) -> Bytes {
140 let message = self.receive_message().await;
141 message_to_bytes(message).error_message("Failed to receive websocket response as bytes")
142 }
143
144 #[must_use]
145 pub async fn receive_message(&mut self) -> WsMessage {
146 self.maybe_receive_message()
147 .await
148 .expect("No message found on WebSocket stream")
149 }
150
151 #[must_use]
152 async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
153 let maybe_message = self.stream.next().await;
154
155 match maybe_message {
156 None => None,
157 Some(message_result) => {
158 let message =
159 message_result.error_message("Failed to receive message from WebSocket stream");
160
161 Some(message)
162 }
163 }
164 }
165
166 pub async fn assert_receive_json<T>(&mut self, expected: &T) -> &mut Self
167 where
168 T: Serialize + DeserializeOwned + PartialEq<T> + Debug,
169 {
170 let received = self.receive_json::<T>().await;
171
172 if *expected != received {
173 if let Err(error) = expect_json_eq(&received, &expected) {
174 panic!(
175 "
176{error:?}
177",
178 );
179 }
180 }
181
182 self
183 }
184
185 pub async fn assert_receive_json_contains<T>(&mut self, expected: &T) -> &mut Self
186 where
187 T: Serialize,
188 {
189 let received = self.receive_json::<Value>().await;
190 let expected_value = serde_json::to_value(expected).unwrap();
191 let result = expect_json_eq(
192 &received,
193 &expect::object().propagated_contains(expected_value),
194 );
195 if let Err(error) = result {
196 panic!(
197 "
198{error}
199",
200 );
201 }
202
203 self
204 }
205
206 pub async fn assert_receive_text<C>(&mut self, expected: C) -> &mut Self
207 where
208 C: AsRef<str>,
209 {
210 let expected_contents = expected.as_ref();
211 assert_eq!(expected_contents, &self.receive_text().await);
212
213 self
214 }
215
216 pub async fn assert_receive_text_contains<C>(&mut self, expected: C) -> &mut Self
217 where
218 C: AsRef<str>,
219 {
220 let expected_contents = expected.as_ref();
221 let received = self.receive_text().await;
222 let is_contained = received.contains(expected_contents);
223
224 assert!(
225 is_contained,
226 "Failed to find '{expected_contents}', received '{received}'"
227 );
228
229 self
230 }
231
232 #[cfg(feature = "yaml")]
233 pub async fn assert_receive_yaml<T>(&mut self, expected: &T) -> &mut Self
234 where
235 T: DeserializeOwned + PartialEq<T> + Debug,
236 {
237 assert_eq!(*expected, self.receive_yaml::<T>().await);
238
239 self
240 }
241
242 #[cfg(feature = "msgpack")]
243 pub async fn assert_receive_msgpack<T>(&mut self, expected: &T) -> &mut Self
244 where
245 T: DeserializeOwned + PartialEq<T> + Debug,
246 {
247 assert_eq!(*expected, self.receive_msgpack::<T>().await);
248
249 self
250 }
251}
252
253fn message_to_text(message: WsMessage) -> Result<String> {
254 let text = match message {
255 WsMessage::Text(text) => text.to_string(),
256 WsMessage::Binary(data) => {
257 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
258 }
259 WsMessage::Ping(data) => {
260 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
261 }
262 WsMessage::Pong(data) => {
263 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
264 }
265 WsMessage::Close(None) => String::new(),
266 WsMessage::Close(Some(frame)) => frame.reason.to_string(),
267 WsMessage::Frame(_) => {
268 return Err(anyhow!(
269 "Unexpected Frame, did not expect Frame message whilst reading"
270 ));
271 }
272 };
273
274 Ok(text)
275}
276
277fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
278 let bytes = match message {
279 WsMessage::Text(string) => string.into(),
280 WsMessage::Binary(data) => data,
281 WsMessage::Ping(data) => data,
282 WsMessage::Pong(data) => data,
283 WsMessage::Close(None) => Bytes::new(),
284 WsMessage::Close(Some(frame)) => frame.reason.into(),
285 WsMessage::Frame(_) => {
286 return Err(anyhow!(
287 "Unexpected Frame, did not expect Frame message whilst reading"
288 ));
289 }
290 };
291
292 Ok(bytes)
293}
294
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;
    use crate::testing::assert_error_message;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// For each message received, the handler replies twice with the message's
    /// text: first as a text message prefixed with `Text: `, then as a binary
    /// message prefixed with `Binary: `.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // 16_777_200 bytes is just under 16 MiB (16_777_216). This is presumably
        // chosen so the echoed reply, including the server's `Text: ` / `Binary: `
        // prefix, stays below a 16 MiB message limit.
        // NOTE(review): confirm against the WebSocket config defaults.
        const LARGE_BLOB_SIZE: usize = 16_777_200;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    #[tokio::test]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        let message =
            catch_panic_error_message_async(websocket.assert_receive_text("Hello World!")).await;
        assert_error_message(
            "assertion failed: `(left == right)`

Diff < left / right > :
<Hello World!
>Text: Hello World!

",
            message,
        );
    }

    #[tokio::test]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        let message = catch_panic_error_message_async(websocket.assert_receive_text("🦊")).await;
        assert_error_message(
            "assertion failed: `(left == right)`

Diff < left / right > :
<🦊
>Text: Hello World!

",
            message,
        );
    }
}
417
#[cfg(test)]
mod test_assert_receive_text_contains {
    use super::TestWebSocket;
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// Every message received is echoed back once, as a text message prefixed
    /// with `Text: `.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let incoming_text = incoming.unwrap().into_text().unwrap();
                    let reply = format!("Text: {incoming_text}").try_into().unwrap();

                    socket.send(Message::Text(reply)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a WebSocket connection to the server's `/ws-ping-pong` route.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_text("Hello World!").await;
        websocket
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_text("Hello World!").await;

        let failure =
            catch_panic_error_message_async(websocket.assert_receive_text_contains("🦊")).await;
        assert_str_eq!(
            "Failed to find '🦊', received 'Text: Hello World!'",
            failure
        );
    }
}
498
#[cfg(test)]
mod test_assert_receive_json {
    use super::TestWebSocket;
    use crate::TestServer;
    use crate::testing::ExpectStrMinLen;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use expect_json::expect;
    use pretty_assertions::assert_str_eq;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// Each incoming text message is parsed as JSON and echoed back twice, wrapped
    /// as `{ "format", "message" }`: first as a text message (`"format": "text"`),
    /// then as a binary message (`"format": "binary"`).
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let incoming_text = incoming.unwrap().into_text().unwrap();
                    let decoded = serde_json::from_str::<Value>(&incoming_text).unwrap();
                    let wrap = |format: &str| json!({ "format": format, "message": decoded });

                    let text_reply = serde_json::to_string(&wrap("text"))
                        .unwrap()
                        .try_into()
                        .unwrap();
                    let binary_reply = serde_json::to_vec(&wrap("binary")).unwrap().into();

                    socket.send(Message::Text(text_reply)).await.unwrap();
                    socket.send(Message::Binary(binary_reply)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a WebSocket connection to the server's `/ws-ping-pong` route.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // The server replies once as a text message, then once as a binary message.
        for format in ["text", "binary"] {
            websocket
                .assert_receive_json(&json!({
                    "format": format,
                    "message": {
                        "hello": "world",
                        "numbers": [1, 2, 3],
                    },
                }))
                .await;
        }
    }

    #[tokio::test]
    async fn it_should_work_with_custom_expect_op() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        for format in ["text", "binary"] {
            websocket
                .assert_receive_json(&json!({
                    "format": format,
                    "message": {
                        "hello": ExpectStrMinLen { min: 3 },
                        "numbers": expect::array().len(3).all(expect::integer()),
                    },
                }))
                .await;
        }
    }

    #[tokio::test]
    async fn it_should_panic_if_custom_expect_op_fails() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // "world" is shorter than the required 10 characters.
        let expected = json!({
            "format": "text",
            "message": {
                "hello": ExpectStrMinLen { min: 10 },
                "numbers": expect::array().len(3).all(expect::integer()),
            },
        });
        let failure =
            catch_panic_error_message_async(websocket.assert_receive_json(&expected)).await;
        assert_str_eq!("String is too short, received: world", failure);
    }
}
655
#[cfg(test)]
mod test_assert_receive_json_contains {
    use super::TestWebSocket;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// Each incoming text message is parsed as JSON and echoed back twice, wrapped
    /// as `{ "format", "message" }`: first as a text message (`"format": "text"`),
    /// then as a binary message (`"format": "binary"`).
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let incoming_text = incoming.unwrap().into_text().unwrap();
                    let decoded = serde_json::from_str::<Value>(&incoming_text).unwrap();
                    let wrap = |format: &str| json!({ "format": format, "message": decoded });

                    let text_reply = serde_json::to_string(&wrap("text"))
                        .unwrap()
                        .try_into()
                        .unwrap();
                    let binary_reply = serde_json::to_vec(&wrap("binary")).unwrap().into();

                    socket.send(Message::Text(text_reply)).await.unwrap();
                    socket.send(Message::Binary(binary_reply)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a WebSocket connection to the server's `/ws-ping-pong` route.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary_with_root_content_missing_in_contains() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Text reply: the root `format` key is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Binary reply: the root `message` key is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "format": "binary",
            }))
            .await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary_with_nested_content_missing_in_contains() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Text reply: nested `message.hello` is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "format": "text",
                "message": {
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Binary reply: nested `message.numbers` is left out of the expectation.
        websocket
            .assert_receive_json_contains(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                },
            }))
            .await;
    }
}
778
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// Each incoming text message is parsed as YAML and echoed back twice. The reply
    /// is wrapped as `{ format, message }` and re-encoded as YAML: first as a text
    /// message (`format: text`), then as a binary message (`format: binary`).
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(move |socket| handle_ping_pong(socket))
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        // Send YAML (not JSON) so that `send_yaml` itself is exercised; the server
        // parses the incoming text as YAML.
        websocket
            .send_yaml(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
866
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use super::TestWebSocket;
    use crate::TestServer;

    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use serde_json::Value;
    use serde_json::json;

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// Each incoming message's payload is decoded as MessagePack and echoed back
    /// once, wrapped as `{ "format": "binary", "message" }`, in a binary message.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let payload = incoming.unwrap().into_data();
                    let decoded = rmp_serde::from_slice::<Value>(&payload).unwrap();
                    let reply = json!({
                        "format": "binary",
                        "message": decoded
                    });
                    let encoded = ::rmp_serde::to_vec(&reply).unwrap().into();

                    socket.send(Message::Binary(encoded)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a WebSocket connection to the server's `/ws-ping-pong` route.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket
            .send_msgpack(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        websocket
            .assert_receive_msgpack(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
934
#[cfg(test)]
mod test_receive_json {
    use super::TestWebSocket;
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::extract::WebSocketUpgrade;
    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::response::Response;
    use axum::routing::get;
    use pretty_assertions::assert_eq;
    use pretty_assertions::assert_str_eq;
    use serde::Deserialize;
    use serde::Serialize;

    /// Payload the server replies with when asked for a well-formed response.
    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
    struct PingPongMessage {
        ping: String,
    }

    /// Chooses the server's reply for an incoming text payload.
    ///
    /// The JSON string `"good"` gets a serialized `PingPongMessage`. `"bad"` gets
    /// `🦊`, which is not valid JSON. Anything else panics the handler.
    fn reply_for(request: &str) -> String {
        match request {
            r#""good""# => serde_json::to_string(&PingPongMessage {
                ping: "pong".to_string(),
            })
            .unwrap(),
            r#""bad""# => "🦊".to_string(),
            _ => panic!("unknown message given '{request}'"),
        }
    }

    /// Builds a test server, over a real HTTP transport, serving `/ws-ping-pong`.
    ///
    /// Each incoming text message is answered with one text message chosen by
    /// `reply_for`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(received_message) = socket.recv().await {
                    let received = received_message.unwrap();
                    let reply = reply_for(received.to_text().unwrap());

                    socket.send(Message::Text(reply.into())).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app)
    }

    /// Opens a WebSocket connection to the server's `/ws-ping-pong` route.
    async fn connect(server: &TestServer) -> TestWebSocket {
        server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await
    }

    #[tokio::test]
    async fn it_should_parse_when_correct_structure() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&"good").await;
        let received = websocket.receive_json::<PingPongMessage>().await;

        let expected = PingPongMessage {
            ping: "pong".to_string(),
        };
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn it_should_display_error_with_body_on_parse_fail() {
        let server = new_test_app();
        let mut websocket = connect(&server).await;

        websocket.send_json(&"bad").await;

        let error_message =
            catch_panic_error_message_async(websocket.receive_json::<PingPongMessage>()).await;

        assert_str_eq!(
            r#"Failed to deserialize Json websocket response,
    expected value at line 1 column 1

received:
    🦊
"#,
            error_message
        );
    }
}
1031}