Skip to main content

axum_test/
test_server.rs

1use crate::TestRequest;
2use crate::TestRequestConfig;
3use crate::TestServerBuilder;
4use crate::TestServerConfig;
5use crate::Transport;
6use crate::internals::AtomicCrossCookieJar;
7use crate::internals::ErrorMessage;
8use crate::internals::ExpectedState;
9use crate::internals::QueryParamsStore;
10use crate::transport_layer::IntoTransportLayer;
11use crate::transport_layer::TransportLayer;
12use crate::transport_layer::TransportLayerBuilder;
13use anyhow::Result;
14use anyhow::anyhow;
15use cookie::Cookie;
16use cookie::CookieJar;
17use http::HeaderName;
18use http::HeaderValue;
19use http::Method;
20use http::Uri;
21use serde::Serialize;
22use std::fmt::Debug;
23use std::sync::Arc;
24use url::Url;
25
26#[cfg(feature = "typed-routing")]
27use axum_extra::routing::TypedPath;
28
29#[cfg(feature = "reqwest")]
30use crate::transport_layer::TransportLayerType;
31#[cfg(feature = "reqwest")]
32use reqwest::Client;
33#[cfg(feature = "reqwest")]
34use reqwest::RequestBuilder;
35#[cfg(feature = "reqwest")]
36use std::cell::OnceCell;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
/// The base url used for building request urls,
/// when the transport layer does not provide a url of its own.
const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
43///
44/// The `TestServer` runs your Axum application,
45/// allowing you to make HTTP requests against it.
46///
47/// # Building
48///
49/// A `TestServer` can be used to run an [`axum::Router`], an [`::axum::routing::IntoMakeService`],
50/// and others.
51///
/// The most straightforward approach is to call [`TestServer::new`],
53/// and pass in your application:
54///
55/// ```rust
56/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
57/// #
58/// use axum::Router;
59/// use axum::routing::get;
60///
61/// use axum_test::TestServer;
62///
63/// let app = Router::new()
64///     .route(&"/hello", get(|| async { "hello!" }));
65///
66/// let server = TestServer::new(app);
67/// #
68/// # Ok(())
69/// # }
70/// ```
71///
72/// # Requests
73///
74/// Requests are built by calling [`TestServer::get()`](crate::TestServer::get()),
75/// [`TestServer::post()`](crate::TestServer::post()), [`TestServer::put()`](crate::TestServer::put()),
76/// [`TestServer::delete()`](crate::TestServer::delete()), and [`TestServer::patch()`](crate::TestServer::patch()) methods.
77/// Each returns a [`TestRequest`](crate::TestRequest), which allows for customising the request content.
78///
79/// For example:
80///
81/// ```rust
82/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
83/// #
84/// use axum::Router;
85/// use axum::routing::get;
86///
87/// use axum_test::TestServer;
88///
89/// let app = Router::new()
90///     .route(&"/hello", get(|| async { "hello!" }));
91///
92/// let server = TestServer::new(app);
93///
94/// let response = server.get(&"/hello")
95///     .authorization_bearer("password12345")
96///     .add_header("x-custom-header", "custom-value")
97///     .await;
98///
99/// response.assert_text("hello!");
100/// #
101/// # Ok(())
102/// # }
103/// ```
104///
105/// Request methods also exist for using Axum Extra [`axum_extra::routing::TypedPath`],
/// or for building Reqwest [`reqwest::RequestBuilder`]. See those methods for details.
107///
108/// # Customising
109///
110/// A `TestServer` can be built from a builder, by calling [`TestServer::builder`],
111/// and customising settings. This allows one to set **mocked** (default when possible)
112/// or **real http** networking for your service.
113///
114/// ```rust
115/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
116/// #
117/// use axum::Router;
118/// use axum::routing::get;
119///
120/// use axum_test::TestServer;
121///
122/// let app = Router::new()
123///     .route(&"/hello", get(|| async { "hello!" }));
124///
125/// // Customise server when building
126/// let mut server = TestServer::builder()
127///     .http_transport()
128///     .expect_success_by_default()
129///     .save_cookies()
130///     .build(app);
131///
/// // Add items to be sent on _all_ requests
133/// server.add_header("x-custom-for-all", "common-value");
134///
135/// let response = server.get("/hello").await;
136/// #
137/// # Ok(())
138/// # }
139/// ```
140///
#[derive(Debug)]
pub struct TestServer {
    /// Shared state holding the query parameters and headers
    /// that are copied onto every request built by this server.
    state: ServerSharedState,
    /// Cookie jar shared with the requests built by this server,
    /// and with the Reqwest client when the `reqwest` feature is in use.
    cookie_jar: Arc<AtomicCrossCookieJar>,
    /// The transport layer requests are sent through.
    transport: Arc<Box<dyn TransportLayer>>,
    /// The expected outcome that new requests start with.
    expected_state: ExpectedState,
    /// The content type given to new requests, if one was configured.
    default_content_type: Option<String>,
    /// When true, requests to an absolute url with a different scheme or authority are rejected.
    is_http_path_restricted: bool,

