axum_test/
test_server.rs

1use anyhow::Context;
2use anyhow::Result;
3use anyhow::anyhow;
4use cookie::Cookie;
5use cookie::CookieJar;
6use http::HeaderName;
7use http::HeaderValue;
8use http::Method;
9use http::Uri;
10use serde::Serialize;
11use std::fmt::Debug;
12use std::sync::Arc;
13use std::sync::Mutex;
14use url::Url;
15
16#[cfg(feature = "typed-routing")]
17use axum_extra::routing::TypedPath;
18
19#[cfg(feature = "reqwest")]
20use crate::transport_layer::TransportLayerType;
21#[cfg(feature = "reqwest")]
22use reqwest::Client;
23#[cfg(feature = "reqwest")]
24use reqwest::RequestBuilder;
25
26use crate::TestRequest;
27use crate::TestRequestConfig;
28use crate::TestServerBuilder;
29use crate::TestServerConfig;
30use crate::Transport;
31use crate::internals::ExpectedState;
32use crate::internals::QueryParamsStore;
33use crate::internals::RequestPathFormatter;
34use crate::transport_layer::IntoTransportLayer;
35use crate::transport_layer::TransportLayer;
36use crate::transport_layer::TransportLayerBuilder;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// Base address used to build request urls when the transport has no address
/// of its own (i.e. when `TestServer::url()` returns `None`, such as with mock transport).
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
///
/// The `TestServer` runs your Axum application,
/// allowing you to make HTTP requests against it.
///
/// # Building
///
/// A `TestServer` can be used to run an [`axum::Router`], an [`::axum::routing::IntoMakeService`],
/// a [`shuttle_axum::ShuttleAxum`], and others.
///
/// The most straightforward approach is to call [`crate::TestServer::new`],
/// and pass in your application:
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use axum::Router;
/// use axum::routing::get;
///
/// use axum_test::TestServer;
///
/// let app = Router::new()
///     .route(&"/hello", get(|| async { "hello!" }));
///
/// let server = TestServer::new(app)?;
/// #
/// # Ok(())
/// # }
/// ```
///
/// # Requests
///
/// Requests are built by calling [`TestServer::get()`](crate::TestServer::get()),
/// [`TestServer::post()`](crate::TestServer::post()), [`TestServer::put()`](crate::TestServer::put()),
/// [`TestServer::delete()`](crate::TestServer::delete()), and [`TestServer::patch()`](crate::TestServer::patch()) methods.
/// Each returns a [`TestRequest`](crate::TestRequest), which allows for customising the request content.
///
/// For example:
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use axum::Router;
/// use axum::routing::get;
///
/// use axum_test::TestServer;
///
/// let app = Router::new()
///     .route(&"/hello", get(|| async { "hello!" }));
///
/// let server = TestServer::new(app)?;
///
/// let response = server.get(&"/hello")
///     .authorization_bearer("password12345")
///     .add_header("x-custom-header", "custom-value")
///     .await;
///
/// response.assert_text("hello!");
/// #
/// # Ok(())
/// # }
/// ```
///
/// Request methods also exist for using Axum Extra [`axum_extra::routing::TypedPath`],
/// or for building Reqwest [`reqwest::RequestBuilder`]. See those methods for details.
///
/// # Customising
///
/// A `TestServer` can be built from a builder, by calling [`crate::TestServer::builder`],
/// and customising settings. This allows one to set **mocked** (default when possible)
/// or **real http** networking for your service.
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// use axum::Router;
/// use axum::routing::get;
///
/// use axum_test::TestServer;
///
/// let app = Router::new()
///     .route(&"/hello", get(|| async { "hello!" }));
///
/// // Customise server when building
/// let mut server = TestServer::builder()
///     .http_transport()
///     .expect_success_by_default()
///     .save_cookies()
///     .build(app)?;
///
/// // Add items to be sent on _all_ requests
/// server.add_header("x-custom-for-all", "common-value");
///
/// let response = server.get("/hello").await;
/// #
/// # Ok(())
/// # }
/// ```
///
#[derive(Debug)]
pub struct TestServer {
    /// Cookies, query params, headers, and scheme applied to every request built
    /// from this server. The same `Arc` is handed to each `TestRequest`.
    state: Arc<Mutex<ServerSharedState>>,
    /// The transport layer the app is running on, shared with each `TestRequest`.
    transport: Arc<Box<dyn TransportLayer>>,
    /// Default for whether new requests save their response cookies for future requests.
    save_cookies: bool,
    /// Default status expectation (none, success, or failure) for new requests.
    expected_state: ExpectedState,
    /// Content type given to new requests by default, if one is set.
    default_content_type: Option<String>,
    /// When true, requests using an absolute url must match the server's scheme and authority.
    is_http_path_restricted: bool,

