1use anyhow::anyhow;
2use anyhow::Context;
3use anyhow::Result;
4use cookie::Cookie;
5use cookie::CookieJar;
6use http::HeaderName;
7use http::HeaderValue;
8use http::Method;
9use http::Uri;
10use serde::Serialize;
11use std::fmt::Debug;
12use std::sync::Arc;
13use std::sync::Mutex;
14use url::Url;
15
16#[cfg(feature = "typed-routing")]
17use axum_extra::routing::TypedPath;
18
19#[cfg(feature = "reqwest")]
20use crate::transport_layer::TransportLayerType;
21#[cfg(feature = "reqwest")]
22use reqwest::Client;
23#[cfg(feature = "reqwest")]
24use reqwest::RequestBuilder;
25
26use crate::internals::ExpectedState;
27use crate::internals::QueryParamsStore;
28use crate::internals::RequestPathFormatter;
29use crate::transport_layer::IntoTransportLayer;
30use crate::transport_layer::TransportLayer;
31use crate::transport_layer::TransportLayerBuilder;
32use crate::TestRequest;
33use crate::TestRequestConfig;
34use crate::TestServerBuilder;
35use crate::TestServerConfig;
36use crate::Transport;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// Base URL used when building requests if the transport exposes no URL of its own
/// (for example the mock transport, whose `server_address` is `None`).
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
/// A server wrapping an application under test, used to build and send requests to it.
///
/// Shared defaults (cookies, query params, headers and scheme) are kept in `state`.
/// They are copied into each request's config at the moment the request is created.
#[derive(Debug)]
pub struct TestServer {
    /// Shared mutable defaults; a clone of this handle is given to every request.
    state: Arc<Mutex<ServerSharedState>>,
    /// The transport layer requests are sent through (HTTP or mock).
    transport: Arc<Box<dyn TransportLayer>>,
    /// Default `is_saving_cookies` value for newly created requests.
    save_cookies: bool,
    /// Default expected outcome for newly created requests.
    expected_state: ExpectedState,
    /// Default content type for newly created requests, if any.
    default_content_type: Option<String>,
    /// When `true`, absolute request paths must match this server's scheme and authority.
    is_http_path_restricted: bool,

