axum_test/
test_web_socket.rs

1use crate::WsMessage;
2use anyhow::anyhow;
3use anyhow::Context;
4use anyhow::Result;
5use bytes::Bytes;
6use futures_util::sink::SinkExt;
7use futures_util::stream::StreamExt;
8use hyper::upgrade::Upgraded;
9use hyper_util::rt::TokioIo;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::fmt::Debug;
13use std::fmt::Display;
14use tokio_tungstenite::tungstenite::protocol::Role;
15use tokio_tungstenite::WebSocketStream;
16
17#[cfg(feature = "pretty-assertions")]
18use pretty_assertions::assert_eq;
19
20#[derive(Debug)]
21pub struct TestWebSocket {
22    stream: WebSocketStream<TokioIo<Upgraded>>,
23}
24
25impl TestWebSocket {
26    pub(crate) async fn new(upgraded: Upgraded) -> Self {
27        let upgraded_io = TokioIo::new(upgraded);
28        let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
29
30        Self { stream }
31    }
32
33    pub async fn close(mut self) {
34        self.stream
35            .close(None)
36            .await
37            .expect("Failed to close WebSocket stream");
38    }
39
40    pub async fn send_text<T>(&mut self, raw_text: T)
41    where
42        T: Display,
43    {
44        let text = raw_text.to_string();
45        self.send_message(WsMessage::Text(text.into())).await;
46    }
47
48    pub async fn send_json<J>(&mut self, body: &J)
49    where
50        J: ?Sized + Serialize,
51    {
52        let raw_json =
53            ::serde_json::to_string(body).expect("It should serialize the content into Json");
54
55        self.send_message(WsMessage::Text(raw_json.into())).await;
56    }
57
58    #[cfg(feature = "yaml")]
59    pub async fn send_yaml<Y>(&mut self, body: &Y)
60    where
61        Y: ?Sized + Serialize,
62    {
63        let raw_yaml =
64            ::serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
65
66        self.send_message(WsMessage::Text(raw_yaml.into())).await;
67    }
68
69    #[cfg(feature = "msgpack")]
70    pub async fn send_msgpack<M>(&mut self, body: &M)
71    where
72        M: ?Sized + Serialize,
73    {
74        let body_bytes =
75            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
76
77        self.send_message(WsMessage::Binary(body_bytes.into()))
78            .await;
79    }
80
81    pub async fn send_message(&mut self, message: WsMessage) {
82        self.stream.send(message).await.unwrap();
83    }
84
85    #[must_use]
86    pub async fn receive_text(&mut self) -> String {
87        let message = self.receive_message().await;
88
89        message_to_text(message)
90            .context("Failed to read message as a String")
91            .unwrap()
92    }
93
94    #[must_use]
95    pub async fn receive_json<T>(&mut self) -> T
96    where
97        T: DeserializeOwned,
98    {
99        let bytes = self.receive_bytes().await;
100        serde_json::from_slice::<T>(&bytes)
101            .context("Failed to deserialize message as Json")
102            .unwrap()
103    }
104
105    #[cfg(feature = "yaml")]
106    #[must_use]
107    pub async fn receive_yaml<T>(&mut self) -> T
108    where
109        T: DeserializeOwned,
110    {
111        let bytes = self.receive_bytes().await;
112        serde_yaml::from_slice::<T>(&bytes)
113            .context("Failed to deserialize message as Yaml")
114            .unwrap()
115    }
116
117    #[cfg(feature = "msgpack")]
118    #[must_use]
119    pub async fn receive_msgpack<T>(&mut self) -> T
120    where
121        T: DeserializeOwned,
122    {
123        let received_bytes = self.receive_bytes().await;
124        rmp_serde::from_slice::<T>(&received_bytes)
125            .context("Failed to deserializing message as MsgPack")
126            .unwrap()
127    }
128
129    #[must_use]
130    pub async fn receive_bytes(&mut self) -> Bytes {
131        let message = self.receive_message().await;
132
133        message_to_bytes(message)
134            .context("Failed to read message as a Bytes")
135            .unwrap()
136    }
137
138    #[must_use]
139    pub async fn receive_message(&mut self) -> WsMessage {
140        self.maybe_receive_message()
141            .await
142            .expect("No message found on WebSocket stream")
143    }
144
145    pub async fn assert_receive_json<T>(&mut self, expected: &T)
146    where
147        T: DeserializeOwned + PartialEq<T> + Debug,
148    {
149        assert_eq!(*expected, self.receive_json::<T>().await);
150    }
151
152    pub async fn assert_receive_text<C>(&mut self, expected: C)
153    where
154        C: AsRef<str>,
155    {
156        let expected_contents = expected.as_ref();
157        assert_eq!(expected_contents, &self.receive_text().await);
158    }
159
160    pub async fn assert_receive_text_contains<C>(&mut self, expected: C)
161    where
162        C: AsRef<str>,
163    {
164        let expected_contents = expected.as_ref();
165        let received = self.receive_text().await;
166        let is_contained = received.contains(expected_contents);
167
168        assert!(
169            is_contained,
170            "Failed to find '{expected_contents}', received '{received}'"
171        );
172    }
173
174    #[cfg(feature = "yaml")]
175    pub async fn assert_receive_yaml<T>(&mut self, expected: &T)
176    where
177        T: DeserializeOwned + PartialEq<T> + Debug,
178    {
179        assert_eq!(*expected, self.receive_yaml::<T>().await);
180    }
181
182    #[cfg(feature = "msgpack")]
183    pub async fn assert_receive_msgpack<T>(&mut self, expected: &T)
184    where
185        T: DeserializeOwned + PartialEq<T> + Debug,
186    {
187        assert_eq!(*expected, self.receive_msgpack::<T>().await);
188    }
189
190    #[must_use]
191    async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
192        let maybe_message = self.stream.next().await;
193
194        match maybe_message {
195            None => None,
196            Some(message_result) => {
197                let message =
198                    message_result.expect("Failed to receive message from WebSocket stream");
199                Some(message)
200            }
201        }
202    }
203}
204
205fn message_to_text(message: WsMessage) -> Result<String> {
206    let text = match message {
207        WsMessage::Text(text) => text.to_string(),
208        WsMessage::Binary(data) => {
209            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
210        }
211        WsMessage::Ping(data) => {
212            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
213        }
214        WsMessage::Pong(data) => {
215            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
216        }
217        WsMessage::Close(None) => String::new(),
218        WsMessage::Close(Some(frame)) => frame.reason.to_string(),
219        WsMessage::Frame(_) => {
220            return Err(anyhow!(
221                "Unexpected Frame, did not expect Frame message whilst reading"
222            ))
223        }
224    };
225
226    Ok(text)
227}
228
229fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
230    let bytes = match message {
231        WsMessage::Text(string) => string.into(),
232        WsMessage::Binary(data) => data,
233        WsMessage::Ping(data) => data,
234        WsMessage::Pong(data) => data,
235        WsMessage::Close(None) => Bytes::new(),
236        WsMessage::Close(Some(frame)) => frame.reason.into(),
237        WsMessage::Frame(_) => {
238            return Err(anyhow!(
239                "Unexpected Frame, did not expect Frame message whilst reading"
240            ))
241        }
242    };
243
244    Ok(bytes)
245}
246
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a server whose `/ws-ping-pong` route answers every incoming message twice:
    /// first with `"Text: {message}"` as a text message, then with `"Binary: {message}"`
    /// as a binary message.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // 16 bytes under 16 MiB (16_777_216), the size the test treats as the WebSocket
        // message limit. The headroom covers the "Text: " (6 bytes) or "Binary: " (8 bytes)
        // prefix the server adds to each reply.
        const LARGE_BLOB_SIZE: usize = 16_777_200;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The reply is "Text: Hello World!", so an exact match on the raw text must fail.
        websocket.assert_receive_text("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text("🦊").await;
    }
}
347
#[cfg(test)]
mod test_assert_receive_text_contains {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a server whose `/ws-ping-pong` route replies to every incoming message with a
    /// single text message `"Text: {message}"`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let text = incoming.unwrap().into_text().unwrap();
                    let reply = Message::Text(format!("Text: {text}").try_into().unwrap());

