axum_test/
test_web_socket.rs

1use crate::WsMessage;
2use anyhow::anyhow;
3use anyhow::Context;
4use anyhow::Result;
5use bytes::Bytes;
6use futures_util::sink::SinkExt;
7use futures_util::stream::StreamExt;
8use hyper::upgrade::Upgraded;
9use hyper_util::rt::TokioIo;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::fmt::Debug;
13use std::fmt::Display;
14use tokio_tungstenite::tungstenite::protocol::Role;
15use tokio_tungstenite::WebSocketStream;
16
17#[cfg(feature = "pretty-assertions")]
18use pretty_assertions::assert_eq;
19
20pub struct TestWebSocket {
21    stream: WebSocketStream<TokioIo<Upgraded>>,
22}
23
24impl TestWebSocket {
25    pub(crate) async fn new(upgraded: Upgraded) -> Self {
26        let upgraded_io = TokioIo::new(upgraded);
27        let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
28
29        Self { stream }
30    }
31
32    pub async fn close(mut self) {
33        self.stream
34            .close(None)
35            .await
36            .expect("Failed to close WebSocket stream");
37    }
38
39    pub async fn send_text<T>(&mut self, raw_text: T)
40    where
41        T: Display,
42    {
43        let text = format!("{}", raw_text);
44        self.send_message(WsMessage::Text(text.into())).await;
45    }
46
47    pub async fn send_json<J>(&mut self, body: &J)
48    where
49        J: ?Sized + Serialize,
50    {
51        let raw_json =
52            ::serde_json::to_string(body).expect("It should serialize the content into Json");
53
54        self.send_message(WsMessage::Text(raw_json.into())).await;
55    }
56
57    #[cfg(feature = "yaml")]
58    pub async fn send_yaml<Y>(&mut self, body: &Y)
59    where
60        Y: ?Sized + Serialize,
61    {
62        let raw_yaml =
63            ::serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
64
65        self.send_message(WsMessage::Text(raw_yaml.into())).await;
66    }
67
68    #[cfg(feature = "msgpack")]
69    pub async fn send_msgpack<M>(&mut self, body: &M)
70    where
71        M: ?Sized + Serialize,
72    {
73        let body_bytes =
74            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
75
76        self.send_message(WsMessage::Binary(body_bytes.into()))
77            .await;
78    }
79
80    pub async fn send_message(&mut self, message: WsMessage) {
81        self.stream.send(message).await.unwrap();
82    }
83
84    #[must_use]
85    pub async fn receive_text(&mut self) -> String {
86        let message = self.receive_message().await;
87
88        message_to_text(message)
89            .context("Failed to read message as a String")
90            .unwrap()
91    }
92
93    #[must_use]
94    pub async fn receive_json<T>(&mut self) -> T
95    where
96        T: DeserializeOwned,
97    {
98        let bytes = self.receive_bytes().await;
99        serde_json::from_slice::<T>(&bytes)
100            .context("Failed to deserialize message as Json")
101            .unwrap()
102    }
103
104    #[cfg(feature = "yaml")]
105    #[must_use]
106    pub async fn receive_yaml<T>(&mut self) -> T
107    where
108        T: DeserializeOwned,
109    {
110        let bytes = self.receive_bytes().await;
111        serde_yaml::from_slice::<T>(&bytes)
112            .context("Failed to deserialize message as Yaml")
113            .unwrap()
114    }
115
116    #[cfg(feature = "msgpack")]
117    #[must_use]
118    pub async fn receive_msgpack<T>(&mut self) -> T
119    where
120        T: DeserializeOwned,
121    {
122        let received_bytes = self.receive_bytes().await;
123        rmp_serde::from_slice::<T>(&received_bytes)
124            .context("Failed to deserializing message as MsgPack")
125            .unwrap()
126    }
127
128    #[must_use]
129    pub async fn receive_bytes(&mut self) -> Bytes {
130        let message = self.receive_message().await;
131
132        message_to_bytes(message)
133            .context("Failed to read message as a Bytes")
134            .unwrap()
135    }
136
137    #[must_use]
138    pub async fn receive_message(&mut self) -> WsMessage {
139        self.maybe_receive_message()
140            .await
141            .expect("No message found on WebSocket stream")
142    }
143
144    pub async fn assert_receive_json<T>(&mut self, expected: &T)
145    where
146        T: DeserializeOwned + PartialEq<T> + Debug,
147    {
148        assert_eq!(*expected, self.receive_json::<T>().await);
149    }
150
151    pub async fn assert_receive_text<C>(&mut self, expected: C)
152    where
153        C: AsRef<str>,
154    {
155        let expected_contents = expected.as_ref();
156        assert_eq!(expected_contents, &self.receive_text().await);
157    }
158
159    pub async fn assert_receive_text_contains<C>(&mut self, expected: C)
160    where
161        C: AsRef<str>,
162    {
163        let expected_contents = expected.as_ref();
164        let received = self.receive_text().await;
165        let is_contained = received.contains(expected_contents);
166
167        assert!(
168            is_contained,
169            "Failed to find '{expected_contents}', received '{received}'"
170        );
171    }
172
173    #[cfg(feature = "yaml")]
174    pub async fn assert_receive_yaml<T>(&mut self, expected: &T)
175    where
176        T: DeserializeOwned + PartialEq<T> + Debug,
177    {
178        assert_eq!(*expected, self.receive_yaml::<T>().await);
179    }
180
181    #[cfg(feature = "msgpack")]
182    pub async fn assert_receive_msgpack<T>(&mut self, expected: &T)
183    where
184        T: DeserializeOwned + PartialEq<T> + Debug,
185    {
186        assert_eq!(*expected, self.receive_msgpack::<T>().await);
187    }
188
189    #[must_use]
190    async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
191        let maybe_message = self.stream.next().await;
192
193        match maybe_message {
194            None => None,
195            Some(message_result) => {
196                let message =
197                    message_result.expect("Failed to receive message from WebSocket stream");
198                Some(message)
199            }
200        }
201    }
202}
203
204fn message_to_text(message: WsMessage) -> Result<String> {
205    let text = match message {
206        WsMessage::Text(text) => text.to_string(),
207        WsMessage::Binary(data) => {
208            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
209        }
210        WsMessage::Ping(data) => {
211            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
212        }
213        WsMessage::Pong(data) => {
214            String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
215        }
216        WsMessage::Close(None) => String::new(),
217        WsMessage::Close(Some(frame)) => frame.reason.to_string(),
218        WsMessage::Frame(_) => {
219            return Err(anyhow!(
220                "Unexpected Frame, did not expect Frame message whilst reading"
221            ))
222        }
223    };
224
225    Ok(text)
226}
227
228fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
229    let bytes = match message {
230        WsMessage::Text(string) => string.into(),
231        WsMessage::Binary(data) => data,
232        WsMessage::Ping(data) => data,
233        WsMessage::Pong(data) => data,
234        WsMessage::Close(None) => Bytes::new(),
235        WsMessage::Close(Some(frame)) => frame.reason.into(),
236        WsMessage::Frame(_) => {
237            return Err(anyhow!(
238                "Unexpected Frame, did not expect Frame message whilst reading"
239            ))
240        }
241    };
242
243    Ok(bytes)
244}
245
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a server whose `/ws-ping-pong` route answers each incoming
    /// message twice: as text (`Text: <message>`), then as binary
    /// (`Binary: <message>`).
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // Max websocket size (16mb), less 16 bytes of headroom for the prefix
        // the server adds to each reply.
        const LARGE_BLOB_SIZE: usize = 16777200;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The reply is prefixed with `Text: `, so an exact match must fail.
        websocket.assert_receive_text("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text("🦊").await;
    }
}
346
#[cfg(test)]
mod test_assert_receive_text_contains {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Answers each incoming message with a text message `Text: <incoming>`.
    async fn echo_as_text(mut socket: WebSocket) {
        while let Some(incoming) = socket.recv().await {
            let incoming_text = incoming.unwrap().into_text().unwrap();
            let reply = format!("Text: {incoming_text}").try_into().unwrap();

            socket.send(Message::Text(reply)).await.unwrap();
        }
    }