    /// Reqwest client, only `Some` when the server runs on HTTP transport.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: Option<Client>,
}
153
154impl TestServer {
155    /// A helper function to create a builder for creating a [`TestServer`].
156    pub fn builder() -> TestServerBuilder {
157        TestServerBuilder::default()
158    }
159
160    /// This will run the given Axum app,
161    /// allowing you to make requests against it.
162    ///
163    /// This is the same as creating a new `TestServer` with a configuration,
164    /// and passing [`crate::TestServerConfig::default()`].
165    ///
166    /// ```rust
167    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
168    /// #
169    /// use axum::Router;
170    /// use axum::routing::get;
171    /// use axum_test::TestServer;
172    ///
173    /// let app = Router::new()
174    ///     .route(&"/hello", get(|| async { "hello!" }));
175    ///
176    /// let server = TestServer::new(app)?;
177    /// #
178    /// # Ok(())
179    /// # }
180    /// ```
181    ///
182    /// The of applications that can be passed in include:
183    ///
184    ///  - [`axum::Router`]
185    ///  - [`axum::routing::IntoMakeService`]
186    ///  - [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`]
187    ///  - [`axum::serve::Serve`]
188    ///  - [`axum::serve::WithGracefulShutdown`]
189    ///  - [`shuttle_axum::ShuttleAxum`]
190    ///
191    pub fn new<A>(app: A) -> Result<Self>
192    where
193        A: IntoTransportLayer,
194    {
195        Self::new_with_config(app, TestServerConfig::default())
196    }
197
198    /// Similar to [`TestServer::new()`], with a customised configuration.
199    /// This includes type of transport in use (i.e. specify a specific port),
200    /// or change default settings (like the default content type for requests).
201    ///
202    /// This can take a [`crate::TestServerConfig`] or a [`crate::TestServerBuilder`].
203    /// See those for more information on configuration settings.
204    pub fn new_with_config<A, C>(app: A, config: C) -> Result<Self>
205    where
206        A: IntoTransportLayer,
207        C: Into<TestServerConfig>,
208    {
209        let config = config.into();
210        let mut shared_state = ServerSharedState::new();
211        if let Some(scheme) = config.default_scheme {
212            shared_state.set_scheme_unlocked(scheme);
213        }
214
215        let shared_state_mutex = Mutex::new(shared_state);
216        let state = Arc::new(shared_state_mutex);
217
218        let transport = match config.transport {
219            None => {
220                let builder = TransportLayerBuilder::new(None, None);
221                let transport = app.into_default_transport(builder)?;
222                Arc::new(transport)
223            }
224            Some(Transport::HttpRandomPort) => {
225                let builder = TransportLayerBuilder::new(None, None);
226                let transport = app.into_http_transport_layer(builder)?;
227                Arc::new(transport)
228            }
229            Some(Transport::HttpIpPort { ip, port }) => {
230                let builder = TransportLayerBuilder::new(ip, port);
231                let transport = app.into_http_transport_layer(builder)?;
232                Arc::new(transport)
233            }
234            Some(Transport::MockHttp) => {
235                let transport = app.into_mock_transport_layer()?;
236                Arc::new(transport)
237            }
238        };
239
240        let expected_state = match config.expect_success_by_default {
241            true => ExpectedState::Success,
242            false => ExpectedState::None,
243        };
244
245        #[cfg(feature = "reqwest")]
246        let maybe_reqwest_client = match transport.transport_layer_type() {
247            TransportLayerType::Http => {
248                let reqwest_client = reqwest::Client::builder()
249                    .redirect(reqwest::redirect::Policy::none())
250                    .cookie_store(config.save_cookies)
251                    .build()
252                    .expect("Failed to build Reqwest Client");
253
254                Some(reqwest_client)
255            }
256            TransportLayerType::Mock => None,
257        };
258
259        Ok(Self {
260            state,
261            transport,
262            save_cookies: config.save_cookies,
263            expected_state,
264            default_content_type: config.default_content_type,
265            is_http_path_restricted: config.restrict_requests_with_http_schema,
266
267            #[cfg(feature = "reqwest")]
268            maybe_reqwest_client,
269        })
270    }
271
272    /// Creates a HTTP GET request to the path.
273    pub fn get(&self, path: &str) -> TestRequest {
274        self.method(Method::GET, path)
275    }
276
277    /// Creates a HTTP POST request to the given path.
278    pub fn post(&self, path: &str) -> TestRequest {
279        self.method(Method::POST, path)
280    }
281
282    /// Creates a HTTP PATCH request to the path.
283    pub fn patch(&self, path: &str) -> TestRequest {
284        self.method(Method::PATCH, path)
285    }
286
287    /// Creates a HTTP PUT request to the path.
288    pub fn put(&self, path: &str) -> TestRequest {
289        self.method(Method::PUT, path)
290    }
291
292    /// Creates a HTTP DELETE request to the path.
293    pub fn delete(&self, path: &str) -> TestRequest {
294        self.method(Method::DELETE, path)
295    }
296
297    /// Creates a HTTP request, to the method and path provided.
298    pub fn method(&self, method: Method, path: &str) -> TestRequest {
299        let maybe_config = self.build_test_request_config(method.clone(), path);
300        let config = maybe_config
301            .with_context(|| format!("Failed to build, for request {method} {path}"))
302            .unwrap();
303
304        TestRequest::new(self.state.clone(), self.transport.clone(), config)
305    }
306
307    #[cfg(feature = "reqwest")]
308    fn reqwest_client(&self) -> &Client {
309        self.maybe_reqwest_client
310            .as_ref()
311            .expect("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available")
312    }
313
314    #[cfg(feature = "reqwest")]
315    pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
316        self.reqwest_method(Method::GET, path)
317    }
318
319    #[cfg(feature = "reqwest")]
320    pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
321        self.reqwest_method(Method::POST, path)
322    }
323
324    #[cfg(feature = "reqwest")]
325    pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
326        self.reqwest_method(Method::PUT, path)
327    }
328
329    #[cfg(feature = "reqwest")]
330    pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
331        self.reqwest_method(Method::PATCH, path)
332    }
333
334    #[cfg(feature = "reqwest")]
335    pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
336        self.reqwest_method(Method::DELETE, path)
337    }
338
339    #[cfg(feature = "reqwest")]
340    pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
341        self.reqwest_method(Method::HEAD, path)
342    }
343
344    /// Creates a HTTP request, using Reqwest, using the method + path described.
345    /// This expects a relative url to the `TestServer`.
346    ///
347    /// ```rust
348    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
349    /// #
350    /// use axum::Router;
351    /// use axum_test::TestServer;
352    ///
353    /// let my_app = Router::new();
354    /// let server = TestServer::builder()
355    ///     .http_transport() // Important, must be HTTP!
356    ///     .build(my_app)?;
357    ///
358    /// // Build your request
359    /// let request = server.get(&"/user")
360    ///     .add_header("x-custom-header", "example.com")
361    ///     .content_type("application/yaml");
362    ///
363    /// // await request to execute
364    /// let response = request.await;
365    /// #
366    /// # Ok(()) }
367    /// ```
368    #[cfg(feature = "reqwest")]
369    pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
370        let request_url = self
371            .server_url(path)
372            .expect("Failed to generate server url for request {method} {path}");
373
374        self.reqwest_client().request(method, request_url)
375    }
376
377    /// Creates a request to the server, to start a Websocket connection,
378    /// on the path given.
379    ///
380    /// This is the requivalent of making a GET request to the endpoint,
381    /// and setting the various headers needed for making an upgrade request.
382    ///
383    /// *Note*, this requires the server to be running on a real HTTP
384    /// port. Either using a randomly assigned port, or a specified one.
385    /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
386    ///
387    /// # Example
388    ///
389    /// ```rust
390    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
391    /// #
392    /// use axum::Router;
393    /// use axum_test::TestServer;
394    ///
395    /// let app = Router::new();
396    /// let server = TestServer::builder()
397    ///     .http_transport()
398    ///     .build(app)?;
399    ///
400    /// let mut websocket = server
401    ///     .get_websocket(&"/my-web-socket-end-point")
402    ///     .await
403    ///     .into_websocket()
404    ///     .await;
405    ///
406    /// websocket.send_text("Hello!").await;
407    /// #
408    /// # Ok(()) }
409    /// ```
410    ///
411    #[cfg(feature = "ws")]
412    pub fn get_websocket(&self, path: &str) -> TestRequest {
413        use http::header;
414
415        self.get(path)
416            .add_header(header::CONNECTION, "upgrade")
417            .add_header(header::UPGRADE, "websocket")
418            .add_header(header::SEC_WEBSOCKET_VERSION, "13")
419            .add_header(
420                header::SEC_WEBSOCKET_KEY,
421                crate::internals::generate_ws_key(),
422            )
423    }
424
425    /// Creates a HTTP GET request, using the typed path provided.
426    ///
427    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
428    ///
429    /// # Example Test
430    ///
431    /// Using a `TypedPath` you can write build and test a route like below:
432    ///
433    /// ```rust
434    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
435    /// #
436    /// use axum::Json;
437    /// use axum::Router;
438    /// use axum::routing::get;
439    /// use axum_extra::routing::RouterExt;
440    /// use axum_extra::routing::TypedPath;
441    /// use serde::Deserialize;
442    /// use serde::Serialize;
443    ///
444    /// use axum_test::TestServer;
445    ///
446    /// #[derive(TypedPath, Deserialize)]
447    /// #[typed_path("/users/{user_id}")]
448    /// struct UserPath {
449    ///     pub user_id: u32,
450    /// }
451    ///
452    /// // Build a typed route:
453    /// async fn route_get_user(UserPath { user_id }: UserPath) -> String {
454    ///     format!("hello user {user_id}")
455    /// }
456    ///
457    /// let app = Router::new()
458    ///     .typed_get(route_get_user);
459    ///
460    /// // Then test the route:
461    /// let server = TestServer::new(app)?;
462    /// server
463    ///     .typed_get(&UserPath { user_id: 123 })
464    ///     .await
465    ///     .assert_text("hello user 123");
466    /// #
467    /// # Ok(())
468    /// # }
469    /// ```
470    ///
471    #[cfg(feature = "typed-routing")]
472    pub fn typed_get<P>(&self, path: &P) -> TestRequest
473    where
474        P: TypedPath,
475    {
476        self.typed_method(Method::GET, path)
477    }
478
479    /// Creates a HTTP POST request, using the typed path provided.
480    ///
481    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
482    #[cfg(feature = "typed-routing")]
483    pub fn typed_post<P>(&self, path: &P) -> TestRequest
484    where
485        P: TypedPath,
486    {
487        self.typed_method(Method::POST, path)
488    }
489
490    /// Creates a HTTP PATCH request, using the typed path provided.
491    ///
492    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
493    #[cfg(feature = "typed-routing")]
494    pub fn typed_patch<P>(&self, path: &P) -> TestRequest
495    where
496        P: TypedPath,
497    {
498        self.typed_method(Method::PATCH, path)
499    }
500
501    /// Creates a HTTP PUT request, using the typed path provided.
502    ///
503    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
504    #[cfg(feature = "typed-routing")]
505    pub fn typed_put<P>(&self, path: &P) -> TestRequest
506    where
507        P: TypedPath,
508    {
509        self.typed_method(Method::PUT, path)
510    }
511
512    /// Creates a HTTP DELETE request, using the typed path provided.
513    ///
514    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
515    #[cfg(feature = "typed-routing")]
516    pub fn typed_delete<P>(&self, path: &P) -> TestRequest
517    where
518        P: TypedPath,
519    {
520        self.typed_method(Method::DELETE, path)
521    }
522
523    /// Creates a typed HTTP request, using the method provided.
524    ///
525    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
526    #[cfg(feature = "typed-routing")]
527    pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
528    where
529        P: TypedPath,
530    {
531        self.method(method, &path.to_string())
532    }
533
534    /// Returns the local web address for the test server,
535    /// if an address is available.
536    ///
537    /// The address is available when running as a real web server,
538    /// by setting the [`TestServerConfig`](crate::TestServerConfig) `transport` field to `Transport::HttpRandomPort` or `Transport::HttpIpPort`.
539    ///
540    /// This will return `None` when there is mock HTTP transport (the default).
541    pub fn server_address(&self) -> Option<Url> {
542        self.url()
543    }
544
545    /// This turns a relative path, into an absolute path to the server.
546    /// i.e. A path like `/users/123` will become something like `http://127.0.0.1:1234/users/123`.
547    ///
548    /// The absolute address can be used to make requests to the running server,
549    /// using any appropriate client you wish.
550    ///
551    /// # Example
552    ///
553    /// ```rust
554    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
555    /// #
556    /// use axum::Router;
557    /// use axum_test::TestServer;
558    ///
559    /// let app = Router::new();
560    /// let server = TestServer::builder()
561    ///         .http_transport()
562    ///         .build(app)?;
563    ///
564    /// let full_url = server.server_url(&"/users/123?filter=enabled")?;
565    ///
566    /// // Prints something like ... http://127.0.0.1:1234/users/123?filter=enabled
567    /// println!("{full_url}");
568    /// #
569    /// # Ok(()) }
570    /// ```
571    ///
572    /// This will return an error if you are using the mock transport.
573    /// Real HTTP transport is required to use this method (see [`TestServerConfig`](crate::TestServerConfig) `transport` field).
574    ///
575    /// It will also return an error if you provide an absolute path,
576    /// for example if you pass in `http://google.com`.
577    pub fn server_url(&self, path: &str) -> Result<Url> {
578        let path_uri = path.parse::<Uri>()?;
579        if is_absolute_uri(&path_uri) {
580            return Err(anyhow!(
581                "Absolute path provided for building server url, need to provide a relative uri"
582            ));
583        }
584
585        let server_url = self.url()
586            .ok_or_else(||
587                anyhow!(
588                    "No local address for server, need to run with HTTP transport to have a server address",
589                )
590            )?;
591
592        let server_locked = self.state.as_ref().lock().map_err(|err| {
593            anyhow!("Failed to lock InternalTestServer, for building server_url, received {err:?}",)
594        })?;
595        let mut query_params = server_locked.query_params().clone();
596        let mut full_server_url = build_url(
597            server_url,
598            path,
599            &mut query_params,
600            self.is_http_path_restricted,
601        )?;
602
603        // Ensure the query params are present
604        if query_params.has_content() {
605            full_server_url.set_query(Some(&query_params.to_string()));
606        }
607
608        Ok(full_server_url)
609    }
610
611    /// Adds a single cookie to be included on *all* future requests.
612    ///
613    /// If a cookie with the same name already exists,
614    /// then it will be replaced.
615    pub fn add_cookie(&mut self, cookie: Cookie) {
616        ServerSharedState::add_cookie(&self.state, cookie)
617            .context("Trying to call add_cookie")
618            .unwrap()
619    }
620
621    /// Adds extra cookies to be used on *all* future requests.
622    ///
623    /// Any cookies which have the same name as the new cookies,
624    /// will get replaced.
625    pub fn add_cookies(&mut self, cookies: CookieJar) {
626        ServerSharedState::add_cookies(&self.state, cookies)
627            .context("Trying to call add_cookies")
628            .unwrap()
629    }
630
631    /// Clears all of the cookies stored internally.
632    pub fn clear_cookies(&mut self) {
633        ServerSharedState::clear_cookies(&self.state)
634            .context("Trying to call clear_cookies")
635            .unwrap()
636    }
637
638    /// Requests made using this `TestServer` will save their cookies for future requests to send.
639    ///
640    /// This behaviour is off by default.
641    pub fn save_cookies(&mut self) {
642        self.save_cookies = true;
643    }
644
645    /// Requests made using this `TestServer` will _not_ save their cookies for future requests to send up.
646    ///
647    /// This is the default behaviour.
648    pub fn do_not_save_cookies(&mut self) {
649        self.save_cookies = false;
650    }
651
652    /// Requests made using this `TestServer` will assert a HTTP status in the 2xx range will be returned, unless marked otherwise.
653    ///
654    /// By default this behaviour is off.
655    pub fn expect_success(&mut self) {
656        self.expected_state = ExpectedState::Success;
657    }
658
659    /// Requests made using this `TestServer` will assert a HTTP status is outside the 2xx range will be returned, unless marked otherwise.
660    ///
661    /// By default this behaviour is off.
662    pub fn expect_failure(&mut self) {
663        self.expected_state = ExpectedState::Failure;
664    }
665
666    /// Adds a query parameter to be sent on *all* future requests.
667    pub fn add_query_param<V>(&mut self, key: &str, value: V)
668    where
669        V: Serialize,
670    {
671        ServerSharedState::add_query_param(&self.state, key, value)
672            .context("Trying to call add_query_param")
673            .unwrap()
674    }
675
676    /// Adds query parameters to be sent on *all* future requests.
677    pub fn add_query_params<V>(&mut self, query_params: V)
678    where
679        V: Serialize,
680    {
681        ServerSharedState::add_query_params(&self.state, query_params)
682            .context("Trying to call add_query_params")
683            .unwrap()
684    }
685
686    /// Adds a raw query param, with no urlencoding of any kind,
687    /// to be send on *all* future requests.
688    pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
689        ServerSharedState::add_raw_query_param(&self.state, raw_query_param)
690            .context("Trying to call add_raw_query_param")
691            .unwrap()
692    }
693
694    /// Clears all query params set.
695    pub fn clear_query_params(&mut self) {
696        ServerSharedState::clear_query_params(&self.state)
697            .context("Trying to call clear_query_params")
698            .unwrap()
699    }
700
701    /// Adds a header to be sent with all future requests built from this `TestServer`.
702    ///
703    /// ```rust
704    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
705    /// #
706    /// use axum::Router;
707    /// use axum_test::TestServer;
708    ///
709    /// let app = Router::new();
710    /// let mut server = TestServer::new(app)?;
711    ///
712    /// server.add_header("x-custom-header", "custom-value");
713    /// server.add_header(http::header::CONTENT_LENGTH, 12345);
714    /// server.add_header(http::header::HOST, "example.com");
715    ///
716    /// let response = server.get(&"/my-end-point")
717    ///     .await;
718    /// #
719    /// # Ok(()) }
720    /// ```
721    pub fn add_header<N, V>(&mut self, name: N, value: V)
722    where
723        N: TryInto<HeaderName>,
724        N::Error: Debug,
725        V: TryInto<HeaderValue>,
726        V::Error: Debug,
727    {
728        let header_name: HeaderName = name
729            .try_into()
730            .expect("Failed to convert header name to HeaderName");
731        let header_value: HeaderValue = value
732            .try_into()
733            .expect("Failed to convert header vlue to HeaderValue");
734
735        ServerSharedState::add_header(&self.state, header_name, header_value)
736            .context("Trying to call add_header")
737            .unwrap()
738    }
739
740    /// Clears all headers set so far.
741    pub fn clear_headers(&mut self) {
742        ServerSharedState::clear_headers(&self.state)
743            .context("Trying to call clear_headers")
744            .unwrap()
745    }
746
747    /// Sets the scheme to use when making _all_ requests from the `TestServer`.
748    /// i.e. http or https.
749    ///
750    /// The default scheme is 'http'.
751    ///
752    /// ```rust
753    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
754    /// #
755    /// use axum::Router;
756    /// use axum_test::TestServer;
757    ///
758    /// let app = Router::new();
759    /// let mut server = TestServer::new(app)?;
760    /// server
761    ///     .scheme(&"https");
762    ///
763    /// let response = server
764    ///     .get(&"/my-end-point")
765    ///     .await;
766    /// #
767    /// # Ok(()) }
768    /// ```
769    ///
770    pub fn scheme(&mut self, scheme: &str) {
771        ServerSharedState::set_scheme(&self.state, scheme.to_string())
772            .context("Trying to call set_scheme")
773            .unwrap()
774    }
775
776    pub(crate) fn url(&self) -> Option<Url> {
777        self.transport.url().cloned()
778    }
779
780    pub(crate) fn build_test_request_config(
781        &self,
782        method: Method,
783        path: &str,
784    ) -> Result<TestRequestConfig> {
785        let url = self
786            .url()
787            .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
788
789        let server_locked = self.state.as_ref().lock().map_err(|err| {
790            anyhow!(
791                "Failed to lock InternalTestServer, for request {method} {path}, received {err:?}",
792            )
793        })?;
794
795        let cookies = server_locked.cookies().clone();
796        let mut query_params = server_locked.query_params().clone();
797        let headers = server_locked.headers().clone();
798        let mut full_request_url =
799            build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
800
801        if let Some(scheme) = server_locked.scheme() {
802            full_request_url.set_scheme(scheme).map_err(|_| {
803                let debug_request_format = RequestPathFormatter::new(&method, full_request_url.as_str(), Some(&query_params));
804                anyhow!("Scheme '{scheme}' from TestServer cannot be set to request {debug_request_format}")
805            })?;
806        }
807
808        ::std::mem::drop(server_locked);
809
810        Ok(TestRequestConfig {
811            is_saving_cookies: self.save_cookies,
812            expected_state: self.expected_state,
813            content_type: self.default_content_type.clone(),
814            method,
815
816            full_request_url,
817            cookies,
818            query_params,
819            headers,
820        })
821    }
822
823    /// Returns true or false if the underlying service inside the `TestServer`
824    /// is still running. For many types of services this will always return `true`.
825    ///
826    /// When a `TestServer` is built using [`axum::serve::WithGracefulShutdown`],
827    /// this will return false if the service has shutdown.
828    pub fn is_running(&self) -> bool {
829        self.transport.is_running()
830    }
831}
832
833fn build_url(
834    mut url: Url,
835    path: &str,
836    query_params: &mut QueryParamsStore,
837    is_http_restricted: bool,
838) -> Result<Url> {
839    let path_uri = path.parse::<Uri>()?;
840
841    // If there is a scheme, then this is an absolute path.
842    if let Some(scheme) = path_uri.scheme_str() {
843        if is_http_restricted {
844            if has_different_schema(&url, &path_uri) || has_different_authority(&url, &path_uri) {
845                return Err(anyhow!(
846                    "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_schema' to change this."
847                ));
848            }
849        } else {
850            url.set_scheme(scheme)
851                .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
852
853            // We only set the host/port if the scheme is also present.
854            if let Some(authority) = path_uri.authority() {
855                url.set_host(Some(authority.host()))
856                    .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
857                url.set_port(authority.port().map(|p| p.as_u16()))
858                    .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
859
860                // todo, add username:password support
861            }
862        }
863    }
864
865    // Why does this exist?
866    //
867    // This exists to allow `server.get("/users")` and `server.get("users")` (without a slash)
868    // to go to the same place.
869    //
870    // It does this by saying ...
871    //  - if there is a scheme, it's a full path.
872    //  - if no scheme, it must be a path
873    //
874    if is_absolute_uri(&path_uri) {
875        url.set_path(path_uri.path());
876
877        // In this path we are replacing, so drop any query params on the original url.
878        if url.query().is_some() {
879            url.set_query(None);
880        }
881    } else {
882        // Grab everything up until the query parameters, or everything after that
883        let calculated_path = path.split('?').next().unwrap_or(path);
884        url.set_path(calculated_path);
885
886        // Move any query parameters from the url to the query params store.
887        if let Some(url_query) = url.query() {
888            query_params.add_raw(url_query.to_string());
889            url.set_query(None);
890        }
891    }
892
893    if let Some(path_query) = path_uri.query() {
894        query_params.add_raw(path_query.to_string());
895    }
896
897    Ok(url)
898}
899
900fn is_absolute_uri(path_uri: &Uri) -> bool {
901    path_uri.scheme_str().is_some()
902}
903
904fn has_different_schema(base_url: &Url, path_uri: &Uri) -> bool {
905    if let Some(scheme) = path_uri.scheme_str() {
906        return scheme != base_url.scheme();
907    }
908
909    false
910}
911
912fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
913    if let Some(authority) = path_uri.authority() {
914        return authority.as_str() != base_url.authority();
915    }
916
917    false
918}
919
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Parses `base_url`, feeds it and `path` through [`build_url`] with an empty
    /// [`QueryParamsStore`], and hands back the outcome together with that store,
    /// so each test can inspect both the final URL and the collected query params.
    fn run_build_url(
        base_url: &str,
        path: &str,
        is_restricted: bool,
    ) -> (anyhow::Result<Url>, QueryParamsStore) {
        let parsed_base = base_url.parse::<Url>().unwrap();
        let mut collected = QueryParamsStore::new();
        let outcome = build_url(parsed_base, path, &mut collected, is_restricted);

        (outcome, collected)
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (outcome, params) = run_build_url("http://example.com", "/users", true);

        assert_eq!("http://example.com/users", outcome.unwrap().as_str());
        assert!(params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (outcome, params) =
            run_build_url("http://example.com?base=aaa", "/users?path=bbb&path-flag", true);

        assert_eq!("http://example.com/users", outcome.unwrap().as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", params.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        let (outcome, _) = run_build_url(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        let (outcome, _) = run_build_url(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        let (outcome, _) =
            run_build_url("http://example.com?base=666", "http://google.com", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        let (outcome, _) =
            run_build_url("http://example.com?base=666", "ftp://example.com/users", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (outcome, params) = run_build_url("http://example.com", "/users", false);

        assert_eq!("http://example.com/users", outcome.unwrap().as_str());
        assert!(params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (outcome, params) =
            run_build_url("http://example.com?base=aaa", "/users?path=bbb&path-flag", false);

        assert_eq!("http://example.com/users", outcome.unwrap().as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", params.to_string());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let (outcome, params) = run_build_url("http://example.com", "google.com", false);

        assert_eq!("http://example.com/google.com", outcome.unwrap().as_str());
        assert!(params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let (outcome, params) = run_build_url("http://example.com", "google.com", true);

        assert_eq!("http://example.com/google.com", outcome.unwrap().as_str());
        assert!(params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (outcome, params) = run_build_url(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );

        assert_eq!("ftp://google.com:123/users.csv", outcome.unwrap().as_str());
        assert_eq!("limit=456", params.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (outcome, _) = run_build_url("http://example.com", "ftp://example.com", false);

        assert_eq!("ftp://example.com/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (outcome, _) = run_build_url("http://example.com", "http://google.com", false);

        assert_eq!("http://google.com/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (outcome, _) =
            run_build_url("http://example.com:123", "http://example.com:456", false);

        assert_eq!("http://example.com:456/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (outcome, _) =
            run_build_url("http://example.com:123", "http://example.com:123", false);

        assert_eq!("http://example.com:123/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        let (outcome, _) = run_build_url("http://example.com", "ftp://example.com", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        let (outcome, _) = run_build_url("http://example.com", "http://google.com", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        let (outcome, _) =
            run_build_url("http://example.com:123", "http://example.com:456", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (outcome, _) =
            run_build_url("http://example.com:123", "http://example.com:123", true);

        assert_eq!("http://example.com:123/", outcome.unwrap().as_str());
    }
}
1121
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// `TestServer::new` should accept a router converted with
    /// `into_make_service_with_connect_info` and serve it with default settings.
    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        let make_service = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();
        let server = TestServer::new(make_service).expect("Should create test server");

        let response = server.get("/ping").await;
        response.assert_text("pong!");
    }
}
1148
#[cfg(test)]
mod test_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use reserve_port::ReservedSocketAddr;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// The application under test: a single `GET /ping` route answering `"pong!"`.
    fn ping_app() -> Router {
        Router::new().route("/ping", get(get_ping))
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let server = TestServer::new(ping_app()).expect("Should create test server");

        // Path given *with* a leading slash.
        server.get("/ping").await.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let server = TestServer::new(ping_app()).expect("Should create test server");

        // Path given *without* a leading slash; it must resolve to the same route.
        server.get("ping").await.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let (ip, port) = (reserved_address.ip(), reserved_address.port());

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(ping_app())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let expected_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&expected_url).await;

        response.assert_text("pong!");
        assert_eq!(response.request_url().to_string(), expected_url);
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let (ip, port) = (reserved_address.ip(), reserved_address.port());

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema() // The behaviour under test.
            .build(ping_app())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let expected_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&expected_url).await;

        response.assert_text("pong!");
        assert_eq!(response.request_url().to_string(), expected_url);
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let (ip, port) = (reserved_address.ip(), reserved_address.port());

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema() // The behaviour under test.
            .build(ping_app())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Aim one port above the server's, so the request no longer targets it.
        let mismatched_port = port + 1;
        server
            .get(&format!("http://{ip}:{mismatched_port}/ping"))
            .await;
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let server = TestServer::new(ping_app()).expect("Should create test server");

        let (first, second) = tokio::join!(
            async { server.get("/ping").await },
            async { server.get("/ping").await }
        );

        assert_eq!(first.text(), second.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let slow_app = Router::new().route(
            "/slow",
            get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );
        let server = TestServer::new(slow_app).expect("Should create test server");

        let (first, second) = tokio::join!(
            async { server.get("/slow").await },
            async { server.get("/slow").await }
        );

        assert_eq!(first.text(), second.text());
    }
}
1288
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// `reqwest_get` should resolve a relative path against the running HTTP server.
    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");

        let reply = server.reqwest_get("/ping").send().await.unwrap();
        let body = reply.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1321
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;

    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Echo handler that doubles `number` and appends `_plus_response` to `text`,
    /// so the test can tell the body was actually processed by the server.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let response = server
            .reqwest_post("/json")
            .json(&TestBody {
                number: 111,
                text: "request".to_string(),
            })
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        assert_eq!(
            response,
            TestBody {
                number: 222,
                text: "request_plus_response".to_string(),
            }
        );
    }
}
1377
#[cfg(test)]
mod test_server_address {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let actual = server.server_address().unwrap().to_string();
        assert_eq!(actual, format!("http://{ip}:{port}/"));
    }

    /// NOTE: despite this test's name, the address is expected to END with `/`;
    /// the pattern below requires the trailing slash.
    #[tokio::test]
    async fn it_should_return_default_address_without_ending_slash() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new())
            .expect("Should create test server");

        let address = server.server_address().unwrap().to_string();
        let address_pattern = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();

        assert!(address_pattern.is_match(&address));
    }

    /// The mock transport never binds a socket, so there is no address to report.
    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new())
            .expect("Should create test server");

        assert!(server.server_address().is_none());
    }
}
1432
#[cfg(test)]
mod test_server_url {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let actual = server.server_url("/users").unwrap().to_string();
        assert_eq!(actual, format!("http://{ip}:{port}/users"));
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new())
            .expect("Should create test server");

        let actual = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();
        let url_pattern =
            Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();

        assert!(url_pattern.is_match(&actual));
    }

    /// Without a real socket there is no base URL, so building one must fail.
    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new())
            .expect("Should create test server");

        assert!(server.server_url("/users").is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let actual = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();
        assert_eq!(actual, format!("http://{ip}:{port}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();
        server.add_query_param("filter", "enabled");

        let actual = server.server_url("/users").unwrap().to_string();
        assert_eq!(actual, format!("http://{ip}:{port}/users?filter=enabled"));
    }

    /// Server-level params come first, followed by the params written in the path.
    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(Router::new())
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();
        server.add_query_param("filter", "enabled");

        let actual = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();
        assert_eq!(
            actual,
            format!("http://{ip}:{port}/users?filter=enabled&animal=donkeys")
        );
    }
}
1575
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    // `const` string references are already `'static`; no explicit lifetime or borrow needed.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of the `test-cookie` cookie, or `"cookie-not-found"`
    /// when the request did not carry it. The jar is returned unchanged.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        server.add_cookie(cookie);

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }
}
1608
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Echoes every received cookie as `name=value`, sorted and joined with `", "`,
    /// so assertions do not depend on the order in which cookies arrive.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();
        pairs.sort_unstable();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        // Collect the cookies to send in one jar, then hand the whole jar over.
        let mut cookie_jar = CookieJar::new();
        for cookie in [
            Cookie::new("first-cookie", "my-custom-cookie"),
            Cookie::new("second-cookie", "other-cookie"),
        ] {
            cookie_jar.add(cookie);
        }
        server.add_cookies(cookie_jar);

        let response = server.get("/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1649
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Echoes every received cookie as `name=value`, sorted and joined with `", "`;
    /// an empty string means no cookies were sent.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();
        pairs.sort_unstable();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app).expect("Should create test server");

        let mut cookie_jar = CookieJar::new();
        for cookie in [
            Cookie::new("first-cookie", "my-custom-cookie"),
            Cookie::new("second-cookie", "other-cookie"),
        ] {
            cookie_jar.add(cookie);
        }
        server.add_cookies(cookie_jar);

        // Clearing must discard everything added above.
        server.clear_cookies();

        let response = server.get("/cookies").await;
        response.assert_text("");
    }
}
1689
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;

    // `const` string references are already `'static`; no explicit lifetime or borrow needed.
    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    /// Rejects with `400 Bad Request` and `"Missing test header"` when it is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, with a header added at server level.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Send a request; the server-level header should be attached automatically.
        let response = server.get("/header").await;

        // Check it sent back the right text.
        response.assert_text(TEST_HEADER_CONTENT);
    }
}
1746
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;

    // `const` string references are already `'static`; no explicit lifetime or borrow needed.
    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    /// Rejects with `400 Bad Request` and `"Missing test header"` when it is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, add a header, then clear all server-level headers.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        // Send a request; the cleared header must not be attached.
        let response = server.get("/header").await;

        // The extractor should reject the request for the missing header.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1805