    /// The Reqwest client, built on first use.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: OnceCell<Client>,
}
153
154impl TestServer {
    /// A helper function to create a builder for creating a [`TestServer`].
    ///
    /// The builder returned is a [`TestServerBuilder::default()`].
    pub fn builder() -> TestServerBuilder {
        TestServerBuilder::default()
    }
159
    /// This will run the given Axum app,
    /// allowing you to make requests against it.
    ///
    /// This is the same as creating a new `TestServer` with a configuration,
    /// and passing [`TestServerConfig::default()`].
    ///
    /// ```rust
    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
    /// #
    /// use axum::Router;
    /// use axum::routing::get;
    /// use axum_test::TestServer;
    ///
    /// let app = Router::new()
    ///     .route(&"/hello", get(|| async { "hello!" }));
    ///
    /// let server = TestServer::new(app);
    /// #
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// The type of applications that can be passed in include:
    ///
    ///  - [`axum::Router`]
    ///  - [`axum::routing::IntoMakeService`]
    ///  - [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`]
    ///  - [`axum::serve::Serve`]
    ///  - [`axum::serve::WithGracefulShutdown`]
    ///  - A function returning an [`actix_web::App`]
    ///
    /// # Panics
    ///
    /// This will panic if the `TestServer` cannot be built.
    /// To catch the error use [`TestServer::try_new`].
    ///
    pub fn new<A>(app: A) -> Self
    where
        A: IntoTransportLayer,
    {
        Self::try_new(app).error_message("Failed to build TestServer")
    }
200
    /// Attempts to create a [`TestServer`], and returns an error if this fails.
    ///
    /// The server is built using [`TestServerConfig::default()`].
    /// See [`TestServer::new`] for the types of application that can be passed in.
    pub fn try_new<A>(app: A) -> Result<Self>
    where
        A: IntoTransportLayer,
    {
        Self::try_new_with_config(app, TestServerConfig::default())
    }
208
    /// Similar to [`TestServer::new()`], with a customised configuration.
    /// The configuration covers the type of transport in use (e.g. running on a specific port),
    /// and the default settings (like the default content type for requests).
    ///
    /// This can take a [`TestServerConfig`] or a [`TestServerBuilder`].
    /// See those for more information on configuration settings.
    ///
    /// # Panics
    ///
    /// This will panic if the `TestServer` cannot be built.
    /// To catch the error use [`TestServer::try_new_with_config`].
    pub fn new_with_config<A, C>(app: A, config: C) -> Self
    where
        A: IntoTransportLayer,
        C: Into<TestServerConfig>,
    {
        Self::try_new_with_config(app, config).error_message("Failed to build TestServer")
    }
222
223    /// Attempts to create a [`TestServer`], and returns an error if this fails.
224    pub fn try_new_with_config<A, C>(app: A, config: C) -> Result<Self>
225    where
226        A: IntoTransportLayer,
227        C: Into<TestServerConfig>,
228    {
229        let config = config.into();
230        let state = ServerSharedState::new();
231
232        let transport = match config.transport {
233            None => {
234                let builder = TransportLayerBuilder::new(None, None);
235                let transport = app.into_default_transport(builder)?;
236                Arc::new(transport)
237            }
238            Some(Transport::HttpRandomPort) => {
239                let builder = TransportLayerBuilder::new(None, None);
240                let transport = app.into_http_transport_layer(builder)?;
241                Arc::new(transport)
242            }
243            Some(Transport::HttpIpPort { ip, port }) => {
244                let builder = TransportLayerBuilder::new(ip, port);
245                let transport = app.into_http_transport_layer(builder)?;
246                Arc::new(transport)
247            }
248            Some(Transport::MockHttp) => {
249                let transport = app.into_mock_transport_layer()?;
250                Arc::new(transport)
251            }
252        };
253
254        let expected_state = match config.expect_success_by_default {
255            true => ExpectedState::Success,
256            false => ExpectedState::None,
257        };
258
259        Ok(Self {
260            state,
261            cookie_jar: Arc::new(AtomicCrossCookieJar::new(config.save_cookies)),
262            transport,
263            expected_state,
264            default_content_type: config.default_content_type,
265            is_http_path_restricted: config.restrict_requests_with_http_scheme,
266
267            #[cfg(feature = "reqwest")]
268            maybe_reqwest_client: Default::default(),
269        })
270    }
271
    /// Creates a HTTP GET request to the path.
    ///
    /// This is the same as calling [`TestServer::method`] with [`Method::GET`].
    pub fn get(&self, path: &str) -> TestRequest {
        self.method(Method::GET, path)
    }
276
    /// Creates a HTTP POST request to the given path.
    ///
    /// This is the same as calling [`TestServer::method`] with [`Method::POST`].
    pub fn post(&self, path: &str) -> TestRequest {
        self.method(Method::POST, path)
    }
281
    /// Creates a HTTP PATCH request to the path.
    ///
    /// This is the same as calling [`TestServer::method`] with [`Method::PATCH`].
    pub fn patch(&self, path: &str) -> TestRequest {
        self.method(Method::PATCH, path)
    }
286
    /// Creates a HTTP PUT request to the path.
    ///
    /// This is the same as calling [`TestServer::method`] with [`Method::PUT`].
    pub fn put(&self, path: &str) -> TestRequest {
        self.method(Method::PUT, path)
    }
291
    /// Creates a HTTP DELETE request to the path.
    ///
    /// This is the same as calling [`TestServer::method`] with [`Method::DELETE`].
    pub fn delete(&self, path: &str) -> TestRequest {
        self.method(Method::DELETE, path)
    }
296
297    /// Creates a HTTP request, to the method and path provided.
298    pub fn method(&self, method: Method, path: &str) -> TestRequest {
299        let config = self
300            .build_test_request_config(method.clone(), path)
301            .error_message_fn(|| format!("Failed to build request, for {method} {path}"));
302
303        TestRequest::new(self.transport.clone(), config)
304    }
305
306    #[cfg(feature = "reqwest")]
307    fn reqwest_client(&self) -> &Client {
308        self.maybe_reqwest_client.get_or_init(|| {
309            if self.transport.transport_layer_type() == TransportLayerType::Mock {
310                panic!("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available");
311            }
312
313            reqwest::Client::builder()
314                .redirect(reqwest::redirect::Policy::none())
315                .cookie_provider(self.cookie_jar.clone())
316                .build()
317                .expect("Failed to build Reqwest Client")
318        })
319    }
320
    /// Creates a HTTP GET request, using Reqwest, to the path given.
    ///
    /// See [`TestServer::reqwest_method`] for details.
    #[cfg(feature = "reqwest")]
    pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::GET, path)
    }
325
    /// Creates a HTTP POST request, using Reqwest, to the path given.
    ///
    /// See [`TestServer::reqwest_method`] for details.
    #[cfg(feature = "reqwest")]
    pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::POST, path)
    }
330
    /// Creates a HTTP PUT request, using Reqwest, to the path given.
    ///
    /// See [`TestServer::reqwest_method`] for details.
    #[cfg(feature = "reqwest")]
    pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::PUT, path)
    }
335
    /// Creates a HTTP PATCH request, using Reqwest, to the path given.
    ///
    /// See [`TestServer::reqwest_method`] for details.
    #[cfg(feature = "reqwest")]
    pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::PATCH, path)
    }
340
    /// Creates a HTTP DELETE request, using Reqwest, to the path given.
    ///
    /// See [`TestServer::reqwest_method`] for details.
    #[cfg(feature = "reqwest")]
    pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::DELETE, path)
    }
345
    /// Creates a HTTP HEAD request, using Reqwest, to the path given.
    ///
    /// See [`TestServer::reqwest_method`] for details.
    #[cfg(feature = "reqwest")]
    pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::HEAD, path)
    }
350
351    /// Creates a HTTP request, using Reqwest, using the method + path described.
352    /// This expects a relative url to the `TestServer`.
353    ///
354    /// ```rust
355    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
356    /// #
357    /// use axum::Router;
358    /// use axum_test::TestServer;
359    ///
360    /// let my_app = Router::new();
361    /// let server = TestServer::builder()
362    ///     .http_transport() // Important, must be HTTP!
363    ///     .build(my_app);
364    ///
365    /// // Build your request
366    /// let request = server.get(&"/user")
367    ///     .add_header("x-custom-header", "example.com")
368    ///     .content_type("application/yaml");
369    ///
370    /// // await request to execute
371    /// let response = request.await;
372    /// #
373    /// # Ok(()) }
374    /// ```
375    #[cfg(feature = "reqwest")]
376    pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
377        let request_url = self
378            .server_url(path)
379            .expect("Failed to generate server url for request {method} {path}");
380
381        self.reqwest_client().request(method, request_url)
382    }
383
    /// Creates a request to the server, to start a Websocket connection,
    /// on the path given.
    ///
    /// This is the equivalent of making a GET request to the endpoint,
    /// and setting the various headers needed for making an upgrade request.
    ///
    /// *Note*, this requires the server to be running on a real HTTP
    /// port. Either using a randomly assigned port, or a specified one.
    /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
    ///
    /// # Example
    ///
    /// ```rust
    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
    /// #
    /// use axum::Router;
    /// use axum_test::TestServer;
    ///
    /// let app = Router::new();
    /// let server = TestServer::builder()
    ///     .http_transport()
    ///     .build(app);
    ///
    /// let mut websocket = server
    ///     .get_websocket(&"/my-web-socket-end-point")
    ///     .await
    ///     .into_websocket()
    ///     .await;
    ///
    /// websocket.send_text("Hello!").await;
    /// #
    /// # Ok(()) }
    /// ```
    ///
    #[cfg(feature = "ws")]
    pub fn get_websocket(&self, path: &str) -> TestRequest {
        use http::header;

        // A GET request, carrying the headers that ask for an upgrade to a Websocket.
        self.get(path)
            .add_header(header::CONNECTION, "upgrade")
            .add_header(header::UPGRADE, "websocket")
            .add_header(header::SEC_WEBSOCKET_VERSION, "13")
            .add_header(
                header::SEC_WEBSOCKET_KEY,
                crate::internals::generate_ws_key(),
            )
    }