                    socket.send(reply).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let router = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The entire reply also counts as "contained".
        websocket
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // Only a suffix of the "Text: Hello World!" reply.
        websocket.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text_contains("🦊").await;
    }
}
422
#[cfg(test)]
mod test_assert_receive_json {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route parses each incoming text message as JSON.
    /// It replies twice with the parsed value wrapped as `{ "format", "message" }`:
    /// once as JSON text (`"format": "text"`), then as JSON bytes in a binary message
    /// (`"format": "binary"`).
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let text = incoming.unwrap().into_text().unwrap();
                    let parsed = serde_json::from_str::<Value>(&text).unwrap();

                    let text_reply = json!({
                        "format": "text",
                        "message": parsed
                    });
                    let binary_reply = json!({
                        "format": "binary",
                        "message": parsed
                    });

                    let text_payload = serde_json::to_string(&text_reply)
                        .unwrap()
                        .try_into()
                        .unwrap();
                    let binary_payload = serde_json::to_vec(&binary_reply).unwrap().into();

                    socket.send(Message::Text(text_payload)).await.unwrap();
                    socket.send(Message::Binary(binary_payload)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let router = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        let payload = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });

        websocket.send_json(&payload).await;

        // The text reply arrives first...
        websocket
            .assert_receive_json(&json!({
                "format": "text",
                "message": payload,
            }))
            .await;

        // ...followed by the binary reply.
        websocket
            .assert_receive_json(&json!({
                "format": "binary",
                "message": payload,
            }))
            .await;
    }
}
509
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route parses each incoming text message as YAML.
    /// It replies twice with the parsed value wrapped as `{ format, message }`: once as YAML
    /// text (`format: text`), then as YAML bytes in a binary message (`format: binary`).
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        // Send as YAML so that `send_yaml` is exercised, not only `assert_receive_yaml`.
        websocket
            .send_yaml(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
597
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route decodes each incoming message payload as
    /// MessagePack. It replies with a single binary MessagePack message containing
    /// `{ "format": "binary", "message": <decoded value> }`.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(incoming) = socket.recv().await {
                    let payload = incoming.unwrap().into_data();
                    let parsed = rmp_serde::from_slice::<Value>(&payload).unwrap();

                    let reply = json!({
                        "format": "binary",
                        "message": parsed
                    });
                    let reply_bytes = ::rmp_serde::to_vec(&reply).unwrap().into();

                    socket.send(Message::Binary(reply_bytes)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let router = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();
        let mut websocket = server
            .get_websocket(&"/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        let payload = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });

        websocket.send_msgpack(&payload).await;

        websocket
            .assert_receive_msgpack(&json!({
                "format": "binary",
                "message": payload,
            }))
            .await;
    }
}
664}