    /// Upgrades the request to a WebSocket served by `echo_as_text`.
    async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
        ws.on_upgrade(echo_as_text)
    }

    /// Builds a server exposing the echo WebSocket on `/ws-ping-pong`.
    fn new_test_app() -> TestServer {
        let router = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap()
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();
        let request = server.get_websocket("/ws-ping-pong");
        let mut socket = request.await.into_websocket().await;

        socket.send_text("Hello World!").await;

        // The full reply contains itself.
        socket
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();
        let request = server.get_websocket("/ws-ping-pong");
        let mut socket = request.await.into_websocket().await;

        socket.send_text("Hello World!").await;

        // The reply is `Text: Hello World!`, which contains the sent text.
        socket.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();
        let request = server.get_websocket("/ws-ping-pong");
        let mut socket = request.await.into_websocket().await;

        socket.send_text("Hello World!").await;
        socket.assert_receive_text_contains("🦊").await;
    }
}
421
#[cfg(test)]
mod test_assert_receive_json {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Decodes each incoming text message as Json, and answers twice with it
    /// wrapped in an envelope: first as a text message, then as binary.
    async fn reply_with_json_envelopes(mut socket: WebSocket) {
        while let Some(incoming) = socket.recv().await {
            let raw = incoming.unwrap().into_text().unwrap();
            let payload = serde_json::from_str::<Value>(&raw).unwrap();

            let text_envelope = json!({
                "format": "text",
                "message": payload
            });
            let text_body = serde_json::to_string(&text_envelope)
                .unwrap()
                .try_into()
                .unwrap();

            let binary_envelope = json!({
                "format": "binary",
                "message": payload
            });
            let binary_body = serde_json::to_vec(&binary_envelope).unwrap().into();

            socket.send(Message::Text(text_body)).await.unwrap();
            socket.send(Message::Binary(binary_body)).await.unwrap();
        }
    }