431
    /// Creates a HTTP GET request, using the typed path provided.
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    ///
    /// # Example Test
    ///
    /// Using a `TypedPath` you can build and test a route like below:
    ///
    /// ```rust
    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
    /// #
    /// use axum::Json;
    /// use axum::Router;
    /// use axum::routing::get;
    /// use axum_extra::routing::RouterExt;
    /// use axum_extra::routing::TypedPath;
    /// use serde::Deserialize;
    /// use serde::Serialize;
    ///
    /// use axum_test::TestServer;
    ///
    /// #[derive(TypedPath, Deserialize)]
    /// #[typed_path("/users/{user_id}")]
    /// struct UserPath {
    ///     pub user_id: u32,
    /// }
    ///
    /// // Build a typed route:
    /// async fn route_get_user(UserPath { user_id }: UserPath) -> String {
    ///     format!("hello user {user_id}")
    /// }
    ///
    /// let app = Router::new()
    ///     .typed_get(route_get_user);
    ///
    /// // Then test the route:
    /// let server = TestServer::new(app);
    /// server
    ///     .typed_get(&UserPath { user_id: 123 })
    ///     .await
    ///     .assert_text("hello user 123");
    /// #
    /// # Ok(())
    /// # }
    /// ```
    ///
    #[cfg(feature = "typed-routing")]
    pub fn typed_get<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::GET, path)
    }
485
    /// Creates a HTTP POST request, using the typed path provided.
    ///
    /// This is the same as calling [`TestServer::typed_method`] with [`Method::POST`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_post<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::POST, path)
    }
496
    /// Creates a HTTP PATCH request, using the typed path provided.
    ///
    /// This is the same as calling [`TestServer::typed_method`] with [`Method::PATCH`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_patch<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::PATCH, path)
    }
507
    /// Creates a HTTP PUT request, using the typed path provided.
    ///
    /// This is the same as calling [`TestServer::typed_method`] with [`Method::PUT`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_put<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::PUT, path)
    }
518
    /// Creates a HTTP DELETE request, using the typed path provided.
    ///
    /// This is the same as calling [`TestServer::typed_method`] with [`Method::DELETE`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_delete<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::DELETE, path)
    }
529
    /// Creates a typed HTTP request, using the method provided.
    ///
    /// The typed path is turned into a string, and then passed to [`TestServer::method`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.method(method, &path.to_string())
    }
540
    /// Returns the local web address for the test server,
    /// if an address is available.
    ///
    /// The address is available when running as a real web server,
    /// by setting the [`TestServerConfig`](crate::TestServerConfig) `transport` field to `Transport::HttpRandomPort` or `Transport::HttpIpPort`.
    ///
    /// This will return `None` when there is mock HTTP transport (the default).
    ///
    /// The value returned is a copy of the url held by the transport layer.
    pub fn server_address(&self) -> Option<Url> {
        self.url()
    }
551
552    /// This turns a relative path, into an absolute path to the server.
553    /// i.e. A path like `/users/123` will become something like `http://127.0.0.1:1234/users/123`.
554    ///
555    /// The absolute address can be used to make requests to the running server,
556    /// using any appropriate client you wish.
557    ///
558    /// # Example
559    ///
560    /// ```rust
561    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
562    /// #
563    /// use axum::Router;
564    /// use axum_test::TestServer;
565    ///
566    /// let app = Router::new();
567    /// let server = TestServer::builder()
568    ///         .http_transport()
569    ///         .build(app);
570    ///
571    /// let full_url = server.server_url(&"/users/123?filter=enabled")?;
572    ///
573    /// // Prints something like ... http://127.0.0.1:1234/users/123?filter=enabled
574    /// println!("{full_url}");
575    /// #
576    /// # Ok(()) }
577    /// ```
578    ///
579    /// This will return an error if you are using the mock transport.
580    /// Real HTTP transport is required to use this method (see [`TestServerConfig`](crate::TestServerConfig) `transport` field).
581    ///
582    /// It will also return an error if you provide an absolute path,
583    /// for example if you pass in `http://google.com`.
584    pub fn server_url(&self, path: &str) -> Result<Url> {
585        let path_uri = path.parse::<Uri>()?;
586        if is_absolute_uri(&path_uri) {
587            return Err(anyhow!(
588                "Absolute path provided for building server url, need to provide a relative uri"
589            ));
590        }
591
592        let server_url = self.url()
593            .ok_or_else(||
594                anyhow!(
595                    "No local address for server, need to run with HTTP transport to have a server address",
596                )
597            )?;
598
599        let mut query_params = self.state.query_params().clone();
600        let mut full_server_url = build_url(
601            server_url,
602            path,
603            &mut query_params,
604            self.is_http_path_restricted,
605        )?;
606
607        // Ensure the query params are present
608        if query_params.has_content() {
609            full_server_url.set_query(Some(&query_params.to_string()));
610        }
611
612        Ok(full_server_url)
613    }
614
    /// Adds a single cookie to be included on *all* future requests.
    ///
    /// If a cookie with the same name already exists,
    /// then it will be replaced.
    ///
    /// The cookie is added to the cookie jar held by this `TestServer`.
    pub fn add_cookie(&mut self, cookie: Cookie) {
        self.cookie_jar.add_cookie(cookie);
    }
622
    /// Adds extra cookies to be used on *all* future requests.
    ///
    /// Any cookies which have the same name as the new cookies,
    /// will get replaced.
    ///
    /// The cookies are added to the cookie jar held by this `TestServer`.
    pub fn add_cookies(&mut self, cookies: CookieJar) {
        self.cookie_jar.add_cookies_by_jar(cookies);
    }
630
    /// Clears all of the cookies stored internally,
    /// in the cookie jar held by this `TestServer`.
    pub fn clear_cookies(&mut self) {
        self.cookie_jar.clear_cookies();
    }
635
    /// Requests made using this `TestServer` will save their cookies, for future requests to send.
    /// This includes sharing cookies with requests made using the `reqwest` feature.
    ///
    /// This behaviour is off by default.
    pub fn save_cookies(&mut self) {
        self.cookie_jar.enable_saving();
    }
643
    /// Requests made using this `TestServer` will _not_ save their cookies, for future requests to send.
    /// This includes sharing cookies with requests made using the `reqwest` feature.
    ///
    /// This is the default behaviour.
    pub fn do_not_save_cookies(&mut self) {
        self.cookie_jar.disable_saving();
    }
651
    /// Requests made using this `TestServer` will assert that a HTTP status in the 2xx range is returned,
    /// unless marked otherwise.
    ///
    /// By default this behaviour is off.
    pub fn expect_success(&mut self) {
        self.expected_state = ExpectedState::Success;
    }
658
    /// Requests made using this `TestServer` will assert that a HTTP status outside the 2xx range is returned,
    /// unless marked otherwise.
    ///
    /// By default this behaviour is off.
    pub fn expect_failure(&mut self) {
        self.expected_state = ExpectedState::Failure;
    }
665
    /// Adds a query parameter to be sent on *all* future requests.
    ///
    /// The value can be anything that implements [`Serialize`].
    ///
    /// NOTE(review): a failure to add the parameter is passed to `error_message`,
    /// which presumably panics with "Failed to add query parameter" — confirm.
    pub fn add_query_param<V>(&mut self, key: &str, value: V)
    where
        V: Serialize,
    {
        self.state
            .add_query_param(key, value)
            .error_message("Failed to add query parameter");
    }
675
    /// Adds query parameters to be sent on *all* future requests.
    ///
    /// The parameters can be anything that implements [`Serialize`].
    ///
    /// NOTE(review): a failure to add the parameters is passed to `error_message`,
    /// which presumably panics with "Failed to add query parameters" — confirm.
    pub fn add_query_params<V>(&mut self, query_params: V)
    where
        V: Serialize,
    {
        self.state
            .add_query_params(query_params)
            .error_message("Failed to add query parameters");
    }