#[cfg(test)]
mod test_add_query_params {
    //! Tests for `TestServer::add_query_params`, fed from structs, pairs, and JSON.

    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query string carrying a single `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string carrying both `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `"{message}-{other}"`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    /// Builds a server whose app serves both `/query` and `/query-2`.
    fn new_server() -> TestServer {
        let app = Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2));

        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let mut server = new_server();
        let params = QueryParam {
            message: "it works".to_string(),
        };
        server.add_query_params(params);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = new_server();
        server.add_query_params(&[("message", "it works")]);

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let mut server = new_server();
        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = new_server();

        // Each call should add to, not replace, the params already stored.
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let mut server = new_server();
        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        server.add_query_params(params);

        server.get("/query-2").await.assert_text("it works-yup");
    }
}
1908
#[cfg(test)]
mod test_add_query_param {
    //! Tests for `TestServer::add_query_param` (single key/value pairs).

    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string carrying a single `message`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string carrying both `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `"{message}-{other}"`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    /// Builds a server whose app serves both `/query` and `/query-2`.
    fn new_server() -> TestServer {
        let app = Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2));

        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = new_server();
        server.add_query_param("message", "it works");

        server.get("/query").await.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = new_server();
        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        server.get("/query-2").await.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let mut server = new_server();

        // One param lives on the server, the other is added on the individual request.
        server.add_query_param("message", "it works");
        let request = server.get("/query-2").add_query_param("other", "yup");

        request.await.assert_text("it works-yup");
    }
}
1983
#[cfg(test)]
mod test_add_raw_query_param {
    //! Tests for `TestServer::add_raw_query_param`, which adds query text as given.

    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the single `message` query parameter (parsed with axum's std `Query`).
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated-key query parameters, parsed with `axum_extra`'s `Query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        /// Values sent as `items=a&items=b`.
        #[serde(default)]
        items: Vec<String>,

