axum_test/
test_server.rs

1use anyhow::anyhow;
2use anyhow::Context;
3use anyhow::Result;
4use cookie::Cookie;
5use cookie::CookieJar;
6use http::HeaderName;
7use http::HeaderValue;
8use http::Method;
9use http::Uri;
10use serde::Serialize;
11use std::fmt::Debug;
12use std::sync::Arc;
13use std::sync::Mutex;
14use url::Url;
15
16#[cfg(feature = "typed-routing")]
17use axum_extra::routing::TypedPath;
18
19#[cfg(feature = "reqwest")]
20use crate::transport_layer::TransportLayerType;
21#[cfg(feature = "reqwest")]
22use reqwest::Client;
23#[cfg(feature = "reqwest")]
24use reqwest::RequestBuilder;
25
26use crate::internals::ExpectedState;
27use crate::internals::QueryParamsStore;
28use crate::internals::RequestPathFormatter;
29use crate::transport_layer::IntoTransportLayer;
30use crate::transport_layer::TransportLayer;
31use crate::transport_layer::TransportLayerBuilder;
32use crate::TestRequest;
33use crate::TestRequestConfig;
34use crate::TestServerBuilder;
35use crate::TestServerConfig;
36use crate::Transport;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// Base url used to build request urls when the transport does not expose a url
/// of its own; `TestServer::build_test_request_config` falls back to this.
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
///
/// The `TestServer` runs your Axum application,
/// allowing you to make HTTP requests against it.
///
/// # Building
///
/// A `TestServer` can be used to run an [`axum::Router`], an [`::axum::routing::IntoMakeService`],
/// a [`shuttle_axum::ShuttleAxum`], and others.
///
/// The most straightforward approach is to call [`crate::TestServer::new`],
/// and pass in your application:
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use axum::Router;
/// use axum::routing::get;
///
/// use axum_test::TestServer;
///
/// let app = Router::new()
///     .route(&"/hello", get(|| async { "hello!" }));
///
/// let server = TestServer::new(app)?;
/// #
/// # Ok(())
/// # }
/// ```
///
/// # Requests
///
/// Requests are built by calling [`TestServer::get()`](crate::TestServer::get()),
/// [`TestServer::post()`](crate::TestServer::post()), [`TestServer::put()`](crate::TestServer::put()),
/// [`TestServer::delete()`](crate::TestServer::delete()), and [`TestServer::patch()`](crate::TestServer::patch()) methods.
/// Each returns a [`TestRequest`](crate::TestRequest), which allows for customising the request content.
///
/// For example:
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use axum::Router;
/// use axum::routing::get;
///
/// use axum_test::TestServer;
///
/// let app = Router::new()
///     .route(&"/hello", get(|| async { "hello!" }));
///
/// let server = TestServer::new(app)?;
///
/// let response = server.get(&"/hello")
///     .authorization_bearer("password12345")
///     .add_header("x-custom-header", "custom-value")
///     .await;
///
/// response.assert_text("hello!");
/// #
/// # Ok(())
/// # }
/// ```
///
/// Request methods also exist for using Axum Extra [`axum_extra::routing::TypedPath`],
/// or for building Reqwest [`reqwest::RequestBuilder`]. See those methods for details.
///
/// # Customising
///
/// A `TestServer` can be built from a builder, by calling [`crate::TestServer::builder`],
/// and customising settings. This allows one to set **mocked** (default when possible)
/// or **real http** networking for your service.
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use axum::Router;
/// use axum::routing::get;
///
/// use axum_test::TestServer;
///
/// let app = Router::new()
///     .route(&"/hello", get(|| async { "hello!" }));
///
/// // Customise server when building
/// let mut server = TestServer::builder()
///     .http_transport()
///     .expect_success_by_default()
///     .save_cookies()
///     .build(app)?;
///
/// // Add items to be sent on _all_ requests
/// server.add_header("x-custom-for-all", "common-value");
///
/// let response = server.get("/hello").await;
/// #
/// # Ok(())
/// # }
/// ```
///
#[derive(Debug)]
pub struct TestServer {
    /// State shared with every request built from this server: the cookies, query params,
    /// headers, and scheme that are copied into each new request. It is also handed to
    /// each `TestRequest`.
    state: Arc<Mutex<ServerSharedState>>,
    /// The transport (mock or real HTTP) that requests are sent through.
    transport: Arc<Box<dyn TransportLayer>>,
    /// Whether requests should save the cookies they receive, for use on future requests.
    save_cookies: bool,
    /// The default status expectation given to each new request.
    expected_state: ExpectedState,
    /// The default content type given to each new request, if any.
    default_content_type: Option<String>,
    /// When true, absolute request urls must use the same scheme and authority as the server.
    is_http_path_restricted: bool,