685
    /// Adds a raw query param, with no urlencoding of any kind,
    /// to be sent on *all* future requests.
    pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
        self.state.add_raw_query_param(raw_query_param);
    }
691
    /// Clears all query params set on this `TestServer`.
    pub fn clear_query_params(&mut self) {
        self.state.clear_query_params();
    }
696
697    /// Adds a header to be sent with all future requests built from this `TestServer`.
698    ///
699    /// ```rust
700    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
701    /// #
702    /// use axum::Router;
703    /// use axum_test::TestServer;
704    ///
705    /// let app = Router::new();
706    /// let mut server = TestServer::new(app);
707    ///
708    /// server.add_header("x-custom-header", "custom-value");
709    /// server.add_header(http::header::CONTENT_LENGTH, 12345);
710    /// server.add_header(http::header::HOST, "example.com");
711    ///
712    /// let response = server.get(&"/my-end-point")
713    ///     .await;
714    /// #
715    /// # Ok(()) }
716    /// ```
717    pub fn add_header<N, V>(&mut self, name: N, value: V)
718    where
719        N: TryInto<HeaderName>,
720        N::Error: Debug,
721        V: TryInto<HeaderValue>,
722        V::Error: Debug,
723    {
724        let header_name: HeaderName = name
725            .try_into()
726            .expect("Failed to convert header name to HeaderName");
727        let header_value: HeaderValue = value
728            .try_into()
729            .expect("Failed to convert header vlue to HeaderValue");
730
731        self.state.add_header(header_name, header_value);
732    }
733
    /// Clears all headers set so far on this `TestServer`.
    pub fn clear_headers(&mut self) {
        self.state.clear_headers();
    }
738
    /// Returns a copy of the url held by the transport layer,
    /// or `None` if the transport layer has no url.
    pub(crate) fn url(&self) -> Option<Url> {
        self.transport.url().cloned()
    }
742
743    pub(crate) fn build_test_request_config(
744        &self,
745        method: Method,
746        path: &str,
747    ) -> Result<TestRequestConfig> {
748        let url = self
749            .url()
750            .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
751
752        let mut query_params = self.state.query_params().clone();
753        let headers = self.state.headers().clone();
754        let full_request_url =
755            build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
756
757        Ok(TestRequestConfig {
758            atomic_cookie_jar: self.cookie_jar.clone(),
759
760            // These are copied over from the cookie jar,
761            // as the server could change it's save state after the request is made.
762            is_saving_cookies: self.cookie_jar.is_saving(),
763            cookies: self.cookie_jar.to_cookie_jar(),
764
765            expected_state: self.expected_state,
766            content_type: self.default_content_type.clone(),
767            method,
768
769            full_request_url,
770            query_params,
771            headers,
772        })
773    }
774
    /// Returns true if the underlying service inside the `TestServer` is still running,
    /// and false if not. For many types of services this will always return `true`.
    ///
    /// When a `TestServer` is built using [`axum::serve::WithGracefulShutdown`],
    /// this will return false if the service has shutdown.
    ///
    /// The answer comes from the transport layer.
    pub fn is_running(&self) -> bool {
        self.transport.is_running()
    }