        /// Values sent as `arrs[]=a&arrs[]=b`.
        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Responds with every `items` value followed by every `arrs[]` value, joined by `", "`.
    ///
    /// Both lists are joined as a single sequence so that a request carrying both keys
    /// still yields one well-formed, comma-separated list. No params yields `""`.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        let QueryParamExtra { mut items, arrs } = params;
        items.extend(arrs);
        items.join(", ")
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"message=it-works");

        // Get the request.
        server.get(&"/query").await.assert_text(&"it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one&items=two&items=three");

        // Get the request.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        // Get the request.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }
}
2074
#[cfg(test)]
mod test_clear_query_params {
    //! Tests for `TestServer::clear_query_params`.

    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional params; a missing key deserializes to `None`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of the two params arrived with the request.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        format!(
            "has first? {}, has second? {}",
            params.first.is_some(),
            params.second.is_some()
        )
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_params));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_query_params(QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        });
        server.clear_query_params();

        // Get the request.
        server
            .get("/query")
            .await
            .assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_params));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_query_params(QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        });
        server.clear_query_params();

        // Re-add only `first`. Had the clear been a no-op, `second` would still be
        // sent, so this distinguishes "cleared then replaced" from "never cleared".
        server.add_query_param("first", "first");

        // Get the request.
        server
            .get("/query")
            .await
            .assert_text("has first? true, has second? false");
    }
}
2144
#[cfg(test)]
mod test_expect_success_by_default {
    //! Tests for the builder's `expect_success_by_default` option.

    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Router with a single `/known_route` that responds with a fixed body.
    fn app_with_known_route() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    /// Builds a server for `app` with `expect_success_by_default` switched on.
    fn strict_server(app: Router) -> TestServer {
        TestServer::builder()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new()).expect("Should create test server");
        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(app_with_known_route()).expect("Should create test server");
        server.get("/known_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = strict_server(Router::new());
        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = strict_server(app_with_known_route());
        server.get("/known_route").await;
    }
}
2191
#[cfg(test)]
mod test_content_type {
    //! Tests for the builder's `default_content_type` option.

