1use crate::TestRequest;
2use crate::TestRequestConfig;
3use crate::TestServerBuilder;
4use crate::TestServerConfig;
5use crate::Transport;
6use crate::internals::AtomicCrossCookieJar;
7use crate::internals::ErrorMessage;
8use crate::internals::ExpectedState;
9use crate::internals::QueryParamsStore;
10use crate::transport_layer::IntoTransportLayer;
11use crate::transport_layer::TransportLayer;
12use crate::transport_layer::TransportLayerBuilder;
13use anyhow::Result;
14use anyhow::anyhow;
15use cookie::Cookie;
16use cookie::CookieJar;
17use http::HeaderName;
18use http::HeaderValue;
19use http::Method;
20use http::Uri;
21use serde::Serialize;
22use std::fmt::Debug;
23use std::sync::Arc;
24use url::Url;
25
26#[cfg(feature = "typed-routing")]
27use axum_extra::routing::TypedPath;
28
29#[cfg(feature = "reqwest")]
30use crate::transport_layer::TransportLayerType;
31#[cfg(feature = "reqwest")]
32use reqwest::Client;
33#[cfg(feature = "reqwest")]
34use reqwest::RequestBuilder;
35#[cfg(feature = "reqwest")]
36use std::cell::OnceCell;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// Base URL used to build request URLs when the transport reports no URL of its own
/// (see `TestServer::build_test_request_config`).
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
43#[derive(Debug)]
142pub struct TestServer {
143 state: ServerSharedState,
144 cookie_jar: Arc<AtomicCrossCookieJar>,
145 transport: Arc<Box<dyn TransportLayer>>,
146 expected_state: ExpectedState,
147 default_content_type: Option<String>,
148 is_http_path_restricted: bool,
149
150 #[cfg(feature = "reqwest")]
151 maybe_reqwest_client: OnceCell<Client>,
152}
153
154impl TestServer {
155 pub fn builder() -> TestServerBuilder {
157 TestServerBuilder::default()
158 }
159
160 pub fn new<A>(app: A) -> Self
195 where
196 A: IntoTransportLayer,
197 {
198 Self::try_new(app).error_message("Failed to build TestServer")
199 }
200
201 pub fn try_new<A>(app: A) -> Result<Self>
203 where
204 A: IntoTransportLayer,
205 {
206 Self::try_new_with_config(app, TestServerConfig::default())
207 }
208
209 pub fn new_with_config<A, C>(app: A, config: C) -> Self
216 where
217 A: IntoTransportLayer,
218 C: Into<TestServerConfig>,
219 {
220 Self::try_new_with_config(app, config).error_message("Failed to build TestServer")
221 }
222
223 pub fn try_new_with_config<A, C>(app: A, config: C) -> Result<Self>
225 where
226 A: IntoTransportLayer,
227 C: Into<TestServerConfig>,
228 {
229 let config = config.into();
230 let state = ServerSharedState::new();
231
232 let transport = match config.transport {
233 None => {
234 let builder = TransportLayerBuilder::new(None, None);
235 let transport = app.into_default_transport(builder)?;
236 Arc::new(transport)
237 }
238 Some(Transport::HttpRandomPort) => {
239 let builder = TransportLayerBuilder::new(None, None);
240 let transport = app.into_http_transport_layer(builder)?;
241 Arc::new(transport)
242 }
243 Some(Transport::HttpIpPort { ip, port }) => {
244 let builder = TransportLayerBuilder::new(ip, port);
245 let transport = app.into_http_transport_layer(builder)?;
246 Arc::new(transport)
247 }
248 Some(Transport::MockHttp) => {
249 let transport = app.into_mock_transport_layer()?;
250 Arc::new(transport)
251 }
252 };
253
254 let expected_state = match config.expect_success_by_default {
255 true => ExpectedState::Success,
256 false => ExpectedState::None,
257 };
258
259 Ok(Self {
260 state,
261 cookie_jar: Arc::new(AtomicCrossCookieJar::new(config.save_cookies)),
262 transport,
263 expected_state,
264 default_content_type: config.default_content_type,
265 is_http_path_restricted: config.restrict_requests_with_http_scheme,
266
267 #[cfg(feature = "reqwest")]
268 maybe_reqwest_client: Default::default(),
269 })
270 }
271
272 pub fn get(&self, path: &str) -> TestRequest {
274 self.method(Method::GET, path)
275 }
276
277 pub fn post(&self, path: &str) -> TestRequest {
279 self.method(Method::POST, path)
280 }
281
282 pub fn patch(&self, path: &str) -> TestRequest {
284 self.method(Method::PATCH, path)
285 }
286
287 pub fn put(&self, path: &str) -> TestRequest {
289 self.method(Method::PUT, path)
290 }
291
292 pub fn delete(&self, path: &str) -> TestRequest {
294 self.method(Method::DELETE, path)
295 }
296
297 pub fn method(&self, method: Method, path: &str) -> TestRequest {
299 let config = self
300 .build_test_request_config(method.clone(), path)
301 .error_message_fn(|| format!("Failed to build request, for {method} {path}"));
302
303 TestRequest::new(self.transport.clone(), config)
304 }
305
306 #[cfg(feature = "reqwest")]
307 fn reqwest_client(&self) -> &Client {
308 self.maybe_reqwest_client.get_or_init(|| {
309 if self.transport.transport_layer_type() == TransportLayerType::Mock {
310 panic!("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available");
311 }
312
313 reqwest::Client::builder()
314 .redirect(reqwest::redirect::Policy::none())
315 .cookie_provider(self.cookie_jar.clone())
316 .build()
317 .expect("Failed to build Reqwest Client")
318 })
319 }
320
321 #[cfg(feature = "reqwest")]
322 pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
323 self.reqwest_method(Method::GET, path)
324 }
325
326 #[cfg(feature = "reqwest")]
327 pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
328 self.reqwest_method(Method::POST, path)
329 }
330
331 #[cfg(feature = "reqwest")]
332 pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
333 self.reqwest_method(Method::PUT, path)
334 }
335
336 #[cfg(feature = "reqwest")]
337 pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
338 self.reqwest_method(Method::PATCH, path)
339 }
340
341 #[cfg(feature = "reqwest")]
342 pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
343 self.reqwest_method(Method::DELETE, path)
344 }
345
346 #[cfg(feature = "reqwest")]
347 pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
348 self.reqwest_method(Method::HEAD, path)
349 }
350
351 #[cfg(feature = "reqwest")]
376 pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
377 let request_url = self
378 .server_url(path)
379 .expect("Failed to generate server url for request {method} {path}");
380
381 self.reqwest_client().request(method, request_url)
382 }
383
384 #[cfg(feature = "ws")]
419 pub fn get_websocket(&self, path: &str) -> TestRequest {
420 use http::header;
421
422 self.get(path)
423 .add_header(header::CONNECTION, "upgrade")
424 .add_header(header::UPGRADE, "websocket")
425 .add_header(header::SEC_WEBSOCKET_VERSION, "13")
426 .add_header(
427 header::SEC_WEBSOCKET_KEY,
428 crate::internals::generate_ws_key(),
429 )
430 }
431
432 #[cfg(feature = "typed-routing")]
479 pub fn typed_get<P>(&self, path: &P) -> TestRequest
480 where
481 P: TypedPath,
482 {
483 self.typed_method(Method::GET, path)
484 }
485
486 #[cfg(feature = "typed-routing")]
490 pub fn typed_post<P>(&self, path: &P) -> TestRequest
491 where
492 P: TypedPath,
493 {
494 self.typed_method(Method::POST, path)
495 }
496
497 #[cfg(feature = "typed-routing")]
501 pub fn typed_patch<P>(&self, path: &P) -> TestRequest
502 where
503 P: TypedPath,
504 {
505 self.typed_method(Method::PATCH, path)
506 }
507
508 #[cfg(feature = "typed-routing")]
512 pub fn typed_put<P>(&self, path: &P) -> TestRequest
513 where
514 P: TypedPath,
515 {
516 self.typed_method(Method::PUT, path)
517 }
518
519 #[cfg(feature = "typed-routing")]
523 pub fn typed_delete<P>(&self, path: &P) -> TestRequest
524 where
525 P: TypedPath,
526 {
527 self.typed_method(Method::DELETE, path)
528 }
529
530 #[cfg(feature = "typed-routing")]
534 pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
535 where
536 P: TypedPath,
537 {
538 self.method(method, &path.to_string())
539 }
540
541 pub fn server_address(&self) -> Option<Url> {
549 self.url()
550 }
551
552 pub fn server_url(&self, path: &str) -> Result<Url> {
585 let path_uri = path.parse::<Uri>()?;
586 if is_absolute_uri(&path_uri) {
587 return Err(anyhow!(
588 "Absolute path provided for building server url, need to provide a relative uri"
589 ));
590 }
591
592 let server_url = self.url()
593 .ok_or_else(||
594 anyhow!(
595 "No local address for server, need to run with HTTP transport to have a server address",
596 )
597 )?;
598
599 let mut query_params = self.state.query_params().clone();
600 let mut full_server_url = build_url(
601 server_url,
602 path,
603 &mut query_params,
604 self.is_http_path_restricted,
605 )?;
606
607 if query_params.has_content() {
609 full_server_url.set_query(Some(&query_params.to_string()));
610 }
611
612 Ok(full_server_url)
613 }
614
615 pub fn add_cookie(&mut self, cookie: Cookie) {
620 self.cookie_jar.add_cookie(cookie);
621 }
622
623 pub fn add_cookies(&mut self, cookies: CookieJar) {
628 self.cookie_jar.add_cookies_by_jar(cookies);
629 }
630
631 pub fn clear_cookies(&mut self) {
633 self.cookie_jar.clear_cookies();
634 }
635
636 pub fn save_cookies(&mut self) {
641 self.cookie_jar.enable_saving();
642 }
643
644 pub fn do_not_save_cookies(&mut self) {
649 self.cookie_jar.disable_saving();
650 }
651
652 pub fn expect_success(&mut self) {
656 self.expected_state = ExpectedState::Success;
657 }
658
659 pub fn expect_failure(&mut self) {
663 self.expected_state = ExpectedState::Failure;
664 }
665
666 pub fn add_query_param<V>(&mut self, key: &str, value: V)
668 where
669 V: Serialize,
670 {
671 self.state
672 .add_query_param(key, value)
673 .error_message("Failed to add query parameter");
674 }
675
676 pub fn add_query_params<V>(&mut self, query_params: V)
678 where
679 V: Serialize,
680 {
681 self.state
682 .add_query_params(query_params)
683 .error_message("Failed to add query parameters");
684 }
685
686 pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
689 self.state.add_raw_query_param(raw_query_param);
690 }
691
692 pub fn clear_query_params(&mut self) {
694 self.state.clear_query_params();
695 }
696
697 pub fn add_header<N, V>(&mut self, name: N, value: V)
718 where
719 N: TryInto<HeaderName>,
720 N::Error: Debug,
721 V: TryInto<HeaderValue>,
722 V::Error: Debug,
723 {
724 let header_name: HeaderName = name
725 .try_into()
726 .expect("Failed to convert header name to HeaderName");
727 let header_value: HeaderValue = value
728 .try_into()
729 .expect("Failed to convert header vlue to HeaderValue");
730
731 self.state.add_header(header_name, header_value);
732 }
733
734 pub fn clear_headers(&mut self) {
736 self.state.clear_headers();
737 }
738
739 pub(crate) fn url(&self) -> Option<Url> {
740 self.transport.url().cloned()
741 }
742
743 pub(crate) fn build_test_request_config(
744 &self,
745 method: Method,
746 path: &str,
747 ) -> Result<TestRequestConfig> {
748 let url = self
749 .url()
750 .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
751
752 let mut query_params = self.state.query_params().clone();
753 let headers = self.state.headers().clone();
754 let full_request_url =
755 build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
756
757 Ok(TestRequestConfig {
758 atomic_cookie_jar: self.cookie_jar.clone(),
759
760 is_saving_cookies: self.cookie_jar.is_saving(),
763 cookies: self.cookie_jar.to_cookie_jar(),
764
765 expected_state: self.expected_state,
766 content_type: self.default_content_type.clone(),
767 method,
768
769 full_request_url,
770 query_params,
771 headers,
772 })
773 }
774
775 pub fn is_running(&self) -> bool {
781 self.transport.is_running()
782 }
783}
784
785fn build_url(
786 mut url: Url,
787 path: &str,
788 query_params: &mut QueryParamsStore,
789 is_http_restricted: bool,
790) -> Result<Url> {
791 let path_uri = path.parse::<Uri>()?;
792
793 if let Some(scheme) = path_uri.scheme_str() {
795 if is_http_restricted {
796 if has_different_scheme(&url, &path_uri) || has_different_authority(&url, &path_uri) {
797 return Err(anyhow!(
798 "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_scheme' to change this."
799 ));
800 }
801 } else {
802 url.set_scheme(scheme)
803 .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
804
805 if let Some(authority) = path_uri.authority() {
807 url.set_host(Some(authority.host()))
808 .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
809 url.set_port(authority.port().map(|p| p.as_u16()))
810 .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
811
812 }
814 }
815 }
816
817 if is_absolute_uri(&path_uri) {
827 url.set_path(path_uri.path());
828
829 if url.query().is_some() {
831 url.set_query(None);
832 }
833 } else {
834 let calculated_path = path.split('?').next().unwrap_or(path);
836 url.set_path(calculated_path);
837
838 if let Some(url_query) = url.query() {
840 query_params.add_raw(url_query.to_string());
841 url.set_query(None);
842 }
843 }
844
845 if let Some(path_query) = path_uri.query() {
846 query_params.add_raw(path_query.to_string());
847 }
848
849 Ok(url)
850}
851
852fn is_absolute_uri(path_uri: &Uri) -> bool {
853 path_uri.scheme_str().is_some()
854}
855
856fn has_different_scheme(base_url: &Url, path_uri: &Uri) -> bool {
857 if let Some(scheme) = path_uri.scheme_str() {
858 return scheme != base_url.scheme();
859 }
860
861 false
862}
863
864fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
865 if let Some(authority) = path_uri.authority() {
866 return authority.as_str() != base_url.authority();
867 }
868
869 false
870}
871
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Runs `build_url` on a freshly parsed `base_url` with an empty query store.
    ///
    /// Returns both the result and the store, so tests can inspect the query params
    /// that were moved out of the URLs.
    fn build(base_url: &str, path: &str, is_restricted: bool) -> (Result<Url>, QueryParamsStore) {
        let base_url = base_url.parse::<Url>().unwrap();
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, is_restricted);

        (result, query_params)
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (result, query_params) = build("http://example.com", "/users", true);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (result, query_params) =
            build("http://example.com?base=aaa", "/users?path=bbb&path-flag", true);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        let (result, _) = build(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        let (result, _) = build(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(result.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        let (result, _) = build("http://example.com?base=666", "http://google.com", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        let (result, _) = build("http://example.com?base=666", "ftp://example.com/users", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (result, query_params) = build("http://example.com", "/users", false);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (result, query_params) =
            build("http://example.com?base=aaa", "/users?path=bbb&path-flag", false);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let (result, query_params) = build("http://example.com", "google.com", false);

        assert_eq!("http://example.com/google.com", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let (result, query_params) = build("http://example.com", "google.com", true);

        assert_eq!("http://example.com/google.com", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (result, query_params) = build(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );

        assert_eq!("ftp://google.com:123/users.csv", result.unwrap().as_str());
        assert_eq!("limit=456", query_params.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (result, _) = build("http://example.com", "ftp://example.com", false);

        assert_eq!("ftp://example.com/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (result, _) = build("http://example.com", "http://google.com", false);

        assert_eq!("http://google.com/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (result, _) = build("http://example.com:123", "http://example.com:456", false);

        assert_eq!("http://example.com:456/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (result, _) = build("http://example.com:123", "http://example.com:123", false);

        assert_eq!("http://example.com:123/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        let (result, _) = build("http://example.com", "ftp://example.com", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        let (result, _) = build("http://example.com", "http://google.com", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        let (result, _) = build("http://example.com:123", "http://example.com:456", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (result, _) = build("http://example.com:123", "http://example.com:123", true);

        assert_eq!("http://example.com:123/", result.unwrap().as_str());
    }
}
1073
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// A service built with `into_make_service_with_connect_info` should be usable
    /// with the default transport.
    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        let app = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();

        let server = TestServer::new(app);

        server.get("/ping").await.assert_text(&"pong!");
    }
}
1100
#[cfg(test)]
mod test_get {
    use super::*;
    use crate::testing::catch_panic_error_message;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;
    use reserve_port::ReservedSocketAddr;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app);

        server.get("/ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app);

        server.get("ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_scheme()
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_scheme()
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // Any port other than the server's. Step down instead of overflowing when the
        // reserved port happens to be `u16::MAX`.
        let other_port = port.checked_add(1).unwrap_or_else(|| port - 1);
        let absolute_url = format!("http://{ip}:{other_port}/ping");

        let message = catch_panic_error_message(|| {
            let _ = server.get(&absolute_url);
        });

        let expected = format!("Failed to build request, for GET http://{ip}:{other_port}/ping,
    Request disallowed for path 'http://{ip}:{other_port}/ping', requests are only allowed to local server. Turn off 'restrict_requests_with_http_scheme' to change this.
");
        assert_str_eq!(expected, message);
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app);

        let future1 = async { server.get("/ping").await };
        let future2 = async { server.get("/ping").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let app = axum::Router::new().route(
            "/slow",
            axum::routing::get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );

        let server = TestServer::new(app);

        let future1 = async { server.get("/slow").await };
        let future2 = async { server.get("/slow").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }
}
1245
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;
    use axum::Router;
    use axum::routing::get;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder().http_transport().build(app);

        let response = server
            .reqwest_get("/ping")
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();

        assert_eq!(response, "pong!");
    }
}
1274
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;
    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Echoes the body back with `number` doubled and `_plus_response` appended to `text`.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder().http_transport().build(app);

        let response = server
            .reqwest_post("/json")
            .json(&TestBody {
                number: 111,
                text: "request".to_string(),
            })
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        assert_eq!(
            response,
            TestBody {
                number: 222,
                text: "request_plus_response".to_string(),
            }
        );
    }
}
1326
#[cfg(test)]
mod test_server_address {
    use super::*;
    use axum::Router;
    use regex::Regex;
    use reserve_port::ReservedPort;
    use std::net::Ipv4Addr;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        let app = Router::new();
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let expected_ip_port = format!("http://{}:{}/", ip, port);
        assert_eq!(
            server.server_address().unwrap().to_string(),
            expected_ip_port
        );
    }

    /// The default HTTP transport binds to 127.0.0.1, and the address it reports ends
    /// with a `/`.
    #[tokio::test]
    async fn it_should_return_default_address_with_ending_slash() {
        let app = Router::new();
        let server = TestServer::builder().http_transport().build(app);

        let address_regex = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();
        let is_match = address_regex.is_match(&server.server_address().unwrap().to_string());
        assert!(is_match);
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder().mock_transport().build(app);

        assert!(server.server_address().is_none());
    }
}
1373
1374#[cfg(test)]
1375mod test_server_url {
1376 use super::*;
1377 use axum::Router;
1378 use pretty_assertions::assert_str_eq;
1379 use regex::Regex;
1380 use reserve_port::ReservedPort;
1381 use std::net::Ipv4Addr;
1382
1383 #[tokio::test]
1384 async fn it_should_return_address_with_url_on_http_ip_port() {
1385 let reserved_port = ReservedPort::random().unwrap();
1386 let ip = Ipv4Addr::LOCALHOST.into();
1387 let port = reserved_port.port();
1388
1389 let app = Router::new();
1391 let server = TestServer::builder()
1392 .http_transport_with_ip_port(Some(ip), Some(port))
1393 .try_build(app)
1394 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1395
1396 let expected_ip_port_url = format!("http://{}:{}/users", ip, reserved_port.port());
1397 let absolute_url = server.server_url("/users").unwrap().to_string();
1398 assert_eq!(expected_ip_port_url, absolute_url);
1399 }
1400
1401 #[tokio::test]
1402 async fn it_should_return_address_with_url_on_random_http() {
1403 let app = Router::new();
1404 let server = TestServer::builder().http_transport().build(app);
1405
1406 let address_regex =
1407 Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();
1408 let absolute_url = &server
1409 .server_url(&"/users/123?filter=enabled")
1410 .unwrap()
1411 .to_string();
1412
1413 let is_match = address_regex.is_match(absolute_url);
1414 assert!(is_match);
1415 }
1416
1417 #[tokio::test]
1418 async fn it_should_error_on_mock_transport() {
1419 let app = Router::new();
1421 let server = TestServer::builder().mock_transport().build(app);
1422
1423 let result = server.server_url("/users");
1424 assert!(result.is_err());
1425 }
1426
1427 #[tokio::test]
1428 async fn it_should_include_path_query_params() {
1429 let reserved_port = ReservedPort::random().unwrap();
1430 let ip = Ipv4Addr::LOCALHOST.into();
1431 let port = reserved_port.port();
1432
1433 let app = Router::new();
1435 let server = TestServer::builder()
1436 .http_transport_with_ip_port(Some(ip), Some(port))
1437 .try_build(app)
1438 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1439
1440 let expected_url = format!(
1441 "http://{}:{}/users?filter=enabled",
1442 ip,
1443 reserved_port.port()
1444 );
1445 let received_url = server
1446 .server_url("/users?filter=enabled")
1447 .unwrap()
1448 .to_string();
1449
1450 assert_eq!(expected_url, received_url);
1451 }
1452
1453 #[tokio::test]
1454 async fn it_should_include_server_query_params() {
1455 let reserved_port = ReservedPort::random().unwrap();
1456 let ip = Ipv4Addr::LOCALHOST.into();
1457 let port = reserved_port.port();
1458
1459 let app = Router::new();
1461 let mut server = TestServer::builder()
1462 .http_transport_with_ip_port(Some(ip), Some(port))
1463 .try_build(app)
1464 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1465
1466 server.add_query_param("filter", "enabled");
1467
1468 let expected_url = format!(
1469 "http://{}:{}/users?filter=enabled",
1470 ip,
1471 reserved_port.port()
1472 );
1473 let received_url = server.server_url("/users").unwrap().to_string();
1474
1475 assert_eq!(expected_url, received_url);
1476 }
1477
1478 #[tokio::test]
1479 async fn it_should_include_server_and_path_query_params() {
1480 let reserved_port = ReservedPort::random().unwrap();
1481 let ip = Ipv4Addr::LOCALHOST.into();
1482 let port = reserved_port.port();
1483
1484 let app = Router::new();
1486 let mut server = TestServer::builder()
1487 .http_transport_with_ip_port(Some(ip), Some(port))
1488 .try_build(app)
1489 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1490
1491 server.add_query_param("filter", "enabled");
1492
1493 let expected_url = format!(
1494 "http://{}:{}/users?filter=enabled&animal=donkeys",
1495 ip,
1496 reserved_port.port()
1497 );
1498 let received_url = server
1499 .server_url("/users?animal=donkeys")
1500 .unwrap()
1501 .to_string();
1502
1503 assert_eq!(expected_url, received_url);
1504 }
1505
1506 #[tokio::test]
1507 async fn it_should_include_both_server_and_path_queries() {
1508 let reserved_port = ReservedPort::random().unwrap();
1509 let ip = Ipv4Addr::LOCALHOST.into();
1510 let port = reserved_port.port();
1511
1512 let app = Router::new();
1514 let mut server = TestServer::builder()
1515 .http_transport_with_ip_port(Some(ip), Some(port))
1516 .try_build(app)
1517 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1518
1519 server.add_query_param("query", "server");
1520
1521 let expected_url = format!(
1522 "http://{}:{}/users?query=server&query=path",
1523 ip,
1524 reserved_port.port()
1525 );
1526 let received_url = server.server_url("/users?query=path").unwrap().to_string();
1527
1528 assert_eq!(expected_url, received_url);
1529 }
1530
1531 #[tokio::test]
1532 async fn it_should_work_for_paths_with_leading_slash() {
1533 let reserved_port = ReservedPort::random().unwrap();
1534 let ip = Ipv4Addr::LOCALHOST.into();
1535 let port = reserved_port.port();
1536
1537 let app = Router::new();
1539 let server = TestServer::builder()
1540 .http_transport_with_ip_port(Some(ip), Some(port))
1541 .try_build(app)
1542 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1543
1544 let expected_url = format!("http://{}:{}/users", ip, reserved_port.port());
1545 let received_url = server.server_url("users").unwrap().to_string();
1546
1547 assert_eq!(expected_url, received_url);
1548 }
1549
1550 #[tokio::test]
1552 async fn it_should_panic_when_provided_an_empty_path() {
1553 let reserved_port = ReservedPort::random().unwrap();
1554 let ip = Ipv4Addr::LOCALHOST.into();
1555 let port = reserved_port.port();
1556
1557 let app = Router::new();
1559 let server = TestServer::builder()
1560 .http_transport_with_ip_port(Some(ip), Some(port))
1561 .try_build(app)
1562 .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));
1563
1564 let error_message = server.server_url("").unwrap_err().to_string();
1566
1567 assert_str_eq!("empty string", error_message);
1568 }
1569}
1570
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Responds with the value of the test cookie, or `"cookie-not-found"`
    /// when the request did not carry it. The jar is returned unchanged.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let value = match cookies.get(TEST_COOKIE_NAME) {
            Some(found) => found.value().to_string(),
            None => "cookie-not-found".to_string(),
        };

        (cookies, value)
    }

    /// A cookie added to the server is sent with subsequent requests.
    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let mut server = TestServer::new(Router::new().route("/cookie", get(get_cookie)));
        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let body = server.get("/cookie").await.text();
        assert_eq!(body, "my-custom-cookie");
    }
}
1602
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Lists every received cookie as `name=value`, sorted, joined by `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in cookies.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        pairs.sort();

        pairs.join(", ")
    }

    /// Every cookie in a jar handed to `add_cookies` is sent on later requests.
    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));

        let mut server = TestServer::new(Router::new().route("/cookies", get(route_get_cookies)));
        server.add_cookies(jar);

        let response = server.get("/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1643
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Lists every received cookie as `name=value`, sorted, joined by `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in cookies.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        pairs.sort();

        pairs.join(", ")
    }

    /// After `clear_cookies`, previously added cookies are no longer sent.
    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));

        let mut server = TestServer::new(Router::new().route("/cookies", get(route_get_cookies)));
        server.add_cookies(jar);
        server.clear_cookies();

        let response = server.get("/cookies").await;
        response.assert_text("");
    }
}
1683
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Rejects with `400 Bad Request` / `"Missing test header"` when absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let name = HeaderName::from_static(TEST_HEADER_NAME);
            match parts.headers.get(name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the header bytes back as the response body.
    async fn ping_header(TestHeader(bytes): TestHeader) -> Vec<u8> {
        bytes
    }

    /// A header added to the server is sent on subsequent requests.
    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        let mut server = TestServer::new(Router::new().route("/header", get(ping_header)));
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        server.get("/header").await.assert_text(TEST_HEADER_CONTENT);
    }
}
1740
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Rejects with `400 Bad Request` / `"Missing test header"` when absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let name = HeaderName::from_static(TEST_HEADER_NAME);
            match parts.headers.get(name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the header bytes back as the response body.
    async fn ping_header(TestHeader(bytes): TestHeader) -> Vec<u8> {
        bytes
    }

    /// After `clear_headers`, the handler sees no header and rejects the request.
    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        let mut server = TestServer::new(Router::new().route("/header", get(ping_header)));
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        let response = server.get("/header").await;
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1799
#[cfg(test)]
mod test_add_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query string carrying a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes `message` back as the body.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string carrying `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `"{message}-{other}"`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let mut server = TestServer::new(Router::new().route("/query", get(get_query_param)));
        server.add_query_params(QueryParam {
            message: String::from("it works"),
        });

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = TestServer::new(Router::new().route("/query", get(get_query_param)));
        server.add_query_params(&[("message", "it works")]);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let mut server = TestServer::new(Router::new().route("/query-2", get(get_query_param_2)));
        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = TestServer::new(Router::new().route("/query-2", get(get_query_param_2)));
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let mut server = TestServer::new(Router::new().route("/query-2", get(get_query_param_2)));
        server.add_query_params(json!({
            "message": "it works",
            "other": "yup"
        }));

        server.get("/query-2").await.assert_text("it works-yup");
    }
}
1902
#[cfg(test)]
mod test_add_query_param {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string carrying a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes `message` back as the body.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string carrying `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `"{message}-{other}"`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = TestServer::new(Router::new().route("/query", get(get_query_param)));
        server.add_query_param("message", "it works");

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = TestServer::new(Router::new().route("/query-2", get(get_query_param_2)));
        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        server.get("/query-2").await.assert_text("it works-yup");
    }

    /// Server-level and request-level params are merged into one query string.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let mut server = TestServer::new(Router::new().route("/query-2", get(get_query_param_2)));
        server.add_query_param("message", "it works");

        let response = server
            .get("/query-2")
            .add_query_param("other", "yup")
            .await;
        response.assert_text("it works-yup");
    }
}
1977
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string carrying a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes `message` back as the body.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated-key query params: `items=…&items=…` and `arrs[]=…&arrs[]=…`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Responds with every `items` value followed by every `arrs[]` value, all
    /// joined with `", "`. Responds with an empty string when neither is present.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        // Chain both lists so that, when a request carries both params, the
        // result is still a single separated list rather than two lists glued
        // together with no separator.
        params
            .items
            .iter()
            .chain(params.arrs.iter())
            .map(String::as_str)
            .collect::<Vec<_>>()
            .join(", ")
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param("message=it-works");

        server.get("/query").await.assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param("items=one&items=two&items=three");

        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param("arrs[]=one");
        server.add_raw_query_param("arrs[]=two");
        server.add_raw_query_param("arrs[]=three");

        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }
}
2068
#[cfg(test)]
mod test_clear_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional query fields, so the handler can report which arrived.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports whether `first` and `second` were present in the query string.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();
        format!("has first? {has_first}, has second? {has_second}")
    }

    /// Both params populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some(String::from("first")),
            second: Some(String::from("second")),
        }
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let mut server = TestServer::new(Router::new().route("/query", get(get_query_params)));
        server.add_query_params(both_params());
        server.clear_query_params();

        server
            .get("/query")
            .await
            .assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let mut server = TestServer::new(Router::new().route("/query", get(get_query_params)));
        server.add_query_params(both_params());
        server.clear_query_params();
        server.add_query_params(both_params());

        server
            .get("/query")
            .await
            .assert_text("has first? true, has second? true");
    }
}
2138
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;

    /// Router with a single `/known_route` that always responds 200.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        TestServer::new(Router::new())
            .get("/some_unknown_route")
            .await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        TestServer::new(known_route_app()).get("/known_route").await;
    }

    /// With `expect_success_by_default`, a 404 panics with a descriptive message.
    #[tokio::test]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = TestServer::builder()
            .expect_success_by_default()
            .build(Router::new());

        let panic_message =
            catch_panic_error_message_async(server.get("/some_unknown_route")).await;
        assert_str_eq!(
            "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''",
            panic_message
        );
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = TestServer::builder()
            .expect_success_by_default()
            .build(known_route_app());

        server.get("/known_route").await;
    }
}
2183
#[cfg(test)]
mod test_content_type {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    ///
    /// Panics if the header value cannot be read as a `&str`.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_owned(),
            None => String::new(),
        }
    }

    /// The builder's default content type is applied to requests without one.
    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let app = Router::new().route("/content_type", get(get_content_type));
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(app);

        let echoed = server.get("/content_type").await.text();
        assert_eq!(echoed, "text/plain");
    }
}
2215
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Builds a server for `app` with `expect_success` switched on.
    fn server_expecting_success(app: Router) -> TestServer {
        let mut server = TestServer::new(app);
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        async fn get_ping() -> &'static str {
            "pong!"
        }

        let server = server_expecting_success(Router::new().route("/ping", get(get_ping)));
        server.get("/ping").await;
    }

    /// Any 2xx status satisfies `expect_success`, not just 200.
    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        async fn get_accepted() -> StatusCode {
            StatusCode::ACCEPTED
        }

        let server = server_expecting_success(Router::new().route("/accepted", get(get_accepted)));
        server.get("/accepted").await;
    }

    #[tokio::test]
    async fn it_should_panic_on_404() {
        let server = server_expecting_success(Router::new());

        let panic_message =
            catch_panic_error_message_async(server.get("/some_unknown_route")).await;
        assert_str_eq!(
            "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''",
            panic_message
        );
    }
}
2276
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Builds a server for `app` with `expect_failure` switched on.
    fn server_expecting_failure(app: Router) -> TestServer {
        let mut server = TestServer::new(app);
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let server = server_expecting_failure(Router::new());
        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_panic_if_success_is_returned() {
        async fn get_ping() -> &'static str {
            "pong!"
        }

        let server = server_expecting_failure(Router::new().route("/ping", get(get_ping)));

        let panic_message = catch_panic_error_message_async(server.get("/ping")).await;
        assert_str_eq!(
            "Expect status code outside 2xx range, received 200 (OK), for request GET http://localhost/ping, with body 'pong!'",
            panic_message
        );
    }

    /// Any 2xx status violates `expect_failure`, not just 200.
    #[tokio::test]
    async fn it_should_panic_on_other_2xx_status_code() {
        async fn get_accepted() -> StatusCode {
            StatusCode::ACCEPTED
        }

        let server = server_expecting_failure(Router::new().route("/accepted", get(get_accepted)));

        let panic_message = catch_panic_error_message_async(server.get("/accepted")).await;
        assert_str_eq!(
            "Expect status code outside 2xx range, received 202 (Accepted), for request GET http://localhost/accepted, with body ''",
            panic_message
        );
    }
}
2341
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"get {id}"`.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let path = TestingPath { id: 123 };
        let server = TestServer::new(new_app());

        let response = server.typed_get(&path).await;
        response.assert_text("get 123");
    }
}
2374
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"post {id}"`.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let path = TestingPath { id: 123 };
        let server = TestServer::new(new_app());

        let response = server.typed_post(&path).await;
        response.assert_text("post 123");
    }
}
2407
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"patch {id}"`.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let path = TestingPath { id: 123 };
        let server = TestServer::new(new_app());

        let response = server.typed_patch(&path).await;
        response.assert_text("patch 123");
    }
}
2440
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"put {id}"`.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let path = TestingPath { id: 123 };
        let server = TestServer::new(new_app());

        let response = server.typed_put(&path).await;
        response.assert_text("put 123");
    }
}
2473
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"delete {id}"`.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let path = TestingPath { id: 123 };
        let server = TestServer::new(new_app());

        let response = server.typed_delete(&path).await;
        response.assert_text("delete 123");
    }
}
2506
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// One router serving every verb on `/path/{id}`; each handler echoes its verb.
    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to `/path/123` via `typed_method` and asserts the body.
    async fn assert_typed_method(method: Method, expected_body: &str) {
        let server = TestServer::new(new_app());
        server
            .typed_method(method, &TestingPath { id: 123 })
            .await
            .assert_text(expected_body);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method(Method::DELETE, "delete 123").await;
    }
}
2600
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    /// A `TestServer` can be lazily built inside a `std::cell::OnceCell` and
    /// then used through the shared reference the cell hands back.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        async fn route_get() -> &'static str {
            "it works"
        }

        let cell: OnceCell<TestServer> = OnceCell::new();
        let server = cell.get_or_init(|| {
            let router = Router::new().route("/test", get(route_get));
            TestServer::new(router)
        });

        let response = server.get("/test").await;
        response.assert_text("it works");
    }
}
2624
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::testing::catch_panic_error_message_async;
    use crate::util::new_random_tokio_tcp_listener_with_socket_addr;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Polls `server.is_running()` up to `attempts` times, sleeping `interval`
    /// between polls, and returns as soon as it reports `false`.
    ///
    /// Returns the last observed `is_running()` value (`true` means it never stopped).
    async fn wait_until_stopped(server: &TestServer, attempts: u32, interval: Duration) -> bool {
        for _ in 0..attempts {
            if !server.is_running() {
                return false;
            }
            sleep(interval).await;
        }

        server.is_running()
    }

    /// Wraps a real `axum::serve` application with graceful shutdown.
    ///
    /// Checks that `is_running` is `true` while it serves, flips to `false` once
    /// shutdown is signalled, and that later requests fail to connect.
    #[tokio::test]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown_notification = Arc::new(Notify::new());
        let waiting_notification = shutdown_notification.clone();

        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let (listener, ip_port) = new_random_tokio_tcp_listener_with_socket_addr().unwrap();
        let application = serve(listener, app)
            .with_graceful_shutdown(async move { waiting_notification.notified().await });

        let server = TestServer::builder().build(application);

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        shutdown_notification.notify_one();

        // Shutdown completes asynchronously. Poll with a bounded budget instead of a
        // single fixed sleep, which is flaky on slow or loaded machines.
        let still_running = wait_until_stopped(&server, 200, Duration::from_millis(10)).await;
        assert!(!still_running, "server should stop after graceful shutdown");

        // The OS-level tail of the connect error carries a platform-specific errno
        // (e.g. 61 on macOS, 111 on Linux), so only assert the portable parts.
        let ip = ip_port.ip();
        let port = ip_port.port();
        let expected_prefix =
            format!("Sending request failed, for request GET http://{ip}:{port}/ping,");

        let message = catch_panic_error_message_async(server.get("/ping")).await;
        assert!(
            message.starts_with(&expected_prefix),
            "unexpected error message: {message}"
        );
        assert!(
            message.contains("client error (Connect)"),
            "unexpected error message: {message}"
        );
        assert!(
            message.contains("tcp connect error"),
            "unexpected error message: {message}"
        );
    }
}
2680
#[cfg(test)]
mod test_save_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    /// Name of the cookie that `PUT /cookie` sets on its response.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    #[tokio::test]
    async fn it_should_save_cookies_across_requests_when_enabled() {
        let mut server = TestServer::new(app());

        server.save_cookies();

        save_cookie_using_axum_test(&server).await;
        assert_cookie_using_axum_test(&server).await;
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_across_reqwest_requests_when_enabled() {
        let mut server = TestServer::builder().http_transport().build(app());

        server.save_cookies();

        save_cookie_using_reqwest(&server).await;
        // Read the cookie back; saving it alone verifies nothing.
        assert_cookie_using_reqwest(&server).await;
    }

    #[tokio::test]
    async fn it_should_save_cookies_across_axum_test_requests_when_enabled_for_second_request() {
        let mut server = TestServer::builder().http_transport().build(app());

        // Before `save_cookies` is enabled, the cookie set by the PUT is not
        // sent on the following GET.
        save_cookie_using_axum_test(&server).await;
        assert_no_cookie_using_axum_test(&server).await;

        server.save_cookies();

        save_cookie_using_axum_test(&server).await;
        assert_cookie_using_axum_test(&server).await;
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_across_reqwest_requests_when_enabled_for_second_request() {
        let mut server = TestServer::builder().http_transport().build(app());

        // Before `save_cookies` is enabled, the cookie set by the PUT is not
        // sent on the following GET.
        save_cookie_using_reqwest(&server).await;
        assert_no_cookie_using_reqwest(&server).await;

        server.save_cookies();

        save_cookie_using_reqwest(&server).await;
        assert_cookie_using_reqwest(&server).await;
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_when_set_by_reqwest_and_read_by_axum_test() {
        let mut server = TestServer::builder().http_transport().build(app());

        server.save_cookies();

        save_cookie_using_reqwest(&server).await;
        assert_cookie_using_axum_test(&server).await;
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_when_set_by_axum_test_and_read_by_reqwest() {
        let mut server = TestServer::builder().http_transport().build(app());

        server.save_cookies();

        save_cookie_using_axum_test(&server).await;
        assert_cookie_using_reqwest(&server).await;
    }

    /// Builds the app under test:
    ///
    /// - `PUT /cookie` sets `test-cookie` (http-only, secure, SameSite=Strict,
    ///   path `/cookie`) to the request body text.
    /// - `GET /cookie` returns every `cookie` request header, joined by `"; "`.
    fn app() -> Router {
        async fn put_cookie_with_attributes(
            cookies: AxumCookieJar,
            request: Request,
        ) -> (AxumCookieJar, &'static str) {
            let body_bytes = request
                .into_body()
                .collect()
                .await
                .expect("Should turn the body into bytes")
                .to_bytes();

            let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
            let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
                .http_only(true)
                .secure(true)
                .same_site(SameSite::Strict)
                .path("/cookie")
                .build();

            (cookies.add(cookie), "done")
        }

        async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
            // Header values that are not visible ASCII are reported as empty.
            headers
                .get_all("cookie")
                .iter()
                .map(|value| value.to_str().unwrap_or(""))
                .collect::<Vec<_>>()
                .join("; ")
        }

        Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined))
    }

    async fn save_cookie_using_axum_test(server: &TestServer) {
        server.put("/cookie").text("cookie-found!").await;
    }

    #[cfg(feature = "reqwest")]
    async fn save_cookie_using_reqwest(server: &TestServer) {
        server
            .reqwest_put("/cookie")
            .body("cookie-found!".to_string())
            .send()
            .await
            .expect("PUT /cookie via reqwest should succeed");
    }

    /// Fetches `GET /cookie` through the reqwest client and returns the body,
    /// i.e. the cookies the server received.
    #[cfg(feature = "reqwest")]
    async fn get_cookie_text_using_reqwest(server: &TestServer) -> String {
        server
            .reqwest_get("/cookie")
            .send()
            .await
            .expect("GET /cookie via reqwest should succeed")
            .text()
            .await
            .expect("GET /cookie via reqwest should return a text body")
    }

    async fn assert_cookie_using_axum_test(server: &TestServer) {
        server
            .get("/cookie")
            .await
            .assert_text("test-cookie=cookie-found!");
    }

    #[cfg(feature = "reqwest")]
    async fn assert_cookie_using_reqwest(server: &TestServer) {
        let response_text = get_cookie_text_using_reqwest(server).await;
        assert_eq!("test-cookie=cookie-found!", response_text);
    }

    async fn assert_no_cookie_using_axum_test(server: &TestServer) {
        server.get("/cookie").await.assert_text("");
    }

    #[cfg(feature = "reqwest")]
    async fn assert_no_cookie_using_reqwest(server: &TestServer) {
        let response_text = get_cookie_text_using_reqwest(server).await;
        assert_eq!("", response_text);
    }
}