783}
784
785fn build_url(
786    mut url: Url,
787    path: &str,
788    query_params: &mut QueryParamsStore,
789    is_http_restricted: bool,
790) -> Result<Url> {
791    let path_uri = path.parse::<Uri>()?;
792
793    // If there is a scheme, then this is an absolute path.
794    if let Some(scheme) = path_uri.scheme_str() {
795        if is_http_restricted {
796            if has_different_scheme(&url, &path_uri) || has_different_authority(&url, &path_uri) {
797                return Err(anyhow!(
798                    "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_scheme' to change this."
799                ));
800            }
801        } else {
802            url.set_scheme(scheme)
803                .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
804
805            // We only set the host/port if the scheme is also present.
806            if let Some(authority) = path_uri.authority() {
807                url.set_host(Some(authority.host()))
808                    .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
809                url.set_port(authority.port().map(|p| p.as_u16()))
810                    .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
811
812                // todo, add username:password support
813            }
814        }
815    }
816
817    // Why does this exist?
818    //
819    // This exists to allow `server.get("/users")` and `server.get("users")` (without a slash)
820    // to go to the same place.
821    //
822    // It does this by saying ...
823    //  - if there is a scheme, it's a full path.
824    //  - if no scheme, it must be a path
825    //
826    if is_absolute_uri(&path_uri) {
827        url.set_path(path_uri.path());
828
829        // In this path we are replacing, so drop any query params on the original url.
830        if url.query().is_some() {
831            url.set_query(None);
832        }
833    } else {
834        // Grab everything up until the query parameters, or everything after that
835        let calculated_path = path.split('?').next().unwrap_or(path);
836        url.set_path(calculated_path);
837
838        // Move any query parameters from the url to the query params store.
839        if let Some(url_query) = url.query() {
840            query_params.add_raw(url_query.to_string());
841            url.set_query(None);
842        }
843    }
844
845    if let Some(path_query) = path_uri.query() {
846        query_params.add_raw(path_query.to_string());
847    }
848
849    Ok(url)
850}
851
/// Returns true if the uri has a scheme, which is what marks it as absolute here.
fn is_absolute_uri(path_uri: &Uri) -> bool {
    path_uri.scheme_str().is_some()
}
855
856fn has_different_scheme(base_url: &Url, path_uri: &Uri) -> bool {
857    if let Some(scheme) = path_uri.scheme_str() {
858        return scheme != base_url.scheme();
859    }
860
861    false
862}
863
864fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
865    if let Some(authority) = path_uri.authority() {
866        return authority.as_str() != base_url.authority();
867    }
868
869    false
870}
871
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Runs `build_url` for `path` against `base` using a fresh, empty
    /// query params store, and unwraps the result.
    ///
    /// Hands back the built url together with the store, so a test can
    /// assert on either of them.
    fn build(base: &str, path: &str, is_restricted: bool) -> (Url, QueryParamsStore) {
        let mut store = QueryParamsStore::new();
        let base_url = base.parse::<Url>().unwrap();
        let built = build_url(base_url, &path, &mut store, is_restricted).unwrap();

        (built, store)
    }

    /// Asserts that `build_url` returns an error for `path` when it is
    /// called with the restricted flag turned on.
    fn assert_rejected_when_restricted(base: &str, path: &str) {
        let mut store = QueryParamsStore::new();
        let base_url = base.parse::<Url>().unwrap();
        let outcome = build_url(base_url, &path, &mut store, true);

        assert!(outcome.is_err());
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (url, store) = build("http://example.com", "/users", true);

        assert_eq!("http://example.com/users", url.as_str());
        assert!(store.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (url, store) = build(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            true,
        );

        assert_eq!("http://example.com/users", url.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", store.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        assert_rejected_when_restricted(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
        );
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        assert_rejected_when_restricted(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
        );
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        assert_rejected_when_restricted("http://example.com?base=666", "http://google.com");
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        assert_rejected_when_restricted("http://example.com?base=666", "ftp://example.com/users");
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (url, store) = build("http://example.com", "/users", false);

        assert_eq!("http://example.com/users", url.as_str());
        assert!(store.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (url, store) = build(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            false,
        );

        assert_eq!("http://example.com/users", url.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", store.to_string());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        // No scheme is given, so the host-looking text ends up as the path.
        let (url, store) = build("http://example.com", "google.com", false);

        assert_eq!("http://example.com/google.com", url.as_str());
        assert!(store.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        // No scheme is given, so the host-looking text ends up as the path.
        let (url, store) = build("http://example.com", "google.com", true);

        assert_eq!("http://example.com/google.com", url.as_str());
        assert!(store.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (url, store) = build(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            false,
        );

        assert_eq!("ftp://google.com:123/users.csv", url.as_str());
        assert_eq!("limit=456", store.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (url, _store) = build("http://example.com", "ftp://example.com", false);

        assert_eq!("ftp://example.com/", url.as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (url, _store) = build("http://example.com", "http://google.com", false);

        assert_eq!("http://google.com/", url.as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (url, _store) = build("http://example.com:123", "http://example.com:456", false);

        assert_eq!("http://example.com:456/", url.as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (url, _store) = build("http://example.com:123", "http://example.com:123", false);

        assert_eq!("http://example.com:123/", url.as_str());
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        assert_rejected_when_restricted("http://example.com", "ftp://example.com");
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        assert_rejected_when_restricted("http://example.com", "http://google.com");
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        assert_rejected_when_restricted("http://example.com:123", "http://example.com:456");
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (url, _store) = build("http://example.com:123", "http://example.com:123", true);

        assert_eq!("http://example.com:123/", url.as_str());
    }
}
1073
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    /// Handler that always answers with the text `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        // Turn the router into a make-service with connect info,
        // which is the thing being handed to the server in this test.
        let service = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service_with_connect_info::<SocketAddr>();

        let server = TestServer::new(service);

        // The route should answer as normal.
        let response = server.get("/ping").await;
        response.assert_text("pong!");
    }
}
1100
#[cfg(test)]
mod test_get {
    use super::*;
    use crate::testing::catch_panic_error_message;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;
    use reserve_port::ReservedSocketAddr;

    /// Handler that always answers with the text `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Builds the application most tests here run against:
    /// a router with a single `/ping` route.
    fn ping_app() -> Router {
        Router::new().route("/ping", get(get_ping))
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let server = TestServer::new(ping_app());

        // The path is given _with_ a leading slash.
        let response = server.get("/ping").await;
        response.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let server = TestServer::new(ping_app());

        // The path is given _without_ a leading slash.
        let response = server.get("ping").await;
        response.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        // Reserve the address the server will be run on.
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(ping_app())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // Request the route using the full url of the server.
        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text("pong!");
        assert_eq!(response.request_url().to_string(), absolute_url);
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        // Reserve the address the server will be run on.
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_scheme() // Key part of the test!
            .try_build(ping_app())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // The full url points at the server itself, so it should go through.
        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text("pong!");
        assert_eq!(response.request_url().to_string(), absolute_url);
    }

    #[tokio::test]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        // Reserve the address the server will be run on.
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let server_port = reserved_address.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(server_port))
            .restrict_requests_with_http_scheme() // Key part of the test!
            .try_build(ping_app())
            .error_message_fn(|| {
                format!("Should create test server with address {}:{}", ip, server_port)
            });

        // Request a port that is off by one, so it does not match the server.
        let port = server_port + 1;
        let absolute_url = format!("http://{ip}:{port}/ping");

        let message = catch_panic_error_message(|| {
            let _ = server.get(&absolute_url);
        });

        let expected = format!("Failed to build request, for GET http://{ip}:{port}/ping,
    Request disallowed for path 'http://{ip}:{port}/ping', requests are only allowed to local server. Turn off 'restrict_requests_with_http_scheme' to change this.
");
        assert_str_eq!(expected, message);
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let server = TestServer::new(ping_app());

        // Drive two requests against the same server at the same time.
        let (first, second) = tokio::join!(
            async { server.get("/ping").await },
            async { server.get("/ping").await },
        );

        assert_eq!(first.text(), second.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        // A route that sleeps for a second before it answers.
        let slow_handler = get(|| async {
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            "hello!"
        });
        let app = Router::new().route("/slow", slow_handler);

        let server = TestServer::new(app);

        // Drive two requests against the same server at the same time.
        let (first, second) = tokio::join!(
            async { server.get("/slow").await },
            async { server.get("/slow").await },
        );

        assert_eq!(first.text(), second.text());
    }
}
1245
#[cfg(all(test, feature = "reqwest"))]
mod test_reqwest_get {
    use super::*;
    use axum::Router;
    use axum::routing::get;

    /// Handler that always answers with the text `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder().http_transport().build(app);

        // Send the request through the reqwest builder, then read the body as text.
        let response = server.reqwest_get(&"/ping").send().await.unwrap();
        let body = response.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1274
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;
    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    /// Json body used for both the request and the response in this module.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Handler that answers with the received body altered:
    /// the number is doubled, and `_plus_response` is appended to the text.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let response = TestBody {
            number: body.number * 2,
            text: format!("{}_plus_response", body.text),
        };

        Json(response)
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder().http_transport().build(app);

        let request_body = TestBody {
            number: 111,
            text: String::from("request"),
        };

        // Post the body through the reqwest builder, then decode the Json sent back.
        let response = server
            .reqwest_post(&"/json")
            .json(&request_body)
            .send()
            .await
            .unwrap()
            .json::<TestBody>()
            .await
            .unwrap();

        let expected = TestBody {
            number: 222,
            text: String::from("request_plus_response"),
        };
        assert_eq!(response, expected);
    }
}
1326
#[cfg(test)]
mod test_server_address {
    use super::*;
    use axum::Router;
    use regex::Regex;
    use reserve_port::ReservedPort;
    use std::net::Ipv4Addr;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let received_address = server.server_address().unwrap().to_string();
        let expected_address = format!("http://{}:{}/", ip, reserved_port.port());

        assert_eq!(received_address, expected_address);
    }

    // NOTE(review): despite the name of this test, the pattern it checks
    // against requires the address to end with a slash.
    #[tokio::test]
    async fn it_should_return_default_address_without_ending_slash() {
        let server = TestServer::builder().http_transport().build(Router::new());

        let received_address = server.server_address().unwrap().to_string();
        let address_pattern = Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/$").unwrap();

        assert!(address_pattern.is_match(&received_address));
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let server = TestServer::builder().mock_transport().build(Router::new());

        let address = server.server_address();
        assert!(address.is_none());
    }
}
1373
#[cfg(test)]
mod test_server_url {
    use super::*;
    use axum::Router;
    use pretty_assertions::assert_str_eq;
    use regex::Regex;
    use reserve_port::ReservedPort;
    use std::net::Ipv4Addr;

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let received_url = server.server_url("/users").unwrap().to_string();
        let expected_url = format!("http://{}:{}/users", ip, reserved_port.port());

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let server = TestServer::builder().http_transport().build(Router::new());

        let received_url = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();

        // The port is not known up front, so match any port number.
        let url_pattern =
            Regex::new("^http://127\\.0\\.0\\.1:[0-9]+/users/123\\?filter=enabled$").unwrap();

        assert!(url_pattern.is_match(&received_url));
    }

    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let server = TestServer::builder().mock_transport().build(Router::new());

        let outcome = server.server_url("/users");

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();
        let expected_url = format!(
            "http://{}:{}/users?filter=enabled",
            ip,
            reserved_port.port()
        );

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // The query param is set on the server, and not in the path.
        server.add_query_param("filter", "enabled");

        let received_url = server.server_url("/users").unwrap().to_string();
        let expected_url = format!(
            "http://{}:{}/users?filter=enabled",
            ip,
            reserved_port.port()
        );

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // One query param comes from the server, the other from the path.
        server.add_query_param("filter", "enabled");

        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();
        let expected_url = format!(
            "http://{}:{}/users?filter=enabled&animal=donkeys",
            ip,
            reserved_port.port()
        );

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_include_both_server_and_path_queries() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let mut server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // The server and the path both use the same query param name.
        server.add_query_param("query", "server");

        let received_url = server.server_url("/users?query=path").unwrap().to_string();
        let expected_url = format!(
            "http://{}:{}/users?query=server&query=path",
            ip,
            reserved_port.port()
        );

        assert_eq!(expected_url, received_url);
    }

    // NOTE(review): despite the name of this test, the path it passes in
    // has no leading slash.
    #[tokio::test]
    async fn it_should_work_for_paths_with_leading_slash() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let received_url = server.server_url("users").unwrap().to_string();
        let expected_url = format!("http://{}:{}/users", ip, reserved_port.port());

        assert_eq!(expected_url, received_url);
    }

    // TODO, change this behaviour to allow an empty path. It should be the same as no path at all.
    //
    // NOTE(review): despite the name of this test, it checks for a returned
    // error, and not for a panic.
    #[tokio::test]
    async fn it_should_panic_when_provided_an_empty_path() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run an empty application on the reserved ip and port.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let error_message = server.server_url("").unwrap_err().to_string();

        assert_str_eq!("empty string", error_message);
    }
}
1570
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    /// Name of the cookie the test route looks for.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that answers with the value of the test cookie,
    /// or with `cookie-not-found` when the request does not carry it.
    ///
    /// The cookie jar it received is returned alongside the text.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|cookie| cookie.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app);

        // The cookie is added to the server, and not to the request below.
        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }
}
1602
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Handler that answers with every cookie it received, as `name=value`
    /// pairs, sorted and joined with `, `.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();

        // Sorted so the text sent back does not depend on the jar's order.
        pairs.sort();
        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app);

        // Fill a jar with two cookies, and hand the whole jar to the server.
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(cookie_jar);

        let response = server.get("/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1643
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Handler that answers with every cookie it received, as `name=value`
    /// pairs, sorted and joined with `, `.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();

        // Sorted so the text sent back does not depend on the jar's order.
        pairs.sort();
        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app);

        // Fill a jar with two cookies, and hand the whole jar to the server.
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(cookie_jar);

        // The important bit of this test
        server.clear_cookies();

        // With the cookies cleared, the route has nothing to list.
        let response = server.get("/cookies").await;
        response.assert_text("");
    }
}
1683
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Name of the header the test route looks for.
    const TEST_HEADER_NAME: &str = "test-header";
    /// Value the test sets for that header.
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the test header's value.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Copies the test header's value out of the request parts,
        /// rejecting with a bad request status when the header is missing.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|value| TestHeader(value.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Handler that answers with the bytes of the test header it received.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, with the header set on the server itself.
        let mut server = TestServer::new(app);
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Send a request, which sets no header of its own.
        let response = server.get("/header").await;

        // Check the route sent back the header's content.
        response.assert_text(TEST_HEADER_CONTENT);
    }
}
1740
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Name of the header the test route looks for.
    const TEST_HEADER_NAME: &str = "test-header";
    /// Value the test sets for that header, before clearing it.
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the test header's value.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Copies the test header's value out of the request parts,
        /// rejecting with a bad request status when the header is missing.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|value| TestHeader(value.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Handler that answers with the bytes of the test header it received.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, add the header, and then clear it again.
        let mut server = TestServer::new(app);
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );
        server.clear_headers();

        // Send a request, which should now go out without the header.
        let response = server.get("/header").await;

        // Check the route rejected the request for missing the header.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1799
