1use crate::WsMessage;
2use anyhow::anyhow;
3use anyhow::Context;
4use anyhow::Result;
5use bytes::Bytes;
6use futures_util::sink::SinkExt;
7use futures_util::stream::StreamExt;
8use hyper::upgrade::Upgraded;
9use hyper_util::rt::TokioIo;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::fmt::Debug;
13use std::fmt::Display;
14use tokio_tungstenite::tungstenite::protocol::Role;
15use tokio_tungstenite::WebSocketStream;
16
17#[cfg(feature = "pretty-assertions")]
18use pretty_assertions::assert_eq;
19
20pub struct TestWebSocket {
21 stream: WebSocketStream<TokioIo<Upgraded>>,
22}
23
24impl TestWebSocket {
25 pub(crate) async fn new(upgraded: Upgraded) -> Self {
26 let upgraded_io = TokioIo::new(upgraded);
27 let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
28
29 Self { stream }
30 }
31
32 pub async fn close(mut self) {
33 self.stream
34 .close(None)
35 .await
36 .expect("Failed to close WebSocket stream");
37 }
38
39 pub async fn send_text<T>(&mut self, raw_text: T)
40 where
41 T: Display,
42 {
43 let text = format!("{}", raw_text);
44 self.send_message(WsMessage::Text(text.into())).await;
45 }
46
47 pub async fn send_json<J>(&mut self, body: &J)
48 where
49 J: ?Sized + Serialize,
50 {
51 let raw_json =
52 ::serde_json::to_string(body).expect("It should serialize the content into Json");
53
54 self.send_message(WsMessage::Text(raw_json.into())).await;
55 }
56
57 #[cfg(feature = "yaml")]
58 pub async fn send_yaml<Y>(&mut self, body: &Y)
59 where
60 Y: ?Sized + Serialize,
61 {
62 let raw_yaml =
63 ::serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
64
65 self.send_message(WsMessage::Text(raw_yaml.into())).await;
66 }
67
68 #[cfg(feature = "msgpack")]
69 pub async fn send_msgpack<M>(&mut self, body: &M)
70 where
71 M: ?Sized + Serialize,
72 {
73 let body_bytes =
74 ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
75
76 self.send_message(WsMessage::Binary(body_bytes.into()))
77 .await;
78 }
79
80 pub async fn send_message(&mut self, message: WsMessage) {
81 self.stream.send(message).await.unwrap();
82 }
83
84 #[must_use]
85 pub async fn receive_text(&mut self) -> String {
86 let message = self.receive_message().await;
87
88 message_to_text(message)
89 .context("Failed to read message as a String")
90 .unwrap()
91 }
92
93 #[must_use]
94 pub async fn receive_json<T>(&mut self) -> T
95 where
96 T: DeserializeOwned,
97 {
98 let bytes = self.receive_bytes().await;
99 serde_json::from_slice::<T>(&bytes)
100 .context("Failed to deserialize message as Json")
101 .unwrap()
102 }
103
104 #[cfg(feature = "yaml")]
105 #[must_use]
106 pub async fn receive_yaml<T>(&mut self) -> T
107 where
108 T: DeserializeOwned,
109 {
110 let bytes = self.receive_bytes().await;
111 serde_yaml::from_slice::<T>(&bytes)
112 .context("Failed to deserialize message as Yaml")
113 .unwrap()
114 }
115
116 #[cfg(feature = "msgpack")]
117 #[must_use]
118 pub async fn receive_msgpack<T>(&mut self) -> T
119 where
120 T: DeserializeOwned,
121 {
122 let received_bytes = self.receive_bytes().await;
123 rmp_serde::from_slice::<T>(&received_bytes)
124 .context("Failed to deserializing message as MsgPack")
125 .unwrap()
126 }
127
128 #[must_use]
129 pub async fn receive_bytes(&mut self) -> Bytes {
130 let message = self.receive_message().await;
131
132 message_to_bytes(message)
133 .context("Failed to read message as a Bytes")
134 .unwrap()
135 }
136
137 #[must_use]
138 pub async fn receive_message(&mut self) -> WsMessage {
139 self.maybe_receive_message()
140 .await
141 .expect("No message found on WebSocket stream")
142 }
143
144 pub async fn assert_receive_json<T>(&mut self, expected: &T)
145 where
146 T: DeserializeOwned + PartialEq<T> + Debug,
147 {
148 assert_eq!(*expected, self.receive_json::<T>().await);
149 }
150
151 pub async fn assert_receive_text<C>(&mut self, expected: C)
152 where
153 C: AsRef<str>,
154 {
155 let expected_contents = expected.as_ref();
156 assert_eq!(expected_contents, &self.receive_text().await);
157 }
158
159 pub async fn assert_receive_text_contains<C>(&mut self, expected: C)
160 where
161 C: AsRef<str>,
162 {
163 let expected_contents = expected.as_ref();
164 let received = self.receive_text().await;
165 let is_contained = received.contains(expected_contents);
166
167 assert!(
168 is_contained,
169 "Failed to find '{expected_contents}', received '{received}'"
170 );
171 }
172
173 #[cfg(feature = "yaml")]
174 pub async fn assert_receive_yaml<T>(&mut self, expected: &T)
175 where
176 T: DeserializeOwned + PartialEq<T> + Debug,
177 {
178 assert_eq!(*expected, self.receive_yaml::<T>().await);
179 }
180
181 #[cfg(feature = "msgpack")]
182 pub async fn assert_receive_msgpack<T>(&mut self, expected: &T)
183 where
184 T: DeserializeOwned + PartialEq<T> + Debug,
185 {
186 assert_eq!(*expected, self.receive_msgpack::<T>().await);
187 }
188
189 #[must_use]
190 async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
191 let maybe_message = self.stream.next().await;
192
193 match maybe_message {
194 None => None,
195 Some(message_result) => {
196 let message =
197 message_result.expect("Failed to receive message from WebSocket stream");
198 Some(message)
199 }
200 }
201 }
202}
203
204fn message_to_text(message: WsMessage) -> Result<String> {
205 let text = match message {
206 WsMessage::Text(text) => text.to_string(),
207 WsMessage::Binary(data) => {
208 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
209 }
210 WsMessage::Ping(data) => {
211 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
212 }
213 WsMessage::Pong(data) => {
214 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
215 }
216 WsMessage::Close(None) => String::new(),
217 WsMessage::Close(Some(frame)) => frame.reason.to_string(),
218 WsMessage::Frame(_) => {
219 return Err(anyhow!(
220 "Unexpected Frame, did not expect Frame message whilst reading"
221 ))
222 }
223 };
224
225 Ok(text)
226}
227
228fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
229 let bytes = match message {
230 WsMessage::Text(string) => string.into(),
231 WsMessage::Binary(data) => data,
232 WsMessage::Ping(data) => data,
233 WsMessage::Pong(data) => data,
234 WsMessage::Close(None) => Bytes::new(),
235 WsMessage::Close(Some(frame)) => frame.reason.into(),
236 WsMessage::Frame(_) => {
237 return Err(anyhow!(
238 "Unexpected Frame, did not expect Frame message whilst reading"
239 ))
240 }
241 };
242
243 Ok(bytes)
244}
245
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a server whose `/ws-ping-pong` route answers each incoming
    /// message twice: once as text prefixed with `Text: `, and once as
    /// binary prefixed with `Binary: `.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // 16 bytes short of 16 MiB (16 * 1024 * 1024 = 16_777_216).
        // NOTE(review): presumably sized so the prefixed replies stay under a
        // 16 MiB frame limit; confirm against the WebSocket configuration.
        const LARGE_BLOB_SIZE: usize = 16_777_200;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The reply is "Text: Hello World!", so an exact match on the
        // unprefixed text must fail.
        websocket.assert_receive_text("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text("🦊").await;
    }
}
346
#[cfg(test)]
mod test_assert_receive_text_contains {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a server whose `/ws-ping-pong` route answers each incoming
    /// message with a text message prefixed with `Text: `.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The reply is "Text: Hello World!", which contains the unprefixed text.
        websocket.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text_contains("🦊").await;
    }
}
421
#[cfg(test)]
mod test_assert_receive_json {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route parses each incoming
    /// message as Json and replies twice with it wrapped in an envelope:
    /// once as a text message, and once as a binary message.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_json::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_json::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_json::to_vec(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_json(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // First reply arrives as a text message.
        websocket
            .assert_receive_json(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Second reply arrives as a binary message.
        websocket
            .assert_receive_json(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
508
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route parses each incoming
    /// message as Yaml and replies twice with it wrapped in an envelope:
    /// once as a text message, and once as a binary message.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        // Send with the Yaml sender so the round trip covers both directions.
        websocket
            .send_yaml(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // First reply arrives as a text message.
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // Second reply arrives as a binary message.
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
596
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route decodes each incoming
    /// message as MsgPack and replies with it wrapped in an envelope,
    /// re-encoded as a binary MsgPack message.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_data = maybe_message.unwrap().into_data();
                    let decoded = rmp_serde::from_slice::<Value>(&message_data).unwrap();

                    let encoded_data = ::rmp_serde::to_vec(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket
            .send_msgpack(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        websocket
            .assert_receive_msgpack(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}