1use crate::WsMessage;
2use anyhow::anyhow;
3use anyhow::Context;
4use anyhow::Result;
5use bytes::Bytes;
6use futures_util::sink::SinkExt;
7use futures_util::stream::StreamExt;
8use hyper::upgrade::Upgraded;
9use hyper_util::rt::TokioIo;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::fmt::Debug;
13use std::fmt::Display;
14use tokio_tungstenite::tungstenite::protocol::Role;
15use tokio_tungstenite::WebSocketStream;
16
17#[cfg(feature = "pretty-assertions")]
18use pretty_assertions::assert_eq;
19
20#[derive(Debug)]
21pub struct TestWebSocket {
22 stream: WebSocketStream<TokioIo<Upgraded>>,
23}
24
25impl TestWebSocket {
26 pub(crate) async fn new(upgraded: Upgraded) -> Self {
27 let upgraded_io = TokioIo::new(upgraded);
28 let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await;
29
30 Self { stream }
31 }
32
33 pub async fn close(mut self) {
34 self.stream
35 .close(None)
36 .await
37 .expect("Failed to close WebSocket stream");
38 }
39
40 pub async fn send_text<T>(&mut self, raw_text: T)
41 where
42 T: Display,
43 {
44 let text = raw_text.to_string();
45 self.send_message(WsMessage::Text(text.into())).await;
46 }
47
48 pub async fn send_json<J>(&mut self, body: &J)
49 where
50 J: ?Sized + Serialize,
51 {
52 let raw_json =
53 ::serde_json::to_string(body).expect("It should serialize the content into Json");
54
55 self.send_message(WsMessage::Text(raw_json.into())).await;
56 }
57
58 #[cfg(feature = "yaml")]
59 pub async fn send_yaml<Y>(&mut self, body: &Y)
60 where
61 Y: ?Sized + Serialize,
62 {
63 let raw_yaml =
64 ::serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
65
66 self.send_message(WsMessage::Text(raw_yaml.into())).await;
67 }
68
69 #[cfg(feature = "msgpack")]
70 pub async fn send_msgpack<M>(&mut self, body: &M)
71 where
72 M: ?Sized + Serialize,
73 {
74 let body_bytes =
75 ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
76
77 self.send_message(WsMessage::Binary(body_bytes.into()))
78 .await;
79 }
80
81 pub async fn send_message(&mut self, message: WsMessage) {
82 self.stream.send(message).await.unwrap();
83 }
84
85 #[must_use]
86 pub async fn receive_text(&mut self) -> String {
87 let message = self.receive_message().await;
88
89 message_to_text(message)
90 .context("Failed to read message as a String")
91 .unwrap()
92 }
93
94 #[must_use]
95 pub async fn receive_json<T>(&mut self) -> T
96 where
97 T: DeserializeOwned,
98 {
99 let bytes = self.receive_bytes().await;
100 serde_json::from_slice::<T>(&bytes)
101 .context("Failed to deserialize message as Json")
102 .unwrap()
103 }
104
105 #[cfg(feature = "yaml")]
106 #[must_use]
107 pub async fn receive_yaml<T>(&mut self) -> T
108 where
109 T: DeserializeOwned,
110 {
111 let bytes = self.receive_bytes().await;
112 serde_yaml::from_slice::<T>(&bytes)
113 .context("Failed to deserialize message as Yaml")
114 .unwrap()
115 }
116
117 #[cfg(feature = "msgpack")]
118 #[must_use]
119 pub async fn receive_msgpack<T>(&mut self) -> T
120 where
121 T: DeserializeOwned,
122 {
123 let received_bytes = self.receive_bytes().await;
124 rmp_serde::from_slice::<T>(&received_bytes)
125 .context("Failed to deserializing message as MsgPack")
126 .unwrap()
127 }
128
129 #[must_use]
130 pub async fn receive_bytes(&mut self) -> Bytes {
131 let message = self.receive_message().await;
132
133 message_to_bytes(message)
134 .context("Failed to read message as a Bytes")
135 .unwrap()
136 }
137
138 #[must_use]
139 pub async fn receive_message(&mut self) -> WsMessage {
140 self.maybe_receive_message()
141 .await
142 .expect("No message found on WebSocket stream")
143 }
144
145 pub async fn assert_receive_json<T>(&mut self, expected: &T)
146 where
147 T: DeserializeOwned + PartialEq<T> + Debug,
148 {
149 assert_eq!(*expected, self.receive_json::<T>().await);
150 }
151
152 pub async fn assert_receive_text<C>(&mut self, expected: C)
153 where
154 C: AsRef<str>,
155 {
156 let expected_contents = expected.as_ref();
157 assert_eq!(expected_contents, &self.receive_text().await);
158 }
159
160 pub async fn assert_receive_text_contains<C>(&mut self, expected: C)
161 where
162 C: AsRef<str>,
163 {
164 let expected_contents = expected.as_ref();
165 let received = self.receive_text().await;
166 let is_contained = received.contains(expected_contents);
167
168 assert!(
169 is_contained,
170 "Failed to find '{expected_contents}', received '{received}'"
171 );
172 }
173
174 #[cfg(feature = "yaml")]
175 pub async fn assert_receive_yaml<T>(&mut self, expected: &T)
176 where
177 T: DeserializeOwned + PartialEq<T> + Debug,
178 {
179 assert_eq!(*expected, self.receive_yaml::<T>().await);
180 }
181
182 #[cfg(feature = "msgpack")]
183 pub async fn assert_receive_msgpack<T>(&mut self, expected: &T)
184 where
185 T: DeserializeOwned + PartialEq<T> + Debug,
186 {
187 assert_eq!(*expected, self.receive_msgpack::<T>().await);
188 }
189
190 #[must_use]
191 async fn maybe_receive_message(&mut self) -> Option<WsMessage> {
192 let maybe_message = self.stream.next().await;
193
194 match maybe_message {
195 None => None,
196 Some(message_result) => {
197 let message =
198 message_result.expect("Failed to receive message from WebSocket stream");
199 Some(message)
200 }
201 }
202 }
203}
204
205fn message_to_text(message: WsMessage) -> Result<String> {
206 let text = match message {
207 WsMessage::Text(text) => text.to_string(),
208 WsMessage::Binary(data) => {
209 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
210 }
211 WsMessage::Ping(data) => {
212 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
213 }
214 WsMessage::Pong(data) => {
215 String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())?
216 }
217 WsMessage::Close(None) => String::new(),
218 WsMessage::Close(Some(frame)) => frame.reason.to_string(),
219 WsMessage::Frame(_) => {
220 return Err(anyhow!(
221 "Unexpected Frame, did not expect Frame message whilst reading"
222 ))
223 }
224 };
225
226 Ok(text)
227}
228
229fn message_to_bytes(message: WsMessage) -> Result<Bytes> {
230 let bytes = match message {
231 WsMessage::Text(string) => string.into(),
232 WsMessage::Binary(data) => data,
233 WsMessage::Ping(data) => data,
234 WsMessage::Pong(data) => data,
235 WsMessage::Close(None) => Bytes::new(),
236 WsMessage::Close(Some(frame)) => frame.reason.into(),
237 WsMessage::Frame(_) => {
238 return Err(anyhow!(
239 "Unexpected Frame, did not expect Frame message whilst reading"
240 ))
241 }
242 };
243
244 Ok(bytes)
245}
246
#[cfg(test)]
mod test_assert_receive_text {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Builds a server whose `/ws-ping-pong` route answers every incoming
    /// message twice: once as text prefixed with `Text: `, and once as binary
    /// prefixed with `Binary: `.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();

