axum_test/
test_server.rs

1use anyhow::Context;
2use anyhow::Result;
3use anyhow::anyhow;
4use cookie::Cookie;
5use cookie::CookieJar;
6use http::HeaderName;
7use http::HeaderValue;
8use http::Method;
9use http::Uri;
10use serde::Serialize;
11use std::fmt::Debug;
12use std::sync::Arc;
13use std::sync::Mutex;
14use url::Url;
15
16#[cfg(feature = "typed-routing")]
17use axum_extra::routing::TypedPath;
18
19#[cfg(feature = "reqwest")]
20use crate::transport_layer::TransportLayerType;
21#[cfg(feature = "reqwest")]
22use reqwest::Client;
23#[cfg(feature = "reqwest")]
24use reqwest::RequestBuilder;
25
26use crate::TestRequest;
27use crate::TestRequestConfig;
28use crate::TestServerBuilder;
29use crate::TestServerConfig;
30use crate::Transport;
31use crate::internals::ExpectedState;
32use crate::internals::QueryParamsStore;
33use crate::internals::RequestPathFormatter;
34use crate::transport_layer::IntoTransportLayer;
35use crate::transport_layer::TransportLayer;
36use crate::transport_layer::TransportLayerBuilder;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
41const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
43///
44/// The `TestServer` runs your Axum application,
45/// allowing you to make HTTP requests against it.
46///
47/// # Building
48///
49/// A `TestServer` can be used to run an [`axum::Router`], an [`::axum::routing::IntoMakeService`],
50/// and others.
51///
52/// The most straight forward approach is to call [`crate::TestServer::new`],
53/// and pass in your application:
54///
55/// ```rust
56/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
57/// #
58/// use axum::Router;
59/// use axum::routing::get;
60///
61/// use axum_test::TestServer;
62///
63/// let app = Router::new()
64///     .route(&"/hello", get(|| async { "hello!" }));
65///
66/// let server = TestServer::new(app)?;
67/// #
68/// # Ok(())
69/// # }
70/// ```
71///
72/// # Requests
73///
74/// Requests are built by calling [`TestServer::get()`](crate::TestServer::get()),
75/// [`TestServer::post()`](crate::TestServer::post()), [`TestServer::put()`](crate::TestServer::put()),
76/// [`TestServer::delete()`](crate::TestServer::delete()), and [`TestServer::patch()`](crate::TestServer::patch()) methods.
77/// Each returns a [`TestRequest`](crate::TestRequest), which allows for customising the request content.
78///
79/// For example:
80///
81/// ```rust
82/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
83/// #
84/// use axum::Router;
85/// use axum::routing::get;
86///
87/// use axum_test::TestServer;
88///
89/// let app = Router::new()
90///     .route(&"/hello", get(|| async { "hello!" }));
91///
92/// let server = TestServer::new(app)?;
93///
94/// let response = server.get(&"/hello")
95///     .authorization_bearer("password12345")
96///     .add_header("x-custom-header", "custom-value")
97///     .await;
98///
99/// response.assert_text("hello!");
100/// #
101/// # Ok(())
102/// # }
103/// ```
104///
105/// Request methods also exist for using Axum Extra [`axum_extra::routing::TypedPath`],
106/// or for building Reqwest [`reqwest::RequestBuilder`]. See those methods for detauls.
107///
108/// # Customising
109///
110/// A `TestServer` can be built from a builder, by calling [`crate::TestServer::builder`],
111/// and customising settings. This allows one to set **mocked** (default when possible)
112/// or **real http** networking for your service.
113///
114/// ```rust
115/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
116/// #
117/// use axum::Router;
118/// use axum::routing::get;
119///
120/// use axum_test::TestServer;
121///
122/// let app = Router::new()
123///     .route(&"/hello", get(|| async { "hello!" }));
124///
125/// // Customise server when building
126/// let mut server = TestServer::builder()
127///     .http_transport()
128///     .expect_success_by_default()
129///     .save_cookies()
130///     .build(app)?;
131///
132/// // Add items to be sent on _all_ all requests
133/// server.add_header("x-custom-for-all", "common-value");
134///
135/// let response = server.get("/hello").await;
136/// #
137/// # Ok(())
138/// # }
139/// ```
140///
141#[derive(Debug)]
142pub struct TestServer {
143    state: Arc<Mutex<ServerSharedState>>,
144    transport: Arc<Box<dyn TransportLayer>>,
145    save_cookies: bool,
146    expected_state: ExpectedState,
147    default_content_type: Option<String>,
148    is_http_path_restricted: bool,
149
150    #[cfg(feature = "reqwest")]
151    maybe_reqwest_client: Option<Client>,
152}
153
154impl TestServer {
155    /// A helper function to create a builder for creating a [`TestServer`].
156    pub fn builder() -> TestServerBuilder {
157        TestServerBuilder::default()
158    }
159
160    /// This will run the given Axum app,
161    /// allowing you to make requests against it.
162    ///
163    /// This is the same as creating a new `TestServer` with a configuration,
164    /// and passing [`crate::TestServerConfig::default()`].
165    ///
166    /// ```rust
167    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
168    /// #
169    /// use axum::Router;
170    /// use axum::routing::get;
171    /// use axum_test::TestServer;
172    ///
173    /// let app = Router::new()
174    ///     .route(&"/hello", get(|| async { "hello!" }));
175    ///
176    /// let server = TestServer::new(app)?;
177    /// #
178    /// # Ok(())
179    /// # }
180    /// ```
181    ///
182    /// The of applications that can be passed in include:
183    ///
184    ///  - [`axum::Router`]
185    ///  - [`axum::routing::IntoMakeService`]
186    ///  - [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`]
187    ///  - [`axum::serve::Serve`]
188    ///  - [`axum::serve::WithGracefulShutdown`]
189    ///
190    pub fn new<A>(app: A) -> Result<Self>
191    where
192        A: IntoTransportLayer,
193    {
194        Self::new_with_config(app, TestServerConfig::default())
195    }
196
197    /// Similar to [`TestServer::new()`], with a customised configuration.
198    /// This includes type of transport in use (i.e. specify a specific port),
199    /// or change default settings (like the default content type for requests).
200    ///
201    /// This can take a [`crate::TestServerConfig`] or a [`crate::TestServerBuilder`].
202    /// See those for more information on configuration settings.
203    pub fn new_with_config<A, C>(app: A, config: C) -> Result<Self>
204    where
205        A: IntoTransportLayer,
206        C: Into<TestServerConfig>,
207    {
208        let config = config.into();
209        let mut shared_state = ServerSharedState::new();
210        if let Some(scheme) = config.default_scheme {
211            shared_state.set_scheme_unlocked(scheme);
212        }
213
214        let shared_state_mutex = Mutex::new(shared_state);
215        let state = Arc::new(shared_state_mutex);
216
217        let transport = match config.transport {
218            None => {
219                let builder = TransportLayerBuilder::new(None, None);
220                let transport = app.into_default_transport(builder)?;
221                Arc::new(transport)
222            }
223            Some(Transport::HttpRandomPort) => {
224                let builder = TransportLayerBuilder::new(None, None);
225                let transport = app.into_http_transport_layer(builder)?;
226                Arc::new(transport)
227            }
228            Some(Transport::HttpIpPort { ip, port }) => {
229                let builder = TransportLayerBuilder::new(ip, port);
230                let transport = app.into_http_transport_layer(builder)?;
231                Arc::new(transport)
232            }
233            Some(Transport::MockHttp) => {
234                let transport = app.into_mock_transport_layer()?;
235                Arc::new(transport)
236            }
237        };
238
239        let expected_state = match config.expect_success_by_default {
240            true => ExpectedState::Success,
241            false => ExpectedState::None,
242        };
243
244        #[cfg(feature = "reqwest")]
245        let maybe_reqwest_client = match transport.transport_layer_type() {
246            TransportLayerType::Http => {
247                let reqwest_client = reqwest::Client::builder()
248                    .redirect(reqwest::redirect::Policy::none())
249                    .cookie_store(config.save_cookies)
250                    .build()
251                    .expect("Failed to build Reqwest Client");
252
253                Some(reqwest_client)
254            }
255            TransportLayerType::Mock => None,
256        };
257
258        Ok(Self {
259            state,
260            transport,
261            save_cookies: config.save_cookies,
262            expected_state,
263            default_content_type: config.default_content_type,
264            is_http_path_restricted: config.restrict_requests_with_http_schema,
265
266            #[cfg(feature = "reqwest")]
267            maybe_reqwest_client,
268        })
269    }
270
271    /// Creates a HTTP GET request to the path.
272    pub fn get(&self, path: &str) -> TestRequest {
273        self.method(Method::GET, path)
274    }
275
276    /// Creates a HTTP POST request to the given path.
277    pub fn post(&self, path: &str) -> TestRequest {
278        self.method(Method::POST, path)
279    }
280
281    /// Creates a HTTP PATCH request to the path.
282    pub fn patch(&self, path: &str) -> TestRequest {
283        self.method(Method::PATCH, path)
284    }
285
286    /// Creates a HTTP PUT request to the path.
287    pub fn put(&self, path: &str) -> TestRequest {
288        self.method(Method::PUT, path)
289    }
290
291    /// Creates a HTTP DELETE request to the path.
292    pub fn delete(&self, path: &str) -> TestRequest {
293        self.method(Method::DELETE, path)
294    }
295
296    /// Creates a HTTP request, to the method and path provided.
297    pub fn method(&self, method: Method, path: &str) -> TestRequest {
298        let maybe_config = self.build_test_request_config(method.clone(), path);
299        let config = maybe_config
300            .with_context(|| format!("Failed to build, for request {method} {path}"))
301            .unwrap();
302
303        TestRequest::new(self.state.clone(), self.transport.clone(), config)
304    }
305
306    #[cfg(feature = "reqwest")]
307    fn reqwest_client(&self) -> &Client {
308        self.maybe_reqwest_client
309            .as_ref()
310            .expect("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available")
311    }
312
313    #[cfg(feature = "reqwest")]
314    pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
315        self.reqwest_method(Method::GET, path)
316    }
317
318    #[cfg(feature = "reqwest")]
319    pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
320        self.reqwest_method(Method::POST, path)
321    }
322
323    #[cfg(feature = "reqwest")]
324    pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
325        self.reqwest_method(Method::PUT, path)
326    }
327
328    #[cfg(feature = "reqwest")]
329    pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
330        self.reqwest_method(Method::PATCH, path)
331    }
332
333    #[cfg(feature = "reqwest")]
334    pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
335        self.reqwest_method(Method::DELETE, path)
336    }
337
338    #[cfg(feature = "reqwest")]
339    pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
340        self.reqwest_method(Method::HEAD, path)
341    }
342
343    /// Creates a HTTP request, using Reqwest, using the method + path described.
344    /// This expects a relative url to the `TestServer`.
345    ///
346    /// ```rust
347    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
348    /// #
349    /// use axum::Router;
350    /// use axum_test::TestServer;
351    ///
352    /// let my_app = Router::new();
353    /// let server = TestServer::builder()
354    ///     .http_transport() // Important, must be HTTP!
355    ///     .build(my_app)?;
356    ///
357    /// // Build your request
358    /// let request = server.get(&"/user")
359    ///     .add_header("x-custom-header", "example.com")
360    ///     .content_type("application/yaml");
361    ///
362    /// // await request to execute
363    /// let response = request.await;
364    /// #
365    /// # Ok(()) }
366    /// ```
367    #[cfg(feature = "reqwest")]
368    pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
369        let request_url = self
370            .server_url(path)
371            .expect("Failed to generate server url for request {method} {path}");
372
373        self.reqwest_client().request(method, request_url)
374    }
375
376    /// Creates a request to the server, to start a Websocket connection,
377    /// on the path given.
378    ///
379    /// This is the requivalent of making a GET request to the endpoint,
380    /// and setting the various headers needed for making an upgrade request.
381    ///
382    /// *Note*, this requires the server to be running on a real HTTP
383    /// port. Either using a randomly assigned port, or a specified one.
384    /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
385    ///
386    /// # Example
387    ///
388    /// ```rust
389    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
390    /// #
391    /// use axum::Router;
392    /// use axum_test::TestServer;
393    ///
394    /// let app = Router::new();
395    /// let server = TestServer::builder()
396    ///     .http_transport()
397    ///     .build(app)?;
398    ///
399    /// let mut websocket = server
400    ///     .get_websocket(&"/my-web-socket-end-point")
401    ///     .await
402    ///     .into_websocket()
403    ///     .await;
404    ///
405    /// websocket.send_text("Hello!").await;
406    /// #
407    /// # Ok(()) }
408    /// ```
409    ///
410    #[cfg(feature = "ws")]
411    pub fn get_websocket(&self, path: &str) -> TestRequest {
412        use http::header;
413
414        self.get(path)
415            .add_header(header::CONNECTION, "upgrade")
416            .add_header(header::UPGRADE, "websocket")
417            .add_header(header::SEC_WEBSOCKET_VERSION, "13")
418            .add_header(
419                header::SEC_WEBSOCKET_KEY,
420                crate::internals::generate_ws_key(),
421            )
422    }
423
424    /// Creates a HTTP GET request, using the typed path provided.
425    ///
426    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
427    ///
428    /// # Example Test
429    ///
430    /// Using a `TypedPath` you can write build and test a route like below:
431    ///
432    /// ```rust
433    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
434    /// #
435    /// use axum::Json;
436    /// use axum::Router;
437    /// use axum::routing::get;
438    /// use axum_extra::routing::RouterExt;
439    /// use axum_extra::routing::TypedPath;
440    /// use serde::Deserialize;
441    /// use serde::Serialize;
442    ///
443    /// use axum_test::TestServer;
444    ///
445    /// #[derive(TypedPath, Deserialize)]
446    /// #[typed_path("/users/{user_id}")]
447    /// struct UserPath {
448    ///     pub user_id: u32,
449    /// }
450    ///
451    /// // Build a typed route:
452    /// async fn route_get_user(UserPath { user_id }: UserPath) -> String {
453    ///     format!("hello user {user_id}")
454    /// }
455    ///
456    /// let app = Router::new()
457    ///     .typed_get(route_get_user);
458    ///
459    /// // Then test the route:
460    /// let server = TestServer::new(app)?;
461    /// server
462    ///     .typed_get(&UserPath { user_id: 123 })
463    ///     .await
464    ///     .assert_text("hello user 123");
465    /// #
466    /// # Ok(())
467    /// # }
468    /// ```
469    ///
470    #[cfg(feature = "typed-routing")]
471    pub fn typed_get<P>(&self, path: &P) -> TestRequest
472    where
473        P: TypedPath,
474    {
475        self.typed_method(Method::GET, path)
476    }
477
478    /// Creates a HTTP POST request, using the typed path provided.
479    ///
480    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
481    #[cfg(feature = "typed-routing")]
482    pub fn typed_post<P>(&self, path: &P) -> TestRequest
483    where
484        P: TypedPath,
485    {
486        self.typed_method(Method::POST, path)
487    }
488
489    /// Creates a HTTP PATCH request, using the typed path provided.
490    ///
491    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
492    #[cfg(feature = "typed-routing")]
493    pub fn typed_patch<P>(&self, path: &P) -> TestRequest
494    where
495        P: TypedPath,
496    {
497        self.typed_method(Method::PATCH, path)
498    }
499
500    /// Creates a HTTP PUT request, using the typed path provided.
501    ///
502    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
503    #[cfg(feature = "typed-routing")]
504    pub fn typed_put<P>(&self, path: &P) -> TestRequest
505    where
506        P: TypedPath,
507    {
508        self.typed_method(Method::PUT, path)
509    }
510
511    /// Creates a HTTP DELETE request, using the typed path provided.
512    ///
513    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
514    #[cfg(feature = "typed-routing")]
515    pub fn typed_delete<P>(&self, path: &P) -> TestRequest
516    where
517        P: TypedPath,
518    {
519        self.typed_method(Method::DELETE, path)
520    }
521
522    /// Creates a typed HTTP request, using the method provided.
523    ///
524    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
525    #[cfg(feature = "typed-routing")]
526    pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
527    where
528        P: TypedPath,
529    {
530        self.method(method, &path.to_string())
531    }
532
533    /// Returns the local web address for the test server,
534    /// if an address is available.
535    ///
536    /// The address is available when running as a real web server,
537    /// by setting the [`TestServerConfig`](crate::TestServerConfig) `transport` field to `Transport::HttpRandomPort` or `Transport::HttpIpPort`.
538    ///
539    /// This will return `None` when there is mock HTTP transport (the default).
540    pub fn server_address(&self) -> Option<Url> {
541        self.url()
542    }
543
544    /// This turns a relative path, into an absolute path to the server.
545    /// i.e. A path like `/users/123` will become something like `http://127.0.0.1:1234/users/123`.
546    ///
547    /// The absolute address can be used to make requests to the running server,
548    /// using any appropriate client you wish.
549    ///
550    /// # Example
551    ///
552    /// ```rust
553    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
554    /// #
555    /// use axum::Router;
556    /// use axum_test::TestServer;
557    ///
558    /// let app = Router::new();
559    /// let server = TestServer::builder()
560    ///         .http_transport()
561    ///         .build(app)?;
562    ///
563    /// let full_url = server.server_url(&"/users/123?filter=enabled")?;
564    ///
565    /// // Prints something like ... http://127.0.0.1:1234/users/123?filter=enabled
566    /// println!("{full_url}");
567    /// #
568    /// # Ok(()) }
569    /// ```
570    ///
571    /// This will return an error if you are using the mock transport.
572    /// Real HTTP transport is required to use this method (see [`TestServerConfig`](crate::TestServerConfig) `transport` field).
573    ///
574    /// It will also return an error if you provide an absolute path,
575    /// for example if you pass in `http://google.com`.
576    pub fn server_url(&self, path: &str) -> Result<Url> {
577        let path_uri = path.parse::<Uri>()?;
578        if is_absolute_uri(&path_uri) {
579            return Err(anyhow!(
580                "Absolute path provided for building server url, need to provide a relative uri"
581            ));
582        }
583
584        let server_url = self.url()
585            .ok_or_else(||
586                anyhow!(
587                    "No local address for server, need to run with HTTP transport to have a server address",
588                )
589            )?;
590
591        let server_locked = self.state.as_ref().lock().map_err(|err| {
592            anyhow!("Failed to lock InternalTestServer, for building server_url, received {err:?}",)
593        })?;
594        let mut query_params = server_locked.query_params().clone();
595        let mut full_server_url = build_url(
596            server_url,
597            path,
598            &mut query_params,
599            self.is_http_path_restricted,
600        )?;
601
602        // Ensure the query params are present
603        if query_params.has_content() {
604            full_server_url.set_query(Some(&query_params.to_string()));
605        }
606
607        Ok(full_server_url)
608    }
609
610    /// Adds a single cookie to be included on *all* future requests.
611    ///
612    /// If a cookie with the same name already exists,
613    /// then it will be replaced.
614    pub fn add_cookie(&mut self, cookie: Cookie) {
615        ServerSharedState::add_cookie(&self.state, cookie)
616            .context("Trying to call add_cookie")
617            .unwrap()
618    }
619
620    /// Adds extra cookies to be used on *all* future requests.
621    ///
622    /// Any cookies which have the same name as the new cookies,
623    /// will get replaced.
624    pub fn add_cookies(&mut self, cookies: CookieJar) {
625        ServerSharedState::add_cookies(&self.state, cookies)
626            .context("Trying to call add_cookies")
627            .unwrap()
628    }
629
630    /// Clears all of the cookies stored internally.
631    pub fn clear_cookies(&mut self) {
632        ServerSharedState::clear_cookies(&self.state)
633            .context("Trying to call clear_cookies")
634            .unwrap()
635    }
636
637    /// Requests made using this `TestServer` will save their cookies for future requests to send.
638    ///
639    /// This behaviour is off by default.
640    pub fn save_cookies(&mut self) {
641        self.save_cookies = true;
642    }
643
644    /// Requests made using this `TestServer` will _not_ save their cookies for future requests to send up.
645    ///
646    /// This is the default behaviour.
647    pub fn do_not_save_cookies(&mut self) {
648        self.save_cookies = false;
649    }
650
651    /// Requests made using this `TestServer` will assert a HTTP status in the 2xx range will be returned, unless marked otherwise.
652    ///
653    /// By default this behaviour is off.
654    pub fn expect_success(&mut self) {
655        self.expected_state = ExpectedState::Success;
656    }
657
658    /// Requests made using this `TestServer` will assert a HTTP status is outside the 2xx range will be returned, unless marked otherwise.
659    ///
660    /// By default this behaviour is off.
661    pub fn expect_failure(&mut self) {
662        self.expected_state = ExpectedState::Failure;
663    }
664
665    /// Adds a query parameter to be sent on *all* future requests.
666    pub fn add_query_param<V>(&mut self, key: &str, value: V)
667    where
668        V: Serialize,
669    {
670        ServerSharedState::add_query_param(&self.state, key, value)
671            .context("Trying to call add_query_param")
672            .unwrap()
673    }
674
675    /// Adds query parameters to be sent on *all* future requests.
676    pub fn add_query_params<V>(&mut self, query_params: V)
677    where
678        V: Serialize,
679    {
680        ServerSharedState::add_query_params(&self.state, query_params)
681            .context("Trying to call add_query_params")
682            .unwrap()
683    }
684
685    /// Adds a raw query param, with no urlencoding of any kind,
686    /// to be send on *all* future requests.
687    pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
688        ServerSharedState::add_raw_query_param(&self.state, raw_query_param)
689            .context("Trying to call add_raw_query_param")
690            .unwrap()
691    }
692
693    /// Clears all query params set.
694    pub fn clear_query_params(&mut self) {
695        ServerSharedState::clear_query_params(&self.state)
696            .context("Trying to call clear_query_params")
697            .unwrap()
698    }
699
700    /// Adds a header to be sent with all future requests built from this `TestServer`.
701    ///
702    /// ```rust
703    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
704    /// #
705    /// use axum::Router;
706    /// use axum_test::TestServer;
707    ///
708    /// let app = Router::new();
709    /// let mut server = TestServer::new(app)?;
710    ///
711    /// server.add_header("x-custom-header", "custom-value");
712    /// server.add_header(http::header::CONTENT_LENGTH, 12345);
713    /// server.add_header(http::header::HOST, "example.com");
714    ///
715    /// let response = server.get(&"/my-end-point")
716    ///     .await;
717    /// #
718    /// # Ok(()) }
719    /// ```
720    pub fn add_header<N, V>(&mut self, name: N, value: V)
721    where
722        N: TryInto<HeaderName>,
723        N::Error: Debug,
724        V: TryInto<HeaderValue>,
725        V::Error: Debug,
726    {
727        let header_name: HeaderName = name
728            .try_into()
729            .expect("Failed to convert header name to HeaderName");
730        let header_value: HeaderValue = value
731            .try_into()
732            .expect("Failed to convert header vlue to HeaderValue");
733
734        ServerSharedState::add_header(&self.state, header_name, header_value)
735            .context("Trying to call add_header")
736            .unwrap()
737    }
738
739    /// Clears all headers set so far.
740    pub fn clear_headers(&mut self) {
741        ServerSharedState::clear_headers(&self.state)
742            .context("Trying to call clear_headers")
743            .unwrap()
744    }
745
746    /// Sets the scheme to use when making _all_ requests from the `TestServer`.
747    /// i.e. http or https.
748    ///
749    /// The default scheme is 'http'.
750    ///
751    /// ```rust
752    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
753    /// #
754    /// use axum::Router;
755    /// use axum_test::TestServer;
756    ///
757    /// let app = Router::new();
758    /// let mut server = TestServer::new(app)?;
759    /// server
760    ///     .scheme(&"https");
761    ///
762    /// let response = server
763    ///     .get(&"/my-end-point")
764    ///     .await;
765    /// #
766    /// # Ok(()) }
767    /// ```
768    ///
769    pub fn scheme(&mut self, scheme: &str) {
770        ServerSharedState::set_scheme(&self.state, scheme.to_string())
771            .context("Trying to call set_scheme")
772            .unwrap()
773    }
774
775    pub(crate) fn url(&self) -> Option<Url> {
776        self.transport.url().cloned()
777    }
778
779    pub(crate) fn build_test_request_config(
780        &self,
781        method: Method,
782        path: &str,
783    ) -> Result<TestRequestConfig> {
784        let url = self
785            .url()
786            .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
787
788        let server_locked = self.state.as_ref().lock().map_err(|err| {
789            anyhow!(
790                "Failed to lock InternalTestServer, for request {method} {path}, received {err:?}",
791            )
792        })?;
793
794        let cookies = server_locked.cookies().clone();
795        let mut query_params = server_locked.query_params().clone();
796        let headers = server_locked.headers().clone();
797        let mut full_request_url =
798            build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
799
800        if let Some(scheme) = server_locked.scheme() {
801            full_request_url.set_scheme(scheme).map_err(|_| {
802                let debug_request_format = RequestPathFormatter::new(&method, full_request_url.as_str(), Some(&query_params));
803                anyhow!("Scheme '{scheme}' from TestServer cannot be set to request {debug_request_format}")
804            })?;
805        }
806
807        ::std::mem::drop(server_locked);
808
809        Ok(TestRequestConfig {
810            is_saving_cookies: self.save_cookies,
811            expected_state: self.expected_state,
812            content_type: self.default_content_type.clone(),
813            method,
814
815            full_request_url,
816            cookies,
817            query_params,
818            headers,
819        })
820    }
821
822    /// Returns true or false if the underlying service inside the `TestServer`
823    /// is still running. For many types of services this will always return `true`.
824    ///
825    /// When a `TestServer` is built using [`axum::serve::WithGracefulShutdown`],
826    /// this will return false if the service has shutdown.
827    pub fn is_running(&self) -> bool {
828        self.transport.is_running()
829    }
830}
831
832fn build_url(
833    mut url: Url,
834    path: &str,
835    query_params: &mut QueryParamsStore,
836    is_http_restricted: bool,
837) -> Result<Url> {
838    let path_uri = path.parse::<Uri>()?;
839
840    // If there is a scheme, then this is an absolute path.
841    if let Some(scheme) = path_uri.scheme_str() {
842        if is_http_restricted {
843            if has_different_schema(&url, &path_uri) || has_different_authority(&url, &path_uri) {
844                return Err(anyhow!(
845                    "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_schema' to change this."
846                ));
847            }
848        } else {
849            url.set_scheme(scheme)
850                .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
851
852            // We only set the host/port if the scheme is also present.
853            if let Some(authority) = path_uri.authority() {
854                url.set_host(Some(authority.host()))
855                    .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
856                url.set_port(authority.port().map(|p| p.as_u16()))
857                    .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
858
859                // todo, add username:password support
860            }
861        }
862    }
863
864    // Why does this exist?
865    //
866    // This exists to allow `server.get("/users")` and `server.get("users")` (without a slash)
867    // to go to the same place.
868    //
869    // It does this by saying ...
870    //  - if there is a scheme, it's a full path.
871    //  - if no scheme, it must be a path
872    //
873    if is_absolute_uri(&path_uri) {
874        url.set_path(path_uri.path());
875
876        // In this path we are replacing, so drop any query params on the original url.
877        if url.query().is_some() {
878            url.set_query(None);
879        }
880    } else {
881        // Grab everything up until the query parameters, or everything after that
882        let calculated_path = path.split('?').next().unwrap_or(path);
883        url.set_path(calculated_path);
884
885        // Move any query parameters from the url to the query params store.
886        if let Some(url_query) = url.query() {
887            query_params.add_raw(url_query.to_string());
888            url.set_query(None);
889        }
890    }
891
892    if let Some(path_query) = path_uri.query() {
893        query_params.add_raw(path_query.to_string());
894    }
895
896    Ok(url)
897}
898
899fn is_absolute_uri(path_uri: &Uri) -> bool {
900    path_uri.scheme_str().is_some()
901}
902
903fn has_different_schema(base_url: &Url, path_uri: &Uri) -> bool {
904    if let Some(scheme) = path_uri.scheme_str() {
905        return scheme != base_url.scheme();
906    }
907
908    false
909}
910
911fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
912    if let Some(authority) = path_uri.authority() {
913        return authority.as_str() != base_url.authority();
914    }
915
916    false
917}
918
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Calls `build_url` with a freshly parsed `base` URL and an empty query store.
    ///
    /// Returns the outcome together with the query params collected along the way.
    fn run_build_url(
        base: &str,
        path: &str,
        is_restricted: bool,
    ) -> (anyhow::Result<Url>, QueryParamsStore) {
        let base_url = base.parse::<Url>().unwrap();
        let mut query_params = QueryParamsStore::new();
        let outcome = build_url(base_url, &path, &mut query_params, is_restricted);

        (outcome, query_params)
    }

    // --- Restricted mode ---------------------------------------------------------------

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (outcome, query_params) = run_build_url("http://example.com", "/users", true);
        let url = outcome.unwrap();

        assert_eq!("http://example.com/users", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (outcome, query_params) = run_build_url(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            true,
        );
        let url = outcome.unwrap();

        assert_eq!("http://example.com/users", url.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        let (outcome, _) = run_build_url(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        let (outcome, _) = run_build_url(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            true,
        );

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        let (outcome, _) = run_build_url("http://example.com?base=666", "http://google.com", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        let (outcome, _) =
            run_build_url("http://example.com?base=666", "ftp://example.com/users", true);

        assert!(outcome.is_err());
    }

    // --- Unrestricted mode -------------------------------------------------------------

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (outcome, query_params) = run_build_url("http://example.com", "/users", false);
        let url = outcome.unwrap();

        assert_eq!("http://example.com/users", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (outcome, query_params) = run_build_url(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            false,
        );
        let url = outcome.unwrap();

        assert_eq!("http://example.com/users", url.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    // A scheme-less "host" is just a relative path segment, whatever the mode.

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        let (outcome, query_params) = run_build_url("http://example.com", "google.com", false);
        let url = outcome.unwrap();

        assert_eq!("http://example.com/google.com", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        let (outcome, query_params) = run_build_url("http://example.com", "google.com", true);
        let url = outcome.unwrap();

        assert_eq!("http://example.com/google.com", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (outcome, query_params) = run_build_url(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );
        let url = outcome.unwrap();

        assert_eq!("ftp://google.com:123/users.csv", url.as_str());
        assert_eq!("limit=456", query_params.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (outcome, _) = run_build_url("http://example.com", "ftp://example.com", false);

        assert_eq!("ftp://example.com/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (outcome, _) = run_build_url("http://example.com", "http://google.com", false);

        assert_eq!("http://google.com/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (outcome, _) = run_build_url("http://example.com:123", "http://example.com:456", false);

        assert_eq!("http://example.com:456/", outcome.unwrap().as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (outcome, _) = run_build_url("http://example.com:123", "http://example.com:123", false);

        assert_eq!("http://example.com:123/", outcome.unwrap().as_str());
    }

    // --- Restricted mode: scheme / host / port checks ----------------------------------

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        let (outcome, _) = run_build_url("http://example.com", "ftp://example.com", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        let (outcome, _) = run_build_url("http://example.com", "http://google.com", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        let (outcome, _) = run_build_url("http://example.com:123", "http://example.com:456", true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (outcome, _) = run_build_url("http://example.com:123", "http://example.com:123", true);

        assert_eq!("http://example.com:123/", outcome.unwrap().as_str());
    }
}
1120
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    /// Route handler that always answers with a fixed `pong!` body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        // Hand `TestServer::new` a connect-info make-service instead of a plain `Router`.
        let make_service = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();

        let server = TestServer::new(make_service).expect("Should create test server");
        let response = server.get(&"/ping").await;

        response.assert_text(&"pong!");
    }
}
1147
#[cfg(test)]
mod test_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use reserve_port::ReservedSocketAddr;

    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        // Get the request _with_ slash
        server.get(&"/ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        // Get the request _without_ slash
        server.get(&"ping").await.assert_text(&"pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Reserve an address
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        // Run the server.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Get the request.
        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Reserve an IP / Port
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        // Run the server.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema() // Key part of the test!
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Get the request.
        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text(&"pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Reserve an IP / Port
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        // Run the server.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_schema() // Key part of the test!
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        // Pick a port that is guaranteed to differ from the server's without overflowing.
        // A plain `port + 1` panics in debug builds when `port == u16::MAX`. That would
        // satisfy `#[should_panic]` without ever exercising the restriction.
        let other_port = if port == u16::MAX { port - 1 } else { port + 1 };

        // Get the request.
        let absolute_url = format!("http://{ip}:{other_port}/ping");
        server.get(&absolute_url).await;
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/ping").await };
        let future2 = async { server.get("/ping").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let app = axum::Router::new().route(
            &"/slow",
            axum::routing::get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );

        let server = TestServer::new(app).expect("Should create test server");

        let future1 = async { server.get("/slow").await };
        let future2 = async { server.get("/slow").await };
        let (r1, r2) = tokio::join!(future1, future2);

        assert_eq!(r1.text(), r2.text());
    }
}
1287
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Route handler that always answers with a fixed `pong!` body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        // The reqwest helpers are exercised over the HTTP transport here.
        let router = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder()
            .http_transport()
            .build(router)
            .expect("Should create test server");

        let reply = server.reqwest_get(&"/ping").send().await.unwrap();
        let body = reply.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1320
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;

    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    /// JSON payload exchanged in both directions by the `/json` route.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Echoes the payload back with `number` doubled and `_plus_response` appended to `text`.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        Json(TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        })
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let router = Router::new().route("/json", post(post_json));
        let server = TestServer::builder()
            .http_transport()
            .build(router)
            .expect("Should create test server");

        let request_body = TestBody {
            number: 111,
            text: "request".to_string(),
        };
        let reply = server
            .reqwest_post(&"/json")
            .json(&request_body)
            .send()
            .await
            .unwrap();
        let received: TestBody = reply.json().await.unwrap();

        let expected = TestBody {
            number: 222,
            text: "request_plus_response".to_string(),
        };
        assert_eq!(received, expected);
    }
}
1376
#[cfg(test)]
mod test_server_address {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    /// An explicitly configured IP/port is reported back verbatim, with a trailing `/`.
    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reserved_port.port();

        // Build an application with a route.
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        let expected_ip_port = format!("http://{}:{}/", ip, port);
        assert_eq!(
            server.server_address().unwrap().to_string(),
            expected_ip_port
        );
    }

    /// With no explicit IP/port, the reported address is `127.0.0.1` on some port and ends
    /// with a `/`. The name reflects what the regex actually asserts (`/$`).
    #[tokio::test]
    async fn it_should_return_default_address_with_ending_slash() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let address_regex = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();
        let is_match = address_regex.is_match(&server.server_address().unwrap().to_string());
        assert!(is_match);
    }

    /// The mock transport has no socket, so there is no address to report.
    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        assert!(server.server_address().is_none());
    }
}
1431
#[cfg(test)]
mod test_server_url {
    use super::*;

    use axum::Router;
    use local_ip_address::local_ip;
    use regex::Regex;
    use reserve_port::ReservedPort;

    /// Starts an empty app over HTTP on the machine's local IP and a freshly reserved port.
    ///
    /// Returns three values:
    /// - the port reservation, listed first so that callers drop it after the server;
    /// - the server;
    /// - the `http://ip:port` prefix that URLs built by the server are expected to start with.
    fn serve_on_local_ip() -> (ReservedPort, TestServer, String) {
        let reservation = ReservedPort::random().unwrap();
        let ip = local_ip().unwrap();
        let port = reservation.port();

        let app = Router::new();
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .build(app)
            .with_context(|| format!("Should create test server with address {}:{}", ip, port))
            .unwrap();

        (reservation, server, format!("http://{ip}:{port}"))
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let (_reservation, server, base) = serve_on_local_ip();

        let absolute_url = server.server_url("/users").unwrap().to_string();
        assert_eq!(absolute_url, format!("{base}/users"));
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let app = Router::new();
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let pattern =
            Regex::new(r"^http://127\.0\.0\.1:[0-9]+/users/123\?filter=enabled$").unwrap();
        let absolute_url = server
            .server_url(&"/users/123?filter=enabled")
            .unwrap()
            .to_string();

        assert!(pattern.is_match(&absolute_url));
    }

    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        // No socket behind the mock transport, so an absolute URL cannot be built.
        let app = Router::new();
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        assert!(server.server_url("/users").is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let (_reservation, server, base) = serve_on_local_ip();

        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();
        assert_eq!(received_url, format!("{base}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let (_reservation, mut server, base) = serve_on_local_ip();
        server.add_query_param("filter", "enabled");

        let received_url = server.server_url("/users").unwrap().to_string();
        assert_eq!(received_url, format!("{base}/users?filter=enabled"));
    }

    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let (_reservation, mut server, base) = serve_on_local_ip();
        server.add_query_param("filter", "enabled");

        // Server-level params come first, followed by those on the path.
        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();
        assert_eq!(
            received_url,
            format!("{base}/users?filter=enabled&animal=donkeys")
        );
    }
}
1574
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the test cookie's value, or `cookie-not-found` when it was not sent.
    /// The jar is returned unchanged.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = match cookies.get(TEST_COOKIE_NAME) {
            Some(found) => found.value().to_string(),
            None => "cookie-not-found".to_string(),
        };

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let router = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let body = server.get(&"/cookie").await.text();
        assert_eq!(body, "my-custom-cookie");
    }
}
1607
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Lists every received cookie as `name=value`, sorted and joined with `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in cookies.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        pairs.sort();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let router = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(router).expect("Should create test server");

        // Every cookie in the jar should reach the server.
        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(jar);

        let response = server.get(&"/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1648
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Lists every received cookie as `name=value`, sorted and joined with `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in cookies.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        pairs.sort();

        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let router = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(router).expect("Should create test server");

        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(jar);

        // Clearing after adding must leave nothing to send.
        server.clear_cookies();

        let response = server.get(&"/cookies").await;
        response.assert_text("");
    }
}
1688
#[cfg(test)]
mod test_add_header {
    use super::*;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Rejects with `400 Missing test header` when the header is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let header_name = HeaderName::from_static(TEST_HEADER_NAME);
            match parts.headers.get(header_name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        let router = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(router).expect("Should create test server");

        // A server-level header should be attached to every request.
        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        server.add_header(name, value);

        let response = server.get(&"/header").await;
        response.assert_text(TEST_HEADER_CONTENT)
    }
}
1745
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Rejects with `400 Missing test header` when the header is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let header_name = HeaderName::from_static(TEST_HEADER_NAME);
            match parts.headers.get(header_name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        let router = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(router).expect("Should create test server");

        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        server.add_header(name, value);

        // Clearing after adding must mean the header never reaches the handler.
        server.clear_headers();

        let response = server.get(&"/header").await;
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1804
#[cfg(test)]
mod test_add_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query string carrying a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query string carrying both `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Returns the `message` query parameter verbatim.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Returns `message` and `other` joined by a hyphen.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    /// Creates a server exposing `/query` (one param) and `/query-2` (two params).
    fn new_server() -> TestServer {
        let app = Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2));

        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let mut server = new_server();
        server.add_query_params(QueryParam {
            message: "it works".to_string(),
        });

        server.get(&"/query").await.assert_text(&"it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = new_server();
        server.add_query_params(&[("message", "it works")]);

        server.get(&"/query").await.assert_text(&"it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let mut server = new_server();
        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        server.get(&"/query-2").await.assert_text(&"it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = new_server();
        // Successive calls accumulate rather than replace.
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        server.get(&"/query-2").await.assert_text(&"it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let mut server = new_server();
        server.add_query_params(json!({
            "message": "it works",
            "other": "yup"
        }));

        server.get(&"/query-2").await.assert_text(&"it works-yup");
    }
}
1907
#[cfg(test)]
mod test_add_query_param {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string with only a `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query string with `message` and `other` fields.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` back.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Echoes `message` and `other`, separated by a hyphen.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    /// Server with `/query` and `/query-2` wired to the handlers above.
    fn new_server() -> TestServer {
        let app = Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2));

        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut server = new_server();
        server.add_query_param("message", "it works");

        server.get(&"/query").await.assert_text(&"it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut server = new_server();
        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        server.get(&"/query-2").await.assert_text(&"it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let mut server = new_server();
        // One param is set on the server, the other on the individual request.
        server.add_query_param("message", "it works");

        server
            .get(&"/query-2")
            .add_query_param("other", "yup")
            .await
            .assert_text(&"it works-yup");
    }
}
1982
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Single `message` param, parsed with axum's standard `Query` extractor.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query param.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated params, parsed with `axum_extra`'s `Query` extractor.
    /// Both lists default to empty when absent.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Returns every `items` value followed by every `arrs[]` value, joined with `", "`.
    ///
    /// Both lists are joined as ONE sequence so that a request carrying both keys
    /// still gets a `", "` between them, rather than the two runs being glued
    /// together. An empty query yields an empty string.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        params
            .items
            .into_iter()
            .chain(params.arrs)
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Router exposing both the standard and the `axum_extra` query handlers.
    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"message=it-works");

        // Get the request.
        server.get(&"/query").await.assert_text(&"it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one&items=two&items=three");

        // Get the request.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        // Get the request.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_items_and_array_params_together() {
        // Run the server with both kinds of repeated param in one raw string.
        let mut server = TestServer::new(build_app()).expect("Should create test server");
        server.add_raw_query_param(&"items=one&arrs[]=two");

        // Both lists must come back, separated like every other element.
        server
            .get(&"/query-extra")
            .await
            .assert_text(&"one, two");
    }
}
2073
#[cfg(test)]
mod test_clear_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional parameters, so the handler can report which ones arrived.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports whether `first` and `second` were present in the query string.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();
        format!("has first? {has_first}, has second? {has_second}")
    }

    /// Params with both fields populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    /// Server exposing `get_query_params` at `/query`.
    fn new_server() -> TestServer {
        let app = Router::new().route("/query", get(get_query_params));
        TestServer::new(app).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let mut server = new_server();
        server.add_query_params(both_params());
        server.clear_query_params();

        // Nothing should reach the handler once the params are cleared.
        server
            .get(&"/query")
            .await
            .assert_text(&"has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let mut server = new_server();
        server.add_query_params(both_params());
        server.clear_query_params();
        server.add_query_params(both_params());

        // Params added after the clear are sent as normal.
        server
            .get(&"/query")
            .await
            .assert_text(&"has first? true, has second? true");
    }
}
2143
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;

    use axum::Router;
    use axum::routing::get;

    /// Router with a single `/known_route` that always answers successfully.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    /// Builds a server with `expect_success_by_default` switched on.
    fn server_expecting_success_by_default(app: Router) -> TestServer {
        TestServer::builder()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        // Without the builder option, a 404 is not treated as a failure.
        server.get(&"/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(known_route_app()).expect("Should create test server");

        server.get(&"/known_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = server_expecting_success_by_default(Router::new());

        // With the builder option on, the 404 must panic.
        server.get(&"/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = server_expecting_success_by_default(known_route_app());

        server.get(&"/known_route").await;
    }
}
2190
#[cfg(test)]
mod test_content_type {
    use super::*;

    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Echoes the request's `Content-Type` header, or an empty string when it is absent.
    ///
    /// Panics if the header value cannot be viewed as a `&str` (`HeaderValue::to_str`
    /// fails); that is acceptable for a test-only handler.
    async fn get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        // Build an application with a route.
        let app = Router::new().route("/content_type", get(get_content_type));

        // Run the server with a server-wide default content type.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(app)
            .expect("Should create test server");

        // Get the request; it should carry the default content type.
        let text = server.get(&"/content_type").await.text();

        assert_eq!(text, "text/plain");
    }
}
2224
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler answering `200 OK` with a fixed body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering `202 Accepted`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Starts `app` on a server that has `expect_success` enabled.
    fn server_expecting_success(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = server_expecting_success(Router::new().route("/ping", get(get_ping)));

        server.get(&"/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server =
            server_expecting_success(Router::new().route("/accepted", get(get_accepted)));

        // Any 2xx, not just 200, counts as success.
        server.get(&"/accepted").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        let server = server_expecting_success(Router::new());

        server.get(&"/some_unknown_route").await;
    }
}
2280
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler answering `200 OK` with a fixed body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering `202 Accepted`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Starts `app` on a server that has `expect_failure` enabled, so a
    /// successful response panics.
    fn server_expecting_failure(app: Router) -> TestServer {
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_failure();
        server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let server = server_expecting_failure(Router::new());

        server.get(&"/some_unknown_route").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = server_expecting_failure(Router::new().route("/ping", get(get_ping)));

        server.get(&"/ping").await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let server =
            server_expecting_failure(Router::new().route("/accepted", get(get_accepted)));

        // A 2xx other than 200 is still a success, so it must panic too.
        server.get(&"/accepted").await;
    }
}
2337
#[cfg(test)]
mod test_scheme {
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;

    use crate::TestServer;

    /// Replies with the scheme of the request URI as received by the handler.
    async fn route_get_scheme(request: Request) -> String {
        request.uri().scheme_str().unwrap().to_string()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let server = TestServer::builder().build(router).unwrap();

        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_across_multiple_requests_when_set() {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        let mut server = TestServer::builder().build(router).unwrap();
        server.scheme(&"https");

        // The scheme is server-level state, so it must apply to every request sent
        // afterwards, not only the first one. Send two to actually cover that.
        server.get("/scheme").await.assert_text("https");
        server.get("/scheme").await.assert_text("https");
    }
}
2367
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric id segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"get <id>"`.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    /// Router serving only `GET /path/{id}`.
    fn new_app() -> Router {
        Router::new().typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_get(&path).await;
        response.assert_text("get 123");
    }
}
2401
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric id segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"post <id>"`.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    /// Router serving only `POST /path/{id}`.
    fn new_app() -> Router {
        Router::new().typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_post(&path).await;
        response.assert_text("post 123");
    }
}
2435
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric id segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"patch <id>"`.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    /// Router serving only `PATCH /path/{id}`.
    fn new_app() -> Router {
        Router::new().typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_patch(&path).await;
        response.assert_text("patch 123");
    }
}
2469
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric id segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"put <id>"`.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    /// Router serving only `PUT /path/{id}`.
    fn new_app() -> Router {
        Router::new().typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_put(&path).await;
        response.assert_text("put 123");
    }
}
2503
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` with a numeric id segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"delete <id>"`.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// Router serving only `DELETE /path/{id}`.
    fn new_app() -> Router {
        Router::new().typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let server = TestServer::new(new_app()).unwrap();
        let path = TestingPath { id: 123 };

        let response = server.typed_delete(&path).await;
        response.assert_text("delete 123");
    }
}
2537
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;

    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed route `/path/{id}` shared by every method handler below.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Replies with `"get <id>"`.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    /// Replies with `"post <id>"`.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    /// Replies with `"patch <id>"`.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    /// Replies with `"put <id>"`.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    /// Replies with `"delete <id>"`.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// Router mounting one handler per method on the same typed path.
    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to `/path/123` through `typed_method` on a fresh server and
    /// asserts that the body identifies the handler that answered.
    async fn assert_typed_method_replies(method: Method, expected_text: &str) {
        let server = TestServer::new(new_app()).unwrap();

        server
            .typed_method(method, &TestingPath { id: 123 })
            .await
            .assert_text(expected_text);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_typed_method_replies(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_typed_method_replies(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_typed_method_replies(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_typed_method_replies(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_typed_method_replies(Method::DELETE, "delete 123").await;
    }
}
2632
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    /// Handler returning a fixed body.
    async fn route_get() -> &'static str {
        "it works"
    }

    /// Lazily builds a `TestServer` inside a `std::cell::OnceCell`, then uses it
    /// through the shared reference the cell hands back.
    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell = OnceCell::<TestServer>::new();
        let server = cell.get_or_init(|| {
            let router = Router::new().route("/test", get(route_get));
            TestServer::new(router).unwrap()
        });

        server.get("/test").await.assert_text("it works");
    }
}
2656
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::util::new_random_tokio_tcp_listener;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    /// Trivial handler used to confirm the server is answering requests.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Serves the app with `axum::serve` on a listener from
    /// `new_random_tokio_tcp_listener`, with graceful shutdown tied to a `Notify`.
    /// It checks that the server answers and reports running, triggers shutdown,
    /// and then expects the test to panic.
    ///
    /// NOTE(review): despite `mock_http` in the name, this serves over the TCP
    /// listener passed to `axum::serve`, not over a mock transport. Consider renaming.
    ///
    /// NOTE(review): `#[should_panic]` has no `expected = "..."`, so ANY panic passes
    /// the test. That includes either `is_running()` assertion failing, e.g. if
    /// shutdown has not finished within the fixed 10ms sleep. The test therefore
    /// cannot tell which step panicked; pinning the expected message would fix that.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown_notification = Arc::new(Notify::new());
        // Second handle, moved into the graceful-shutdown future below.
        let waiting_notification = shutdown_notification.clone();

        // Build an application with a route.
        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let port = new_random_tokio_tcp_listener().unwrap();
        // Graceful shutdown begins once `notify_one` wakes this future.
        let application = serve(port, app)
            .with_graceful_shutdown(async move { waiting_notification.notified().await });

        // Run the server.
        let server = TestServer::builder()
            .build(application)
            .expect("Should create test server");

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        shutdown_notification.notify_one();
        // Fixed delay to let the shutdown complete (timing-sensitive; see NOTE above).
        sleep(Duration::from_millis(10)).await;

        assert!(!server.is_running());
        // Presumably panics because the server has stopped. TODO confirm.
        server.get("/ping").await.assert_status_ok();
    }
}