    use super::*;

    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Responds with the request's `Content-Type` header, or `""` when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let app = Router::new().route("/content_type", get(get_content_type));
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(app)
            .expect("Should create test server");

        let response = server.get("/content_type").await;

        assert_eq!(response.text(), "text/plain");
    }
}
2225
#[cfg(test)]
mod test_expect_success {
    //! Tests for `TestServer::expect_success`.

    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler answering `200 OK`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering with a non-200 success status (`202 Accepted`).
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Builds a server for `app` with `expect_success` switched on.
    fn expecting_success(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = expecting_success(Router::new().route("/ping", get(get_ping)));
        server.get("/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = expecting_success(Router::new().route("/accepted", get(get_accepted)));
        server.get("/accepted").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        // An empty router has no matching route for the request below.
        let server = expecting_success(Router::new());
        server.get("/some_unknown_route").await;
    }
}
2281
#[cfg(test)]
mod test_expect_failure {
    //! Tests for `TestServer::expect_failure`.

    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler answering `200 OK`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering with a non-200 success status (`202 Accepted`).
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Builds a server for `app` with `expect_failure` switched on.
    fn expecting_failure(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // An empty router has no matching route for the request below.
        let server = expecting_failure(Router::new());
        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = expecting_failure(Router::new().route("/ping", get(get_ping)));
        server.get("/ping").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let server = expecting_failure(Router::new().route("/accepted", get(get_accepted)));
        server.get("/accepted").await;
    }
}
2338
#[cfg(test)]
mod test_scheme {
    //! Tests for `TestServer::scheme`.

    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;