                    let encoded_text = format!("Text: {message_text}").try_into().unwrap();
                    let encoded_data = format!("Binary: {message_text}").into_bytes().into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_text_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;

        websocket.assert_receive_text("Text: Hello World!").await;
        websocket.assert_receive_text("Binary: Hello World!").await;
    }

    #[tokio::test]
    async fn it_should_ping_pong_large_text_blobs() {
        // 16 MiB minus 16 bytes. NOTE(review): presumably leaves headroom for
        // the reply prefix under a 16 MiB size limit — confirm.
        const LARGE_BLOB_SIZE: usize = 16777200;
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let server = new_test_app();
        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text(&large_blob).await;

        websocket
            .assert_receive_text(format!("Text: {large_blob}"))
            .await;
        websocket
            .assert_receive_text(format!("Binary: {large_blob}"))
            .await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_partial_text_match() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        // The reply is "Text: Hello World!", so an exact match must fail.
        websocket.assert_receive_text("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        websocket.send_text("Hello World!").await;
        websocket.assert_receive_text("🦊").await;
    }
}
347
#[cfg(test)]
mod test_assert_receive_text_contains {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;

    /// Route shared by every test in this module.
    const WS_PATH: &str = "/ws-ping-pong";

    /// Replies to each incoming message with its text prefixed by `Text: `.
    async fn echo_with_text_prefix(mut socket: WebSocket) {
        while let Some(incoming) = socket.recv().await {
            let received_text = incoming.unwrap().into_text().unwrap();
            let reply = format!("Text: {received_text}").try_into().unwrap();

            socket.send(Message::Text(reply)).await.unwrap();
        }
    }