    /// Reqwest client; only `Some` when the server runs on an HTTP transport.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: Option<Client>,
}
153
154impl TestServer {
155 pub fn builder() -> TestServerBuilder {
157 TestServerBuilder::default()
158 }
159
160 pub fn new<A>(app: A) -> Result<Self>
192 where
193 A: IntoTransportLayer,
194 {
195 Self::new_with_config(app, TestServerConfig::default())
196 }
197
198 pub fn new_with_config<A, C>(app: A, config: C) -> Result<Self>
205 where
206 A: IntoTransportLayer,
207 C: Into<TestServerConfig>,
208 {
209 let config = config.into();
210 let mut shared_state = ServerSharedState::new();
211 if let Some(scheme) = config.default_scheme {
212 shared_state.set_scheme_unlocked(scheme);
213 }
214
215 let shared_state_mutex = Mutex::new(shared_state);
216 let state = Arc::new(shared_state_mutex);
217
218 let transport = match config.transport {
219 None => {
220 let builder = TransportLayerBuilder::new(None, None);
221 let transport = app.into_default_transport(builder)?;
222 Arc::new(transport)
223 }
224 Some(Transport::HttpRandomPort) => {
225 let builder = TransportLayerBuilder::new(None, None);
226 let transport = app.into_http_transport_layer(builder)?;
227 Arc::new(transport)
228 }
229 Some(Transport::HttpIpPort { ip, port }) => {
230 let builder = TransportLayerBuilder::new(ip, port);
231 let transport = app.into_http_transport_layer(builder)?;
232 Arc::new(transport)
233 }
234 Some(Transport::MockHttp) => {
235 let transport = app.into_mock_transport_layer()?;
236 Arc::new(transport)
237 }
238 };
239
240 let expected_state = match config.expect_success_by_default {
241 true => ExpectedState::Success,
242 false => ExpectedState::None,
243 };
244
245 #[cfg(feature = "reqwest")]
246 let maybe_reqwest_client = match transport.transport_layer_type() {
247 TransportLayerType::Http => {
248 let reqwest_client = reqwest::Client::builder()
249 .redirect(reqwest::redirect::Policy::none())
250 .cookie_store(config.save_cookies)
251 .build()
252 .expect("Failed to build Reqwest Client");
253
254 Some(reqwest_client)
255 }
256 TransportLayerType::Mock => None,
257 };
258
259 Ok(Self {
260 state,
261 transport,
262 save_cookies: config.save_cookies,
263 expected_state,
264 default_content_type: config.default_content_type,
265 is_http_path_restricted: config.restrict_requests_with_http_schema,
266
267 #[cfg(feature = "reqwest")]
268 maybe_reqwest_client,
269 })
270 }
271
272 pub fn get(&self, path: &str) -> TestRequest {
274 self.method(Method::GET, path)
275 }
276
277 pub fn post(&self, path: &str) -> TestRequest {
279 self.method(Method::POST, path)
280 }
281
282 pub fn patch(&self, path: &str) -> TestRequest {
284 self.method(Method::PATCH, path)
285 }
286
287 pub fn put(&self, path: &str) -> TestRequest {
289 self.method(Method::PUT, path)
290 }
291
292 pub fn delete(&self, path: &str) -> TestRequest {
294 self.method(Method::DELETE, path)
295 }
296
297 pub fn method(&self, method: Method, path: &str) -> TestRequest {
299 let maybe_config = self.build_test_request_config(method.clone(), path);
300 let config = maybe_config
301 .with_context(|| format!("Failed to build, for request {method} {path}"))
302 .unwrap();
303
304 TestRequest::new(self.state.clone(), self.transport.clone(), config)
305 }
306
307 #[cfg(feature = "reqwest")]
308 fn reqwest_client(&self) -> &Client {
309 self.maybe_reqwest_client
310 .as_ref()
311 .expect("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available")
312 }
313
314 #[cfg(feature = "reqwest")]
315 pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
316 self.reqwest_method(Method::GET, path)
317 }
318
319 #[cfg(feature = "reqwest")]
320 pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
321 self.reqwest_method(Method::POST, path)
322 }
323
324 #[cfg(feature = "reqwest")]
325 pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
326 self.reqwest_method(Method::PUT, path)
327 }
328
329 #[cfg(feature = "reqwest")]
330 pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
331 self.reqwest_method(Method::PATCH, path)
332 }
333
334 #[cfg(feature = "reqwest")]
335 pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
336 self.reqwest_method(Method::DELETE, path)
337 }
338
339 #[cfg(feature = "reqwest")]
340 pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
341 self.reqwest_method(Method::HEAD, path)
342 }
343
344 #[cfg(feature = "reqwest")]
369 pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
370 let request_url = self
371 .server_url(path)
372 .expect("Failed to generate server url for request {method} {path}");
373
374 self.reqwest_client().request(method, request_url)
375 }
376
377 #[cfg(feature = "ws")]
412 pub fn get_websocket(&self, path: &str) -> TestRequest {
413 use http::header;
414
415 self.get(path)
416 .add_header(header::CONNECTION, "upgrade")
417 .add_header(header::UPGRADE, "websocket")
418 .add_header(header::SEC_WEBSOCKET_VERSION, "13")
419 .add_header(
420 header::SEC_WEBSOCKET_KEY,
421 crate::internals::generate_ws_key(),
422 )
423 }
424
425 #[cfg(feature = "typed-routing")]
472 pub fn typed_get<P>(&self, path: &P) -> TestRequest
473 where
474 P: TypedPath,
475 {
476 self.typed_method(Method::GET, path)
477 }
478
479 #[cfg(feature = "typed-routing")]
483 pub fn typed_post<P>(&self, path: &P) -> TestRequest
484 where
485 P: TypedPath,
486 {
487 self.typed_method(Method::POST, path)
488 }
489
490 #[cfg(feature = "typed-routing")]
494 pub fn typed_patch<P>(&self, path: &P) -> TestRequest
495 where
496 P: TypedPath,
497 {
498 self.typed_method(Method::PATCH, path)
499 }
500
501 #[cfg(feature = "typed-routing")]
505 pub fn typed_put<P>(&self, path: &P) -> TestRequest
506 where
507 P: TypedPath,
508 {
509 self.typed_method(Method::PUT, path)
510 }
511
512 #[cfg(feature = "typed-routing")]
516 pub fn typed_delete<P>(&self, path: &P) -> TestRequest
517 where
518 P: TypedPath,
519 {
520 self.typed_method(Method::DELETE, path)
521 }
522
523 #[cfg(feature = "typed-routing")]
527 pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
528 where
529 P: TypedPath,
530 {
531 self.method(method, &path.to_string())
532 }
533
534 pub fn server_address(&self) -> Option<Url> {
542 self.url()
543 }
544
545 pub fn server_url(&self, path: &str) -> Result<Url> {
578 let path_uri = path.parse::<Uri>()?;
579 if is_absolute_uri(&path_uri) {
580 return Err(anyhow!(
581 "Absolute path provided for building server url, need to provide a relative uri"
582 ));
583 }
584
585 let server_url = self.url()
586 .ok_or_else(||
587 anyhow!(
588 "No local address for server, need to run with HTTP transport to have a server address",
589 )
590 )?;
591
592 let server_locked = self.state.as_ref().lock().map_err(|err| {
593 anyhow!("Failed to lock InternalTestServer, for building server_url, received {err:?}",)
594 })?;
595 let mut query_params = server_locked.query_params().clone();
596 let mut full_server_url = build_url(
597 server_url,
598 path,
599 &mut query_params,
600 self.is_http_path_restricted,
601 )?;
602
603 if query_params.has_content() {
605 full_server_url.set_query(Some(&query_params.to_string()));
606 }
607
608 Ok(full_server_url)
609 }
610
611 pub fn add_cookie(&mut self, cookie: Cookie) {
616 ServerSharedState::add_cookie(&self.state, cookie)
617 .context("Trying to call add_cookie")
618 .unwrap()
619 }
620
621 pub fn add_cookies(&mut self, cookies: CookieJar) {
626 ServerSharedState::add_cookies(&self.state, cookies)
627 .context("Trying to call add_cookies")
628 .unwrap()
629 }
630
631 pub fn clear_cookies(&mut self) {
633 ServerSharedState::clear_cookies(&self.state)
634 .context("Trying to call clear_cookies")
635 .unwrap()
636 }
637
638 pub fn save_cookies(&mut self) {
642 self.save_cookies = true;
643 }
644
645 pub fn do_not_save_cookies(&mut self) {
649 self.save_cookies = false;
650 }
651
652 pub fn expect_success(&mut self) {
656 self.expected_state = ExpectedState::Success;
657 }
658
659 pub fn expect_failure(&mut self) {
663 self.expected_state = ExpectedState::Failure;
664 }
665
666 pub fn add_query_param<V>(&mut self, key: &str, value: V)
668 where
669 V: Serialize,
670 {
671 ServerSharedState::add_query_param(&self.state, key, value)
672 .context("Trying to call add_query_param")
673 .unwrap()
674 }
675
676 pub fn add_query_params<V>(&mut self, query_params: V)
678 where
679 V: Serialize,
680 {
681 ServerSharedState::add_query_params(&self.state, query_params)
682 .context("Trying to call add_query_params")
683 .unwrap()
684 }
685
686 pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
689 ServerSharedState::add_raw_query_param(&self.state, raw_query_param)
690 .context("Trying to call add_raw_query_param")
691 .unwrap()
692 }
693
694 pub fn clear_query_params(&mut self) {
696 ServerSharedState::clear_query_params(&self.state)
697 .context("Trying to call clear_query_params")
698 .unwrap()
699 }
700
701 pub fn add_header<N, V>(&mut self, name: N, value: V)
722 where
723 N: TryInto<HeaderName>,
724 N::Error: Debug,
725 V: TryInto<HeaderValue>,
726 V::Error: Debug,
727 {
728 let header_name: HeaderName = name
729 .try_into()
730 .expect("Failed to convert header name to HeaderName");
731 let header_value: HeaderValue = value
732 .try_into()
733 .expect("Failed to convert header vlue to HeaderValue");
734
735 ServerSharedState::add_header(&self.state, header_name, header_value)
736 .context("Trying to call add_header")
737 .unwrap()
738 }
739
740 pub fn clear_headers(&mut self) {
742 ServerSharedState::clear_headers(&self.state)
743 .context("Trying to call clear_headers")
744 .unwrap()
745 }
746
747 pub fn scheme(&mut self, scheme: &str) {
771 ServerSharedState::set_scheme(&self.state, scheme.to_string())
772 .context("Trying to call set_scheme")
773 .unwrap()
774 }
775
776 pub(crate) fn url(&self) -> Option<Url> {
777 self.transport.url().cloned()
778 }
779
780 pub(crate) fn build_test_request_config(
781 &self,
782 method: Method,
783 path: &str,
784 ) -> Result<TestRequestConfig> {
785 let url = self
786 .url()
787 .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
788
789 let server_locked = self.state.as_ref().lock().map_err(|err| {
790 anyhow!(
791 "Failed to lock InternalTestServer, for request {method} {path}, received {err:?}",
792 )
793 })?;
794
795 let cookies = server_locked.cookies().clone();
796 let mut query_params = server_locked.query_params().clone();
797 let headers = server_locked.headers().clone();
798 let mut full_request_url =
799 build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
800
801 if let Some(scheme) = server_locked.scheme() {
802 full_request_url.set_scheme(scheme).map_err(|_| {
803 let debug_request_format = RequestPathFormatter::new(&method, full_request_url.as_str(), Some(&query_params));
804 anyhow!("Scheme '{scheme}' from TestServer cannot be set to request {debug_request_format}")
805 })?;
806 }
807
808 ::std::mem::drop(server_locked);
809
810 Ok(TestRequestConfig {
811 is_saving_cookies: self.save_cookies,
812 expected_state: self.expected_state,
813 content_type: self.default_content_type.clone(),
814 method,
815
816 full_request_url,
817 cookies,
818 query_params,
819 headers,
820 })
821 }
822
823 pub fn is_running(&self) -> bool {
829 self.transport.is_running()
830 }
831}
832
833fn build_url(
834 mut url: Url,
835 path: &str,
836 query_params: &mut QueryParamsStore,
837 is_http_restricted: bool,
838) -> Result<Url> {
839 let path_uri = path.parse::<Uri>()?;
840
841 if let Some(scheme) = path_uri.scheme_str() {
843 if is_http_restricted {
844 if has_different_schema(&url, &path_uri) || has_different_authority(&url, &path_uri) {
845 return Err(anyhow!("Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_schema' to change this."));
846 }
847 } else {
848 url.set_scheme(scheme)
849 .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
850
851 if let Some(authority) = path_uri.authority() {
853 url.set_host(Some(authority.host()))
854 .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
855 url.set_port(authority.port().map(|p| p.as_u16()))
856 .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
857
858 }
860 }
861 }
862
863 if is_absolute_uri(&path_uri) {
873 url.set_path(path_uri.path());
874
875 if url.query().is_some() {
877 url.set_query(None);
878 }
879 } else {
880 let calculated_path = path.split('?').next().unwrap_or(path);
882 url.set_path(calculated_path);
883
884 if let Some(url_query) = url.query() {
886 query_params.add_raw(url_query.to_string());
887 url.set_query(None);
888 }
889 }
890
891 if let Some(path_query) = path_uri.query() {
892 query_params.add_raw(path_query.to_string());
893 }
894
895 Ok(url)
896}
897
898fn is_absolute_uri(path_uri: &Uri) -> bool {
899 path_uri.scheme_str().is_some()
900}
901
902fn has_different_schema(base_url: &Url, path_uri: &Uri) -> bool {
903 if let Some(scheme) = path_uri.scheme_str() {
904 return scheme != base_url.scheme();
905 }
906
907 false
908}
909
910fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
911 if let Some(authority) = path_uri.authority() {
912 return authority.as_str() != base_url.authority();
913 }
914
915 false
916}
917
#[cfg(test)]
mod test_build_url {
    use super::*;

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "/users";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true).unwrap();

        assert_eq!("http://example.com/users", result.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let base_url = "http://example.com?base=aaa".parse::<Url>().unwrap();
        let path = "/users?path=bbb&path-flag";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true).unwrap();

        assert_eq!("http://example.com/users", result.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        let base_url = "http://example.com?base=666".parse::<Url>().unwrap();
        let path = "ftp://google.com:123/users.csv?limit=456";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        let base_url = "http://example.com?base=666".parse::<Url>().unwrap();
        let path = "http://google.com:123/users.csv?limit=456";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        let base_url = "http://example.com?base=666".parse::<Url>().unwrap();
        let path = "http://google.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        let base_url = "http://example.com?base=666".parse::<Url>().unwrap();
        let path = "ftp://example.com/users";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "/users";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("http://example.com/users", result.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let base_url = "http://example.com?base=aaa".parse::<Url>().unwrap();
        let path = "/users?path=bbb&path-flag";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("http://example.com/users", result.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "google.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("http://example.com/google.com", result.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "google.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true).unwrap();

        assert_eq!("http://example.com/google.com", result.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let base_url = "http://example.com?base=666".parse::<Url>().unwrap();
        let path = "ftp://google.com:123/users.csv?limit=456";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("ftp://google.com:123/users.csv", result.as_str());
        assert_eq!("limit=456", query_params.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "ftp://example.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("ftp://example.com/", result.as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "http://google.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("http://google.com/", result.as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let base_url = "http://example.com:123".parse::<Url>().unwrap();
        let path = "http://example.com:456";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("http://example.com:456/", result.as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let base_url = "http://example.com:123".parse::<Url>().unwrap();
        let path = "http://example.com:123";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, false).unwrap();

        assert_eq!("http://example.com:123/", result.as_str());
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "ftp://example.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        let base_url = "http://example.com".parse::<Url>().unwrap();
        let path = "http://google.com";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        let base_url = "http://example.com:123".parse::<Url>().unwrap();
        let path = "http://example.com:456";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let base_url = "http://example.com:123".parse::<Url>().unwrap();
        let path = "http://example.com:123";
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, true).unwrap();

        assert_eq!("http://example.com:123/", result.as_str());
    }
}
1119
#[cfg(test)]
mod test_new {
    use axum::routing::get;
    use axum::Router;
    use std::net::SocketAddr;

    use crate::TestServer;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        let app = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();

        let server = TestServer::new(app).expect("Should create test server");

        server.get("/ping").await.assert_text(&"pong!");
    }
}
1146
#[cfg(test)]
mod test_get {
    use super::*;

    use axum::routing::get;
    use axum::Router;
    use reserve_port::ReservedSocketAddr;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get("/ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get("ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema()
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    // `expected` ensures the panic comes from the restriction check, not from e.g. a
    // failure to build the server.
    #[tokio::test]
    #[should_panic(expected = "Request disallowed")]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema()
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        // Pick a neighbouring port without overflowing when `port` is `u16::MAX`.
        let other_port = port.checked_add(1).unwrap_or(port - 1);
        let absolute_url = format!("http://{ip}:{other_port}/ping");

        server.get(&absolute_url).await;
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/ping").await };
        let future2 = async { server.get("/ping").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let app = axum::Router::new().route(
            "/slow",
            axum::routing::get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );

        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/slow").await };
        let future2 = async { server.get("/slow").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }
}
1286
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;

    use axum::routing::get;
    use axum::Router;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let response = server
            .reqwest_get("/ping")
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();

        assert_eq!(response, "pong!");
    }
}
1319
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;

    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let response = server
            .reqwest_post("/json")
            .json(&TestBody {
                number: 111,
                text: "request".to_string(),
            })
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        assert_eq!(
            response,
            TestBody {
                number: 222,
                text: "request_plus_response".to_string(),
            }
        );
    }
}
1375
#[cfg(test)]
mod test_server_address {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        // Keep the reservation alive for the whole test.
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        let actual = server.server_address().unwrap().to_string();
        assert_eq!(actual, format!("http://{ip}:{port}/"));
    }

    #[tokio::test]
    async fn it_should_return_default_address_without_ending_slash() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new())
            .expect("Should create test server");

        let address = server.server_address().unwrap().to_string();
        let address_regex = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();
        assert!(address_regex.is_match(&address));
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new())
            .expect("Should create test server");

        assert!(server.server_address().is_none());
    }
}
1430
#[cfg(test)]
mod test_server_url {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    /// Starts an HTTP `TestServer` on this machine's local IP and a freshly
    /// reserved random port.
    ///
    /// Returns the server, its expected origin (`http://{ip}:{port}`, with no
    /// trailing slash) and the port reservation. Callers should bind the
    /// `ReservedPort` to a named variable (e.g. `_reserved_port`, not `_`) so
    /// it lives for the whole test, as it did when each test set this up inline.
    fn new_server_on_local_ip_port() -> (TestServer, String, ReservedPort) {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        (server, format!("http://{}:{}", ip, port), reserved_port)
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let (server, origin, _reserved_port) = new_server_on_local_ip_port();

        let absolute_url = server.server_url("/users").unwrap().to_string();
        assert_eq!(absolute_url, format!("{origin}/users"));
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let address_regex =
            Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();
        let absolute_url = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();

        // Include the URL in the failure message; `assert!(bool)` alone hides it.
        assert!(
            address_regex.is_match(&absolute_url),
            "unexpected server url: {absolute_url}"
        );
    }

    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        let result = server.server_url("/users");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let (server, origin, _reserved_port) = new_server_on_local_ip_port();

        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();
        assert_eq!(received_url, format!("{origin}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let (mut server, origin, _reserved_port) = new_server_on_local_ip_port();
        server.add_query_param("filter", "enabled");

        let received_url = server.server_url("/users").unwrap().to_string();
        assert_eq!(received_url, format!("{origin}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let (mut server, origin, _reserved_port) = new_server_on_local_ip_port();
        server.add_query_param("filter", "enabled");

        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();

        // Server-level params come first, followed by those written in the path.
        assert_eq!(
            received_url,
            format!("{origin}/users?filter=enabled&animal=donkeys")
        );
    }
}
1573
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;

    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    /// Name of the cookie the route below echoes back.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of the `TEST_COOKIE_NAME` cookie, or
    /// `cookie-not-found` when the request lacks it. The jar is handed back
    /// unchanged.
    async fn get_cookie(jar: CookieJar) -> (CookieJar, String) {
        let body = match jar.get(TEST_COOKIE_NAME) {
            Some(found) => found.value().to_owned(),
            None => "cookie-not-found".to_owned(),
        };

        (jar, body)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let body = server.get("/cookie").await.text();
        assert_eq!(body, "my-custom-cookie");
    }
}
1606
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Echoes every received cookie as `name=value`, sorted, joined by `, `.
    async fn route_get_cookies(jar: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in jar.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        pairs.sort();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        let mut jar = CookieJar::new();
        for (name, value) in [
            ("first-cookie", "my-custom-cookie"),
            ("second-cookie", "other-cookie"),
        ] {
            jar.add(Cookie::new(name, value));
        }
        server.add_cookies(jar);

        server
            .get("/cookies")
            .await
            .assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1647
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Echoes every received cookie as `name=value`, sorted, joined by `, `.
    /// The body is empty when the request carries no cookies.
    async fn route_get_cookies(jar: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        jar.iter()
            .for_each(|c| pairs.push(format!("{}={}", c.name(), c.value())));
        pairs.sort();
        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(jar);

        // Everything added above must be gone before the request is built.
        server.clear_cookies();

        server.get("/cookies").await.assert_text("");
    }
}
1687
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor carrying the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Rejects with `400 Bad Request` when the header is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let Some(value) = parts.headers.get(HeaderName::from_static(TEST_HEADER_NAME)) else {
                return Err((StatusCode::BAD_REQUEST, "Missing test header"));
            };

            Ok(Self(value.as_bytes().to_vec()))
        }
    }

    /// Echoes the header's raw bytes back as the response body.
    async fn ping_header(TestHeader(bytes): TestHeader) -> Vec<u8> {
        bytes
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        let app = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(app).expect("Should create test server");

        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        server.add_header(name, value);

        server.get("/header").await.assert_text(TEST_HEADER_CONTENT);
    }
}
1744
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor carrying the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Rejects with `400 Bad Request` and `Missing test header` when the
        /// header is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            match parts.headers.get(HeaderName::from_static(TEST_HEADER_NAME)) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the header's raw bytes back as the response body.
    async fn ping_header(TestHeader(bytes): TestHeader) -> Vec<u8> {
        bytes
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        let app = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(app).expect("Should create test server");

        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        // With the header cleared, the extractor's rejection is what comes back.
        let response = server.get("/header").await;
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1803
#[cfg(test)]
mod test_add_query_params {
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query shape for `/query`: a single required `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query shape for `/query-2`: required `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` back as the body.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Replies with `{message}-{other}`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{}-{}", message, other)
    }

    /// Wraps `app` in a `TestServer`, panicking if construction fails.
    fn new_server(app: Router) -> TestServer {
        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let mut server = new_server(Router::new().route("/query", get(get_query_param)));

        server.add_query_params(QueryParam {
            message: "it works".to_string(),
        });

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = new_server(Router::new().route("/query", get(get_query_param)));

        server.add_query_params(&[("message", "it works")]);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let mut server = new_server(Router::new().route("/query-2", get(get_query_param_2)));

        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = new_server(Router::new().route("/query-2", get(get_query_param_2)));

        // Successive calls must accumulate: the route needs both fields.
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let mut server = new_server(Router::new().route("/query-2", get(get_query_param_2)));

        server.add_query_params(json!({
            "message": "it works",
            "other": "yup"
        }));

        server.get("/query-2").await.assert_text("it works-yup");
    }
}
1906
#[cfg(test)]
mod test_add_query_param {
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query shape for `/query`: a single required `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query shape for `/query-2`: required `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` back as the body.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Replies with `{message}-{other}`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{}-{}", message, other)
    }

    /// Wraps `app` in a `TestServer`, panicking if construction fails.
    fn new_server(app: Router) -> TestServer {
        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = new_server(Router::new().route("/query", get(get_query_param)));

        server.add_query_param("message", "it works");

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = new_server(Router::new().route("/query-2", get(get_query_param_2)));

        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let mut server = new_server(Router::new().route("/query-2", get(get_query_param_2)));

        // One param lives on the server, the other is added per request.
        server.add_query_param("message", "it works");

        server
            .get("/query-2")
            .add_query_param("other", "yup")
            .await
            .assert_text("it works-yup");
    }
}
1981
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query shape for `/query`: a single required `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes `message` back as the body.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query shape for `/query-extra`. Repeated `items` / `arrs[]` keys are
    /// expected to collect into these vectors (the tests below rely on it);
    /// both default to empty when the key is absent.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Replies with every `items` value followed by every `arrs[]` value,
    /// joined with `, `. The body is empty when neither key is present.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        // Join once over both lists. Writing each list separately dropped the
        // `, ` between the last `items` value and the first `arrs[]` value.
        let mut values = params.items;
        values.extend(params.arrs);
        values.join(", ")
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"message=it-works");

        server.get(&"/query").await.assert_text(&"it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one&items=two&items=three");

        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }
}
2072
#[cfg(test)]
mod test_clear_query_params {
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional params, so a request without them still reaches the route.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of `first` / `second` arrived on the request.
    async fn get_query_params(
        Query(QueryParams { first, second }): Query<QueryParams>,
    ) -> String {
        format!(
            "has first? {}, has second? {}",
            first.is_some(),
            second.is_some()
        )
    }

    /// Both params populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    /// Wraps a router serving `/query` in a `TestServer`.
    fn new_server() -> TestServer {
        let app = Router::new().route("/query", get(get_query_params));
        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let mut server = new_server();
        server.add_query_params(both_params());
        server.clear_query_params();

        server
            .get("/query")
            .await
            .assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let mut server = new_server();
        server.add_query_params(both_params());
        server.clear_query_params();
        // Params added after a clear must still be sent.
        server.add_query_params(both_params());

        server
            .get("/query")
            .await
            .assert_text("has first? true, has second? true");
    }
}
2142
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;

    use axum::routing::get;
    use axum::Router;

    /// Router with a single `/known_route` that always answers successfully.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    /// Builds a server with `expect_success_by_default` switched on.
    fn new_expect_success_server(app: Router) -> TestServer {
        TestServer::builder()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(known_route_app()).expect("Should create test server");

        server.get("/known_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = new_expect_success_server(Router::new());

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = new_expect_success_server(known_route_app());

        server.get("/known_route").await;
    }
}
2189
#[cfg(test)]
mod test_content_type {
    use super::*;

    use axum::routing::get;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;

    /// Echoes the request's `Content-Type` header, or an empty body if absent.
    /// Unwraps the header's string form, so a non-visible-ASCII value panics.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let app = Router::new().route("/content_type", get(get_content_type));
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(app)
            .expect("Should create test server");

        let received = server.get("/content_type").await.text();
        assert_eq!(received, "text/plain");
    }
}
2223
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Wraps `app` in a `TestServer` with `expect_success` turned on.
    fn new_expect_success_server(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = new_expect_success_server(Router::new().route("/ping", get(get_ping)));

        server.get("/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let app = Router::new().route("/accepted", get(get_accepted));
        let server = new_expect_success_server(app);

        server.get("/accepted").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        let server = new_expect_success_server(Router::new());

        server.get("/some_unknown_route").await;
    }
}
2279
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Wraps `app` in a `TestServer` with `expect_failure` turned on.
    fn new_expect_failure_server(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let server = new_expect_failure_server(Router::new());

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = new_expect_failure_server(Router::new().route("/ping", get(get_ping)));

        server.get("/ping").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let app = Router::new().route("/accepted", get(get_accepted));
        let server = new_expect_failure_server(app);

        server.get("/accepted").await;
    }
}
2336
#[cfg(test)]
mod test_scheme {
    use axum::extract::Request;
    use axum::routing::get;
    use axum::Router;

    use crate::TestServer;

    /// Echoes the scheme of the incoming request URI.
    ///
    /// Unwraps the scheme, so a request URI without one panics in the handler.
    async fn route_get_scheme(request: Request) -> String {
        request.uri().scheme_str().unwrap().to_string()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let server = TestServer::builder().build(router).unwrap();

        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_across_multiple_requests_when_set() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let mut server = TestServer::builder().build(router).unwrap();
        server.scheme(&"https");

        // The name promises the scheme sticks across requests, so send more
        // than one. A single request cannot catch a setting that is only
        // applied once.
        server.get("/scheme").await.assert_text("https");
        server.get("/scheme").await.assert_text("https");
    }
}
2366
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build
    /// the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies `get {id}`, exposing both the method and the captured id.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let app = Router::new().typed_get(route_get);
        let server = TestServer::new(app).unwrap();

        let path = TestingPath { id: 123 };
        server.typed_get(&path).await.assert_text("get 123");
    }
}
2400
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build
    /// the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies `post {id}`, exposing both the method and the captured id.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let app = Router::new().typed_post(route_post);
        let server = TestServer::new(app).unwrap();

        let path = TestingPath { id: 123 };
        server.typed_post(&path).await.assert_text("post 123");
    }
}
2434
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build
    /// the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies `patch {id}`, exposing both the method and the captured id.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let app = Router::new().typed_patch(route_patch);
        let server = TestServer::new(app).unwrap();

        let path = TestingPath { id: 123 };
        server.typed_patch(&path).await.assert_text("patch 123");
    }
}
2468
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build
    /// the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies `put {id}`, exposing both the method and the captured id.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let app = Router::new().typed_put(route_put);
        let server = TestServer::new(app).unwrap();

        let path = TestingPath { id: 123 };
        server.typed_put(&path).await.assert_text("put 123");
    }
}
2502
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build
    /// the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies `delete {id}`, exposing both the method and the captured id.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let app = Router::new().typed_delete(route_delete);
        let server = TestServer::new(app).unwrap();

        let path = TestingPath { id: 123 };
        server.typed_delete(&path).await.assert_text("delete 123");
    }
}
2536
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, registered under five different methods.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    // Each handler prefixes the id with its own method name, so the response
    // proves which route `typed_method` actually reached.

    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to `/path/123` via `typed_method` on a fresh server and
    /// asserts the body equals `expected`.
    async fn assert_typed_method_reply(method: Method, expected: &str) {
        let server = TestServer::new(new_app()).unwrap();

        server
            .typed_method(method, &TestingPath { id: 123 })
            .await
            .assert_text(expected);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method_reply(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method_reply(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method_reply(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method_reply(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method_reply(Method::DELETE, "delete 123").await;
    }
}
2631
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::routing::get;
    use axum::Router;
    use std::cell::OnceCell;

    async fn route_get() -> &'static str {
        "it works"
    }

    /// Builds the server inside a plain (non-async) `OnceCell` initialiser,
    /// then uses the cached instance from async code.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell: OnceCell<TestServer> = OnceCell::new();
        let server = cell.get_or_init(|| {
            let router = Router::new().route("/test", get(route_get));
            TestServer::new(router).unwrap()
        });

        server.get("/test").await.assert_text("it works");
    }
}
2655
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::util::new_random_tokio_tcp_listener;
    use axum::routing::get;
    use axum::routing::IntoMakeService;
    use axum::serve;
    use axum::Router;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Runs an `axum::serve` app with graceful shutdown behind a `TestServer`.
    /// It checks `is_running` before and after signalling shutdown, then sends
    /// one more request. The test is marked `#[should_panic]`.
    ///
    /// NOTE(review): `#[should_panic]` is satisfied by *any* panic, so a failing
    /// `is_running` assertion would also make this test pass. The name also
    /// mentions mock HTTP although no transport is selected on the builder
    /// here. Confirm which panic this test is meant to pin down.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown_signal = Arc::new(Notify::new());
        let shutdown_waiter = Arc::clone(&shutdown_signal);

        let make_service: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let listener = new_random_tokio_tcp_listener().unwrap();
        let application = serve(listener, make_service).with_graceful_shutdown(async move {
            shutdown_waiter.notified().await;
        });

        let server = TestServer::builder()
            .build(application)
            .expect("Should create test server");

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        // Signal shutdown and give the served task a moment to wind down.
        shutdown_signal.notify_one();
        sleep(Duration::from_millis(10)).await;

        assert!(!server.is_running());
        server.get("/ping").await.assert_status_ok();
    }
}