1use anyhow::Context;
2use anyhow::Result;
3use anyhow::anyhow;
4use cookie::Cookie;
5use cookie::CookieJar;
6use http::HeaderName;
7use http::HeaderValue;
8use http::Method;
9use http::Uri;
10use serde::Serialize;
11use std::fmt::Debug;
12use std::sync::Arc;
13use std::sync::Mutex;
14use url::Url;
15
16#[cfg(feature = "typed-routing")]
17use axum_extra::routing::TypedPath;
18
19#[cfg(feature = "reqwest")]
20use crate::transport_layer::TransportLayerType;
21#[cfg(feature = "reqwest")]
22use reqwest::Client;
23#[cfg(feature = "reqwest")]
24use reqwest::RequestBuilder;
25
26use crate::TestRequest;
27use crate::TestRequestConfig;
28use crate::TestServerBuilder;
29use crate::TestServerConfig;
30use crate::Transport;
31use crate::internals::ExpectedState;
32use crate::internals::QueryParamsStore;
33use crate::internals::RequestPathFormatter;
34use crate::transport_layer::IntoTransportLayer;
35use crate::transport_layer::TransportLayer;
36use crate::transport_layer::TransportLayerBuilder;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// Base URL used to build request URLs when the transport does not expose a URL of its own
/// (`TestServer::url()` returns `None`, as it does for the mock transport).
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
/// A server wrapping an application under test, used to build requests against it.
///
/// Requests created from a `TestServer` (via `get`, `post`, `method`, …) start from a copy of
/// the cookies, query params, headers and scheme held in the server's shared state, plus the
/// server-level defaults stored in the fields below.
#[derive(Debug)]
pub struct TestServer {
    /// Cookies, query params, headers and scheme applied to requests built from this server.
    /// The same `Arc` is handed to every `TestRequest` created by this server.
    state: Arc<Mutex<ServerSharedState>>,
    /// The transport layer requests are sent through; shared with every `TestRequest`.
    transport: Arc<Box<dyn TransportLayer>>,
    /// Default for whether new requests save cookies from their responses.
    save_cookies: bool,
    /// Default expectation (success / failure / none) copied into new requests.
    expected_state: ExpectedState,
    /// Default content type copied into new requests, if any.
    default_content_type: Option<String>,
    /// When `true`, absolute request paths must match the server's scheme and authority.
    is_http_path_restricted: bool,