    /// Upgrades the request and hands the socket to the echo handler.
    async fn upgrade_to_echo(ws: WebSocketUpgrade) -> Response {
        ws.on_upgrade(echo_with_text_prefix)
    }

    /// Builds a server exposing the echo WebSocket over real HTTP transport.
    fn new_test_app() -> TestServer {
        let router = Router::new().route(WS_PATH, get(upgrade_to_echo));

        TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap()
    }

    #[tokio::test]
    async fn it_should_assert_whole_text_match() {
        let server = new_test_app();
        let mut client = server.get_websocket(WS_PATH).await.into_websocket().await;

        client.send_text("Hello World!").await;

        client
            .assert_receive_text_contains("Text: Hello World!")
            .await;
    }

    #[tokio::test]
    async fn it_should_assert_partial_text_match() {
        let server = new_test_app();
        let mut client = server.get_websocket(WS_PATH).await.into_websocket().await;

        client.send_text("Hello World!").await;

        client.assert_receive_text_contains("Hello World!").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_match_different_text() {
        let server = new_test_app();
        let mut client = server.get_websocket(WS_PATH).await.into_websocket().await;

        client.send_text("Hello World!").await;

        client.assert_receive_text_contains("🦊").await;
    }
}
422
#[cfg(test)]
mod test_assert_receive_json {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Route used by the test in this module.
    const WS_PATH: &str = "/ws-ping-pong";

    /// Decodes each incoming message as Json and replies twice with it wrapped
    /// in an envelope: first as a text message, then as a binary message.
    async fn echo_json_envelopes(mut socket: WebSocket) {
        while let Some(incoming) = socket.recv().await {
            let received_text = incoming.unwrap().into_text().unwrap();
            let decoded = serde_json::from_str::<Value>(&received_text).unwrap();

            let text_envelope = json!({
                "format": "text",
                "message": decoded
            });
            let binary_envelope = json!({
                "format": "binary",
                "message": decoded
            });

            let text_reply = serde_json::to_string(&text_envelope)
                .unwrap()
                .try_into()
                .unwrap();
            let binary_reply = serde_json::to_vec(&binary_envelope).unwrap().into();

            socket.send(Message::Text(text_reply)).await.unwrap();
            socket.send(Message::Binary(binary_reply)).await.unwrap();
        }
    }

    /// Upgrades the request and hands the socket to the Json echo handler.
    async fn upgrade_to_echo(ws: WebSocketUpgrade) -> Response {
        ws.on_upgrade(echo_json_envelopes)
    }

    /// Builds a server exposing the Json echo WebSocket over real HTTP transport.
    fn new_test_app() -> TestServer {
        let router = Router::new().route(WS_PATH, get(upgrade_to_echo));

        TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_json_in_text_and_binary() {
        let server = new_test_app();
        let mut client = server.get_websocket(WS_PATH).await.into_websocket().await;

        let payload = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });
        client.send_json(&payload).await;

        // The text reply arrives first.
        let expected_text_reply = json!({
            "format": "text",
            "message": {
                "hello": "world",
                "numbers": [1, 2, 3],
            },
        });
        client.assert_receive_json(&expected_text_reply).await;

