1use anyhow::Context;
2use anyhow::Result;
3use anyhow::anyhow;
4use cookie::Cookie;
5use cookie::CookieJar;
6use http::HeaderName;
7use http::HeaderValue;
8use http::Method;
9use http::Uri;
10use serde::Serialize;
11use std::fmt::Debug;
12use std::sync::Arc;
13use std::sync::Mutex;
14use url::Url;
15
16#[cfg(feature = "typed-routing")]
17use axum_extra::routing::TypedPath;
18
19#[cfg(feature = "reqwest")]
20use crate::transport_layer::TransportLayerType;
21#[cfg(feature = "reqwest")]
22use reqwest::Client;
23#[cfg(feature = "reqwest")]
24use reqwest::RequestBuilder;
25
26use crate::TestRequest;
27use crate::TestRequestConfig;
28use crate::TestServerBuilder;
29use crate::TestServerConfig;
30use crate::Transport;
31use crate::internals::ExpectedState;
32use crate::internals::QueryParamsStore;
33use crate::internals::RequestPathFormatter;
34use crate::transport_layer::IntoTransportLayer;
35use crate::transport_layer::TransportLayer;
36use crate::transport_layer::TransportLayerBuilder;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// Base URL used when building request URLs if the transport has no URL of its own
/// (for example, the mock transport).
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
/// A server for testing an application, created with [`TestServer::new`],
/// [`TestServer::new_with_config`], or [`TestServer::builder`].
///
/// Requests created from it (e.g. via [`TestServer::get`]) start from the cookies,
/// query params, headers and scheme currently stored on the server.
#[derive(Debug)]
pub struct TestServer {
    /// State shared with the requests this server creates (cookies, query params,
    /// headers and scheme).
    state: Arc<Mutex<ServerSharedState>>,
    /// The transport layer that requests are sent through.
    transport: Arc<Box<dyn TransportLayer>>,
    /// Copied into each new request's `is_saving_cookies` flag.
    save_cookies: bool,
    /// Copied into each new request's expected state.
    expected_state: ExpectedState,
    /// Default content type copied into each new request, if any.
    default_content_type: Option<String>,
    /// When `true`, absolute request URLs must match this server's scheme and authority.
    is_http_path_restricted: bool,