    use crate::TestServer;

    /// Responds with the scheme of the request URI (panics if the URI has none).
    async fn route_get_scheme(request: Request) -> String {
        request.uri().scheme_str().unwrap().to_string()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let server = TestServer::builder().build(router).unwrap();

        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_across_multiple_requests_when_set() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let mut server = TestServer::builder().build(router).unwrap();
        server.scheme(&"https");

        // The scheme is set once on the server, so it must apply to every request it
        // sends, not only the first one.
        server.get("/scheme").await.assert_text("https");
        server.get("/scheme").await.assert_text("https");
    }
}
2368
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    //! Tests for `TestServer::typed_get`.

    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `"get {id}"`.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    /// App exposing only the typed GET route.
    fn new_app() -> Router {
        Router::new().typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_get(&path).await;

        response.assert_text("get 123");
    }
}
2402
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    //! Tests for `TestServer::typed_post`.

    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `"post {id}"`.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    /// App exposing only the typed POST route.
    fn new_app() -> Router {
        Router::new().typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_post(&path).await;

        response.assert_text("post 123");
    }
}
2436
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    //! Tests for `TestServer::typed_patch`.

    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `"patch {id}"`.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    /// App exposing only the typed PATCH route.
    fn new_app() -> Router {
        Router::new().typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_patch(&path).await;

        response.assert_text("patch 123");
    }
}
2470
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    //! Tests for `TestServer::typed_put`.

    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `"put {id}"`.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    /// App exposing only the typed PUT route.
    fn new_app() -> Router {
        Router::new().typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_put(&path).await;

        response.assert_text("put 123");
    }
}
2504
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    //! Tests for `TestServer::typed_delete`.

    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `"delete {id}"`.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// App exposing only the typed DELETE route.
    fn new_app() -> Router {
        Router::new().typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_delete(&path).await;

        response.assert_text("delete 123");
    }
}
2538
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    //! Tests for `TestServer::typed_method` across each HTTP verb.

    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with `"get {id}"`.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    /// Responds with `"post {id}"`.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    /// Responds with `"patch {id}"`.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    /// Responds with `"put {id}"`.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    /// Responds with `"delete {id}"`.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// App registering one handler per verb on the same typed path.
    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to the typed path with id `123` on a fresh server and
    /// asserts the response body equals `expected`.
    async fn assert_typed_method(method: Method, expected: &str) {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_method(method, &path).await;

        response.assert_text(expected);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method(Method::DELETE, "delete 123").await;
    }
}
2633
#[cfg(test)]
mod test_sync {
    //! Tests that a `TestServer` can be stored in, and used from, a lazily-initialised cell.

    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    /// Creates the server inside `OnceCell::get_or_init` and sends a request through the
    /// shared reference the cell hands back.
    ///
    /// NOTE(review): `std::cell::OnceCell` is the single-threaded cell, so this does not
    /// exercise `TestServer: Sync`; a thread-safe cell (`std::sync::OnceLock`) would be
    /// needed for that — confirm which guarantee this module is meant to cover.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_once_cell() {
        async fn route_get() -> &'static str {
            "it works"
        }

        let cell: OnceCell<TestServer> = OnceCell::new();
        let server = cell.get_or_init(|| {
            let router = Router::new().route("/test", get(route_get));

            TestServer::new(router).unwrap()
        });

        server.get("/test").await.assert_text("it works");
    }
}
2657
#[cfg(test)]
mod test_is_running {
    //! Tests for `TestServer::is_running` around a graceful shutdown.