    /// A Reqwest client; only present when the server runs with HTTP transport.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: Option<Client>,
}
153
154impl TestServer {
155    /// A helper function to create a builder for creating a [`TestServer`].
156    pub fn builder() -> TestServerBuilder {
157        TestServerBuilder::default()
158    }
159
160    /// This will run the given Axum app,
161    /// allowing you to make requests against it.
162    ///
163    /// This is the same as creating a new `TestServer` with a configuration,
164    /// and passing [`crate::TestServerConfig::default()`].
165    ///
166    /// ```rust
167    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
168    /// #
169    /// use axum::Router;
170    /// use axum::routing::get;
171    /// use axum_test::TestServer;
172    ///
173    /// let app = Router::new()
174    ///     .route(&"/hello", get(|| async { "hello!" }));
175    ///
176    /// let server = TestServer::new(app)?;
177    /// #
178    /// # Ok(())
179    /// # }
180    /// ```
181    ///
182    /// The of applications that can be passed in include:
183    ///
184    ///  - [`axum::Router`]
185    ///  - [`axum::routing::IntoMakeService`]
186    ///  - [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`]
187    ///  - [`axum::serve::Serve`]
188    ///  - [`axum::serve::WithGracefulShutdown`]
189    ///  - [`shuttle_axum::ShuttleAxum`]
190    ///
191    pub fn new<A>(app: A) -> Result<Self>
192    where
193        A: IntoTransportLayer,
194    {
195        Self::new_with_config(app, TestServerConfig::default())
196    }
197
198    /// Similar to [`TestServer::new()`], with a customised configuration.
199    /// This includes type of transport in use (i.e. specify a specific port),
200    /// or change default settings (like the default content type for requests).
201    ///
202    /// This can take a [`crate::TestServerConfig`] or a [`crate::TestServerBuilder`].
203    /// See those for more information on configuration settings.
204    pub fn new_with_config<A, C>(app: A, config: C) -> Result<Self>
205    where
206        A: IntoTransportLayer,
207        C: Into<TestServerConfig>,
208    {
209        let config = config.into();
210        let mut shared_state = ServerSharedState::new();
211        if let Some(scheme) = config.default_scheme {
212            shared_state.set_scheme_unlocked(scheme);
213        }
214
215        let shared_state_mutex = Mutex::new(shared_state);
216        let state = Arc::new(shared_state_mutex);
217
218        let transport = match config.transport {
219            None => {
220                let builder = TransportLayerBuilder::new(None, None);
221                let transport = app.into_default_transport(builder)?;
222                Arc::new(transport)
223            }
224            Some(Transport::HttpRandomPort) => {
225                let builder = TransportLayerBuilder::new(None, None);
226                let transport = app.into_http_transport_layer(builder)?;
227                Arc::new(transport)
228            }
229            Some(Transport::HttpIpPort { ip, port }) => {
230                let builder = TransportLayerBuilder::new(ip, port);
231                let transport = app.into_http_transport_layer(builder)?;
232                Arc::new(transport)
233            }
234            Some(Transport::MockHttp) => {
235                let transport = app.into_mock_transport_layer()?;
236                Arc::new(transport)
237            }
238        };
239
240        let expected_state = match config.expect_success_by_default {
241            true => ExpectedState::Success,
242            false => ExpectedState::None,
243        };
244
245        #[cfg(feature = "reqwest")]
246        let maybe_reqwest_client = match transport.transport_layer_type() {
247            TransportLayerType::Http => {
248                let reqwest_client = reqwest::Client::builder()
249                    .redirect(reqwest::redirect::Policy::none())
250                    .cookie_store(config.save_cookies)
251                    .build()
252                    .expect("Failed to build Reqwest Client");
253
254                Some(reqwest_client)
255            }
256            TransportLayerType::Mock => None,
257        };
258
259        Ok(Self {
260            state,
261            transport,
262            save_cookies: config.save_cookies,
263            expected_state,
264            default_content_type: config.default_content_type,
265            is_http_path_restricted: config.restrict_requests_with_http_schema,
266
267            #[cfg(feature = "reqwest")]
268            maybe_reqwest_client,
269        })
270    }
271
272    /// Creates a HTTP GET request to the path.
273    pub fn get(&self, path: &str) -> TestRequest {
274        self.method(Method::GET, path)
275    }
276
277    /// Creates a HTTP POST request to the given path.
278    pub fn post(&self, path: &str) -> TestRequest {
279        self.method(Method::POST, path)
280    }
281
282    /// Creates a HTTP PATCH request to the path.
283    pub fn patch(&self, path: &str) -> TestRequest {
284        self.method(Method::PATCH, path)
285    }
286
287    /// Creates a HTTP PUT request to the path.
288    pub fn put(&self, path: &str) -> TestRequest {
289        self.method(Method::PUT, path)
290    }
291
292    /// Creates a HTTP DELETE request to the path.
293    pub fn delete(&self, path: &str) -> TestRequest {
294        self.method(Method::DELETE, path)
295    }
296
297    /// Creates a HTTP request, to the method and path provided.
298    pub fn method(&self, method: Method, path: &str) -> TestRequest {
299        let maybe_config = self.build_test_request_config(method.clone(), path);
300        let config = maybe_config
301            .with_context(|| format!("Failed to build, for request {method} {path}"))
302            .unwrap();
303
304        TestRequest::new(self.state.clone(), self.transport.clone(), config)
305    }
306
307    #[cfg(feature = "reqwest")]
308    fn reqwest_client(&self) -> &Client {
309        self.maybe_reqwest_client
310            .as_ref()
311            .expect("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available")
312    }
313
314    #[cfg(feature = "reqwest")]
315    pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
316        self.reqwest_method(Method::GET, path)
317    }
318
319    #[cfg(feature = "reqwest")]
320    pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
321        self.reqwest_method(Method::POST, path)
322    }
323
324    #[cfg(feature = "reqwest")]
325    pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
326        self.reqwest_method(Method::PUT, path)
327    }
328
329    #[cfg(feature = "reqwest")]
330    pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
331        self.reqwest_method(Method::PATCH, path)
332    }
333
334    #[cfg(feature = "reqwest")]
335    pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
336        self.reqwest_method(Method::DELETE, path)
337    }
338
339    #[cfg(feature = "reqwest")]
340    pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
341        self.reqwest_method(Method::HEAD, path)
342    }
343
344    /// Creates a HTTP request, using Reqwest, using the method + path described.
345    /// This expects a relative url to the `TestServer`.
346    ///
347    /// ```rust
348    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
349    /// #
350    /// use axum::Router;
351    /// use axum_test::TestServer;
352    ///
353    /// let my_app = Router::new();
354    /// let server = TestServer::builder()
355    ///     .http_transport() // Important, must be HTTP!
356    ///     .build(my_app)?;
357    ///
358    /// // Build your request
359    /// let request = server.get(&"/user")
360    ///     .add_header("x-custom-header", "example.com")
361    ///     .content_type("application/yaml");
362    ///
363    /// // await request to execute
364    /// let response = request.await;
365    /// #
366    /// # Ok(()) }
367    /// ```
368    #[cfg(feature = "reqwest")]
369    pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
370        let request_url = self
371            .server_url(path)
372            .expect("Failed to generate server url for request {method} {path}");
373
374        self.reqwest_client().request(method, request_url)
375    }
376
377    /// Creates a request to the server, to start a Websocket connection,
378    /// on the path given.
379    ///
380    /// This is the requivalent of making a GET request to the endpoint,
381    /// and setting the various headers needed for making an upgrade request.
382    ///
383    /// *Note*, this requires the server to be running on a real HTTP
384    /// port. Either using a randomly assigned port, or a specified one.
385    /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
386    ///
387    /// # Example
388    ///
389    /// ```rust
390    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
391    /// #
392    /// use axum::Router;
393    /// use axum_test::TestServer;
394    ///
395    /// let app = Router::new();
396    /// let server = TestServer::builder()
397    ///     .http_transport()
398    ///     .build(app)?;
399    ///
400    /// let mut websocket = server
401    ///     .get_websocket(&"/my-web-socket-end-point")
402    ///     .await
403    ///     .into_websocket()
404    ///     .await;
405    ///
406    /// websocket.send_text("Hello!").await;
407    /// #
408    /// # Ok(()) }
409    /// ```
410    ///
411    #[cfg(feature = "ws")]
412    pub fn get_websocket(&self, path: &str) -> TestRequest {
413        use http::header;
414
415        self.get(path)
416            .add_header(header::CONNECTION, "upgrade")
417            .add_header(header::UPGRADE, "websocket")
418            .add_header(header::SEC_WEBSOCKET_VERSION, "13")
419            .add_header(
420                header::SEC_WEBSOCKET_KEY,
421                crate::internals::generate_ws_key(),
422            )
423    }
424
425    /// Creates a HTTP GET request, using the typed path provided.
426    ///
427    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
428    ///
429    /// # Example Test
430    ///
431    /// Using a `TypedPath` you can write build and test a route like below:
432    ///
433    /// ```rust
434    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
435    /// #
436    /// use axum::Json;
437    /// use axum::Router;
438    /// use axum::routing::get;
439    /// use axum_extra::routing::RouterExt;
440    /// use axum_extra::routing::TypedPath;
441    /// use serde::Deserialize;
442    /// use serde::Serialize;
443    ///
444    /// use axum_test::TestServer;
445    ///
446    /// #[derive(TypedPath, Deserialize)]
447    /// #[typed_path("/users/{user_id}")]
448    /// struct UserPath {
449    ///     pub user_id: u32,
450    /// }
451    ///
452    /// // Build a typed route:
453    /// async fn route_get_user(UserPath { user_id }: UserPath) -> String {
454    ///     format!("hello user {user_id}")
455    /// }
456    ///
457    /// let app = Router::new()
458    ///     .typed_get(route_get_user);
459    ///
460    /// // Then test the route:
461    /// let server = TestServer::new(app)?;
462    /// server
463    ///     .typed_get(&UserPath { user_id: 123 })
464    ///     .await
465    ///     .assert_text("hello user 123");
466    /// #
467    /// # Ok(())
468    /// # }
469    /// ```
470    ///
471    #[cfg(feature = "typed-routing")]
472    pub fn typed_get<P>(&self, path: &P) -> TestRequest
473    where
474        P: TypedPath,
475    {
476        self.typed_method(Method::GET, path)
477    }
478
479    /// Creates a HTTP POST request, using the typed path provided.
480    ///
481    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
482    #[cfg(feature = "typed-routing")]
483    pub fn typed_post<P>(&self, path: &P) -> TestRequest
484    where
485        P: TypedPath,
486    {
487        self.typed_method(Method::POST, path)
488    }
489
490    /// Creates a HTTP PATCH request, using the typed path provided.
491    ///
492    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
493    #[cfg(feature = "typed-routing")]
494    pub fn typed_patch<P>(&self, path: &P) -> TestRequest
495    where
496        P: TypedPath,
497    {
498        self.typed_method(Method::PATCH, path)
499    }
500
501    /// Creates a HTTP PUT request, using the typed path provided.
502    ///
503    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
504    #[cfg(feature = "typed-routing")]
505    pub fn typed_put<P>(&self, path: &P) -> TestRequest
506    where
507        P: TypedPath,
508    {
509        self.typed_method(Method::PUT, path)
510    }
511
512    /// Creates a HTTP DELETE request, using the typed path provided.
513    ///
514    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
515    #[cfg(feature = "typed-routing")]
516    pub fn typed_delete<P>(&self, path: &P) -> TestRequest
517    where
518        P: TypedPath,
519    {
520        self.typed_method(Method::DELETE, path)
521    }
522
523    /// Creates a typed HTTP request, using the method provided.
524    ///
525    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
526    #[cfg(feature = "typed-routing")]
527    pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
528    where
529        P: TypedPath,
530    {
531        self.method(method, &path.to_string())
532    }
533
534    /// Returns the local web address for the test server,
535    /// if an address is available.
536    ///
537    /// The address is available when running as a real web server,
538    /// by setting the [`TestServerConfig`](crate::TestServerConfig) `transport` field to `Transport::HttpRandomPort` or `Transport::HttpIpPort`.
539    ///
540    /// This will return `None` when there is mock HTTP transport (the default).
541    pub fn server_address(&self) -> Option<Url> {
542        self.url()
543    }
544
545    /// This turns a relative path, into an absolute path to the server.
546    /// i.e. A path like `/users/123` will become something like `http://127.0.0.1:1234/users/123`.
547    ///
548    /// The absolute address can be used to make requests to the running server,
549    /// using any appropriate client you wish.
550    ///
551    /// # Example
552    ///
553    /// ```rust
554    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
555    /// #
556    /// use axum::Router;
557    /// use axum_test::TestServer;
558    ///
559    /// let app = Router::new();
560    /// let server = TestServer::builder()
561    ///         .http_transport()
562    ///         .build(app)?;
563    ///
564    /// let full_url = server.server_url(&"/users/123?filter=enabled")?;
565    ///
566    /// // Prints something like ... http://127.0.0.1:1234/users/123?filter=enabled
567    /// println!("{full_url}");
568    /// #
569    /// # Ok(()) }
570    /// ```
571    ///
572    /// This will return an error if you are using the mock transport.
573    /// Real HTTP transport is required to use this method (see [`TestServerConfig`](crate::TestServerConfig) `transport` field).
574    ///
575    /// It will also return an error if you provide an absolute path,
576    /// for example if you pass in `http://google.com`.
577    pub fn server_url(&self, path: &str) -> Result<Url> {
578        let path_uri = path.parse::<Uri>()?;
579        if is_absolute_uri(&path_uri) {
580            return Err(anyhow!(
581                "Absolute path provided for building server url, need to provide a relative uri"
582            ));
583        }
584
585        let server_url = self.url()
586            .ok_or_else(||
587                anyhow!(
588                    "No local address for server, need to run with HTTP transport to have a server address",
589                )
590            )?;
591
592        let server_locked = self.state.as_ref().lock().map_err(|err| {
593            anyhow!("Failed to lock InternalTestServer, for building server_url, received {err:?}",)
594        })?;
595        let mut query_params = server_locked.query_params().clone();
596        let mut full_server_url = build_url(
597            server_url,
598            path,
599            &mut query_params,
600            self.is_http_path_restricted,
601        )?;
602
603        // Ensure the query params are present
604        if query_params.has_content() {
605            full_server_url.set_query(Some(&query_params.to_string()));
606        }
607
608        Ok(full_server_url)
609    }
610
611    /// Adds a single cookie to be included on *all* future requests.
612    ///
613    /// If a cookie with the same name already exists,
614    /// then it will be replaced.
615    pub fn add_cookie(&mut self, cookie: Cookie) {
616        ServerSharedState::add_cookie(&self.state, cookie)
617            .context("Trying to call add_cookie")
618            .unwrap()
619    }
620
621    /// Adds extra cookies to be used on *all* future requests.
622    ///
623    /// Any cookies which have the same name as the new cookies,
624    /// will get replaced.
625    pub fn add_cookies(&mut self, cookies: CookieJar) {
626        ServerSharedState::add_cookies(&self.state, cookies)
627            .context("Trying to call add_cookies")
628            .unwrap()
629    }
630
631    /// Clears all of the cookies stored internally.
632    pub fn clear_cookies(&mut self) {
633        ServerSharedState::clear_cookies(&self.state)
634            .context("Trying to call clear_cookies")
635            .unwrap()
636    }
637
638    /// Requests made using this `TestServer` will save their cookies for future requests to send.
639    ///
640    /// This behaviour is off by default.
641    pub fn save_cookies(&mut self) {
642        self.save_cookies = true;
643    }
644
645    /// Requests made using this `TestServer` will _not_ save their cookies for future requests to send up.
646    ///
647    /// This is the default behaviour.
648    pub fn do_not_save_cookies(&mut self) {
649        self.save_cookies = false;
650    }
651
652    /// Requests made using this `TestServer` will assert a HTTP status in the 2xx range will be returned, unless marked otherwise.
653    ///
654    /// By default this behaviour is off.
655    pub fn expect_success(&mut self) {
656        self.expected_state = ExpectedState::Success;
657    }
658
659    /// Requests made using this `TestServer` will assert a HTTP status is outside the 2xx range will be returned, unless marked otherwise.
660    ///
661    /// By default this behaviour is off.
662    pub fn expect_failure(&mut self) {
663        self.expected_state = ExpectedState::Failure;
664    }
665
666    /// Adds a query parameter to be sent on *all* future requests.
667    pub fn add_query_param<V>(&mut self, key: &str, value: V)
668    where
669        V: Serialize,
670    {
671        ServerSharedState::add_query_param(&self.state, key, value)
672            .context("Trying to call add_query_param")
673            .unwrap()
674    }
675
676    /// Adds query parameters to be sent on *all* future requests.
677    pub fn add_query_params<V>(&mut self, query_params: V)
678    where
679        V: Serialize,
680    {
681        ServerSharedState::add_query_params(&self.state, query_params)
682            .context("Trying to call add_query_params")
683            .unwrap()
684    }
685
686    /// Adds a raw query param, with no urlencoding of any kind,
687    /// to be send on *all* future requests.
688    pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
689        ServerSharedState::add_raw_query_param(&self.state, raw_query_param)
690            .context("Trying to call add_raw_query_param")
691            .unwrap()
692    }
693
694    /// Clears all query params set.
695    pub fn clear_query_params(&mut self) {
696        ServerSharedState::clear_query_params(&self.state)
697            .context("Trying to call clear_query_params")
698            .unwrap()
699    }
700
701    /// Adds a header to be sent with all future requests built from this `TestServer`.
702    ///
703    /// ```rust
704    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
705    /// #
706    /// use axum::Router;
707    /// use axum_test::TestServer;
708    ///
709    /// let app = Router::new();
710    /// let mut server = TestServer::new(app)?;
711    ///
712    /// server.add_header("x-custom-header", "custom-value");
713    /// server.add_header(http::header::CONTENT_LENGTH, 12345);
714    /// server.add_header(http::header::HOST, "example.com");
715    ///
716    /// let response = server.get(&"/my-end-point")
717    ///     .await;
718    /// #
719    /// # Ok(()) }
720    /// ```
721    pub fn add_header<N, V>(&mut self, name: N, value: V)
722    where
723        N: TryInto<HeaderName>,
724        N::Error: Debug,
725        V: TryInto<HeaderValue>,
726        V::Error: Debug,
727    {
728        let header_name: HeaderName = name
729            .try_into()
730            .expect("Failed to convert header name to HeaderName");
731        let header_value: HeaderValue = value
732            .try_into()
733            .expect("Failed to convert header vlue to HeaderValue");
734
735        ServerSharedState::add_header(&self.state, header_name, header_value)
736            .context("Trying to call add_header")
737            .unwrap()
738    }
739
740    /// Clears all headers set so far.
741    pub fn clear_headers(&mut self) {
742        ServerSharedState::clear_headers(&self.state)
743            .context("Trying to call clear_headers")
744            .unwrap()
745    }
746
747    /// Sets the scheme to use when making _all_ requests from the `TestServer`.
748    /// i.e. http or https.
749    ///
750    /// The default scheme is 'http'.
751    ///
752    /// ```rust
753    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
754    /// #
755    /// use axum::Router;
756    /// use axum_test::TestServer;
757    ///
758    /// let app = Router::new();
759    /// let mut server = TestServer::new(app)?;
760    /// server
761    ///     .scheme(&"https");
762    ///
763    /// let response = server
764    ///     .get(&"/my-end-point")
765    ///     .await;
766    /// #
767    /// # Ok(()) }
768    /// ```
769    ///
770    pub fn scheme(&mut self, scheme: &str) {
771        ServerSharedState::set_scheme(&self.state, scheme.to_string())
772            .context("Trying to call set_scheme")
773            .unwrap()
774    }
775
776    pub(crate) fn url(&self) -> Option<Url> {
777        self.transport.url().cloned()
778    }
779
780    pub(crate) fn build_test_request_config(
781        &self,
782        method: Method,
783        path: &str,
784    ) -> Result<TestRequestConfig> {
785        let url = self
786            .url()
787            .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
788
789        let server_locked = self.state.as_ref().lock().map_err(|err| {
790            anyhow!(
791                "Failed to lock InternalTestServer, for request {method} {path}, received {err:?}",
792            )
793        })?;
794
795        let cookies = server_locked.cookies().clone();
796        let mut query_params = server_locked.query_params().clone();
797        let headers = server_locked.headers().clone();
798        let mut full_request_url =
799            build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
800
801        if let Some(scheme) = server_locked.scheme() {
802            full_request_url.set_scheme(scheme).map_err(|_| {
803                let debug_request_format = RequestPathFormatter::new(&method, full_request_url.as_str(), Some(&query_params));
804                anyhow!("Scheme '{scheme}' from TestServer cannot be set to request {debug_request_format}")
805            })?;
806        }
807
808        ::std::mem::drop(server_locked);
809
810        Ok(TestRequestConfig {
811            is_saving_cookies: self.save_cookies,
812            expected_state: self.expected_state,
813            content_type: self.default_content_type.clone(),
814            method,
815
816            full_request_url,
817            cookies,
818            query_params,
819            headers,
820        })
821    }
822
823    /// Returns true or false if the underlying service inside the `TestServer`
824    /// is still running. For many types of services this will always return `true`.
825    ///
826    /// When a `TestServer` is built using [`axum::serve::WithGracefulShutdown`],
827    /// this will return false if the service has shutdown.
828    pub fn is_running(&self) -> bool {
829        self.transport.is_running()
830    }
831}
832
833fn build_url(
834    mut url: Url,
835    path: &str,
836    query_params: &mut QueryParamsStore,
837    is_http_restricted: bool,
838) -> Result<Url> {
839    let path_uri = path.parse::<Uri>()?;
840
841    // If there is a scheme, then this is an absolute path.
842    if let Some(scheme) = path_uri.scheme_str() {
843        if is_http_restricted {
844            if has_different_schema(&url, &path_uri) || has_different_authority(&url, &path_uri) {
845                return Err(anyhow!("Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_schema' to change this."));
846            }
847        } else {
848            url.set_scheme(scheme)
849                .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
850
851            // We only set the host/port if the scheme is also present.
852            if let Some(authority) = path_uri.authority() {
853                url.set_host(Some(authority.host()))
854                    .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
855                url.set_port(authority.port().map(|p| p.as_u16()))
856                    .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
857
858                // todo, add username:password support
859            }
860        }
861    }
862
863    // Why does this exist?
864    //
865    // This exists to allow `server.get("/users")` and `server.get("users")` (without a slash)
866    // to go to the same place.
867    //
868    // It does this by saying ...
869    //  - if there is a scheme, it's a full path.
870    //  - if no scheme, it must be a path
871    //
872    if is_absolute_uri(&path_uri) {
873        url.set_path(path_uri.path());
874
875        // In this path we are replacing, so drop any query params on the original url.
876        if url.query().is_some() {
877            url.set_query(None);
878        }
879    } else {
880        // Grab everything up until the query parameters, or everything after that
881        let calculated_path = path.split('?').next().unwrap_or(path);
882        url.set_path(calculated_path);
883
884        // Move any query parameters from the url to the query params store.
885        if let Some(url_query) = url.query() {
886            query_params.add_raw(url_query.to_string());
887            url.set_query(None);
888        }
889    }
890
891    if let Some(path_query) = path_uri.query() {
892        query_params.add_raw(path_query.to_string());
893    }
894
895    Ok(url)
896}
897
898fn is_absolute_uri(path_uri: &Uri) -> bool {
899    path_uri.scheme_str().is_some()
900}
901
902fn has_different_schema(base_url: &Url, path_uri: &Uri) -> bool {
903    if let Some(scheme) = path_uri.scheme_str() {
904        return scheme != base_url.scheme();
905    }
906
907    false
908}
909
910fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
911    if let Some(authority) = path_uri.authority() {
912        return authority.as_str() != base_url.authority();
913    }
914
915    false
916}
917
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Runs `build_url` for a case that must succeed.
    ///
    /// Returns the built URL as a string, together with the query params that
    /// were collected into the store while building it.
    fn build_ok(base: &str, path: &str, is_restricted: bool) -> (String, QueryParamsStore) {
        let mut query_params = QueryParamsStore::new();
        let base_url = base.parse::<Url>().unwrap();
        let url = build_url(base_url, path, &mut query_params, is_restricted).unwrap();

        (url.to_string(), query_params)
    }

    /// Runs `build_url` and reports whether it returned an error.
    fn build_fails(base: &str, path: &str, is_restricted: bool) -> bool {
        let mut query_params = QueryParamsStore::new();
        let base_url = base.parse::<Url>().unwrap();

        build_url(base_url, path, &mut query_params, is_restricted).is_err()
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (url, query_params) = build_ok("http://example.com", "/users", true);

        assert_eq!(url, "http://example.com/users");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (url, query_params) = build_ok(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            true,
        );

        assert_eq!(url, "http://example.com/users");
        assert_eq!(query_params.to_string(), "base=aaa&path=bbb&path-flag");
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        assert!(build_fails(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            true,
        ));
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        assert!(build_fails(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            true,
        ));
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        assert!(build_fails(
            "http://example.com?base=666",
            "http://google.com",
            true
        ));
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        assert!(build_fails(
            "http://example.com?base=666",
            "ftp://example.com/users",
            true,
        ));
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (url, query_params) = build_ok("http://example.com", "/users", false);

        assert_eq!(url, "http://example.com/users");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (url, query_params) = build_ok(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            false,
        );

        assert_eq!(url, "http://example.com/users");
        assert_eq!(query_params.to_string(), "base=aaa&path=bbb&path-flag");
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let (url, query_params) = build_ok("http://example.com", "google.com", false);

        assert_eq!(url, "http://example.com/google.com");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let (url, query_params) = build_ok("http://example.com", "google.com", true);

        assert_eq!(url, "http://example.com/google.com");
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (url, query_params) = build_ok(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );

        assert_eq!(url, "ftp://google.com:123/users.csv");
        assert_eq!(query_params.to_string(), "limit=456");
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (url, _) = build_ok("http://example.com", "ftp://example.com", false);

        assert_eq!(url, "ftp://example.com/");
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (url, _) = build_ok("http://example.com", "http://google.com", false);

        assert_eq!(url, "http://google.com/");
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (url, _) = build_ok("http://example.com:123", "http://example.com:456", false);

        assert_eq!(url, "http://example.com:456/");
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (url, _) = build_ok("http://example.com:123", "http://example.com:123", false);

        assert_eq!(url, "http://example.com:123/");
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        assert!(build_fails("http://example.com", "ftp://example.com", true));
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        assert!(build_fails("http://example.com", "http://google.com", true));
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        assert!(build_fails(
            "http://example.com:123",
            "http://example.com:456",
            true
        ));
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (url, _) = build_ok("http://example.com:123", "http://example.com:123", true);

        assert_eq!(url, "http://example.com:123/");
    }
}
1119
#[cfg(test)]
mod test_new {
    use axum::routing::get;
    use axum::Router;
    use std::net::SocketAddr;