    /// Upgrades the request to a WebSocket served by `reply_with_json_envelopes`.
    async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
        ws.on_upgrade(reply_with_json_envelopes)
    }

    /// Builds a server exposing the Json WebSocket on `/ws-ping-pong`.
    fn new_test_app() -> TestServer {
        let router = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();
        let request = server.get_websocket("/ws-ping-pong");
        let mut socket = request.await.into_websocket().await;

        let sent = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });
        socket.send_json(&sent).await;

        // The first reply arrives as a text message.
        let expected_text = json!({
            "format": "text",
            "message": {
                "hello": "world",
                "numbers": [1, 2, 3],
            },
        });
        socket.assert_receive_json(&expected_text).await;

        // The second reply arrives as a binary message.
        let expected_binary = json!({
            "format": "binary",
            "message": {
                "hello": "world",
                "numbers": [1, 2, 3],
            },
        });
        socket.assert_receive_json(&expected_binary).await;
    }
}
508
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route decodes each incoming text
    /// message as Yaml, and answers twice with it wrapped in an envelope:
    /// first as a text message, then as binary.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        // Send as Yaml so the round trip exercises `send_yaml` as well as
        // `assert_receive_yaml`.
        websocket
            .send_yaml(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // Once for text
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Again for binary
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
596
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Decodes each incoming message as MsgPack, and answers with it wrapped
    /// in an envelope, encoded as a MsgPack binary message.
    async fn reply_with_msgpack_envelope(mut socket: WebSocket) {
        while let Some(incoming) = socket.recv().await {
            let raw = incoming.unwrap().into_data();
            let payload = rmp_serde::from_slice::<Value>(&raw).unwrap();

            let envelope = json!({
                "format": "binary",
                "message": payload
            });
            let body = ::rmp_serde::to_vec(&envelope).unwrap().into();

            socket.send(Message::Binary(body)).await.unwrap();
        }
    }

    /// Upgrades the request to a WebSocket served by `reply_with_msgpack_envelope`.
    async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
        ws.on_upgrade(reply_with_msgpack_envelope)
    }

    /// Builds a server exposing the MsgPack WebSocket on `/ws-ping-pong`.
    fn new_test_app() -> TestServer {
        let router = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();
        let request = server.get_websocket("/ws-ping-pong");
        let mut socket = request.await.into_websocket().await;

        let sent = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });
        socket.send_msgpack(&sent).await;

        let expected = json!({
            "format": "binary",
            "message": {
                "hello": "world",
                "numbers": [1, 2, 3],
            },
        });
        socket.assert_receive_msgpack(&expected).await;
    }
}