    use super::*;
    use crate::util::new_random_tokio_tcp_listener;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    /// Maximum number of `is_running` polls after the shutdown is signalled.
    const SHUTDOWN_POLL_ATTEMPTS: u32 = 100;
    /// Delay between consecutive `is_running` polls.
    const SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(10);

    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// After the served application shuts down, `is_running()` should report `false`
    /// and a further request should panic.
    ///
    /// NOTE(review): `#[should_panic]` accepts *any* panic in this test, including one
    /// from the pre-shutdown assertions; consider `#[should_panic(expected = "...")]`
    /// once the exact panic message for a stopped server is pinned down.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown_notification = Arc::new(Notify::new());
        let waiting_notification = shutdown_notification.clone();

        // Build an application with a route, served until `shutdown_notification` fires.
        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let listener = new_random_tokio_tcp_listener().unwrap();
        let application = serve(listener, app)
            .with_graceful_shutdown(async move { waiting_notification.notified().await });

        // Run the server.
        let server = TestServer::builder()
            .build(application)
            .expect("Should create test server");

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        shutdown_notification.notify_one();

        // Poll with a bounded deadline instead of one fixed sleep: a single short sleep
        // can elapse before the server observes the shutdown on a slow or loaded machine.
        for _ in 0..SHUTDOWN_POLL_ATTEMPTS {
            if !server.is_running() {
                break;
            }
            sleep(SHUTDOWN_POLL_INTERVAL).await;
        }

        assert!(!server.is_running());
        server.get("/ping").await.assert_status_ok();
    }
}