    /// Reqwest client; `Some` only when the server runs over an HTTP transport.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: Option<Client>,
}
153
154impl TestServer {
155 pub fn builder() -> TestServerBuilder {
157 TestServerBuilder::default()
158 }
159
160 pub fn new<A>(app: A) -> Result<Self>
192 where
193 A: IntoTransportLayer,
194 {
195 Self::new_with_config(app, TestServerConfig::default())
196 }
197
198 pub fn new_with_config<A, C>(app: A, config: C) -> Result<Self>
205 where
206 A: IntoTransportLayer,
207 C: Into<TestServerConfig>,
208 {
209 let config = config.into();
210 let mut shared_state = ServerSharedState::new();
211 if let Some(scheme) = config.default_scheme {
212 shared_state.set_scheme_unlocked(scheme);
213 }
214
215 let shared_state_mutex = Mutex::new(shared_state);
216 let state = Arc::new(shared_state_mutex);
217
218 let transport = match config.transport {
219 None => {
220 let builder = TransportLayerBuilder::new(None, None);
221 let transport = app.into_default_transport(builder)?;
222 Arc::new(transport)
223 }
224 Some(Transport::HttpRandomPort) => {
225 let builder = TransportLayerBuilder::new(None, None);
226 let transport = app.into_http_transport_layer(builder)?;
227 Arc::new(transport)
228 }
229 Some(Transport::HttpIpPort { ip, port }) => {
230 let builder = TransportLayerBuilder::new(ip, port);
231 let transport = app.into_http_transport_layer(builder)?;
232 Arc::new(transport)
233 }
234 Some(Transport::MockHttp) => {
235 let transport = app.into_mock_transport_layer()?;
236 Arc::new(transport)
237 }
238 };
239
240 let expected_state = match config.expect_success_by_default {
241 true => ExpectedState::Success,
242 false => ExpectedState::None,
243 };
244
245 #[cfg(feature = "reqwest")]
246 let maybe_reqwest_client = match transport.transport_layer_type() {
247 TransportLayerType::Http => {
248 let reqwest_client = reqwest::Client::builder()
249 .redirect(reqwest::redirect::Policy::none())
250 .cookie_store(config.save_cookies)
251 .build()
252 .expect("Failed to build Reqwest Client");
253
254 Some(reqwest_client)
255 }
256 TransportLayerType::Mock => None,
257 };
258
259 Ok(Self {
260 state,
261 transport,
262 save_cookies: config.save_cookies,
263 expected_state,
264 default_content_type: config.default_content_type,
265 is_http_path_restricted: config.restrict_requests_with_http_schema,
266
267 #[cfg(feature = "reqwest")]
268 maybe_reqwest_client,
269 })
270 }
271
272 pub fn get(&self, path: &str) -> TestRequest {
274 self.method(Method::GET, path)
275 }
276
277 pub fn post(&self, path: &str) -> TestRequest {
279 self.method(Method::POST, path)
280 }
281
282 pub fn patch(&self, path: &str) -> TestRequest {
284 self.method(Method::PATCH, path)
285 }
286
287 pub fn put(&self, path: &str) -> TestRequest {
289 self.method(Method::PUT, path)
290 }
291
292 pub fn delete(&self, path: &str) -> TestRequest {
294 self.method(Method::DELETE, path)
295 }
296
297 pub fn method(&self, method: Method, path: &str) -> TestRequest {
299 let maybe_config = self.build_test_request_config(method.clone(), path);
300 let config = maybe_config
301 .with_context(|| format!("Failed to build, for request {method} {path}"))
302 .unwrap();
303
304 TestRequest::new(self.state.clone(), self.transport.clone(), config)
305 }
306
307 #[cfg(feature = "reqwest")]
308 fn reqwest_client(&self) -> &Client {
309 self.maybe_reqwest_client
310 .as_ref()
311 .expect("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available")
312 }
313
314 #[cfg(feature = "reqwest")]
315 pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
316 self.reqwest_method(Method::GET, path)
317 }
318
319 #[cfg(feature = "reqwest")]
320 pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
321 self.reqwest_method(Method::POST, path)
322 }
323
324 #[cfg(feature = "reqwest")]
325 pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
326 self.reqwest_method(Method::PUT, path)
327 }
328
329 #[cfg(feature = "reqwest")]
330 pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
331 self.reqwest_method(Method::PATCH, path)
332 }
333
334 #[cfg(feature = "reqwest")]
335 pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
336 self.reqwest_method(Method::DELETE, path)
337 }
338
339 #[cfg(feature = "reqwest")]
340 pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
341 self.reqwest_method(Method::HEAD, path)
342 }
343
344 #[cfg(feature = "reqwest")]
369 pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
370 let request_url = self
371 .server_url(path)
372 .expect("Failed to generate server url for request {method} {path}");
373
374 self.reqwest_client().request(method, request_url)
375 }
376
377 #[cfg(feature = "ws")]
412 pub fn get_websocket(&self, path: &str) -> TestRequest {
413 use http::header;
414
415 self.get(path)
416 .add_header(header::CONNECTION, "upgrade")
417 .add_header(header::UPGRADE, "websocket")
418 .add_header(header::SEC_WEBSOCKET_VERSION, "13")
419 .add_header(
420 header::SEC_WEBSOCKET_KEY,
421 crate::internals::generate_ws_key(),
422 )
423 }
424
425 #[cfg(feature = "typed-routing")]
472 pub fn typed_get<P>(&self, path: &P) -> TestRequest
473 where
474 P: TypedPath,
475 {
476 self.typed_method(Method::GET, path)
477 }
478
479 #[cfg(feature = "typed-routing")]
483 pub fn typed_post<P>(&self, path: &P) -> TestRequest
484 where
485 P: TypedPath,
486 {
487 self.typed_method(Method::POST, path)
488 }
489
490 #[cfg(feature = "typed-routing")]
494 pub fn typed_patch<P>(&self, path: &P) -> TestRequest
495 where
496 P: TypedPath,
497 {
498 self.typed_method(Method::PATCH, path)
499 }
500
501 #[cfg(feature = "typed-routing")]
505 pub fn typed_put<P>(&self, path: &P) -> TestRequest
506 where
507 P: TypedPath,
508 {
509 self.typed_method(Method::PUT, path)
510 }
511
512 #[cfg(feature = "typed-routing")]
516 pub fn typed_delete<P>(&self, path: &P) -> TestRequest
517 where
518 P: TypedPath,
519 {
520 self.typed_method(Method::DELETE, path)
521 }
522
523 #[cfg(feature = "typed-routing")]
527 pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
528 where
529 P: TypedPath,
530 {
531 self.method(method, &path.to_string())
532 }
533
534 pub fn server_address(&self) -> Option<Url> {
542 self.url()
543 }
544
545 pub fn server_url(&self, path: &str) -> Result<Url> {
578 let path_uri = path.parse::<Uri>()?;
579 if is_absolute_uri(&path_uri) {
580 return Err(anyhow!(
581 "Absolute path provided for building server url, need to provide a relative uri"
582 ));
583 }
584
585 let server_url = self.url()
586 .ok_or_else(||
587 anyhow!(
588 "No local address for server, need to run with HTTP transport to have a server address",
589 )
590 )?;
591
592 let server_locked = self.state.as_ref().lock().map_err(|err| {
593 anyhow!("Failed to lock InternalTestServer, for building server_url, received {err:?}",)
594 })?;
595 let mut query_params = server_locked.query_params().clone();
596 let mut full_server_url = build_url(
597 server_url,
598 path,
599 &mut query_params,
600 self.is_http_path_restricted,
601 )?;
602
603 if query_params.has_content() {
605 full_server_url.set_query(Some(&query_params.to_string()));
606 }
607
608 Ok(full_server_url)
609 }
610
611 pub fn add_cookie(&mut self, cookie: Cookie) {
616 ServerSharedState::add_cookie(&self.state, cookie)
617 .context("Trying to call add_cookie")
618 .unwrap()
619 }
620
621 pub fn add_cookies(&mut self, cookies: CookieJar) {
626 ServerSharedState::add_cookies(&self.state, cookies)
627 .context("Trying to call add_cookies")
628 .unwrap()
629 }
630
631 pub fn clear_cookies(&mut self) {
633 ServerSharedState::clear_cookies(&self.state)
634 .context("Trying to call clear_cookies")
635 .unwrap()
636 }
637
638 pub fn save_cookies(&mut self) {
642 self.save_cookies = true;
643 }
644
645 pub fn do_not_save_cookies(&mut self) {
649 self.save_cookies = false;
650 }
651
652 pub fn expect_success(&mut self) {
656 self.expected_state = ExpectedState::Success;
657 }
658
659 pub fn expect_failure(&mut self) {
663 self.expected_state = ExpectedState::Failure;
664 }
665
666 pub fn add_query_param<V>(&mut self, key: &str, value: V)
668 where
669 V: Serialize,
670 {
671 ServerSharedState::add_query_param(&self.state, key, value)
672 .context("Trying to call add_query_param")
673 .unwrap()
674 }
675
676 pub fn add_query_params<V>(&mut self, query_params: V)
678 where
679 V: Serialize,
680 {
681 ServerSharedState::add_query_params(&self.state, query_params)
682 .context("Trying to call add_query_params")
683 .unwrap()
684 }
685
686 pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
689 ServerSharedState::add_raw_query_param(&self.state, raw_query_param)
690 .context("Trying to call add_raw_query_param")
691 .unwrap()
692 }
693
694 pub fn clear_query_params(&mut self) {
696 ServerSharedState::clear_query_params(&self.state)
697 .context("Trying to call clear_query_params")
698 .unwrap()
699 }
700
701 pub fn add_header<N, V>(&mut self, name: N, value: V)
722 where
723 N: TryInto<HeaderName>,
724 N::Error: Debug,
725 V: TryInto<HeaderValue>,
726 V::Error: Debug,
727 {
728 let header_name: HeaderName = name
729 .try_into()
730 .expect("Failed to convert header name to HeaderName");
731 let header_value: HeaderValue = value
732 .try_into()
733 .expect("Failed to convert header vlue to HeaderValue");
734
735 ServerSharedState::add_header(&self.state, header_name, header_value)
736 .context("Trying to call add_header")
737 .unwrap()
738 }
739
740 pub fn clear_headers(&mut self) {
742 ServerSharedState::clear_headers(&self.state)
743 .context("Trying to call clear_headers")
744 .unwrap()
745 }
746
747 pub fn scheme(&mut self, scheme: &str) {
771 ServerSharedState::set_scheme(&self.state, scheme.to_string())
772 .context("Trying to call set_scheme")
773 .unwrap()
774 }
775
776 pub(crate) fn url(&self) -> Option<Url> {
777 self.transport.url().cloned()
778 }
779
780 pub(crate) fn build_test_request_config(
781 &self,
782 method: Method,
783 path: &str,
784 ) -> Result<TestRequestConfig> {
785 let url = self
786 .url()
787 .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
788
789 let server_locked = self.state.as_ref().lock().map_err(|err| {
790 anyhow!(
791 "Failed to lock InternalTestServer, for request {method} {path}, received {err:?}",
792 )
793 })?;
794
795 let cookies = server_locked.cookies().clone();
796 let mut query_params = server_locked.query_params().clone();
797 let headers = server_locked.headers().clone();
798 let mut full_request_url =
799 build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
800
801 if let Some(scheme) = server_locked.scheme() {
802 full_request_url.set_scheme(scheme).map_err(|_| {
803 let debug_request_format = RequestPathFormatter::new(&method, full_request_url.as_str(), Some(&query_params));
804 anyhow!("Scheme '{scheme}' from TestServer cannot be set to request {debug_request_format}")
805 })?;
806 }
807
808 ::std::mem::drop(server_locked);
809
810 Ok(TestRequestConfig {
811 is_saving_cookies: self.save_cookies,
812 expected_state: self.expected_state,
813 content_type: self.default_content_type.clone(),
814 method,
815
816 full_request_url,
817 cookies,
818 query_params,
819 headers,
820 })
821 }
822
823 pub fn is_running(&self) -> bool {
829 self.transport.is_running()
830 }
831}
832
833fn build_url(
834 mut url: Url,
835 path: &str,
836 query_params: &mut QueryParamsStore,
837 is_http_restricted: bool,
838) -> Result<Url> {
839 let path_uri = path.parse::<Uri>()?;
840
841 if let Some(scheme) = path_uri.scheme_str() {
843 if is_http_restricted {
844 if has_different_schema(&url, &path_uri) || has_different_authority(&url, &path_uri) {
845 return Err(anyhow!(
846 "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_schema' to change this."
847 ));
848 }
849 } else {
850 url.set_scheme(scheme)
851 .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
852
853 if let Some(authority) = path_uri.authority() {
855 url.set_host(Some(authority.host()))
856 .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
857 url.set_port(authority.port().map(|p| p.as_u16()))
858 .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
859
860 }
862 }
863 }
864
865 if is_absolute_uri(&path_uri) {
875 url.set_path(path_uri.path());
876
877 if url.query().is_some() {
879 url.set_query(None);
880 }
881 } else {
882 let calculated_path = path.split('?').next().unwrap_or(path);
884 url.set_path(calculated_path);
885
886 if let Some(url_query) = url.query() {
888 query_params.add_raw(url_query.to_string());
889 url.set_query(None);
890 }
891 }
892
893 if let Some(path_query) = path_uri.query() {
894 query_params.add_raw(path_query.to_string());
895 }
896
897 Ok(url)
898}
899
900fn is_absolute_uri(path_uri: &Uri) -> bool {
901 path_uri.scheme_str().is_some()
902}
903
904fn has_different_schema(base_url: &Url, path_uri: &Uri) -> bool {
905 if let Some(scheme) = path_uri.scheme_str() {
906 return scheme != base_url.scheme();
907 }
908
909 false
910}
911
912fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
913 if let Some(authority) = path_uri.authority() {
914 return authority.as_str() != base_url.authority();
915 }
916
917 false
918}
919
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Parses `base_url`, runs `build_url` for `path` against an empty query-param store,
    /// and returns both the result and the store for inspection.
    fn run_build_url(
        base_url: &str,
        path: &str,
        is_restricted: bool,
    ) -> (anyhow::Result<Url>, QueryParamsStore) {
        let base_url = base_url.parse::<Url>().unwrap();
        let mut query_params = QueryParamsStore::new();
        let result = build_url(base_url, path, &mut query_params, is_restricted);

        (result, query_params)
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (result, query_params) = run_build_url("http://example.com", "/users", true);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (result, query_params) =
            run_build_url("http://example.com?base=aaa", "/users?path=bbb&path-flag", true);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        let (result, _) = run_build_url(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        let (result, _) = run_build_url(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(result.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        let (result, _) = run_build_url("http://example.com?base=666", "http://google.com", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        let (result, _) =
            run_build_url("http://example.com?base=666", "ftp://example.com/users", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (result, query_params) = run_build_url("http://example.com", "/users", false);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (result, query_params) =
            run_build_url("http://example.com?base=aaa", "/users?path=bbb&path-flag", false);

        assert_eq!("http://example.com/users", result.unwrap().as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let (result, query_params) = run_build_url("http://example.com", "google.com", false);

        assert_eq!("http://example.com/google.com", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let (result, query_params) = run_build_url("http://example.com", "google.com", true);

        assert_eq!("http://example.com/google.com", result.unwrap().as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (result, query_params) = run_build_url(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );

        assert_eq!("ftp://google.com:123/users.csv", result.unwrap().as_str());
        assert_eq!("limit=456", query_params.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (result, _) = run_build_url("http://example.com", "ftp://example.com", false);

        assert_eq!("ftp://example.com/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (result, _) = run_build_url("http://example.com", "http://google.com", false);

        assert_eq!("http://google.com/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (result, _) = run_build_url("http://example.com:123", "http://example.com:456", false);

        assert_eq!("http://example.com:456/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (result, _) = run_build_url("http://example.com:123", "http://example.com:123", false);

        assert_eq!("http://example.com:123/", result.unwrap().as_str());
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        let (result, _) = run_build_url("http://example.com", "ftp://example.com", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        let (result, _) = run_build_url("http://example.com", "http://google.com", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        let (result, _) = run_build_url("http://example.com:123", "http://example.com:456", true);

        assert!(result.is_err());
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (result, _) = run_build_url("http://example.com:123", "http://example.com:123", true);

        assert_eq!("http://example.com:123/", result.unwrap().as_str());
    }
}
1121
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    /// Minimal handler used to check requests reach the app.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        let service = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();

        let server = TestServer::new(service).expect("Should create test server");

        let response = server.get("/ping").await;
        response.assert_text(&"pong!");
    }
}
1148
#[cfg(test)]
mod test_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use reserve_port::ReservedSocketAddr;

    /// Minimal handler used to check requests reach the app.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get("/ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get("ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        let app = Router::new().route("/ping", get(get_ping));

        // Keep the reservation alive for the whole test.
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema()
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    // Pin the panic reason, so an unrelated panic (e.g. failing to create the server) does
    // not make this test pass.
    #[tokio::test]
    #[should_panic(expected = "Request disallowed")]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        let app = Router::new().route("/ping", get(get_ping));

        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema()
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        // Any port other than the server's; avoids overflowing when `port == u16::MAX`.
        let other_port = if port == u16::MAX { port - 1 } else { port + 1 };
        let absolute_url = format!("http://{ip}:{other_port}/ping");
        server.get(&absolute_url).await;
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/ping").await };
        let future2 = async { server.get("/ping").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let app = axum::Router::new().route(
            "/slow",
            axum::routing::get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );

        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/slow").await };
        let future2 = async { server.get("/slow").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }
}
1288
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Minimal handler used to check requests reach the app.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let http_response = server.reqwest_get("/ping").send().await.unwrap();
        let body = http_response.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1321
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;

    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    /// JSON payload echoed (and transformed) by `post_json`.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Doubles `number` and appends `_plus_response` to `text`.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let response = server
            .reqwest_post("/json")
            .json(&TestBody {
                number: 111,
                text: "request".to_string(),
            })
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        assert_eq!(
            response,
            TestBody {
                number: 222,
                text: "request_plus_response".to_string(),
            }
        );
    }
}
1377
#[cfg(test)]
mod test_server_address {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        // Keep the reservation alive for the whole test.
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let app = Router::new();
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {ip}:{port}"))
            .unwrap();

        let address = server.server_address().unwrap().to_string();
        assert_eq!(address, format!("http://{ip}:{port}/"));
    }

    #[tokio::test]
    async fn it_should_return_default_address_without_ending_slash() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let address_pattern = Regex::new(r"^http://127\.0\.0\.1:[0-9]+/$").unwrap();
        let address = server.server_address().unwrap().to_string();
        assert!(address_pattern.is_match(&address));
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        let maybe_address = server.server_address();
        assert!(maybe_address.is_none());
    }
}
1432
#[cfg(test)]
mod test_server_url {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;
    use std::net::IpAddr;

    /// Builds an HTTP test server on this machine's local IP and a reserved random port.
    ///
    /// Returns the server, the port reservation, and the IP and port used. Callers bind the
    /// reservation to a named variable so it stays alive for the whole test.
    fn new_server_on_local_ip() -> (TestServer, ReservedPort, IpAddr, u16) {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        (server, reserved_port, ip, port)
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let (server, _reserved_port, ip, port) = new_server_on_local_ip();

        let expected_ip_port_url = format!("http://{}:{}/users", ip, port);
        let absolute_url = server.server_url("/users").unwrap().to_string();
        assert_eq!(absolute_url, expected_ip_port_url);
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let address_regex =
            Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();
        let absolute_url = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();

        let is_match = address_regex.is_match(&absolute_url);
        assert!(is_match);
    }

    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        let result = server.server_url("/users");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let (server, _reserved_port, ip, port) = new_server_on_local_ip();

        let expected_url = format!("http://{}:{}/users?filter=enabled", ip, port);
        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();

        assert_eq!(received_url, expected_url);
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let (mut server, _reserved_port, ip, port) = new_server_on_local_ip();

        server.add_query_param("filter", "enabled");

        let expected_url = format!("http://{}:{}/users?filter=enabled", ip, port);
        let received_url = server.server_url("/users").unwrap().to_string();

        assert_eq!(received_url, expected_url);
    }

    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let (mut server, _reserved_port, ip, port) = new_server_on_local_ip();

        server.add_query_param("filter", "enabled");

        // Server-level params come first, followed by those in the path.
        let expected_url = format!(
            "http://{}:{}/users?filter=enabled&animal=donkeys",
            ip, port
        );
        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();

        assert_eq!(received_url, expected_url);
    }
}
1575
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Responds with the value of the `TEST_COOKIE_NAME` cookie, or `"cookie-not-found"`.
    ///
    /// The jar is returned unchanged alongside the body.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        server.add_cookie(cookie);

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }
}
1608
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Responds with every received cookie as `name=value`, sorted and joined by `", "`.
    ///
    /// Sorting makes the output independent of the jar's iteration order.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut all_cookies = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect::<Vec<String>>();
        all_cookies.sort();

        all_cookies.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie_1 = Cookie::new("first-cookie", "my-custom-cookie");
        let cookie_2 = Cookie::new("second-cookie", "other-cookie");
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie_1);
        cookie_jar.add(cookie_2);

        server.add_cookies(cookie_jar);

        server
            .get("/cookies")
            .await
            .assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1649
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Responds with every received cookie as `name=value`, sorted and joined by `", "`.
    ///
    /// With no cookies, the body is the empty string.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut all_cookies = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect::<Vec<String>>();
        all_cookies.sort();

        all_cookies.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie_1 = Cookie::new("first-cookie", "my-custom-cookie");
        let cookie_2 = Cookie::new("second-cookie", "other-cookie");
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie_1);
        cookie_jar.add(cookie_2);

        server.add_cookies(cookie_jar);

        server.clear_cookies();

        server.get("/cookies").await.assert_text("");
    }
}
1689
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `TEST_HEADER_NAME` request header.
    ///
    /// If the header is absent, it rejects with `400 Bad Request` and the body
    /// `"Missing test header"`.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the test header's bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        let app = Router::new().route("/header", get(ping_header));

        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        let response = server.get("/header").await;

        response.assert_text(TEST_HEADER_CONTENT);
    }
}
1746
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `TEST_HEADER_NAME` request header.
    ///
    /// If the header is absent, it rejects with `400 Bad Request` and the body
    /// `"Missing test header"`. The test below relies on this rejection.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the test header's bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        let app = Router::new().route("/header", get(ping_header));

        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        let response = server.get("/header").await;

        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1805