    use crate::TestServer;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        // A make-service that attaches the peer `SocketAddr` as connect info.
        let service = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();

        let server = TestServer::new(service).expect("Should create test server");
        let response = server.get(&"/ping").await;

        response.assert_text(&"pong!");
    }
}
1146
#[cfg(test)]
mod test_get {
    use super::*;

    use axum::routing::get;
    use axum::Router;
    use reserve_port::ReservedSocketAddr;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// A router with a single `GET /ping` route that answers `pong!`.
    fn ping_router() -> Router {
        Router::new().route("/ping", get(get_ping))
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let server = TestServer::new(ping_router()).expect("Should create test server");

        // Request the route _with_ a leading slash.
        server.get(&"/ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let server = TestServer::new(ping_router()).expect("Should create test server");

        // Request the route _without_ a leading slash.
        server.get(&"ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        // Keep the reservation bound for the whole test so the address stays ours.
        let reservation = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let (ip, port) = (reservation.ip(), reservation.port());

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(ping_router())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        assert_eq!(response.request_url().to_string(), absolute_url);
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        // Keep the reservation bound for the whole test so the address stays ours.
        let reservation = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let (ip, port) = (reservation.ip(), reservation.port());

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema() // Key part of the test!
            .build(ping_router())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        assert_eq!(response.request_url().to_string(), absolute_url);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        // Keep the reservation bound for the whole test so the address stays ours.
        let reservation = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let (ip, port) = (reservation.ip(), reservation.port());

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema() // Key part of the test!
            .build(ping_router())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Aim one port past the server, so the URL no longer matches it.
        let other_port = port + 1;
        let absolute_url = format!("http://{ip}:{other_port}/ping");
        server.get(&absolute_url).await;
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let server = TestServer::new(ping_router()).expect("Should create test server");

        let (first, second) = tokio::join!(
            async { server.get("/ping").await },
            async { server.get("/ping").await }
        );

        assert_eq!(first.text(), second.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        // A handler that waits one second before answering.
        let slow_handler = || async {
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            "hello!"
        };
        let app = axum::Router::new().route("/slow", axum::routing::get(slow_handler));

        let server = TestServer::new(app).expect("Should create test server");

        let (first, second) = tokio::join!(
            async { server.get("/slow").await },
            async { server.get("/slow").await }
        );

        assert_eq!(first.text(), second.text());
    }
}
1286
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;

    use axum::routing::get;
    use axum::Router;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        // Send through the raw reqwest client, then read the body as text.
        let response = server.reqwest_get(&"/ping").send().await.unwrap();
        let body = response.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1319
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;

    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Echoes the body back with `number` doubled and `_plus_response`
    /// appended to `text`.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let response = server
            .reqwest_post(&"/json")
            .json(&TestBody {
                number: 111,
                text: "request".to_string(),
            })
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        assert_eq!(
            response,
            TestBody {
                number: 222,
                text: "request_plus_response".to_string(),
            }
        );
    }
}
1375
#[cfg(test)]
mod test_server_address {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        // An empty application; only the address it is served on matters here.
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let expected_ip_port = format!("http://{}:{}/", ip, port);
        assert_eq!(
            server.server_address().unwrap().to_string(),
            expected_ip_port
        );
    }