        // The binary reply follows.
        let expected_binary_reply = json!({
            "format": "binary",
            "message": {
                "hello": "world",
                "numbers": [1, 2, 3],
            },
        });
        client.assert_receive_json(&expected_binary_reply).await;
    }
}
509
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_assert_receive_yaml {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Builds a server whose `/ws-ping-pong` route decodes each incoming
    /// message as Yaml and replies with it wrapped in an envelope, once as a
    /// text message and once as a binary message, both encoded as Yaml.
    fn new_test_app() -> TestServer {
        pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response {
            async fn handle_ping_pong(mut socket: WebSocket) {
                while let Some(maybe_message) = socket.recv().await {
                    let message_text = maybe_message.unwrap().into_text().unwrap();
                    let decoded = serde_yaml::from_str::<Value>(&message_text).unwrap();

                    let encoded_text = serde_yaml::to_string(&json!({
                        "format": "text",
                        "message": decoded
                    }))
                    .unwrap()
                    .try_into()
                    .unwrap();
                    let encoded_data = serde_yaml::to_string(&json!({
                        "format": "binary",
                        "message": decoded
                    }))
                    .unwrap()
                    .into();

                    socket.send(Message::Text(encoded_text)).await.unwrap();
                    socket.send(Message::Binary(encoded_data)).await.unwrap();
                }
            }

            ws.on_upgrade(handle_ping_pong)
        }

        let app = Router::new().route("/ws-ping-pong", get(route_get_websocket_ping_pong));
        TestServer::builder().http_transport().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_yaml_in_text_and_binary() {
        let server = new_test_app();

        let mut websocket = server
            .get_websocket("/ws-ping-pong")
            .await
            .into_websocket()
            .await;

        // Send with `send_yaml` so the Yaml encoding path is exercised as
        // well as the Yaml decoding path.
        websocket
            .send_yaml(&json!({
                "hello": "world",
                "numbers": [1, 2, 3],
            }))
            .await;

        // The text reply arrives first.
        websocket
            .assert_receive_yaml(&json!({
                "format": "text",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;

        // The binary reply follows.
        websocket
            .assert_receive_yaml(&json!({
                "format": "binary",
                "message": {
                    "hello": "world",
                    "numbers": [1, 2, 3],
                },
            }))
            .await;
    }
}
597
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_assert_receive_msgpack {
    use crate::TestServer;

    use axum::extract::ws::Message;
    use axum::extract::ws::WebSocket;
    use axum::extract::WebSocketUpgrade;
    use axum::response::Response;
    use axum::routing::get;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Route used by the test in this module.
    const WS_PATH: &str = "/ws-ping-pong";

    /// Decodes each incoming message as MsgPack and replies with it wrapped in
    /// an envelope, encoded as a MsgPack binary message.
    async fn echo_msgpack_envelope(mut socket: WebSocket) {
        while let Some(incoming) = socket.recv().await {
            let received_data = incoming.unwrap().into_data();
            let decoded = rmp_serde::from_slice::<Value>(&received_data).unwrap();

            let envelope = json!({
                "format": "binary",
                "message": decoded
            });
            let binary_reply = ::rmp_serde::to_vec(&envelope).unwrap().into();

            socket.send(Message::Binary(binary_reply)).await.unwrap();
        }
    }

    /// Upgrades the request and hands the socket to the MsgPack echo handler.
    async fn upgrade_to_echo(ws: WebSocketUpgrade) -> Response {
        ws.on_upgrade(echo_msgpack_envelope)
    }

    /// Builds a server exposing the MsgPack echo WebSocket over real HTTP transport.
    fn new_test_app() -> TestServer {
        let router = Router::new().route(WS_PATH, get(upgrade_to_echo));

        TestServer::builder()
            .http_transport()
            .build(router)
            .unwrap()
    }

    #[tokio::test]
    async fn it_should_ping_pong_msgpack_in_binary() {
        let server = new_test_app();
        let mut client = server.get_websocket(WS_PATH).await.into_websocket().await;

        let payload = json!({
            "hello": "world",
            "numbers": [1, 2, 3],
        });
        client.send_msgpack(&payload).await;

        let expected_reply = json!({
            "format": "binary",
            "message": {
                "hello": "world",
                "numbers": [1, 2, 3],
            },
        });
        client.assert_receive_msgpack(&expected_reply).await;
    }
}