#[cfg(test)]
mod test_add_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query string carrying only `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string carrying both `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `message` and `other` joined by a dash.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{}-{}", message, other)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router).expect("Should create test server");

        let params = QueryParam {
            message: "it works".to_string(),
        };
        server.add_query_params(params);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_params(&[("message", "it works")]);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        // Params accumulate across separate calls.
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        server.add_query_params(params);

        server.get("/query-2").await.assert_text("it works-yup");
    }
}
1908
#[cfg(test)]
mod test_add_query_param {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string carrying only `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string carrying both `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `message` and `other` joined by a dash.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{}-{}", message, other)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_param("message", "it works");

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        // One param set on the server, the other on the individual request.
        server.add_query_param("message", "it works");

        let request = server.get("/query-2").add_query_param("other", "yup");
        request.await.assert_text("it works-yup");
    }
}
1983
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter, parsed with axum's standard `Query`.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated query keys, parsed with `axum_extra`'s `Query`.
    ///
    /// `items` collects `items=…` repeats, and `arrs` collects `arrs[]=…` repeats.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Responds with `items` joined by `", "`, immediately followed by `arrs` joined by `", "`.
    ///
    /// Joining an empty list yields `""`, so an absent key contributes nothing.
    /// This makes explicit emptiness checks unnecessary.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        let mut output = params.items.join(", ");
        output.push_str(&params.arrs.join(", "));
        output
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param("message=it-works");

        server.get("/query").await.assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param("items=one&items=two&items=three");

        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param("arrs[]=one");
        server.add_raw_query_param("arrs[]=two");
        server.add_raw_query_param("arrs[]=three");

        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }
}
2074
#[cfg(test)]
mod test_clear_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional query params; the handler only reports which are present.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports whether `first` and `second` were present in the query string.
    async fn get_query_params(
        Query(QueryParams { first, second }): Query<QueryParams>,
    ) -> String {
        let has_first = first.is_some();
        let has_second = second.is_some();
        format!(
            "has first? {}, has second? {}",
            has_first, has_second
        )
    }

    /// Params with both fields populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let router = Router::new().route("/query", get(get_query_params));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_params(both_params());
        server.clear_query_params();

        server
            .get("/query")
            .await
            .assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let router = Router::new().route("/query", get(get_query_params));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_params(both_params());
        server.clear_query_params();
        // Params added after a clear must still be sent.
        server.add_query_params(both_params());

        server
            .get("/query")
            .await
            .assert_text("has first? true, has second? true");
    }
}
2144
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Router exposing a single `/known_route`.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    /// Builds a server configured via `expect_success_by_default()`.
    fn strict_server(app: Router) -> TestServer {
        TestServer::builder()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(known_route_app()).expect("Should create test server");

        server.get("/known_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = strict_server(Router::new());

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = strict_server(known_route_app());

        server.get("/known_route").await;
    }
}
2191
#[cfg(test)]
mod test_content_type {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Responds with the request's `Content-Type` header, or `""` when it is absent.
    ///
    /// Panics if the header value is not visible ASCII. That is acceptable in a test handler.
    async fn get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let app = Router::new().route("/content_type", get(get_content_type));

        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(app)
            .expect("Should create test server");