    /// Reqwest client; built only when the transport layer is HTTP.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: Option<Client>,
}
153
154impl TestServer {
155 pub fn builder() -> TestServerBuilder {
157 TestServerBuilder::default()
158 }
159
160 pub fn new<A>(app: A) -> Result<Self>
191 where
192 A: IntoTransportLayer,
193 {
194 Self::new_with_config(app, TestServerConfig::default())
195 }
196
197 pub fn new_with_config<A, C>(app: A, config: C) -> Result<Self>
204 where
205 A: IntoTransportLayer,
206 C: Into<TestServerConfig>,
207 {
208 let config = config.into();
209 let mut shared_state = ServerSharedState::new();
210 if let Some(scheme) = config.default_scheme {
211 shared_state.set_scheme_unlocked(scheme);
212 }
213
214 let shared_state_mutex = Mutex::new(shared_state);
215 let state = Arc::new(shared_state_mutex);
216
217 let transport = match config.transport {
218 None => {
219 let builder = TransportLayerBuilder::new(None, None);
220 let transport = app.into_default_transport(builder)?;
221 Arc::new(transport)
222 }
223 Some(Transport::HttpRandomPort) => {
224 let builder = TransportLayerBuilder::new(None, None);
225 let transport = app.into_http_transport_layer(builder)?;
226 Arc::new(transport)
227 }
228 Some(Transport::HttpIpPort { ip, port }) => {
229 let builder = TransportLayerBuilder::new(ip, port);
230 let transport = app.into_http_transport_layer(builder)?;
231 Arc::new(transport)
232 }
233 Some(Transport::MockHttp) => {
234 let transport = app.into_mock_transport_layer()?;
235 Arc::new(transport)
236 }
237 };
238
239 let expected_state = match config.expect_success_by_default {
240 true => ExpectedState::Success,
241 false => ExpectedState::None,
242 };
243
244 #[cfg(feature = "reqwest")]
245 let maybe_reqwest_client = match transport.transport_layer_type() {
246 TransportLayerType::Http => {
247 let reqwest_client = reqwest::Client::builder()
248 .redirect(reqwest::redirect::Policy::none())
249 .cookie_store(config.save_cookies)
250 .build()
251 .expect("Failed to build Reqwest Client");
252
253 Some(reqwest_client)
254 }
255 TransportLayerType::Mock => None,
256 };
257
258 Ok(Self {
259 state,
260 transport,
261 save_cookies: config.save_cookies,
262 expected_state,
263 default_content_type: config.default_content_type,
264 is_http_path_restricted: config.restrict_requests_with_http_schema,
265
266 #[cfg(feature = "reqwest")]
267 maybe_reqwest_client,
268 })
269 }
270
271 pub fn get(&self, path: &str) -> TestRequest {
273 self.method(Method::GET, path)
274 }
275
276 pub fn post(&self, path: &str) -> TestRequest {
278 self.method(Method::POST, path)
279 }
280
281 pub fn patch(&self, path: &str) -> TestRequest {
283 self.method(Method::PATCH, path)
284 }
285
286 pub fn put(&self, path: &str) -> TestRequest {
288 self.method(Method::PUT, path)
289 }
290
291 pub fn delete(&self, path: &str) -> TestRequest {
293 self.method(Method::DELETE, path)
294 }
295
296 pub fn method(&self, method: Method, path: &str) -> TestRequest {
298 let maybe_config = self.build_test_request_config(method.clone(), path);
299 let config = maybe_config
300 .with_context(|| format!("Failed to build, for request {method} {path}"))
301 .unwrap();
302
303 TestRequest::new(self.state.clone(), self.transport.clone(), config)
304 }
305
306 #[cfg(feature = "reqwest")]
307 fn reqwest_client(&self) -> &Client {
308 self.maybe_reqwest_client
309 .as_ref()
310 .expect("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available")
311 }
312
313 #[cfg(feature = "reqwest")]
314 pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
315 self.reqwest_method(Method::GET, path)
316 }
317
318 #[cfg(feature = "reqwest")]
319 pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
320 self.reqwest_method(Method::POST, path)
321 }
322
323 #[cfg(feature = "reqwest")]
324 pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
325 self.reqwest_method(Method::PUT, path)
326 }
327
328 #[cfg(feature = "reqwest")]
329 pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
330 self.reqwest_method(Method::PATCH, path)
331 }
332
333 #[cfg(feature = "reqwest")]
334 pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
335 self.reqwest_method(Method::DELETE, path)
336 }
337
338 #[cfg(feature = "reqwest")]
339 pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
340 self.reqwest_method(Method::HEAD, path)
341 }
342
343 #[cfg(feature = "reqwest")]
368 pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
369 let request_url = self
370 .server_url(path)
371 .expect("Failed to generate server url for request {method} {path}");
372
373 self.reqwest_client().request(method, request_url)
374 }
375
376 #[cfg(feature = "ws")]
411 pub fn get_websocket(&self, path: &str) -> TestRequest {
412 use http::header;
413
414 self.get(path)
415 .add_header(header::CONNECTION, "upgrade")
416 .add_header(header::UPGRADE, "websocket")
417 .add_header(header::SEC_WEBSOCKET_VERSION, "13")
418 .add_header(
419 header::SEC_WEBSOCKET_KEY,
420 crate::internals::generate_ws_key(),
421 )
422 }
423
424 #[cfg(feature = "typed-routing")]
471 pub fn typed_get<P>(&self, path: &P) -> TestRequest
472 where
473 P: TypedPath,
474 {
475 self.typed_method(Method::GET, path)
476 }
477
478 #[cfg(feature = "typed-routing")]
482 pub fn typed_post<P>(&self, path: &P) -> TestRequest
483 where
484 P: TypedPath,
485 {
486 self.typed_method(Method::POST, path)
487 }
488
489 #[cfg(feature = "typed-routing")]
493 pub fn typed_patch<P>(&self, path: &P) -> TestRequest
494 where
495 P: TypedPath,
496 {
497 self.typed_method(Method::PATCH, path)
498 }
499
500 #[cfg(feature = "typed-routing")]
504 pub fn typed_put<P>(&self, path: &P) -> TestRequest
505 where
506 P: TypedPath,
507 {
508 self.typed_method(Method::PUT, path)
509 }
510
511 #[cfg(feature = "typed-routing")]
515 pub fn typed_delete<P>(&self, path: &P) -> TestRequest
516 where
517 P: TypedPath,
518 {
519 self.typed_method(Method::DELETE, path)
520 }
521
522 #[cfg(feature = "typed-routing")]
526 pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
527 where
528 P: TypedPath,
529 {
530 self.method(method, &path.to_string())
531 }
532
533 pub fn server_address(&self) -> Option<Url> {
541 self.url()
542 }
543
544 pub fn server_url(&self, path: &str) -> Result<Url> {
577 let path_uri = path.parse::<Uri>()?;
578 if is_absolute_uri(&path_uri) {
579 return Err(anyhow!(
580 "Absolute path provided for building server url, need to provide a relative uri"
581 ));
582 }
583
584 let server_url = self.url()
585 .ok_or_else(||
586 anyhow!(
587 "No local address for server, need to run with HTTP transport to have a server address",
588 )
589 )?;
590
591 let server_locked = self.state.as_ref().lock().map_err(|err| {
592 anyhow!("Failed to lock InternalTestServer, for building server_url, received {err:?}",)
593 })?;
594 let mut query_params = server_locked.query_params().clone();
595 let mut full_server_url = build_url(
596 server_url,
597 path,
598 &mut query_params,
599 self.is_http_path_restricted,
600 )?;
601
602 if query_params.has_content() {
604 full_server_url.set_query(Some(&query_params.to_string()));
605 }
606
607 Ok(full_server_url)
608 }
609
610 pub fn add_cookie(&mut self, cookie: Cookie) {
615 ServerSharedState::add_cookie(&self.state, cookie)
616 .context("Trying to call add_cookie")
617 .unwrap()
618 }
619
620 pub fn add_cookies(&mut self, cookies: CookieJar) {
625 ServerSharedState::add_cookies(&self.state, cookies)
626 .context("Trying to call add_cookies")
627 .unwrap()
628 }
629
630 pub fn clear_cookies(&mut self) {
632 ServerSharedState::clear_cookies(&self.state)
633 .context("Trying to call clear_cookies")
634 .unwrap()
635 }
636
637 pub fn save_cookies(&mut self) {
641 self.save_cookies = true;
642 }
643
644 pub fn do_not_save_cookies(&mut self) {
648 self.save_cookies = false;
649 }
650
651 pub fn expect_success(&mut self) {
655 self.expected_state = ExpectedState::Success;
656 }
657
658 pub fn expect_failure(&mut self) {
662 self.expected_state = ExpectedState::Failure;
663 }
664
665 pub fn add_query_param<V>(&mut self, key: &str, value: V)
667 where
668 V: Serialize,
669 {
670 ServerSharedState::add_query_param(&self.state, key, value)
671 .context("Trying to call add_query_param")
672 .unwrap()
673 }
674
675 pub fn add_query_params<V>(&mut self, query_params: V)
677 where
678 V: Serialize,
679 {
680 ServerSharedState::add_query_params(&self.state, query_params)
681 .context("Trying to call add_query_params")
682 .unwrap()
683 }
684
685 pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
688 ServerSharedState::add_raw_query_param(&self.state, raw_query_param)
689 .context("Trying to call add_raw_query_param")
690 .unwrap()
691 }
692
693 pub fn clear_query_params(&mut self) {
695 ServerSharedState::clear_query_params(&self.state)
696 .context("Trying to call clear_query_params")
697 .unwrap()
698 }
699
700 pub fn add_header<N, V>(&mut self, name: N, value: V)
721 where
722 N: TryInto<HeaderName>,
723 N::Error: Debug,
724 V: TryInto<HeaderValue>,
725 V::Error: Debug,
726 {
727 let header_name: HeaderName = name
728 .try_into()
729 .expect("Failed to convert header name to HeaderName");
730 let header_value: HeaderValue = value
731 .try_into()
732 .expect("Failed to convert header vlue to HeaderValue");
733
734 ServerSharedState::add_header(&self.state, header_name, header_value)
735 .context("Trying to call add_header")
736 .unwrap()
737 }
738
739 pub fn clear_headers(&mut self) {
741 ServerSharedState::clear_headers(&self.state)
742 .context("Trying to call clear_headers")
743 .unwrap()
744 }
745
746 pub fn scheme(&mut self, scheme: &str) {
770 ServerSharedState::set_scheme(&self.state, scheme.to_string())
771 .context("Trying to call set_scheme")
772 .unwrap()
773 }
774
775 pub(crate) fn url(&self) -> Option<Url> {
776 self.transport.url().cloned()
777 }
778
779 pub(crate) fn build_test_request_config(
780 &self,
781 method: Method,
782 path: &str,
783 ) -> Result<TestRequestConfig> {
784 let url = self
785 .url()
786 .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
787
788 let server_locked = self.state.as_ref().lock().map_err(|err| {
789 anyhow!(
790 "Failed to lock InternalTestServer, for request {method} {path}, received {err:?}",
791 )
792 })?;
793
794 let cookies = server_locked.cookies().clone();
795 let mut query_params = server_locked.query_params().clone();
796 let headers = server_locked.headers().clone();
797 let mut full_request_url =
798 build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
799
800 if let Some(scheme) = server_locked.scheme() {
801 full_request_url.set_scheme(scheme).map_err(|_| {
802 let debug_request_format = RequestPathFormatter::new(&method, full_request_url.as_str(), Some(&query_params));
803 anyhow!("Scheme '{scheme}' from TestServer cannot be set to request {debug_request_format}")
804 })?;
805 }
806
807 ::std::mem::drop(server_locked);
808
809 Ok(TestRequestConfig {
810 is_saving_cookies: self.save_cookies,
811 expected_state: self.expected_state,
812 content_type: self.default_content_type.clone(),
813 method,
814
815 full_request_url,
816 cookies,
817 query_params,
818 headers,
819 })
820 }
821
822 pub fn is_running(&self) -> bool {
828 self.transport.is_running()
829 }
830}
831
832fn build_url(
833 mut url: Url,
834 path: &str,
835 query_params: &mut QueryParamsStore,
836 is_http_restricted: bool,
837) -> Result<Url> {
838 let path_uri = path.parse::<Uri>()?;
839
840 if let Some(scheme) = path_uri.scheme_str() {
842 if is_http_restricted {
843 if has_different_schema(&url, &path_uri) || has_different_authority(&url, &path_uri) {
844 return Err(anyhow!(
845 "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_schema' to change this."
846 ));
847 }
848 } else {
849 url.set_scheme(scheme)
850 .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
851
852 if let Some(authority) = path_uri.authority() {
854 url.set_host(Some(authority.host()))
855 .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
856 url.set_port(authority.port().map(|p| p.as_u16()))
857 .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
858
859 }
861 }
862 }
863
864 if is_absolute_uri(&path_uri) {
874 url.set_path(path_uri.path());
875
876 if url.query().is_some() {
878 url.set_query(None);
879 }
880 } else {
881 let calculated_path = path.split('?').next().unwrap_or(path);
883 url.set_path(calculated_path);
884
885 if let Some(url_query) = url.query() {
887 query_params.add_raw(url_query.to_string());
888 url.set_query(None);
889 }
890 }
891
892 if let Some(path_query) = path_uri.query() {
893 query_params.add_raw(path_query.to_string());
894 }
895
896 Ok(url)
897}
898
899fn is_absolute_uri(path_uri: &Uri) -> bool {
900 path_uri.scheme_str().is_some()
901}
902
903fn has_different_schema(base_url: &Url, path_uri: &Uri) -> bool {
904 if let Some(scheme) = path_uri.scheme_str() {
905 return scheme != base_url.scheme();
906 }
907
908 false
909}
910
911fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
912 if let Some(authority) = path_uri.authority() {
913 return authority.as_str() != base_url.authority();
914 }
915
916 false
917}
918
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Parses `base`, resolves `path` against it with [`build_url`], and returns both the
    /// outcome and the query params collected by the call.
    fn resolve(base: &str, path: &str, is_restricted: bool) -> (Result<Url>, QueryParamsStore) {
        let base_url: Url = base.parse().unwrap();
        let mut collected = QueryParamsStore::new();
        let outcome = build_url(base_url, path, &mut collected, is_restricted);

        (outcome, collected)
    }

    /// Like [`resolve`], but requires success and returns the resolved URL as text.
    fn resolve_ok(base: &str, path: &str, is_restricted: bool) -> (String, QueryParamsStore) {
        let (outcome, collected) = resolve(base, path, is_restricted);

        (outcome.unwrap().to_string(), collected)
    }

    /// Asserts that resolving `path` against `base` is rejected.
    fn assert_rejected(base: &str, path: &str, is_restricted: bool) {
        let (outcome, _) = resolve(base, path, is_restricted);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (url, query_params) = resolve_ok("http://example.com", "/users", true);

        assert_eq!(url, "http://example.com/users");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (url, query_params) =
            resolve_ok("http://example.com?base=aaa", "/users?path=bbb&path-flag", true);

        assert_eq!(url, "http://example.com/users");
        assert_eq!(query_params.to_string(), "base=aaa&path=bbb&path-flag");
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        assert_rejected(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            true,
        );
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        assert_rejected(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            true,
        );
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        assert_rejected("http://example.com?base=666", "http://google.com", true);
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        assert_rejected("http://example.com?base=666", "ftp://example.com/users", true);
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (url, query_params) = resolve_ok("http://example.com", "/users", false);

        assert_eq!(url, "http://example.com/users");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (url, query_params) =
            resolve_ok("http://example.com?base=aaa", "/users?path=bbb&path-flag", false);

        assert_eq!(url, "http://example.com/users");
        assert_eq!(query_params.to_string(), "base=aaa&path=bbb&path-flag");
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let (url, query_params) = resolve_ok("http://example.com", "google.com", false);

        assert_eq!(url, "http://example.com/google.com");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let (url, query_params) = resolve_ok("http://example.com", "google.com", true);

        assert_eq!(url, "http://example.com/google.com");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (url, query_params) = resolve_ok(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );

        assert_eq!(url, "ftp://google.com:123/users.csv");
        assert_eq!(query_params.to_string(), "limit=456");
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (url, _) = resolve_ok("http://example.com", "ftp://example.com", false);

        assert_eq!(url, "ftp://example.com/");
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (url, _) = resolve_ok("http://example.com", "http://google.com", false);

        assert_eq!(url, "http://google.com/");
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (url, _) = resolve_ok("http://example.com:123", "http://example.com:456", false);

        assert_eq!(url, "http://example.com:456/");
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (url, _) = resolve_ok("http://example.com:123", "http://example.com:123", false);

        assert_eq!(url, "http://example.com:123/");
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        assert_rejected("http://example.com", "ftp://example.com", true);
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        assert_rejected("http://example.com", "http://google.com", true);
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        assert_rejected("http://example.com:123", "http://example.com:456", true);
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (url, _) = resolve_ok("http://example.com:123", "http://example.com:123", true);

        assert_eq!(url, "http://example.com:123/");
    }
}
1120
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    /// Handler that always answers with the body `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        let service = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();
        let server = TestServer::new(service).expect("Should create test server");

        let response = server.get("/ping").await;
        response.assert_text("pong!");
    }
}
1147
#[cfg(test)]
mod test_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use reserve_port::ReservedSocketAddr;

    /// Handler that always answers with the body `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get("/ping").await.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get("ping").await.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text("pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema()
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text("pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    // The expected message pins the panic to the restriction check, so an unrelated panic
    // (such as an arithmetic overflow) cannot make this test pass.
    #[tokio::test]
    #[should_panic(expected = "Request disallowed")]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema()
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // `wrapping_add` avoids an overflow panic if the reserved port happens to be 65535;
        // the wrapped port still differs from the server's.
        let other_port = port.wrapping_add(1);
        let absolute_url = format!("http://{ip}:{other_port}/ping");

        server.get(&absolute_url).await;
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/ping").await };
        let future2 = async { server.get("/ping").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let app = axum::Router::new().route(
            "/slow",
            axum::routing::get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );

        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/slow").await };
        let future2 = async { server.get("/slow").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }
}
1287
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Handler that always answers with the body `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let router = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder()
            .http_transport()
            .build(router)
            .expect("Should create test server");

        let http_response = server.reqwest_get("/ping").send().await.unwrap();
        let body = http_response.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1320
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;

    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    /// JSON payload echoed (and transformed) by [`post_json`].
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Handler that doubles `number` and appends `_plus_response` to `text`.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let response = server
            .reqwest_post("/json")
            .json(&TestBody {
                number: 111,
                text: "request".to_string(),
            })
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        assert_eq!(
            response,
            TestBody {
                number: 222,
                text: "request_plus_response".to_string(),
            }
        );
    }
}
1376
#[cfg(test)]
mod test_server_address {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let app = Router::new();
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let expected_ip_port = format!("http://{}:{}/", ip, reserved_port.port());
        assert_eq!(
            server.server_address().unwrap().to_string(),
            expected_ip_port
        );
    }

    // Renamed from `..._without_ending_slash`: the regex below requires a trailing `/`.
    #[tokio::test]
    async fn it_should_return_default_address_with_ending_slash() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let address_regex = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();
        let is_match = address_regex.is_match(&server.server_address().unwrap().to_string());
        assert!(is_match);
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        assert!(server.server_address().is_none());
    }
}
1431
#[cfg(test)]
mod test_server_url {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    /// With an explicit `ip:port` HTTP transport, `server_url` prefixes the path
    /// with exactly that address.
    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        // NOTE(review): `reserved_port` stays in scope for the whole test; presumably
        // dropping it early would release the reservation — confirm.
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let absolute_url = server.server_url("/users").unwrap().to_string();
        assert_eq!(absolute_url, format!("http://{}:{}/users", ip, port));
    }

    /// On the random-port HTTP transport the URL is on loopback and keeps the
    /// path and query string given to `server_url`.
    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new())
            .expect("Should create test server");

        let address_regex =
            Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();
        let absolute_url = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();

        assert!(address_regex.is_match(&absolute_url));
    }

    /// The mock transport has no real address, so building a URL must fail.
    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new())
            .expect("Should create test server");

        assert!(server.server_url("/users").is_err());
    }

    /// Query parameters written into the path are preserved in the URL.
    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();
        let expected_url = format!("http://{}:{}/users?filter=enabled", ip, port);

        assert_eq!(received_url, expected_url);
    }

    /// Query parameters registered on the server are appended to the URL.
    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();
        server.add_query_param("filter", "enabled");

        let received_url = server.server_url("/users").unwrap().to_string();
        let expected_url = format!("http://{}:{}/users?filter=enabled", ip, port);

        assert_eq!(received_url, expected_url);
    }

    /// Server-level params come first, followed by params written in the path.
    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();
        server.add_query_param("filter", "enabled");

        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();
        let expected_url = format!("http://{}:{}/users?filter=enabled&animal=donkeys", ip, port);

        assert_eq!(received_url, expected_url);
    }
}
1574
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    /// Name of the cookie the test adds to the server and the route echoes back.
    // `const` string items are already `'static`; no explicit lifetime or extra borrow needed.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Returns the value of the [`TEST_COOKIE_NAME`] cookie, or `"cookie-not-found"`
    /// when the request carries no such cookie. The jar is passed back unchanged.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    /// A cookie added on the server is sent with subsequent requests.
    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        server.add_cookie(cookie);

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }
}
1607
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Renders every received cookie as `name=value`, sorted, joined with `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|c| format!("{}={}", c.name(), c.value()))
            .collect();
        pairs.sort();
        pairs.join(", ")
    }

    /// Every cookie from a jar passed to `add_cookies` is sent on later requests.
    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let mut server = TestServer::new(Router::new().route("/cookies", get(route_get_cookies)))
            .expect("Should create test server");

        let mut cookie_jar = CookieJar::new();
        for (name, value) in [
            ("first-cookie", "my-custom-cookie"),
            ("second-cookie", "other-cookie"),
        ] {
            cookie_jar.add(Cookie::new(name, value));
        }
        server.add_cookies(cookie_jar);

        let response = server.get("/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1648
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Renders every received cookie as `name=value`, sorted, joined with `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|c| format!("{}={}", c.name(), c.value()))
            .collect();
        pairs.sort();
        pairs.join(", ")
    }

    /// After `clear_cookies`, none of the previously added cookies are sent.
    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let mut server = TestServer::new(Router::new().route("/cookies", get(route_get_cookies)))
            .expect("Should create test server");

        let mut cookie_jar = CookieJar::new();
        for (name, value) in [
            ("first-cookie", "my-custom-cookie"),
            ("second-cookie", "other-cookie"),
        ] {
            cookie_jar.add(Cookie::new(name, value));
        }
        server.add_cookies(cookie_jar);
        server.clear_cookies();

        // An empty body means the route saw no cookies at all.
        server.get("/cookies").await.assert_text("");
    }
}
1688
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Header the test registers on the server.
    // `const` string items are already `'static`; no explicit lifetime or extra borrow needed.
    const TEST_HEADER_NAME: &str = "test-header";
    /// Value sent for [`TEST_HEADER_NAME`]; the route echoes it back verbatim.
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the [`TEST_HEADER_NAME`] request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        /// A missing header rejects with `400 Bad Request` and `"Missing test header"`.
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the extracted header bytes as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    /// A header added with `add_header` is sent on subsequent requests.
    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        let app = Router::new().route("/header", get(ping_header));

        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        let response = server.get("/header").await;

        response.assert_text(TEST_HEADER_CONTENT);
    }
}
1745
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Header the test registers and then clears on the server.
    // `const` string items are already `'static`; no explicit lifetime or extra borrow needed.
    const TEST_HEADER_NAME: &str = "test-header";
    /// Value registered for [`TEST_HEADER_NAME`] before it is cleared.
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the [`TEST_HEADER_NAME`] request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        /// A missing header rejects with `400 Bad Request` and `"Missing test header"`.
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the extracted header bytes as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    /// After `clear_headers`, the header is absent, so the extractor rejects.
    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        let app = Router::new().route("/header", get(ping_header));

        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        let response = server.get("/header").await;

        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1804