#[cfg(test)]
mod test_add_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query string carrying a single `message` value.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query string carrying both a `message` and an `other` value.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Responds with `message` and `other` joined by a dash.
    async fn get_query_param_2(query: Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = query.0;
        format!("{}-{}", message, other)
    }

    /// Server exposing `/query`, which expects only `message`.
    fn single_param_server() -> TestServer {
        let router = Router::new().route("/query", get(get_query_param));
        TestServer::new(router)
    }

    /// Server exposing `/query-2`, which expects `message` and `other`.
    fn double_param_server() -> TestServer {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        TestServer::new(router)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let mut test_server = single_param_server();
        let params = QueryParam {
            message: "it works".to_string(),
        };
        test_server.add_query_params(params);

        let response = test_server.get("/query").await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let mut test_server = single_param_server();
        test_server.add_query_params(&[("message", "it works")]);

        let response = test_server.get("/query").await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let mut test_server = double_param_server();
        test_server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        let response = test_server.get("/query-2").await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut test_server = double_param_server();

        // Each call contributes one pair to the server's query parameters.
        test_server.add_query_params(&[("message", "it works")]);
        test_server.add_query_params(&[("other", "yup")]);

        let response = test_server.get("/query-2").await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let mut test_server = double_param_server();
        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        test_server.add_query_params(params);

        let response = test_server.get("/query-2").await;
        response.assert_text("it works-yup");
    }
}
1902
#[cfg(test)]
mod test_add_query_param {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string carrying a single `message` value.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query string carrying both a `message` and an `other` value.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Responds with `message` and `other` joined by a dash.
    async fn get_query_param_2(query: Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = query.0;
        format!("{}-{}", message, other)
    }

    /// Server exposing `/query-2`, which expects `message` and `other`.
    fn double_param_server() -> TestServer {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        TestServer::new(router)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut test_server = TestServer::new(router);
        test_server.add_query_param("message", "it works");

        let response = test_server.get("/query").await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let mut test_server = double_param_server();
        test_server.add_query_param("message", "it works");
        test_server.add_query_param("other", "yup");

        let response = test_server.get("/query-2").await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        // One parameter is set on the server ...
        let mut test_server = double_param_server();
        test_server.add_query_param("message", "it works");

        // ... and the other is set on the individual request.
        let request = test_server.get("/query-2").add_query_param("other", "yup");
        let response = request.await;
        response.assert_text("it works-yup");
    }
}
1977
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string carrying a single `message` value.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query string with two optional, repeatable keys: `items` and `arrs[]`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Responds with the `items` values joined by `", "`, immediately followed
    /// by the `arrs[]` values joined the same way. A key with no values
    /// contributes nothing to the output.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        let mut output = String::new();

        if !params.items.is_empty() {
            output.push_str(&params.items.join(", "));
        }

        if !params.arrs.is_empty() {
            output.push_str(&params.arrs.join(", "));
        }

        output
    }

    /// Application exposing both the plain and the repeatable-key routes.
    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param(&"message=it-works");

        // Get the request.
        server.get("/query").await.assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server, with all three values given in one raw string.
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param(&"items=one&items=two&items=three");

        // Get the request.
        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server, with each value added by its own call.
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        // Get the request.
        server
            .get("/query-extra")
            .await
            .assert_text("one, two, three");
    }
}
2068
#[cfg(test)]
mod test_clear_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string with two optional values.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of the two optional query parameters were received.
    async fn get_query_params(Query(QueryParams { first, second }): Query<QueryParams>) -> String {
        let has_first = first.is_some();
        let has_second = second.is_some();

        format!("has first? {}, has second? {}", has_first, has_second)
    }

    /// Query parameters with both values populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    /// Server exposing the reporting handler on `/query`.
    fn new_server() -> TestServer {
        let router = Router::new().route("/query", get(get_query_params));
        TestServer::new(router)
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let mut test_server = new_server();

        // Set both parameters, then wipe them again.
        test_server.add_query_params(both_params());
        test_server.clear_query_params();

        let response = test_server.get("/query").await;
        response.assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let mut test_server = new_server();

        // Set, wipe, and then set both parameters a second time.
        test_server.add_query_params(both_params());
        test_server.clear_query_params();
        test_server.add_query_params(both_params());

        let response = test_server.get("/query").await;
        response.assert_text("has first? true, has second? true");
    }
}
2138
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;

    /// Handler for the one known route; always responds successfully.
    async fn foxes() -> &'static str {
        "🦊🦊🦊"
    }

    /// Application with a single route, `/known_route`.
    fn app_with_known_route() -> Router {
        Router::new().route("/known_route", get(foxes))
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let test_server = TestServer::new(Router::new());

        test_server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let test_server = TestServer::new(app_with_known_route());

        test_server.get("/known_route").await;
    }

    #[tokio::test]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let test_server = TestServer::builder()
            .expect_success_by_default()
            .build(Router::new());

        let expected = "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''";
        let request = test_server.get("/some_unknown_route");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(expected, message);
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let test_server = TestServer::builder()
            .expect_success_by_default()
            .build(app_with_known_route());

        test_server.get("/known_route").await;
    }
}
2183
#[cfg(test)]
mod test_content_type {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Responds with the request's `Content-Type` header,
    /// or an empty body when the header is not present.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let router = Router::new().route("/content_type", get(get_content_type));