        let text = server.get("/content_type").await.text();

        assert_eq!(text, "text/plain");
    }
}
2225
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Builds a server for `app` with `expect_success()` switched on.
    fn success_expecting_server(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = success_expecting_server(Router::new().route("/ping", get(get_ping)));

        server.get("/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server =
            success_expecting_server(Router::new().route("/accepted", get(get_accepted)));

        server.get("/accepted").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        let server = success_expecting_server(Router::new());

        server.get("/some_unknown_route").await;
    }
}
2281
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Builds a server for `app` with `expect_failure()` switched on.
    fn failure_expecting_server(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let server = failure_expecting_server(Router::new());

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = failure_expecting_server(Router::new().route("/ping", get(get_ping)));

        server.get("/ping").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let server =
            failure_expecting_server(Router::new().route("/accepted", get(get_accepted)));

        server.get("/accepted").await;
    }
}
2338
#[cfg(test)]
mod test_scheme {
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;

    use crate::TestServer;

    /// Responds with the scheme of the incoming request URI.
    ///
    /// Panics if the URI has no scheme.
    async fn route_get_scheme(request: Request) -> String {
        let scheme = request.uri().scheme_str().unwrap();
        scheme.to_string()
    }

    fn scheme_router() -> Router {
        Router::new().route("/scheme", get(route_get_scheme))
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = TestServer::builder().build(scheme_router()).unwrap();

        let response = server.get("/scheme").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_across_multiple_requests_when_set() {
        let mut server = TestServer::builder().build(scheme_router()).unwrap();
        server.scheme(&"https");

        let response = server.get("/scheme").await;
        response.assert_text("https");
    }
}
2368
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `get {id}`.
    async fn route_get(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("get {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_get(&path).await;
        response.assert_text("get 123");
    }
}
2402
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `post {id}`.
    async fn route_post(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("post {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_post(&path).await;
        response.assert_text("post 123");
    }
}
2436
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `patch {id}`.
    async fn route_patch(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("patch {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_patch(&path).await;
        response.assert_text("patch 123");
    }
}
2470
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `put {id}`.
    async fn route_put(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("put {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_put(&path).await;
        response.assert_text("put 123");
    }
}
2504
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `delete {id}`.
    async fn route_delete(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("delete {id}")
    }