    /// The default HTTP transport is expected to report an address on
    /// `127.0.0.1` with some port, ending in a `/`.
    #[tokio::test]
    async fn it_should_return_default_address_with_ending_slash() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let address_regex = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();
        let address = server.server_address().unwrap().to_string();
        assert!(
            address_regex.is_match(&address),
            "server address {address:?} does not match {address_regex}"
        );
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        assert!(server.server_address().is_none());
    }
}
1430
#[cfg(test)]
mod test_server_url {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        // An empty application; only the URLs built against it matter here.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let absolute_url = server.server_url("/users").unwrap().to_string();

        assert_eq!(absolute_url, format!("http://{ip}:{port}/users"));
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new())
            .expect("Should create test server");

        let expected_shape =
            Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();
        let absolute_url = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();

        assert!(expected_shape.is_match(&absolute_url));
    }

    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new())
            .expect("Should create test server");

        assert!(server.server_url("/users").is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();

        assert_eq!(received_url, format!("http://{ip}:{port}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Query params set on the server itself must show up in built URLs.
        server.add_query_param("filter", "enabled");
        let received_url = server.server_url("/users").unwrap().to_string();

        assert_eq!(received_url, format!("http://{ip}:{port}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Server-level params come first, followed by those from the path.
        server.add_query_param("filter", "enabled");
        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();

        assert_eq!(
            received_url,
            format!("http://{ip}:{port}/users?filter=enabled&animal=donkeys")
        );
    }
}
1573
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;

    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Responds with the value of the test cookie, or `cookie-not-found` when
    /// the request did not carry it.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|cookie| cookie.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        server.add_cookie(cookie);

        let response_text = server.get(&"/cookie").await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }
}
1606
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Responds with every received cookie as `name=value`, sorted and joined
    /// by `", "` so the output does not depend on jar iteration order.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();
        pairs.sort();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        // Fill a jar with two cookies and hand the whole jar to the server.
        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(jar);

        let response = server.get(&"/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1647
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Responds with every received cookie as `name=value`, sorted and joined
    /// by `", "`; an empty body means no cookies arrived.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();
        pairs.sort();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(jar);

        // Clearing must drop everything added above before the request is sent.
        server.clear_cookies();

        let response = server.get(&"/cookies").await;
        response.assert_text("");
    }
}
1687
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    ///
    /// Requests without that header are rejected with `400 Bad Request`.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the test header's bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Send a request with the header
        let response = server.get(&"/header").await;

        // Check it sent back the right text
        response.assert_text(TEST_HEADER_CONTENT)
    }
}
1744
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    ///
    /// Requests without that header are rejected with `400 Bad Request` and
    /// the body `Missing test header`, which is what this module asserts on.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the test header's bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, add the header, then clear it again.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        // Send a request, which should now go out without the header.
        let response = server.get(&"/header").await;

        // The extractor should have rejected the request.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1803