#[cfg(test)]
mod test_add_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Single-field query shape served at `/query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    /// Two-field query shape served at `/query-2`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` and `other` joined by a hyphen.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = params;
        format!("{}-{}", message, other)
    }

    /// Server exposing only the single-parameter `/query` route.
    fn single_param_server() -> TestServer {
        TestServer::new(Router::new().route("/query", get(get_query_param)))
            .expect("Should create test server")
    }

    /// Server exposing only the two-parameter `/query-2` route.
    fn two_param_server() -> TestServer {
        TestServer::new(Router::new().route("/query-2", get(get_query_param_2)))
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let mut server = single_param_server();
        server.add_query_params(QueryParam {
            message: "it works".to_string(),
        });

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = single_param_server();
        server.add_query_params(&[("message", "it works")]);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let mut server = two_param_server();
        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    /// Separate `add_query_params` calls accumulate rather than replace.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = two_param_server();
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let mut server = two_param_server();
        server.add_query_params(json!({
            "message": "it works",
            "other": "yup"
        }));

        server.get("/query-2").await.assert_text("it works-yup");
    }
}
1907
#[cfg(test)]
mod test_add_query_param {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Single-field query shape served at `/query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    /// Two-field query shape served at `/query-2`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` and `other` joined by a hyphen.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = params;
        format!("{}-{}", message, other)
    }

    /// Server exposing only the two-parameter `/query-2` route.
    fn two_param_server() -> TestServer {
        TestServer::new(Router::new().route("/query-2", get(get_query_param_2)))
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = TestServer::new(Router::new().route("/query", get(get_query_param)))
            .expect("Should create test server");
        server.add_query_param("message", "it works");

        server.get("/query").await.assert_text("it works");
    }

    /// Repeated `add_query_param` calls accumulate.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = two_param_server();
        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        server.get("/query-2").await.assert_text("it works-yup");
    }

    /// Server-level and request-level params are combined on the same request.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let mut server = two_param_server();
        server.add_query_param("message", "it works");

        let response = server
            .get("/query-2")
            .add_query_param("other", "yup")
            .await;
        response.assert_text("it works-yup");
    }
}
1982
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Single-field query shape served at `/query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query shape for `/query-extra`: a repeated `items` key and a repeated
    /// `arrs[]` key. Both default to empty when absent.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Echoes every `items` value followed by every `arrs[]` value, joined with `", "`.
    ///
    /// Both lists share the one separator. A query that carries both keys therefore
    /// yields `"a, b, c"` instead of the two joined lists glued together (`"a, bc"`).
    /// Returns an empty string when neither key is present.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        params
            .items
            .iter()
            .chain(params.arrs.iter())
            .map(String::as_str)
            .collect::<Vec<&str>>()
            .join(", ")
    }

    /// Router exposing both the standard and the `axum_extra` query routes.
    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    /// A raw `key=value` string is forwarded unchanged.
    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"message=it-works");

        server.get("/query").await.assert_text("it-works");
    }

    /// One raw string may carry several repeated keys.
    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one&items=two&items=three");

        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }

    /// Separate raw params accumulate into one array-style key.
    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }
}
2073
#[cfg(test)]
mod test_clear_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional query parameters, so the route can report which ones arrived.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports whether each of `first` and `second` was present.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();
        format!("has first? {}, has second? {}", has_first, has_second)
    }

    /// Both parameters populated with fixed values.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    /// Server exposing only the `/query` route.
    fn query_server() -> TestServer {
        TestServer::new(Router::new().route("/query", get(get_query_params)))
            .expect("Should create test server")
    }

    /// `clear_query_params` drops everything added beforehand.
    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let mut server = query_server();
        server.add_query_params(both_params());
        server.clear_query_params();

        server
            .get("/query")
            .await
            .assert_text("has first? false, has second? false");
    }

    /// Params added after a clear are still sent.
    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let mut server = query_server();
        server.add_query_params(both_params());
        server.clear_query_params();
        server.add_query_params(both_params());

        server
            .get("/query")
            .await
            .assert_text("has first? true, has second? true");
    }
}
2143
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Router with a single `/known_route` that responds with a fixed body.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    /// Without opting in, a 404 response does not panic.
    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new()).expect("Should create test server");
        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(known_route_app()).expect("Should create test server");
        server.get("/known_route").await;
    }

    /// With `expect_success_by_default`, a 404 response panics.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = TestServer::builder()
            .expect_success_by_default()
            .build(Router::new())
            .expect("Should create test server");

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = TestServer::builder()
            .expect_success_by_default()
            .build(known_route_app())
            .expect("Should create test server");

        server.get("/known_route").await;
    }
}
2190
#[cfg(test)]
mod test_content_type {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Echoes the request's `Content-Type` header, or an empty string if it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// The server-level default content type is applied to a plain GET.
    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(Router::new().route("/content_type", get(get_content_type)))
            .expect("Should create test server");

        let text = server.get("/content_type").await.text();

        assert_eq!(text, "text/plain");
    }
}
2224
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Plain `200 OK` handler.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Non-200 success handler (`202 Accepted`).
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Builds a server for `app` with `expect_success` switched on.
    fn expecting_success(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = expecting_success(Router::new().route("/ping", get(get_ping)));
        server.get("/ping").await;
    }

    /// Any 2xx status counts as success, not only 200.
    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = expecting_success(Router::new().route("/accepted", get(get_accepted)));
        server.get("/accepted").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        let server = expecting_success(Router::new());
        server.get("/some_unknown_route").await;
    }
}
2280
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Plain `200 OK` handler.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Non-200 success handler (`202 Accepted`).
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Builds a server for `app` with `expect_failure` switched on.
    fn expecting_failure(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let server = expecting_failure(Router::new());
        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = expecting_failure(Router::new().route("/ping", get(get_ping)));
        server.get("/ping").await;
    }

    /// Any 2xx status violates `expect_failure`, not only 200.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let server = expecting_failure(Router::new().route("/accepted", get(get_accepted)));
        server.get("/accepted").await;
    }
}
2337
#[cfg(test)]
mod test_scheme {
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;

    use crate::TestServer;

    /// Echoes the scheme of the request URI; panics if the URI has none.
    async fn route_get_scheme(request: Request) -> String {
        let scheme = request.uri().scheme_str().unwrap();
        scheme.to_string()
    }

    /// Router with the single `/scheme` echo route.
    fn scheme_router() -> Router {
        Router::new().route("/scheme", get(route_get_scheme))
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = TestServer::builder().build(scheme_router()).unwrap();

        server.get("/scheme").await.assert_text("http");
    }

    /// A scheme set on the server applies to the requests made afterwards.
    #[tokio::test]
    async fn it_should_return_https_across_multiple_requests_when_set() {
        let mut server = TestServer::builder().build(scheme_router()).unwrap();
        server.scheme(&"https");

        server.get("/scheme").await.assert_text("https");
    }
}
2367
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Echoes `get <id>`.
    async fn route_get(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("get {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_get(&path).await;
        response.assert_text("get 123");
    }
}
2401
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Echoes `post <id>`.
    async fn route_post(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("post {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_post(&path).await;
        response.assert_text("post 123");
    }
}
2435
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Echoes `patch <id>`.
    async fn route_patch(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("patch {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_patch(&path).await;
        response.assert_text("patch 123");
    }
}
2469
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Echoes `put <id>`.
    async fn route_put(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("put {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_put(&path).await;
        response.assert_text("put 123");
    }
}
2503
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Echoes `delete <id>`.
    async fn route_delete(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("delete {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_delete(&path).await;
        response.assert_text("delete 123");
    }
}
2537
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, shared by every method route below.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    async fn route_get(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("get {id}")
    }

    async fn route_post(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("post {id}")
    }

    async fn route_patch(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("patch {id}")
    }

    async fn route_put(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("put {id}")
    }

    async fn route_delete(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("delete {id}")
    }

    /// One router serving every method on the same typed path; each handler echoes
    /// its method name and the id.
    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to `/path/123` through `typed_method` and asserts the echoed body.
    async fn assert_typed_method_response(method: Method, expected_text: &str) {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        server
            .typed_method(method, &path)
            .await
            .assert_text(expected_text);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method_response(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method_response(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method_response(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method_response(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method_response(Method::DELETE, "delete 123").await;
    }
}
2632
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    /// Fixed-body handler for `/test`.
    async fn route_get() -> &'static str {
        "it works"
    }

    /// A `TestServer` can be lazily built inside a `OnceCell` and used through the
    /// shared reference it hands back.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell = OnceCell::<TestServer>::new();

        let server = cell.get_or_init(|| {
            let router = Router::new().route("/test", get(route_get));
            TestServer::new(router).unwrap()
        });

        server.get("/test").await.assert_text("it works");
    }
}
2656
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::util::new_random_tokio_tcp_listener;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    /// Fixed-body handler for `/ping`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Serves a real `axum::serve` application, checks `is_running()` flips to
    /// `false` after graceful shutdown, and expects the final request to panic.
    ///
    /// NOTE(review): a bare `#[should_panic]` also passes if any of the earlier
    /// assertions panic, so this does not prove the panic comes from the last
    /// request. The 10 ms sleep is also a timing assumption that may be flaky
    /// under load.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown = Arc::new(Notify::new());
        let shutdown_signal = Arc::clone(&shutdown);

        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let listener = new_random_tokio_tcp_listener().unwrap();
        let application = serve(listener, app)
            .with_graceful_shutdown(async move { shutdown_signal.notified().await });

        let server = TestServer::builder()
            .build(application)
            .expect("Should create test server");

        // While the application is up, requests succeed and the server reports running.
        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        // Trigger graceful shutdown and give the serve task a moment to stop.
        shutdown.notify_one();
        sleep(Duration::from_millis(10)).await;

        assert!(!server.is_running());
        server.get("/ping").await.assert_status_ok();
    }
}
2701}