    fn new_app() -> Router {
        Router::new().typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_delete(&path).await;
        response.assert_text("delete 123");
    }
}
2538
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`, registered for every method under test.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    async fn route_get(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("get {id}")
    }

    async fn route_post(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("post {id}")
    }

    async fn route_patch(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("patch {id}")
    }

    async fn route_put(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("put {id}")
    }

    async fn route_delete(path: TestingPath) -> String {
        let TestingPath { id } = path;
        format!("delete {id}")
    }

    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to `/path/123` on a fresh server and asserts the response body.
    async fn assert_typed_method_text(method: Method, expected: &str) {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_method(method, &path).await;
        response.assert_text(expected);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method_text(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method_text(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method_text(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method_text(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method_text(Method::DELETE, "delete 123").await;
    }
}
2633
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    async fn route_get() -> &'static str {
        "it works"
    }

    /// A `TestServer` can be lazily stored in a `OnceCell` and then used by reference.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell = OnceCell::<TestServer>::new();
        let server = cell.get_or_init(|| {
            let router = Router::new().route("/test", get(route_get));
            TestServer::new(router).unwrap()
        });

        let response = server.get("/test").await;
        response.assert_text("it works");
    }
}
2657
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::util::new_random_tokio_tcp_listener;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    /// Upper bound on how many times to poll `is_running()` after requesting shutdown.
    const SHUTDOWN_POLL_ATTEMPTS: u32 = 100;
    /// Delay between `is_running()` polls.
    const SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(10);

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Serves a real `axum::serve` application with graceful shutdown, stops it, and then
    /// expects the next request to panic.
    ///
    /// NOTE(review): despite the name, this test does not select a mock transport. It builds
    /// the server from a `serve(...)` application. Also note that `#[should_panic]` accepts
    /// a panic from ANY assertion in here, not only from the final request.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown_notification = Arc::new(Notify::new());
        let waiting_notification = shutdown_notification.clone();

        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let port = new_random_tokio_tcp_listener().unwrap();
        let application = serve(port, app)
            .with_graceful_shutdown(async move { waiting_notification.notified().await });

        let server = TestServer::builder()
            .build(application)
            .expect("Should create test server");

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        shutdown_notification.notify_one();

        // Poll rather than relying on a single fixed sleep. A slow shutdown would otherwise
        // trip the `is_running` assertion below, and under `#[should_panic]` the test would
        // pass for the wrong reason.
        for _ in 0..SHUTDOWN_POLL_ATTEMPTS {
            if !server.is_running() {
                break;
            }
            sleep(SHUTDOWN_POLL_INTERVAL).await;
        }

        assert!(!server.is_running());
        server.get("/ping").await.assert_status_ok();
    }
}