#[cfg(test)]
mod test_add_query_params {
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Single-field query served by `/query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Two-field query served by `/query-2`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Replies with `"<message>-<other>"`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router).expect("Should create test server");

        // A `Serialize` struct is encoded into the query string.
        let params = QueryParam {
            message: String::from("it works"),
        };
        server.add_query_params(params);

        let response = server.get(&"/query").await;
        response.assert_text(&"it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_params(&[("message", "it works")]);

        let response = server.get(&"/query").await;
        response.assert_text(&"it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        // Both pairs are supplied in a single call.
        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        let response = server.get(&"/query-2").await;
        response.assert_text(&"it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        // Each call appends to the params already stored on the server.
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        let response = server.get(&"/query-2").await;
        response.assert_text(&"it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        server.add_query_params(params);

        let response = server.get(&"/query-2").await;
        response.assert_text(&"it works-yup");
    }
}
1906
#[cfg(test)]
mod test_add_query_param {
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Single-field query served by `/query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Two-field query served by `/query-2`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Replies with `"<message>-<other>"`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_query_param("message", "it works");

        let response = server.get(&"/query").await;
        response.assert_text(&"it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        // One `add_query_param` call per key; they accumulate on the server.
        for (key, value) in [("message", "it works"), ("other", "yup")] {
            server.add_query_param(key, value);
        }

        let response = server.get(&"/query-2").await;
        response.assert_text(&"it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router).expect("Should create test server");

        // One param lives on the server, the other is added per request.
        server.add_query_param("message", "it works");
        let request = server.get(&"/query-2").add_query_param("other", "yup");

        request.await.assert_text(&"it works-yup");
    }
}
1981
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Single-field query parsed with axum's standard `Query` extractor.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query with two list-valued keys: repeated `items=…` and repeated `arrs[]=…`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Replies with every `items` value followed by every `arrs[]` value,
    /// joined by `", "` (empty string when neither key is present).
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        // Join both lists as a single sequence so values stay comma-separated
        // even when `items` and `arrs[]` are sent together; previously the two
        // joined lists were concatenated with no separator between them.
        params
            .items
            .into_iter()
            .chain(params.arrs)
            .collect::<Vec<_>>()
            .join(", ")
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"message=it-works");

        // Get the request.
        server.get(&"/query").await.assert_text(&"it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one&items=two&items=three");

        // Get the request.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        // Get the request.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_separate_mixed_array_query_params() {
        // Run the server with both list styles present at once.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one");
        server.add_raw_query_param(&"arrs[]=two");

        // Get the request.
        server.get(&"/query-extra").await.assert_text(&"one, two");
    }
}
2072
#[cfg(test)]
mod test_clear_query_params {
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional params; a missing key deserializes to `None`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of the two params reached the handler.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        format!(
            "has first? {}, has second? {}",
            params.first.is_some(),
            params.second.is_some()
        )
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_params));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_query_params(QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        });
        server.clear_query_params();

        // Get the request.
        server
            .get(&"/query")
            .await
            .assert_text(&"has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_params));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_query_params(QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        });
        server.clear_query_params();

        // Re-add only `first`. If the cleared params leaked through, `second`
        // would still be reported; re-adding identical values could not tell
        // replacement apart from the old params surviving.
        server.add_query_param("first", "first");

        // Get the request.
        server
            .get(&"/query")
            .await
            .assert_text(&"has first? true, has second? false");
    }
}
2142
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;

    use axum::routing::get;
    use axum::Router;

    /// Router exposing a single `/known_route` that always answers 200.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    /// Builds a server for `app` with `expect_success_by_default` switched on.
    fn strict_server(app: Router) -> TestServer {
        TestServer::builder()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get(&"/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(known_route_app()).expect("Should create test server");

        server.get(&"/known_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = strict_server(Router::new());

        server.get(&"/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = strict_server(known_route_app());

        server.get(&"/known_route").await;
    }
}
2189
#[cfg(test)]
mod test_content_type {
    use super::*;

    use axum::routing::get;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    ///
    /// Panics if the header value is not visible ASCII.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let app = Router::new().route("/content_type", get(get_content_type));
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(app)
            .expect("Should create test server");

        let response = server.get(&"/content_type").await;

        // The handler reflects the header, so the body is the content type sent.
        assert_eq!(response.text(), "text/plain");
    }
}
2223
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Creates a server for `app` on which `expect_success` has been called.
    fn server_expecting_success(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        async fn get_ping() -> &'static str {
            "pong!"
        }

        let server = server_expecting_success(Router::new().route("/ping", get(get_ping)));

        server.get(&"/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        async fn get_accepted() -> StatusCode {
            StatusCode::ACCEPTED
        }

        let server =
            server_expecting_success(Router::new().route("/accepted", get(get_accepted)));

        server.get(&"/accepted").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        // No routes at all, so every path answers 404.
        let server = server_expecting_success(Router::new());

        server.get(&"/some_unknown_route").await;
    }
}
2279
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Creates a server for `app` on which `expect_failure` has been called.
    fn server_expecting_failure(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // No routes at all, so every path answers 404.
        let server = server_expecting_failure(Router::new());

        server.get(&"/some_unknown_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        async fn get_ping() -> &'static str {
            "pong!"
        }

        let server = server_expecting_failure(Router::new().route("/ping", get(get_ping)));

        server.get(&"/ping").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        async fn get_accepted() -> StatusCode {
            StatusCode::ACCEPTED
        }

        let server =
            server_expecting_failure(Router::new().route("/accepted", get(get_accepted)));

        server.get(&"/accepted").await;
    }
}
2336
#[cfg(test)]
mod test_scheme {
    use axum::extract::Request;
    use axum::routing::get;
    use axum::Router;