        // The content type is configured on the server, not on the request.
        let builder = TestServer::builder().default_content_type("text/plain");
        let test_server = builder.build(router);

        let response = test_server.get("/content_type").await;
        let text = response.text();

        assert_eq!(text, "text/plain");
    }
}
2215
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Handler responding with a plain text body.
    async fn respond_pong() -> &'static str {
        "pong!"
    }

    /// Handler responding with `202 Accepted` and no body.
    async fn respond_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Wraps the application in a server that expects every request to succeed.
    fn server_expecting_success(app: Router) -> TestServer {
        let mut test_server = TestServer::new(app);
        test_server.expect_success();
        test_server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let router = Router::new().route("/ping", get(respond_pong));
        let test_server = server_expecting_success(router);

        test_server.get("/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let router = Router::new().route("/accepted", get(respond_accepted));
        let test_server = server_expecting_success(router);

        test_server.get("/accepted").await;
    }

    #[tokio::test]
    async fn it_should_panic_on_404() {
        // No routes are registered, so every path is unknown.
        let test_server = server_expecting_success(Router::new());

        let expected = "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''";
        let request = test_server.get("/some_unknown_route");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(expected, message);
    }
}
2276
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Handler responding with a plain text body.
    async fn respond_pong() -> &'static str {
        "pong!"
    }

    /// Handler responding with `202 Accepted` and no body.
    async fn respond_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Wraps the application in a server that expects every request to fail.
    fn server_expecting_failure(app: Router) -> TestServer {
        let mut test_server = TestServer::new(app);
        test_server.expect_failure();
        test_server
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // No routes are registered, so every path is unknown.
        let test_server = server_expecting_failure(Router::new());

        test_server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_panic_if_success_is_returned() {
        let router = Router::new().route("/ping", get(respond_pong));
        let test_server = server_expecting_failure(router);

        let expected = "Expect status code outside 2xx range, received 200 (OK), for request GET http://localhost/ping, with body 'pong!'";
        let request = test_server.get("/ping");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(expected, message);
    }

    #[tokio::test]
    async fn it_should_panic_on_other_2xx_status_code() {
        let router = Router::new().route("/accepted", get(respond_accepted));
        let test_server = server_expecting_failure(router);

        let expected = "Expect status code outside 2xx range, received 202 (Accepted), for request GET http://localhost/accepted, with body ''";
        let request = test_server.get("/accepted");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(expected, message);
    }
}
2341
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_get {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb name followed by the id taken from the path.
    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let router = Router::new().typed_get(route_get);
        let test_server = TestServer::new(router);

        let response = test_server.typed_get(&TestingPath { id: 123 }).await;
        response.assert_text("get 123");
    }
}
2374
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_post {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb name followed by the id taken from the path.
    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let router = Router::new().typed_post(route_post);
        let test_server = TestServer::new(router);

        let response = test_server.typed_post(&TestingPath { id: 123 }).await;
        response.assert_text("post 123");
    }
}
2407
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_patch {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb name followed by the id taken from the path.
    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let router = Router::new().typed_patch(route_patch);
        let test_server = TestServer::new(router);

        let response = test_server.typed_patch(&TestingPath { id: 123 }).await;
        response.assert_text("patch 123");
    }
}
2440
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_put {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb name followed by the id taken from the path.
    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let router = Router::new().typed_put(route_put);
        let test_server = TestServer::new(router);

        let response = test_server.typed_put(&TestingPath { id: 123 }).await;
        response.assert_text("put 123");
    }
}
2473
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_delete {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, used both to register the route and to build the request.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb name followed by the id taken from the path.
    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let router = Router::new().typed_delete(route_delete);
        let test_server = TestServer::new(router);

        let response = test_server.typed_delete(&TestingPath { id: 123 }).await;
        response.assert_text("delete 123");
    }
}
2506
#[cfg(feature = "typed-routing")]
#[cfg(test)]
mod test_typed_method {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path `/path/{id}`, shared by every verb registered below.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    // One handler per verb; each responds with its verb name and the path id,
    // so a test can tell which handler served the request.

    async fn route_get(path: TestingPath) -> String {
        format!("get {}", path.id)
    }

    async fn route_post(path: TestingPath) -> String {
        format!("post {}", path.id)
    }

    async fn route_patch(path: TestingPath) -> String {
        format!("patch {}", path.id)
    }

    async fn route_put(path: TestingPath) -> String {
        format!("put {}", path.id)
    }

    async fn route_delete(path: TestingPath) -> String {
        format!("delete {}", path.id)
    }

    /// Server with all five verbs registered on the same typed path.
    fn new_server() -> TestServer {
        let router = Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete);

        TestServer::new(router)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let test_server = new_server();
        let path = TestingPath { id: 123 };

        let response = test_server.typed_method(Method::GET, &path).await;
        response.assert_text("get 123");
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let test_server = new_server();
        let path = TestingPath { id: 123 };

        let response = test_server.typed_method(Method::POST, &path).await;
        response.assert_text("post 123");
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let test_server = new_server();
        let path = TestingPath { id: 123 };

        let response = test_server.typed_method(Method::PATCH, &path).await;
        response.assert_text("patch 123");
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let test_server = new_server();
        let path = TestingPath { id: 123 };

        let response = test_server.typed_method(Method::PUT, &path).await;
        response.assert_text("put 123");
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let test_server = new_server();
        let path = TestingPath { id: 123 };