    use crate::TestServer;

    /// Replies with the scheme of the request URI; panics if the URI has none.
    async fn route_get_scheme(request: Request) -> String {
        request.uri().scheme_str().unwrap().to_string()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let server = TestServer::builder().build(router).unwrap();

        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_across_multiple_requests_when_set() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let mut server = TestServer::builder().build(router).unwrap();
        server.scheme(&"https");

        // The scheme is set on the server, so it must apply to every request,
        // not just the first one.
        server.get("/scheme").await.assert_text("https");
        server.get("/scheme").await.assert_text("https");
    }
}
2366
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `get <id>`.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_get(&path).await;
        response.assert_text("get 123");
    }
}
2400
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `post <id>`.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_post(&path).await;
        response.assert_text("post 123");
    }
}
2434
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `patch <id>`.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_patch(&path).await;
        response.assert_text("patch 123");
    }
}
2468
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `put <id>`.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_put(&path).await;
        response.assert_text("put 123");
    }
}
2502
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `delete <id>`.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    fn new_app() -> Router {
        Router::new().typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_delete(&path).await;
        response.assert_text("delete 123");
    }
}
2536
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}`, registered once per HTTP method below.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// Router where each method on `/path/{id}` replies with `<method> <id>`.
    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to `/path/123` through `typed_method` on a fresh server
    /// and asserts the reply body equals `expected`.
    async fn assert_typed_method(method: Method, expected: &str) {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_method(method, &path).await;
        response.assert_text(expected);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method(Method::DELETE, "delete 123").await;
    }
}
2631
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::routing::get;
    use axum::Router;
    use std::cell::OnceCell;

    async fn route_get() -> &'static str {
        "it works"
    }

    /// Builds a server whose only route, `/test`, replies "it works".
    fn build_server() -> TestServer {
        let router = Router::new().route("/test", get(route_get));
        TestServer::new(router).unwrap()
    }

    /// Stores the server in a lazily-initialised cell and uses it by reference.
    ///
    /// NOTE(review): `std::cell::OnceCell` is the single-threaded cell and does
    /// not require `TestServer: Sync`, so despite the module name this does not
    /// assert thread-safety — confirm whether that was the intent.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell: OnceCell<TestServer> = OnceCell::new();
        let server = cell.get_or_init(build_server);

        server.get("/test").await.assert_text("it works");
    }
}
2655
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::util::new_random_tokio_tcp_listener;
    use axum::routing::get;
    use axum::routing::IntoMakeService;
    use axum::serve;
    use axum::Router;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Serves the app on a real listener with graceful shutdown, then checks
    /// that requests succeed while running and panic once it has shut down.
    ///
    /// NOTE(review): a bare `#[should_panic]` accepts any panic, including one
    /// from the `is_running` assertions or from shutdown taking longer than the
    /// 10 ms sleep — consider `expected = "..."` to pin the final request.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown = Arc::new(Notify::new());
        let shutdown_signal = Arc::clone(&shutdown);

        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let listener = new_random_tokio_tcp_listener().unwrap();
        let application = serve(listener, app).with_graceful_shutdown(async move {
            shutdown_signal.notified().await;
        });

        let server = TestServer::builder()
            .build(application)
            .expect("Should create test server");

        // While the app is being served, requests succeed.
        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        // Trigger graceful shutdown and give it a moment to complete.
        shutdown.notify_one();
        sleep(Duration::from_millis(10)).await;

        // Once stopped, the request below is expected to panic.
        assert!(!server.is_running());
        server.get("/ping").await.assert_status_ok();
    }
}