        let response = test_server.typed_method(Method::DELETE, &path).await;
        response.assert_text("delete 123");
    }
}
2600
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    /// Handler for the single `/test` route.
    async fn route_get() -> &'static str {
        "it works"
    }

    /// Builds the server that gets stored inside the cell.
    fn build_server() -> TestServer {
        let router = Router::new().route("/test", get(route_get));

        TestServer::new(router)
    }

    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell = OnceCell::<TestServer>::new();

        // The server is created on first access and then borrowed from the cell.
        let test_server = cell.get_or_init(build_server);

        let response = test_server.get("/test").await;
        response.assert_text("it works");
    }
}
2624
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::testing::catch_panic_error_message_async;
    use crate::util::new_random_tokio_tcp_listener_with_socket_addr;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use pretty_assertions::assert_str_eq;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    /// Handler for the single `/ping` route.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Raw OS error code reported when a TCP connection is refused.
    ///
    /// The code differs per platform, so it cannot be hard-coded into the
    /// expected error message.
    fn connection_refused_os_code() -> i32 {
        if cfg!(windows) {
            // WSAECONNREFUSED
            10061
        } else if cfg!(any(target_os = "linux", target_os = "android")) {
            // ECONNREFUSED
            111
        } else {
            // ECONNREFUSED on macOS and the BSDs.
            61
        }
    }

    // NOTE(review): the test name mentions mock HTTP, but the application is
    // served from a real TCP listener through `axum::serve` — confirm the
    // intended name before renaming.
    #[tokio::test]
    async fn it_should_panic_when_run_with_mock_http() {
        let shutdown_notification = Arc::new(Notify::new());
        let waiting_notification = shutdown_notification.clone();

        // Build an application with a route, which shuts down when notified.
        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let (listener, ip_port) = new_random_tokio_tcp_listener_with_socket_addr().unwrap();
        let application = serve(listener, app)
            .with_graceful_shutdown(async move { waiting_notification.notified().await });

        // Run the server.
        let server = TestServer::builder().build(application);

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        // Trigger the shutdown, and give the server a moment to stop.
        shutdown_notification.notify_one();
        sleep(Duration::from_millis(10)).await;

        assert!(!server.is_running());

        // Let the standard library render the OS error line, so the text and
        // code match whatever the current platform reports.
        // NOTE(review): the rest of the error chain is assumed to be the same
        // on every platform — confirm on Linux and Windows.
        let connection_refused = std::io::Error::from_raw_os_error(connection_refused_os_code());

        let ip = ip_port.ip();
        let port = ip_port.port();
        let expected = format!(
            "Sending request failed, for request GET http://{ip}:{port}/ping,
    client error (Connect)
    tcp connect error
    {connection_refused}
"
        );
        let message = catch_panic_error_message_async(server.get("/ping")).await;
        assert_str_eq!(expected, message);
    }
}
2680
2681#[cfg(test)]
2682mod test_save_cookies {
2683    use crate::TestServer;
2684    use axum::Router;
2685    use axum::extract::Request;
2686    use axum::http::header::HeaderMap;
2687    use axum::routing::get;
2688    use axum::routing::put;
2689    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
2690    use cookie::Cookie;
2691    use cookie::SameSite;
2692    use http_body_util::BodyExt;
2693
2694    const TEST_COOKIE_NAME: &'static str = &"test-cookie";
2695
2696    #[tokio::test]
2697    async fn it_should_save_cookies_across_requests_when_enabled() {
2698        let mut server = TestServer::new(app());
2699
2700        server.save_cookies();
2701
2702        save_cookie_using_axum_test(&server).await;
2703        assert_cookie_using_axum_test(&server).await;
2704    }
2705
2706    #[cfg(feature = "reqwest")]
2707    #[tokio::test]
2708    async fn it_should_save_cookies_across_reqwest_requests_when_enabled() {
2709        let mut server = TestServer::builder().http_transport().build(app());
2710
2711        server.save_cookies();
2712
2713        save_cookie_using_reqwest(&server).await;
2714        save_cookie_using_reqwest(&server).await;
2715    }
2716
2717    #[tokio::test]
2718    async fn it_should_save_cookies_across_axum_test_requests_when_enabled_for_second_request() {
2719        let mut server = TestServer::builder().http_transport().build(app());
2720
2721        save_cookie_using_axum_test(&server).await;
2722        assert_no_cookie_using_axum_test(&server).await;
2723
2724        server.save_cookies();
2725
2726        save_cookie_using_axum_test(&server).await;
2727        assert_cookie_using_axum_test(&server).await;
2728    }
2729
2730    #[cfg(feature = "reqwest")]
2731    #[tokio::test]
2732    async fn it_should_save_cookies_across_reqwest_requests_when_enabled_for_second_request() {
2733        let mut server = TestServer::builder().http_transport().build(app());
2734
2735        save_cookie_using_reqwest(&server).await;
2736        assert_no_cookie_using_reqwest(&server).await;
2737
2738        server.save_cookies();
2739
2740        save_cookie_using_reqwest(&server).await;
2741        assert_cookie_using_reqwest(&server).await;
2742    }
2743
2744    #[cfg(feature = "reqwest")]
2745    #[tokio::test]
2746    async fn it_should_save_cookies_when_set_by_reqwest_and_read_by_axum_test() {
2747        let mut server = TestServer::builder().http_transport().build(app());
2748
2749        server.save_cookies();
2750
2751        save_cookie_using_reqwest(&server).await;
2752        assert_cookie_using_axum_test(&server).await;
2753    }
2754
2755    #[cfg(feature = "reqwest")]
2756    #[tokio::test]
2757    async fn it_should_save_cookies_when_set_by_axum_test_and_read_by_reqwest() {
2758        let mut server = TestServer::builder().http_transport().build(app());
2759
2760        server.save_cookies();
2761
2762        save_cookie_using_axum_test(&server).await;
2763        assert_cookie_using_reqwest(&server).await;
2764    }
2765
2766    fn app() -> Router {
2767        async fn put_cookie_with_attributes(
2768            mut cookies: AxumCookieJar,
2769            request: Request,
2770        ) -> (AxumCookieJar, &'static str) {
2771            let body_bytes = request
2772                .into_body()
2773                .collect()
2774                .await
2775                .expect("Should turn the body into bytes")
2776                .to_bytes();
2777
2778            let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
2779            let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
2780                .http_only(true)
2781                .secure(true)
2782                .same_site(SameSite::Strict)
2783                .path("/cookie")
2784                .build();
2785            cookies = cookies.add(cookie);
2786
2787            (cookies, &"done")
2788        }
2789
2790        async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
2791            let cookies: String = headers
2792                .get_all("cookie")
2793                .into_iter()
2794                .map(|c| c.to_str().unwrap_or("").to_string())
2795                .reduce(|a, b| a + "; " + &b)
2796                .unwrap_or_else(|| String::new());
2797
2798            cookies
2799        }
2800
2801        Router::new()
2802            .route("/cookie", put(put_cookie_with_attributes))
2803            .route("/cookie", get(get_cookie_headers_joined))
2804    }
2805
2806    async fn save_cookie_using_axum_test(server: &TestServer) {
2807        server.put(&"/cookie").text(&"cookie-found!").await;
2808    }
2809
2810    #[cfg(feature = "reqwest")]
2811    async fn save_cookie_using_reqwest(server: &TestServer) {
2812        server
2813            .reqwest_put(&"/cookie")
2814            .body("cookie-found!".to_string())
2815            .send()
2816            .await
2817            .unwrap();
2818    }
2819
2820    async fn assert_cookie_using_axum_test(server: &TestServer) {
2821        server
2822            .get(&"/cookie")
2823            .await
2824            .assert_text("test-cookie=cookie-found!");
2825    }
2826
2827    #[cfg(feature = "reqwest")]
2828    async fn assert_cookie_using_reqwest(server: &TestServer) {
2829        let response_text = server
2830            .reqwest_get(&"/cookie")
2831            .send()
2832            .await
2833            .unwrap()
2834            .text()
2835            .await
2836            .unwrap();
2837
2838        assert_eq!("test-cookie=cookie-found!", response_text);
2839    }
2840
2841    async fn assert_no_cookie_using_axum_test(server: &TestServer) {
2842        server.get(&"/cookie").await.assert_text("");
2843    }
2844
2845    #[cfg(feature = "reqwest")]
2846    async fn assert_no_cookie_using_reqwest(server: &TestServer) {
2847        let response_text = server
2848            .reqwest_get(&"/cookie")
2849            .send()
2850            .await
2851            .unwrap()
2852            .text()
2853            .await
2854            .unwrap();
2855
2856        assert_eq!("", response_text);
2857    }
2858}