Skip to main content

axum_test/
test_server.rs

1use crate::TestRequest;
2use crate::TestRequestConfig;
3use crate::TestServerBuilder;
4use crate::TestServerConfig;
5use crate::Transport;
6use crate::internals::AtomicCrossCookieJar;
7use crate::internals::ErrorMessage;
8use crate::internals::ExpectedState;
9use crate::internals::QueryParamsStore;
10use crate::transport_layer::IntoTransportLayer;
11use crate::transport_layer::TransportLayer;
12use crate::transport_layer::TransportLayerBuilder;
13use anyhow::Result;
14use anyhow::anyhow;
15use cookie::Cookie;
16use cookie::CookieJar;
17use http::HeaderName;
18use http::HeaderValue;
19use http::Method;
20use http::Uri;
21use serde::Serialize;
22use std::fmt::Debug;
23use std::sync::Arc;
24use url::Url;
25
26#[cfg(feature = "typed-routing")]
27use axum_extra::routing::TypedPath;
28
29#[cfg(feature = "reqwest")]
30use crate::transport_layer::TransportLayerType;
31#[cfg(feature = "reqwest")]
32use reqwest::Client;
33#[cfg(feature = "reqwest")]
34use reqwest::RequestBuilder;
35#[cfg(feature = "reqwest")]
36use std::cell::OnceCell;
37
38mod server_shared_state;
39pub(crate) use self::server_shared_state::*;
40
41const DEFAULT_URL_ADDRESS: &str = "http://localhost";
42
43///
44/// The `TestServer` runs your Axum application,
45/// allowing you to make HTTP requests against it.
46///
47/// # Building
48///
49/// A `TestServer` can be used to run an [`axum::Router`], an [`::axum::routing::IntoMakeService`],
50/// and others.
51///
/// The most straightforward approach is to call [`TestServer::new`],
53/// and pass in your application:
54///
55/// ```rust
56/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
57/// #
58/// use axum::Router;
59/// use axum::routing::get;
60///
61/// use axum_test::TestServer;
62///
63/// let app = Router::new()
64///     .route(&"/hello", get(|| async { "hello!" }));
65///
66/// let server = TestServer::new(app);
67/// #
68/// # Ok(())
69/// # }
70/// ```
71///
72/// # Requests
73///
74/// Requests are built by calling [`TestServer::get()`](crate::TestServer::get()),
75/// [`TestServer::post()`](crate::TestServer::post()), [`TestServer::put()`](crate::TestServer::put()),
76/// [`TestServer::delete()`](crate::TestServer::delete()), and [`TestServer::patch()`](crate::TestServer::patch()) methods.
77/// Each returns a [`TestRequest`](crate::TestRequest), which allows for customising the request content.
78///
79/// For example:
80///
81/// ```rust
82/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
83/// #
84/// use axum::Router;
85/// use axum::routing::get;
86///
87/// use axum_test::TestServer;
88///
89/// let app = Router::new()
90///     .route(&"/hello", get(|| async { "hello!" }));
91///
92/// let server = TestServer::new(app);
93///
94/// let response = server.get(&"/hello")
95///     .authorization_bearer("password12345")
96///     .add_header("x-custom-header", "custom-value")
97///     .await;
98///
99/// response.assert_text("hello!");
100/// #
101/// # Ok(())
102/// # }
103/// ```
104///
105/// Request methods also exist for using Axum Extra [`axum_extra::routing::TypedPath`],
/// or for building Reqwest [`reqwest::RequestBuilder`]. See those methods for details.
107///
108/// # Customising
109///
110/// A `TestServer` can be built from a builder, by calling [`TestServer::builder`],
111/// and customising settings. This allows one to set **mocked** (default when possible)
112/// or **real http** networking for your service.
113///
114/// ```rust
115/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
116/// #
117/// use axum::Router;
118/// use axum::routing::get;
119///
120/// use axum_test::TestServer;
121///
122/// let app = Router::new()
123///     .route(&"/hello", get(|| async { "hello!" }));
124///
125/// // Customise server when building
126/// let mut server = TestServer::builder()
127///     .http_transport()
128///     .expect_success_by_default()
129///     .save_cookies()
130///     .build(app);
131///
/// // Add items to be sent on _all_ requests
133/// server.add_header("x-custom-for-all", "common-value");
134///
135/// let response = server.get("/hello").await;
136/// #
137/// # Ok(())
138/// # }
139/// ```
140///
#[derive(Debug)]
pub struct TestServer {
    // Headers and query parameters that are copied onto every request built from this server.
    state: ServerSharedState,
    // Cookie store handed to each request built, and to the Reqwest client when that is in use.
    cookie_jar: Arc<AtomicCrossCookieJar>,
    // The layer requests are sent through; a clone of this `Arc` is given to each `TestRequest`.
    transport: Arc<Box<dyn TransportLayer>>,
    // The success / failure expectation that new requests start with.
    expected_state: ExpectedState,
    // The content type new requests start with, when one has been configured.
    default_content_type: Option<String>,
    // Passed to `build_url`; when true, paths carrying a scheme must point at this server.
    is_http_path_restricted: bool,

    // Built lazily, on the first call to `reqwest_client()`.
    #[cfg(feature = "reqwest")]
    maybe_reqwest_client: OnceCell<Client>,
}
153
154impl TestServer {
    /// A helper function to create a builder for creating a [`TestServer`].
    ///
    /// The builder returned is [`TestServerBuilder::default()`].
    pub fn builder() -> TestServerBuilder {
        TestServerBuilder::default()
    }
159
    /// This will run the given Axum app,
    /// allowing you to make requests against it.
    ///
    /// This is the same as creating a new `TestServer` with a configuration,
    /// and passing [`TestServerConfig::default()`].
    ///
    /// ```rust
    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
    /// #
    /// use axum::Router;
    /// use axum::routing::get;
    /// use axum_test::TestServer;
    ///
    /// let app = Router::new()
    ///     .route(&"/hello", get(|| async { "hello!" }));
    ///
    /// let server = TestServer::new(app);
    /// #
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// The type of applications that can be passed in include:
    ///
    ///  - [`axum::Router`]
    ///  - [`axum::routing::IntoMakeService`]
    ///  - [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`]
    ///  - [`axum::serve::Serve`]
    ///  - [`axum::serve::WithGracefulShutdown`]
    ///  - A function returning an [`actix_web::App`]
    ///
    /// # Panics
    ///
    /// This will panic if the `TestServer` cannot be built.
    /// To catch the error use [`TestServer::try_new`].
    ///
    pub fn new<A>(app: A) -> Self
    where
        A: IntoTransportLayer,
    {
        Self::try_new(app).error_message("Failed to build TestServer")
    }
200
    /// Attempts to create a [`TestServer`] using [`TestServerConfig::default()`],
    /// and returns an error if this fails.
    ///
    /// This is the counterpart to [`TestServer::new`] that hands the error back to the caller.
    pub fn try_new<A>(app: A) -> Result<Self>
    where
        A: IntoTransportLayer,
    {
        Self::try_new_with_config(app, TestServerConfig::default())
    }
208
    /// Similar to [`TestServer::new()`], with a customised configuration.
    /// This includes type of transport in use (i.e. specify a specific port),
    /// or change default settings (like the default content type for requests).
    ///
    /// This can take a [`TestServerConfig`] or a [`TestServerBuilder`].
    /// See those for more information on configuration settings.
    ///
    /// # Panics
    ///
    /// Like [`TestServer::new()`], this will panic if the `TestServer` cannot be built.
    /// To catch the error use [`TestServer::try_new_with_config`].
    pub fn new_with_config<A, C>(app: A, config: C) -> Self
    where
        A: IntoTransportLayer,
        C: Into<TestServerConfig>,
    {
        Self::try_new_with_config(app, config).error_message("Failed to build TestServer")
    }
222
223    /// Attempts to create a [`TestServer`], and returns an error if this fails.
224    pub fn try_new_with_config<A, C>(app: A, config: C) -> Result<Self>
225    where
226        A: IntoTransportLayer,
227        C: Into<TestServerConfig>,
228    {
229        let config = config.into();
230        let state = ServerSharedState::new();
231
232        let transport = match config.transport {
233            None => {
234                let builder = TransportLayerBuilder::from_ip_port(None, None);
235                let transport = app.into_default_transport(builder)?;
236                Arc::new(transport)
237            }
238            Some(Transport::HttpRandomPort) => {
239                let builder = TransportLayerBuilder::from_ip_port(None, None);
240                let transport = app.into_http_transport_layer(builder)?;
241                Arc::new(transport)
242            }
243            Some(Transport::HttpIpPort { ip, port }) => {
244                let builder = TransportLayerBuilder::from_ip_port(ip, port);
245                let transport = app.into_http_transport_layer(builder)?;
246                Arc::new(transport)
247            }
248            Some(Transport::HttpTcpListner { tcp_listener }) => {
249                let builder = TransportLayerBuilder::from_tcp_listener(tcp_listener);
250                let transport = app.into_http_transport_layer(builder)?;
251                Arc::new(transport)
252            }
253            Some(Transport::MockHttp) => {
254                let transport = app.into_mock_transport_layer()?;
255                Arc::new(transport)
256            }
257        };
258
259        let expected_state = match config.expect_success_by_default {
260            true => ExpectedState::Success,
261            false => ExpectedState::None,
262        };
263
264        Ok(Self {
265            state,
266            cookie_jar: Arc::new(AtomicCrossCookieJar::new(config.save_cookies)),
267            transport,
268            expected_state,
269            default_content_type: config.default_content_type,
270            is_http_path_restricted: config.restrict_requests_with_http_scheme,
271
272            #[cfg(feature = "reqwest")]
273            maybe_reqwest_client: Default::default(),
274        })
275    }
276
    /// Creates a HTTP GET request to the path.
    ///
    /// This is shorthand for calling [`TestServer::method`] with [`Method::GET`].
    pub fn get(&self, path: &str) -> TestRequest {
        self.method(Method::GET, path)
    }
281
    /// Creates a HTTP POST request to the given path.
    ///
    /// This is shorthand for calling [`TestServer::method`] with [`Method::POST`].
    pub fn post(&self, path: &str) -> TestRequest {
        self.method(Method::POST, path)
    }
286
    /// Creates a HTTP PATCH request to the path.
    ///
    /// This is shorthand for calling [`TestServer::method`] with [`Method::PATCH`].
    pub fn patch(&self, path: &str) -> TestRequest {
        self.method(Method::PATCH, path)
    }
291
    /// Creates a HTTP PUT request to the path.
    ///
    /// This is shorthand for calling [`TestServer::method`] with [`Method::PUT`].
    pub fn put(&self, path: &str) -> TestRequest {
        self.method(Method::PUT, path)
    }
296
    /// Creates a HTTP DELETE request to the path.
    ///
    /// This is shorthand for calling [`TestServer::method`] with [`Method::DELETE`].
    pub fn delete(&self, path: &str) -> TestRequest {
        self.method(Method::DELETE, path)
    }
301
    /// Creates a HTTP request, to the method and path provided.
    ///
    /// The request starts with the settings currently held by this `TestServer`
    /// (its cookies, headers, query parameters, expected state, and default content type),
    /// and is sent over the same transport as the server.
    ///
    /// Building the request can fail, for example when the path does not parse as a uri.
    /// The failure is reported with the message `Failed to build request, for {method} {path}`
    /// (presumably as a panic, in the same way as [`TestServer::new`] -- see `ErrorMessage`).
    pub fn method(&self, method: Method, path: &str) -> TestRequest {
        // The method is cloned, as it is needed again by the error message below.
        let config = self
            .build_test_request_config(method.clone(), path)
            .error_message_fn(|| format!("Failed to build request, for {method} {path}"));

        TestRequest::new(self.transport.clone(), config)
    }
310
311    #[cfg(feature = "reqwest")]
312    fn reqwest_client(&self) -> &Client {
313        self.maybe_reqwest_client.get_or_init(|| {
314            if self.transport.transport_layer_type() == TransportLayerType::Mock {
315                panic!("Reqwest client is not available, TestServer must be build with HTTP transport for Reqwest to be available");
316            }
317
318            reqwest::Client::builder()
319                .redirect(reqwest::redirect::Policy::none())
320                .cookie_provider(self.cookie_jar.clone())
321                .build()
322                .expect("Failed to build Reqwest Client")
323        })
324    }
325
    /// Creates a HTTP GET request to the path, built using Reqwest.
    ///
    /// This is shorthand for calling [`TestServer::reqwest_method`] with [`Method::GET`].
    #[cfg(feature = "reqwest")]
    pub fn reqwest_get(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::GET, path)
    }
330
    /// Creates a HTTP POST request to the path, built using Reqwest.
    ///
    /// This is shorthand for calling [`TestServer::reqwest_method`] with [`Method::POST`].
    #[cfg(feature = "reqwest")]
    pub fn reqwest_post(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::POST, path)
    }
335
    /// Creates a HTTP PUT request to the path, built using Reqwest.
    ///
    /// This is shorthand for calling [`TestServer::reqwest_method`] with [`Method::PUT`].
    #[cfg(feature = "reqwest")]
    pub fn reqwest_put(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::PUT, path)
    }
340
    /// Creates a HTTP PATCH request to the path, built using Reqwest.
    ///
    /// This is shorthand for calling [`TestServer::reqwest_method`] with [`Method::PATCH`].
    #[cfg(feature = "reqwest")]
    pub fn reqwest_patch(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::PATCH, path)
    }
345
    /// Creates a HTTP DELETE request to the path, built using Reqwest.
    ///
    /// This is shorthand for calling [`TestServer::reqwest_method`] with [`Method::DELETE`].
    #[cfg(feature = "reqwest")]
    pub fn reqwest_delete(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::DELETE, path)
    }
350
    /// Creates a HTTP HEAD request to the path, built using Reqwest.
    ///
    /// This is shorthand for calling [`TestServer::reqwest_method`] with [`Method::HEAD`].
    #[cfg(feature = "reqwest")]
    pub fn reqwest_head(&self, path: &str) -> RequestBuilder {
        self.reqwest_method(Method::HEAD, path)
    }
355
356    /// Creates a HTTP request, using Reqwest, using the method + path described.
357    /// This expects a relative url to the `TestServer`.
358    ///
359    /// ```rust
360    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
361    /// #
362    /// use axum::Router;
363    /// use axum_test::TestServer;
364    ///
365    /// let my_app = Router::new();
366    /// let server = TestServer::builder()
367    ///     .http_transport() // Important, must be HTTP!
368    ///     .build(my_app);
369    ///
370    /// // Build your request
371    /// let request = server.get(&"/user")
372    ///     .add_header("x-custom-header", "example.com")
373    ///     .content_type("application/yaml");
374    ///
375    /// // await request to execute
376    /// let response = request.await;
377    /// #
378    /// # Ok(()) }
379    /// ```
380    #[cfg(feature = "reqwest")]
381    pub fn reqwest_method(&self, method: Method, path: &str) -> RequestBuilder {
382        let request_url = self
383            .server_url(path)
384            .expect("Failed to generate server url for request {method} {path}");
385
386        self.reqwest_client().request(method, request_url)
387    }
388
389    /// Creates a request to the server, to start a Websocket connection,
390    /// on the path given.
391    ///
    /// This is the equivalent of making a GET request to the endpoint,
393    /// and setting the various headers needed for making an upgrade request.
394    ///
395    /// *Note*, this requires the server to be running on a real HTTP
396    /// port. Either using a randomly assigned port, or a specified one.
397    /// See the [`TestServerConfig::transport`](crate::TestServerConfig::transport) for more details.
398    ///
399    /// # Example
400    ///
401    /// ```rust
402    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
403    /// #
404    /// use axum::Router;
405    /// use axum_test::TestServer;
406    ///
407    /// let app = Router::new();
408    /// let server = TestServer::builder()
409    ///     .http_transport()
410    ///     .build(app);
411    ///
412    /// let mut websocket = server
413    ///     .get_websocket(&"/my-web-socket-end-point")
414    ///     .await
415    ///     .into_websocket()
416    ///     .await;
417    ///
418    /// websocket.send_text("Hello!").await;
419    /// #
420    /// # Ok(()) }
421    /// ```
422    ///
    #[cfg(feature = "ws")]
    pub fn get_websocket(&self, path: &str) -> TestRequest {
        use http::header;

        // This is a regular GET request,
        // with the headers added that ask for the connection to be upgraded to a Websocket.
        self.get(path)
            .add_header(header::CONNECTION, "upgrade")
            .add_header(header::UPGRADE, "websocket")
            .add_header(header::SEC_WEBSOCKET_VERSION, "13")
            .add_header(
                header::SEC_WEBSOCKET_KEY,
                // A key is generated for each request built.
                crate::internals::generate_ws_key(),
            )
    }
436
437    /// Creates a HTTP GET request, using the typed path provided.
438    ///
439    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
440    ///
441    /// # Example Test
442    ///
    /// Using a `TypedPath` you can build and test a route like below:
444    ///
445    /// ```rust
446    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
447    /// #
448    /// use axum::Json;
449    /// use axum::Router;
450    /// use axum::routing::get;
451    /// use axum_extra::routing::RouterExt;
452    /// use axum_extra::routing::TypedPath;
453    /// use serde::Deserialize;
454    /// use serde::Serialize;
455    ///
456    /// use axum_test::TestServer;
457    ///
458    /// #[derive(TypedPath, Deserialize)]
459    /// #[typed_path("/users/{user_id}")]
460    /// struct UserPath {
461    ///     pub user_id: u32,
462    /// }
463    ///
464    /// // Build a typed route:
465    /// async fn route_get_user(UserPath { user_id }: UserPath) -> String {
466    ///     format!("hello user {user_id}")
467    /// }
468    ///
469    /// let app = Router::new()
470    ///     .typed_get(route_get_user);
471    ///
472    /// // Then test the route:
473    /// let server = TestServer::new(app);
474    /// server
475    ///     .typed_get(&UserPath { user_id: 123 })
476    ///     .await
477    ///     .assert_text("hello user 123");
478    /// #
479    /// # Ok(())
480    /// # }
481    /// ```
482    ///
    /// This is shorthand for calling [`TestServer::typed_method`] with [`Method::GET`].
    #[cfg(feature = "typed-routing")]
    pub fn typed_get<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::GET, path)
    }
490
    /// Creates a HTTP POST request, using the typed path provided.
    ///
    /// This is shorthand for calling [`TestServer::typed_method`] with [`Method::POST`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_post<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::POST, path)
    }
501
    /// Creates a HTTP PATCH request, using the typed path provided.
    ///
    /// This is shorthand for calling [`TestServer::typed_method`] with [`Method::PATCH`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_patch<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::PATCH, path)
    }
512
    /// Creates a HTTP PUT request, using the typed path provided.
    ///
    /// This is shorthand for calling [`TestServer::typed_method`] with [`Method::PUT`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_put<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::PUT, path)
    }
523
    /// Creates a HTTP DELETE request, using the typed path provided.
    ///
    /// This is shorthand for calling [`TestServer::typed_method`] with [`Method::DELETE`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_delete<P>(&self, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        self.typed_method(Method::DELETE, path)
    }
534
    /// Creates a typed HTTP request, using the method provided.
    ///
    /// The typed path is turned into a string,
    /// and the request is then built in the same way as [`TestServer::method`].
    ///
    /// See [`axum-extra`](https://docs.rs/axum-extra) for full documentation on [`TypedPath`](axum_extra::routing::TypedPath).
    #[cfg(feature = "typed-routing")]
    pub fn typed_method<P>(&self, method: Method, path: &P) -> TestRequest
    where
        P: TypedPath,
    {
        // The string form of the typed path is the path of the request.
        self.method(method, &path.to_string())
    }
545
    /// Returns the local web address for the test server,
    /// if an address is available.
    ///
    /// The address is available when running as a real web server,
    /// by setting the [`TestServerConfig`](crate::TestServerConfig) `transport` field to `Transport::HttpRandomPort` or `Transport::HttpIpPort`.
    ///
    /// This will return `None` when there is mock HTTP transport (the default).
    ///
    /// The url returned is a copy of the url reported by the underlying transport layer.
    pub fn server_address(&self) -> Option<Url> {
        self.url()
    }
556
557    /// This turns a relative path, into an absolute path to the server.
558    /// i.e. A path like `/users/123` will become something like `http://127.0.0.1:1234/users/123`.
559    ///
560    /// The absolute address can be used to make requests to the running server,
561    /// using any appropriate client you wish.
562    ///
563    /// # Example
564    ///
565    /// ```rust
566    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
567    /// #
568    /// use axum::Router;
569    /// use axum_test::TestServer;
570    ///
571    /// let app = Router::new();
572    /// let server = TestServer::builder()
573    ///         .http_transport()
574    ///         .build(app);
575    ///
576    /// let full_url = server.server_url(&"/users/123?filter=enabled")?;
577    ///
578    /// // Prints something like ... http://127.0.0.1:1234/users/123?filter=enabled
579    /// println!("{full_url}");
580    /// #
581    /// # Ok(()) }
582    /// ```
583    ///
584    /// This will return an error if you are using the mock transport.
585    /// Real HTTP transport is required to use this method (see [`TestServerConfig`](crate::TestServerConfig) `transport` field).
586    ///
587    /// It will also return an error if you provide an absolute path,
588    /// for example if you pass in `http://google.com`.
589    pub fn server_url(&self, path: &str) -> Result<Url> {
590        let path_uri = path.parse::<Uri>()?;
591        if is_absolute_uri(&path_uri) {
592            return Err(anyhow!(
593                "Absolute path provided for building server url, need to provide a relative uri"
594            ));
595        }
596
597        let server_url = self.url()
598            .ok_or_else(||
599                anyhow!(
600                    "No local address for server, need to run with HTTP transport to have a server address",
601                )
602            )?;
603
604        let mut query_params = self.state.query_params().clone();
605        let mut full_server_url = build_url(
606            server_url,
607            path,
608            &mut query_params,
609            self.is_http_path_restricted,
610        )?;
611
612        // Ensure the query params are present
613        if query_params.has_content() {
614            full_server_url.set_query(Some(&query_params.to_string()));
615        }
616
617        Ok(full_server_url)
618    }
619
    /// Adds a single cookie to be included on *all* future requests.
    ///
    /// If a cookie with the same name already exists,
    /// then it will be replaced.
    ///
    /// The cookie is stored in the cookie jar held by this `TestServer`.
    pub fn add_cookie(&mut self, cookie: Cookie) {
        self.cookie_jar.add_cookie(cookie);
    }
627
    /// Adds extra cookies to be used on *all* future requests.
    ///
    /// Any cookies which have the same name as the new cookies,
    /// will get replaced.
    ///
    /// The cookies are stored in the cookie jar held by this `TestServer`.
    pub fn add_cookies(&mut self, cookies: CookieJar) {
        self.cookie_jar.add_cookies_by_jar(cookies);
    }
635
    /// Clears all of the cookies stored internally,
    /// in the cookie jar held by this `TestServer`.
    pub fn clear_cookies(&mut self) {
        self.cookie_jar.clear_cookies();
    }
640
    /// Requests made using this `TestServer` will save their cookies for future requests to send.
    /// This includes sharing cookies with requests made using the `reqwest` feature.
    ///
    /// This behaviour is off by default.
    /// To turn it off again call [`TestServer::do_not_save_cookies`].
    pub fn save_cookies(&mut self) {
        self.cookie_jar.enable_saving();
    }
648
    /// Requests made using this `TestServer` will _not_ save their cookies for future requests to send up.
    /// This includes sharing cookies with requests made using the `reqwest` feature.
    ///
    /// This is the default behaviour.
    /// To turn saving on call [`TestServer::save_cookies`].
    pub fn do_not_save_cookies(&mut self) {
        self.cookie_jar.disable_saving();
    }
656
    /// Requests made using this `TestServer` will assert that a HTTP status in the 2xx range is returned,
    /// unless marked otherwise.
    ///
    /// By default this behaviour is off.
    pub fn expect_success(&mut self) {
        self.expected_state = ExpectedState::Success;
    }
663
    /// Requests made using this `TestServer` will assert that a HTTP status outside the 2xx range is returned,
    /// unless marked otherwise.
    ///
    /// By default this behaviour is off.
    pub fn expect_failure(&mut self) {
        self.expected_state = ExpectedState::Failure;
    }
670
    /// Adds a query parameter to be sent on *all* future requests.
    ///
    /// The value can be anything that implements [`Serialize`].
    ///
    /// If the parameter cannot be added, this fails with the message
    /// `Failed to add query parameter` (presumably as a panic -- see `ErrorMessage`).
    pub fn add_query_param<V>(&mut self, key: &str, value: V)
    where
        V: Serialize,
    {
        self.state
            .add_query_param(key, value)
            .error_message("Failed to add query parameter");
    }
680
    /// Adds query parameters to be sent on *all* future requests.
    ///
    /// The parameters can be anything that implements [`Serialize`].
    ///
    /// If the parameters cannot be added, this fails with the message
    /// `Failed to add query parameters` (presumably as a panic -- see `ErrorMessage`).
    pub fn add_query_params<V>(&mut self, query_params: V)
    where
        V: Serialize,
    {
        self.state
            .add_query_params(query_params)
            .error_message("Failed to add query parameters");
    }
690
    /// Adds a raw query param, with no urlencoding of any kind,
    /// to be sent on *all* future requests.
    pub fn add_raw_query_param(&mut self, raw_query_param: &str) {
        self.state.add_raw_query_param(raw_query_param);
    }
696
    /// Clears all query params set on this `TestServer`,
    /// so they are no longer added to future requests.
    pub fn clear_query_params(&mut self) {
        self.state.clear_query_params();
    }
701
702    /// Adds a header to be sent with all future requests built from this `TestServer`.
703    ///
704    /// ```rust
705    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
706    /// #
707    /// use axum::Router;
708    /// use axum_test::TestServer;
709    ///
710    /// let app = Router::new();
711    /// let mut server = TestServer::new(app);
712    ///
713    /// server.add_header("x-custom-header", "custom-value");
714    /// server.add_header(http::header::CONTENT_LENGTH, 12345);
715    /// server.add_header(http::header::HOST, "example.com");
716    ///
717    /// let response = server.get(&"/my-end-point")
718    ///     .await;
719    /// #
720    /// # Ok(()) }
721    /// ```
722    pub fn add_header<N, V>(&mut self, name: N, value: V)
723    where
724        N: TryInto<HeaderName>,
725        N::Error: Debug,
726        V: TryInto<HeaderValue>,
727        V::Error: Debug,
728    {
729        let header_name: HeaderName = name
730            .try_into()
731            .expect("Failed to convert header name to HeaderName");
732        let header_value: HeaderValue = value
733            .try_into()
734            .expect("Failed to convert header vlue to HeaderValue");
735
736        self.state.add_header(header_name, header_value);
737    }
738
    /// Clears all headers set so far on this `TestServer`,
    /// so they are no longer added to future requests.
    pub fn clear_headers(&mut self) {
        self.state.clear_headers();
    }
743
    /// Returns a copy of the url reported by the transport layer,
    /// or `None` when the transport layer does not report one.
    pub(crate) fn url(&self) -> Option<Url> {
        self.transport.url().cloned()
    }
747
748    pub(crate) fn build_test_request_config(
749        &self,
750        method: Method,
751        path: &str,
752    ) -> Result<TestRequestConfig> {
753        let url = self
754            .url()
755            .unwrap_or_else(|| DEFAULT_URL_ADDRESS.parse().unwrap());
756
757        let mut query_params = self.state.query_params().clone();
758        let headers = self.state.headers().clone();
759        let full_request_url =
760            build_url(url, path, &mut query_params, self.is_http_path_restricted)?;
761
762        Ok(TestRequestConfig {
763            atomic_cookie_jar: self.cookie_jar.clone(),
764
765            // These are copied over from the cookie jar,
766            // as the server could change it's save state after the request is made.
767            is_saving_cookies: self.cookie_jar.is_saving(),
768            cookies: self.cookie_jar.to_cookie_jar(),
769
770            expected_state: self.expected_state,
771            content_type: self.default_content_type.clone(),
772            method,
773
774            full_request_url,
775            query_params,
776            headers,
777        })
778    }
779
    /// Returns true or false if the underlying service inside the `TestServer`
    /// is still running. For many types of services this will always return `true`.
    ///
    /// When a `TestServer` is built using [`axum::serve::WithGracefulShutdown`],
    /// this will return false if the service has shutdown.
    ///
    /// The answer comes from the transport layer the `TestServer` was built with.
    pub fn is_running(&self) -> bool {
        self.transport.is_running()
    }
788}
789
790fn build_url(
791    mut url: Url,
792    path: &str,
793    query_params: &mut QueryParamsStore,
794    is_http_restricted: bool,
795) -> Result<Url> {
796    let path_uri = path.parse::<Uri>()?;
797
798    // If there is a scheme, then this is an absolute path.
799    if let Some(scheme) = path_uri.scheme_str() {
800        if is_http_restricted {
801            if has_different_scheme(&url, &path_uri) || has_different_authority(&url, &path_uri) {
802                return Err(anyhow!(
803                    "Request disallowed for path '{path}', requests are only allowed to local server. Turn off 'restrict_requests_with_http_scheme' to change this."
804                ));
805            }
806        } else {
807            url.set_scheme(scheme)
808                .map_err(|_| anyhow!("Failed to set scheme for request, with path '{path}'"))?;
809
810            // We only set the host/port if the scheme is also present.
811            if let Some(authority) = path_uri.authority() {
812                url.set_host(Some(authority.host()))
813                    .map_err(|_| anyhow!("Failed to set host for request, with path '{path}'"))?;
814                url.set_port(authority.port().map(|p| p.as_u16()))
815                    .map_err(|_| anyhow!("Failed to set port for request, with path '{path}'"))?;
816
817                // todo, add username:password support
818            }
819        }
820    }
821
822    // Why does this exist?
823    //
824    // This exists to allow `server.get("/users")` and `server.get("users")` (without a slash)
825    // to go to the same place.
826    //
827    // It does this by saying ...
828    //  - if there is a scheme, it's a full path.
829    //  - if no scheme, it must be a path
830    //
831    if is_absolute_uri(&path_uri) {
832        url.set_path(path_uri.path());
833
834        // In this path we are replacing, so drop any query params on the original url.
835        if url.query().is_some() {
836            url.set_query(None);
837        }
838    } else {
839        // Grab everything up until the query parameters, or everything after that
840        let calculated_path = path.split('?').next().unwrap_or(path);
841        url.set_path(calculated_path);
842
843        // Move any query parameters from the url to the query params store.
844        if let Some(url_query) = url.query() {
845            query_params.add_raw(url_query.to_string());
846            url.set_query(None);
847        }
848    }
849
850    if let Some(path_query) = path_uri.query() {
851        query_params.add_raw(path_query.to_string());
852    }
853
854    Ok(url)
855}
856
/// Returns true if the uri carries a scheme (such as `http`),
/// which is what this file treats as an absolute uri.
fn is_absolute_uri(path_uri: &Uri) -> bool {
    path_uri.scheme_str().is_some()
}
860
861fn has_different_scheme(base_url: &Url, path_uri: &Uri) -> bool {
862    if let Some(scheme) = path_uri.scheme_str() {
863        return scheme != base_url.scheme();
864    }
865
866    false
867}
868
869fn has_different_authority(base_url: &Url, path_uri: &Uri) -> bool {
870    if let Some(authority) = path_uri.authority() {
871        return authority.as_str() != base_url.authority();
872    }
873
874    false
875}
876
#[cfg(test)]
mod test_build_url {
    use super::*;

    /// Flag value for building a url with the http restriction turned on.
    const RESTRICTED: bool = true;
    /// Flag value for building a url with the http restriction turned off.
    const UNRESTRICTED: bool = false;

    /// Builds a url that is expected to succeed.
    ///
    /// Returns the resulting url, along with the query params collected whilst building it.
    #[track_caller]
    fn build_ok(base: &str, path: &str, is_restricted: bool) -> (Url, QueryParamsStore) {
        let base_url = base.parse::<Url>().unwrap();
        let mut query_params = QueryParamsStore::new();
        let url = build_url(base_url, path, &mut query_params, is_restricted).unwrap();

        (url, query_params)
    }

    /// Returns true if building the url was refused with an error.
    #[track_caller]
    fn is_rejected(base: &str, path: &str, is_restricted: bool) -> bool {
        let base_url = base.parse::<Url>().unwrap();
        let mut query_params = QueryParamsStore::new();

        build_url(base_url, path, &mut query_params, is_restricted).is_err()
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_restricted() {
        let (url, query_params) = build_ok("http://example.com", "/users", RESTRICTED);

        assert_eq!("http://example.com/users", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_restricted() {
        let (url, query_params) = build_ok(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            RESTRICTED,
        );

        assert_eq!("http://example.com/users", url.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_different_scheme() {
        assert!(is_rejected(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_not_replace_url_when_restricted_with_same_scheme() {
        assert!(is_rejected(
            "http://example.com?base=666",
            "http://google.com:123/users.csv?limit=456",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_block_url_when_restricted_with_same_scheme() {
        assert!(is_rejected(
            "http://example.com?base=666",
            "http://google.com",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_block_url_when_restricted_and_same_domain_with_different_scheme() {
        assert!(is_rejected(
            "http://example.com?base=666",
            "ftp://example.com/users",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_copy_path_to_url_returned_when_unrestricted() {
        let (url, query_params) = build_ok("http://example.com", "/users", UNRESTRICTED);

        assert_eq!("http://example.com/users", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_all_query_params_to_store_when_unrestricted() {
        let (url, query_params) = build_ok(
            "http://example.com?base=aaa",
            "/users?path=bbb&path-flag",
            UNRESTRICTED,
        );

        assert_eq!("http://example.com/users", url.as_str());
        assert_eq!("base=aaa&path=bbb&path-flag", query_params.to_string());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_unrestricted() {
        // No scheme is present, so the host-like text is used as a path.
        let (url, query_params) = build_ok("http://example.com", "google.com", UNRESTRICTED);

        assert_eq!("http://example.com/google.com", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_copy_host_like_a_path_when_restricted() {
        // No scheme is present, so the host-like text is used as a path.
        let (url, query_params) = build_ok("http://example.com", "google.com", RESTRICTED);

        assert_eq!("http://example.com/google.com", url.as_str());
        assert!(query_params.is_empty());
    }

    #[test]
    fn it_should_replace_url_when_unrestricted() {
        let (url, query_params) = build_ok(
            "http://example.com?base=666",
            "ftp://google.com:123/users.csv?limit=456",
            UNRESTRICTED,
        );

        assert_eq!("ftp://google.com:123/users.csv", url.as_str());
        assert_eq!("limit=456", query_params.to_string());
    }

    #[test]
    fn it_should_allow_different_scheme_when_unrestricted() {
        let (url, _) = build_ok("http://example.com", "ftp://example.com", UNRESTRICTED);

        assert_eq!("ftp://example.com/", url.as_str());
    }

    #[test]
    fn it_should_allow_different_host_when_unrestricted() {
        let (url, _) = build_ok("http://example.com", "http://google.com", UNRESTRICTED);

        assert_eq!("http://google.com/", url.as_str());
    }

    #[test]
    fn it_should_allow_different_port_when_unrestricted() {
        let (url, _) = build_ok(
            "http://example.com:123",
            "http://example.com:456",
            UNRESTRICTED,
        );

        assert_eq!("http://example.com:456/", url.as_str());
    }

    #[test]
    fn it_should_allow_same_host_port_when_unrestricted() {
        let (url, _) = build_ok(
            "http://example.com:123",
            "http://example.com:123",
            UNRESTRICTED,
        );

        assert_eq!("http://example.com:123/", url.as_str());
    }

    #[test]
    fn it_should_not_allow_different_scheme_when_restricted() {
        assert!(is_rejected(
            "http://example.com",
            "ftp://example.com",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_not_allow_different_host_when_restricted() {
        assert!(is_rejected(
            "http://example.com",
            "http://google.com",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_not_allow_different_port_when_restricted() {
        assert!(is_rejected(
            "http://example.com:123",
            "http://example.com:456",
            RESTRICTED,
        ));
    }

    #[test]
    fn it_should_allow_same_host_port_when_restricted() {
        let (url, _) = build_ok(
            "http://example.com:123",
            "http://example.com:123",
            RESTRICTED,
        );

        assert_eq!("http://example.com:123/", url.as_str());
    }
}
1078
#[cfg(test)]
mod test_new {
    use axum::Router;
    use axum::routing::get;
    use std::net::SocketAddr;

    use crate::TestServer;

    /// Handler that always replies with a fixed body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_run_into_make_into_service_with_connect_info_by_default() {
        // Wrap the router as a make service that carries connect info.
        let router = Router::new().route("/ping", get(get_ping));
        let service = router.into_make_service_with_connect_info::<SocketAddr>();

        // The server is created with no further configuration.
        let server = TestServer::new(service);

        let response = server.get("/ping").await;
        response.assert_text("pong!");
    }
}
1105
#[cfg(test)]
mod test_get {
    use super::*;
    use crate::testing::catch_panic_error_message;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;
    use reserve_port::ReservedSocketAddr;

    /// Handler that always replies with a fixed body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app);

        // Get the request _with_ slash
        let response = server.get("/ping").await;
        response.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_without_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app);

        // Get the request _without_ slash
        let response = server.get("ping").await;
        response.assert_text("pong!");
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path() {
        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Reserve an address
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        // Run the server.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // Get the request, using the full url of the server.
        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text("pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_get_using_absolute_path_and_restricted_if_path_is_for_server() {
        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Reserve an IP / Port
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        // Run the server.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_scheme() // Key part of the test!
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // Get the request, using the full url of the server.
        let absolute_url = format!("http://{ip}:{port}/ping");
        let response = server.get(&absolute_url).await;

        response.assert_text("pong!");
        let request_path = response.request_url();
        assert_eq!(request_path.to_string(), format!("http://{ip}:{port}/ping"));
    }

    #[tokio::test]
    async fn it_should_not_get_using_absolute_path_if_restricted_and_different_port() {
        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Reserve an IP / Port
        let reserved_address = ReservedSocketAddr::reserve_random_socket_addr().unwrap();
        let ip = reserved_address.ip();
        let port = reserved_address.port();

        // Run the server.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .restrict_requests_with_http_scheme() // Key part of the test!
            .try_build(app)
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        // Change the port to be off by one, so it does not match the server.
        // When the server is on the highest port we step down instead,
        // as adding one would overflow and panic before the test runs.
        let port = port.checked_add(1).unwrap_or_else(|| port - 1);
        let absolute_url = format!("http://{ip}:{port}/ping");

        // Building the request is expected to panic; the message is captured.
        let message = catch_panic_error_message(|| {
            let _ = server.get(&absolute_url);
        });

        let expected = format!("Failed to build request, for GET http://{ip}:{port}/ping,
    Request disallowed for path 'http://{ip}:{port}/ping', requests are only allowed to local server. Turn off 'restrict_requests_with_http_scheme' to change this.
");
        assert_str_eq!(expected, message);
    }

    #[tokio::test]
    async fn it_should_work_in_parallel() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app);

        // Both requests share the one server, and are polled together.
        let (r1, r2) = tokio::join!(
            async { server.get("/ping").await },
            async { server.get("/ping").await },
        );

        assert_eq!(r1.text(), r2.text());
    }

    #[tokio::test]
    async fn it_should_work_in_parallel_with_sleeping_requests() {
        let app = Router::new().route(
            "/slow",
            get(|| async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                "hello!"
            }),
        );

        let server = TestServer::new(app);

        // Both requests share the one server, and are polled together.
        let (r1, r2) = tokio::join!(
            async { server.get("/slow").await },
            async { server.get("/slow").await },
        );

        assert_eq!(r1.text(), r2.text());
    }
}
1250
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_get {
    use super::*;
    use axum::Router;
    use axum::routing::get;

    /// Handler that always replies with a fixed body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    #[tokio::test]
    async fn it_should_get_using_relative_path_with_slash() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::builder().http_transport().build(app);

        // Send the request through the reqwest builder, then read the body as text.
        let response = server.reqwest_get("/ping").send().await.unwrap();
        let body = response.text().await.unwrap();

        assert_eq!(body, "pong!");
    }
}
1279
#[cfg(feature = "reqwest")]
#[cfg(test)]
mod test_reqwest_post {
    use super::*;
    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use serde::Deserialize;

    /// Json payload used for both the request and the response.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestBody {
        number: u32,
        text: String,
    }

    /// Replies with the number doubled, and a suffix appended to the text.
    async fn post_json(Json(body): Json<TestBody>) -> Json<TestBody> {
        let number = body.number * 2;
        let text = format!("{}_plus_response", body.text);

        Json(TestBody { number, text })
    }

    #[tokio::test]
    async fn it_should_post_and_receive_json() {
        let app = Router::new().route("/json", post(post_json));
        let server = TestServer::builder().http_transport().build(app);

        let request_body = TestBody {
            number: 111,
            text: "request".to_string(),
        };
        let expected_body = TestBody {
            number: 222,
            text: "request_plus_response".to_string(),
        };

        // Send the json through the reqwest builder, then decode the reply.
        let response = server
            .reqwest_post("/json")
            .json(&request_body)
            .send()
            .await
            .unwrap();
        let received_body = response.json::<TestBody>().await.unwrap();

        assert_eq!(received_body, expected_body);
    }
}
1331
#[cfg(test)]
mod test_server_address {
    use super::*;
    use axum::Router;
    use regex::Regex;
    use reserve_port::ReservedPort;
    use std::net::Ipv4Addr;

    #[tokio::test]
    async fn it_should_return_address_used_from_config() {
        let reserved_port = ReservedPort::random().unwrap();
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        // Run a server, with no routes, on the reserved ip and port.
        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        let received_address = server.server_address().unwrap().to_string();
        assert_eq!(received_address, format!("http://{ip}:{port}/"));
    }

    #[tokio::test]
    async fn it_should_return_default_address_without_ending_slash() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new());

        // Note that, despite the test name, the pattern requires a trailing slash.
        let address_regex = Regex::new(r"^http://127\.0\.0\.1:[0-9]+/$").unwrap();
        let received_address = server.server_address().unwrap().to_string();

        assert!(address_regex.is_match(&received_address));
    }

    #[tokio::test]
    async fn it_should_return_none_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new());

        let maybe_address = server.server_address();
        assert!(maybe_address.is_none());
    }
}
1378
#[cfg(test)]
mod test_server_url {
    use super::*;
    use axum::Router;
    use pretty_assertions::assert_str_eq;
    use regex::Regex;
    use reserve_port::ReservedPort;
    use std::net::Ipv4Addr;

    /// Runs a server, with no routes, on localhost at the reserved port.
    ///
    /// Returns the server, along with the base address it is expected to use.
    /// The base is formatted as `http://{ip}:{port}`, with no trailing slash.
    fn build_server(reserved_port: &ReservedPort) -> (TestServer, String) {
        let ip = Ipv4Addr::LOCALHOST.into();
        let port = reserved_port.port();

        let server = TestServer::builder()
            .http_transport_with_ip_port(Some(ip), Some(port))
            .try_build(Router::new())
            .error_message_fn(|| format!("Should create test server with address {}:{}", ip, port));

        (server, format!("http://{}:{}", ip, port))
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_http_ip_port() {
        let reserved_port = ReservedPort::random().unwrap();
        let (server, base) = build_server(&reserved_port);

        let expected_ip_port_url = format!("{base}/users");
        let absolute_url = server.server_url("/users").unwrap().to_string();

        assert_eq!(expected_ip_port_url, absolute_url);
    }

    #[tokio::test]
    async fn it_should_return_address_with_url_on_random_http() {
        let server = TestServer::builder()
            .http_transport()
            .build(Router::new());

        // The port is random, so the url is matched against a pattern.
        let address_regex =
            Regex::new(r"^http://127\.0\.0\.1:[0-9]+/users/123\?filter=enabled$").unwrap();
        let absolute_url = server
            .server_url("/users/123?filter=enabled")
            .unwrap()
            .to_string();

        assert!(address_regex.is_match(&absolute_url));
    }

    #[tokio::test]
    async fn it_should_error_on_mock_transport() {
        let server = TestServer::builder()
            .mock_transport()
            .build(Router::new());

        assert!(server.server_url("/users").is_err());
    }

    #[tokio::test]
    async fn it_should_include_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let (server, base) = build_server(&reserved_port);

        let expected_url = format!("{base}/users?filter=enabled");
        let received_url = server
            .server_url("/users?filter=enabled")
            .unwrap()
            .to_string();

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_include_server_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let (mut server, base) = build_server(&reserved_port);

        server.add_query_param("filter", "enabled");

        let expected_url = format!("{base}/users?filter=enabled");
        let received_url = server.server_url("/users").unwrap().to_string();

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_include_server_and_path_query_params() {
        let reserved_port = ReservedPort::random().unwrap();
        let (mut server, base) = build_server(&reserved_port);

        server.add_query_param("filter", "enabled");

        // The query param on the server is expected before the one on the path.
        let expected_url = format!("{base}/users?filter=enabled&animal=donkeys");
        let received_url = server
            .server_url("/users?animal=donkeys")
            .unwrap()
            .to_string();

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_include_both_server_and_path_queries() {
        let reserved_port = ReservedPort::random().unwrap();
        let (mut server, base) = build_server(&reserved_port);

        server.add_query_param("query", "server");

        // The same key is used twice, and both values are expected to be kept.
        let expected_url = format!("{base}/users?query=server&query=path");
        let received_url = server.server_url("/users?query=path").unwrap().to_string();

        assert_eq!(expected_url, received_url);
    }

    #[tokio::test]
    async fn it_should_work_for_paths_with_leading_slash() {
        let reserved_port = ReservedPort::random().unwrap();
        let (server, base) = build_server(&reserved_port);

        // The path is given with no leading slash, and the url is expected to have one.
        let expected_url = format!("{base}/users");
        let received_url = server.server_url("users").unwrap().to_string();

        assert_eq!(expected_url, received_url);
    }

    // TODO, change this behaviour to allow an empty path. It should be the same as no path at all.
    #[tokio::test]
    async fn it_should_panic_when_provided_an_empty_path() {
        let reserved_port = ReservedPort::random().unwrap();
        let (server, _base) = build_server(&reserved_port);

        // An error is returned here, rather than a panic.
        let error_message = server.server_url("").unwrap_err().to_string();

        assert_str_eq!("empty string", error_message);
    }
}
1575
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of the test cookie, or a placeholder when it is missing.
    ///
    /// The cookies received are passed back unchanged.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = match cookies.get(TEST_COOKIE_NAME) {
            Some(cookie) => cookie.value().to_string(),
            None => "cookie-not-found".to_string(),
        };

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app);

        // The cookie is added to the server, and not to the request.
        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response = server.get("/cookie").await;
        assert_eq!(response.text(), "my-custom-cookie");
    }
}
1607
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Replies with every cookie received, as sorted `name=value` pairs joined by `, `.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();

        // Sorted so the output does not depend on the order of the jar.
        pairs.sort();
        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app);

        // Build cookies to send up
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));

        server.add_cookies(cookie_jar);

        let response = server.get("/cookies").await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }
}
1648
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;

    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;

    /// Replies with every cookie received, as sorted `name=value` pairs joined by `, `.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();

        // Sorted so the output does not depend on the order of the jar.
        pairs.sort();
        pairs.join(", ")
    }

    #[tokio::test]
    async fn it_should_not_send_cookies_cleared() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let mut server = TestServer::new(app);

        // Add some cookies to the server first, so there is something to clear.
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));
        server.add_cookies(cookie_jar);

        // The important bit of this test
        server.clear_cookies();

        let response = server.get("/cookies").await;
        response.assert_text("");
    }
}
1688
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the test header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Reads the test header from the request.
        ///
        /// Rejects with a `400 Bad Request` when the header is missing.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            let header_name = HeaderName::from_static(TEST_HEADER_NAME);

            match parts.headers.get(header_name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Replies with the bytes of the test header it received.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, with the header added to the server itself.
        let mut server = TestServer::new(app);
        let header_name = HeaderName::from_static(TEST_HEADER_NAME);
        let header_value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        server.add_header(header_name, header_value);

        // Send a request, which is expected to carry the header.
        let response = server.get("/header").await;

        // Check it sent back the right text
        response.assert_text(TEST_HEADER_CONTENT);
    }
}
1745
/// Tests for `TestServer::clear_headers`.
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    // Name and value of the header that is added and then cleared again.
    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor carrying the raw bytes of the test header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Reads the test header, rejecting with `400` when it is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            let name = HeaderName::from_static(TEST_HEADER_NAME);

            match parts.headers.get(name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the received header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_not_send_headers_cleared_by_server() {
        let router = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(router);

        // Add the header, then immediately wipe all server-level headers.
        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        server.add_header(name, value);
        server.clear_headers();

        // The request goes out without the header ...
        let response = server.get("/header").await;

        // ... so the extractor rejects it.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
1804
/// Tests for `TestServer::add_query_params`.
#[cfg(test)]
mod test_add_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    use crate::TestServer;

    /// Query string with a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    /// Query string with both a `message` and an `other` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `message` and `other` joined by a dash.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = params;
        format!("{}-{}", message, other)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router);

        // Params supplied as a serializable struct.
        let params = QueryParam {
            message: String::from("it works"),
        };
        server.add_query_params(params);

        let response = server.get("/query").await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router);

        // Params supplied as a slice of key-value pairs.
        server.add_query_params(&[("message", "it works")]);

        let response = server.get("/query").await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router);

        // Both params supplied in a single call.
        server.add_query_params(&[("message", "it works"), ("other", "yup")]);

        let response = server.get("/query-2").await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router);

        // One param per call; the second call must not replace the first.
        server.add_query_params(&[("message", "it works")]);
        server.add_query_params(&[("other", "yup")]);

        let response = server.get("/query-2").await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router);

        // Params supplied as a JSON object.
        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        server.add_query_params(params);

        let response = server.get("/query-2").await;
        response.assert_text("it works-yup");
    }
}
1907
/// Tests for `TestServer::add_query_param`.
#[cfg(test)]
mod test_add_query_param {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string with a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    /// Query string with both a `message` and an `other` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with `message` and `other` joined by a dash.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = params;
        format!("{}-{}", message, other)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let mut server = TestServer::new(router);

        server.add_query_param("message", "it works");

        let response = server.get("/query").await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router);

        // Two separate calls on the server; both params must be sent.
        server.add_query_param("message", "it works");
        server.add_query_param("other", "yup");

        let response = server.get("/query-2").await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_calls_across_server_and_request() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let mut server = TestServer::new(router);

        // First param lives on the server ...
        server.add_query_param("message", "it works");

        // ... the second is added to the individual request.
        let request = server.get("/query-2").add_query_param("other", "yup");

        request.await.assert_text("it works-yup");
    }
}
1982
/// Tests for `TestServer::add_raw_query_param`.
#[cfg(test)]
mod test_add_raw_query_param {
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Query string with a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Responds with the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query string with repeated keys, using either `items` or `arrs[]`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Responds with the received `items`, then the received `arrs[]`,
    /// each list joined with `", "`.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        let QueryParamExtra { items, arrs } = params;
        let mut output = String::new();

        if !items.is_empty() {
            output.push_str(&items.join(", "));
        }

        if !arrs.is_empty() {
            output.push_str(&arrs.join(", "));
        }

        output
    }

    /// Application exposing both the plain and the repeated-key query routes.
    fn build_app() -> Router {
        let router = Router::new();
        let router = router.route("/query", get(get_query_param));
        router.route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let mut server = TestServer::new(build_app());
        server.add_raw_query_param(&"message=it-works");

        let response = server.get("/query").await;
        response.assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let mut server = TestServer::new(build_app());

        // All three values are handed over in one raw string.
        server.add_raw_query_param(&"items=one&items=two&items=three");

        let response = server.get("/query-extra").await;
        response.assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let mut server = TestServer::new(build_app());

        // Each value is handed over in its own call.
        server.add_raw_query_param(&"arrs[]=one");
        server.add_raw_query_param(&"arrs[]=two");
        server.add_raw_query_param(&"arrs[]=three");

        let response = server.get("/query-extra").await;
        response.assert_text("one, two, three");
    }
}
2073
/// Tests for `TestServer::clear_query_params`.
#[cfg(test)]
mod test_clear_query_params {
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;

    use serde::Deserialize;
    use serde::Serialize;

    use crate::TestServer;

    /// Two optional params, so the handler can report which ones arrived.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports whether each of the two params was present on the request.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();

        format!("has first? {}, has second? {}", has_first, has_second)
    }

    /// Params value with both fields populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some(String::from("first")),
            second: Some(String::from("second")),
        }
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let router = Router::new().route("/query", get(get_query_params));
        let mut server = TestServer::new(router);

        // Set both params, then drop them again.
        server.add_query_params(both_params());
        server.clear_query_params();

        let response = server.get("/query").await;
        response.assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let router = Router::new().route("/query", get(get_query_params));
        let mut server = TestServer::new(router);

        // Set, clear, and set again; the second set must take effect.
        server.add_query_params(both_params());
        server.clear_query_params();
        server.add_query_params(both_params());

        let response = server.get("/query").await;
        response.assert_text("has first? true, has second? true");
    }
}
2143
/// Tests for the builder's `expect_success_by_default` option,
/// and for the behaviour when it is left off.
#[cfg(test)]
mod test_expect_success_by_default {
    use super::*;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_str_eq;

    /// Application with a single route that always succeeds.
    fn known_route_app() -> Router {
        Router::new().route("/known_route", get(|| async { "🦊🦊🦊" }))
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_404_route() {
        let server = TestServer::new(Router::new());

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route() {
        let server = TestServer::new(known_route_app());

        server.get("/known_route").await;
    }

    #[tokio::test]
    async fn it_should_panic_by_default_if_accessing_404_route_and_expect_success_on() {
        let server = TestServer::builder()
            .expect_success_by_default()
            .build(Router::new());

        // Awaiting the request panics; capture the panic message instead.
        let request = server.get("/some_unknown_route");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(
            "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''",
            message
        );
    }

    #[tokio::test]
    async fn it_should_not_panic_by_default_if_accessing_200_route_and_expect_success_on() {
        let server = TestServer::builder()
            .expect_success_by_default()
            .build(known_route_app());

        server.get("/known_route").await;
    }
}
2188
/// Tests for the builder's `default_content_type` option.
#[cfg(test)]
mod test_content_type {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Responds with the request's `Content-Type`, or an empty body when unset.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_default_to_server_content_type_when_present() {
        let router = Router::new().route("/content_type", get(get_content_type));

        // The content type is set once, on the server.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(router);

        // The request itself sets nothing, so the server default must be sent.
        let response = server.get("/content_type").await;

        assert_eq!(response.text(), "text/plain");
    }
}
2220
/// Tests for `TestServer::expect_success`.
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Handler answering `200 OK` with a body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering `202 Accepted` with no body.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let router = Router::new().route("/ping", get(get_ping));

        let mut server = TestServer::new(router);
        server.expect_success();

        server.get("/ping").await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let router = Router::new().route("/accepted", get(get_accepted));

        let mut server = TestServer::new(router);
        server.expect_success();

        server.get("/accepted").await;
    }

    #[tokio::test]
    async fn it_should_panic_on_404() {
        // No routes at all, so every request is a 404.
        let mut server = TestServer::new(Router::new());
        server.expect_success();

        // Awaiting the request panics; capture the panic message instead.
        let request = server.get("/some_unknown_route");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(
            "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''",
            message
        );
    }
}
2281
/// Tests for `TestServer::expect_failure`.
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Handler answering `200 OK` with a body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering `202 Accepted` with no body.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // No routes at all, so every request is a 404.
        let mut server = TestServer::new(Router::new());
        server.expect_failure();

        server.get("/some_unknown_route").await;
    }

    #[tokio::test]
    async fn it_should_panic_if_success_is_returned() {
        let router = Router::new().route("/ping", get(get_ping));

        let mut server = TestServer::new(router);
        server.expect_failure();

        // Awaiting the request panics; capture the panic message instead.
        let request = server.get("/ping");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(
            "Expect status code outside 2xx range, received 200 (OK), for request GET http://localhost/ping, with body 'pong!'",
            message
        );
    }

    #[tokio::test]
    async fn it_should_panic_on_other_2xx_status_code() {
        let router = Router::new().route("/accepted", get(get_accepted));

        let mut server = TestServer::new(router);
        server.expect_failure();

        // Awaiting the request panics; capture the panic message instead.
        let request = server.get("/accepted");
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(
            "Expect status code outside 2xx range, received 202 (Accepted), for request GET http://localhost/accepted, with body ''",
            message
        );
    }
}
2346
/// Tests for `TestServer::typed_get`.
#[cfg(all(test, feature = "typed-routing"))]
mod test_typed_get {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb and the id taken from the path.
    async fn route_get(path: TestingPath) -> String {
        let id = path.id;
        format!("get {id}")
    }

    /// Application with only the typed GET route registered.
    fn new_app() -> Router {
        let router = Router::new();
        router.typed_get(route_get)
    }

    #[tokio::test]
    async fn it_should_send_get() {
        let server = TestServer::new(new_app());
        let path = TestingPath { id: 123 };

        let response = server.typed_get(&path).await;

        response.assert_text("get 123");
    }
}
2379
/// Tests for `TestServer::typed_post`.
#[cfg(all(test, feature = "typed-routing"))]
mod test_typed_post {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb and the id taken from the path.
    async fn route_post(path: TestingPath) -> String {
        let id = path.id;
        format!("post {id}")
    }

    /// Application with only the typed POST route registered.
    fn new_app() -> Router {
        let router = Router::new();
        router.typed_post(route_post)
    }

    #[tokio::test]
    async fn it_should_send_post() {
        let server = TestServer::new(new_app());
        let path = TestingPath { id: 123 };

        let response = server.typed_post(&path).await;

        response.assert_text("post 123");
    }
}
2412
/// Tests for `TestServer::typed_patch`.
#[cfg(all(test, feature = "typed-routing"))]
mod test_typed_patch {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb and the id taken from the path.
    async fn route_patch(path: TestingPath) -> String {
        let id = path.id;
        format!("patch {id}")
    }

    /// Application with only the typed PATCH route registered.
    fn new_app() -> Router {
        let router = Router::new();
        router.typed_patch(route_patch)
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        let server = TestServer::new(new_app());
        let path = TestingPath { id: 123 };

        let response = server.typed_patch(&path).await;

        response.assert_text("patch 123");
    }
}
2445
/// Tests for `TestServer::typed_put`.
#[cfg(all(test, feature = "typed-routing"))]
mod test_typed_put {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb and the id taken from the path.
    async fn route_put(path: TestingPath) -> String {
        let id = path.id;
        format!("put {id}")
    }

    /// Application with only the typed PUT route registered.
    fn new_app() -> Router {
        let router = Router::new();
        router.typed_put(route_put)
    }

    #[tokio::test]
    async fn it_should_send_put() {
        let server = TestServer::new(new_app());
        let path = TestingPath { id: 123 };

        let response = server.typed_put(&path).await;

        response.assert_text("put 123");
    }
}
2478
/// Tests for `TestServer::typed_delete`.
#[cfg(all(test, feature = "typed-routing"))]
mod test_typed_delete {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path with a numeric `id` segment.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    /// Responds with the verb and the id taken from the path.
    async fn route_delete(path: TestingPath) -> String {
        let id = path.id;
        format!("delete {id}")
    }

    /// Application with only the typed DELETE route registered.
    fn new_app() -> Router {
        let router = Router::new();
        router.typed_delete(route_delete)
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        let server = TestServer::new(new_app());
        let path = TestingPath { id: 123 };

        let response = server.typed_delete(&path).await;

        response.assert_text("delete 123");
    }
}
2511
/// Tests for `TestServer::typed_method`, one per HTTP verb.
#[cfg(all(test, feature = "typed-routing"))]
mod test_typed_method {
    use super::*;
    use axum::Router;
    use axum_extra::routing::RouterExt;
    use serde::Deserialize;

    /// Typed path with a numeric `id` segment, shared by every verb.
    #[derive(TypedPath, Deserialize)]
    #[typed_path("/path/{id}")]
    struct TestingPath {
        id: u32,
    }

    // One handler per verb; each responds with its verb and the id from the path.

    async fn route_get(TestingPath { id }: TestingPath) -> String {
        format!("get {id}")
    }

    async fn route_post(TestingPath { id }: TestingPath) -> String {
        format!("post {id}")
    }

    async fn route_patch(TestingPath { id }: TestingPath) -> String {
        format!("patch {id}")
    }

    async fn route_put(TestingPath { id }: TestingPath) -> String {
        format!("put {id}")
    }

    async fn route_delete(TestingPath { id }: TestingPath) -> String {
        format!("delete {id}")
    }

    /// Application with all five verbs registered on the same typed path.
    fn new_app() -> Router {
        Router::new()
            .typed_get(route_get)
            .typed_post(route_post)
            .typed_patch(route_patch)
            .typed_put(route_put)
            .typed_delete(route_delete)
    }

    /// Sends `method` to the typed path with id `123` on a fresh server,
    /// and asserts the response body is `expected_text`.
    async fn assert_method_is_routed(method: Method, expected_text: &str) {
        let server = TestServer::new(new_app());
        let path = TestingPath { id: 123 };

        let response = server.typed_method(method, &path).await;

        response.assert_text(expected_text);
    }

    #[tokio::test]
    async fn it_should_send_get() {
        assert_method_is_routed(Method::GET, "get 123").await;
    }

    #[tokio::test]
    async fn it_should_send_post() {
        assert_method_is_routed(Method::POST, "post 123").await;
    }

    #[tokio::test]
    async fn it_should_send_patch() {
        assert_method_is_routed(Method::PATCH, "patch 123").await;
    }

    #[tokio::test]
    async fn it_should_send_put() {
        assert_method_is_routed(Method::PUT, "put 123").await;
    }

    #[tokio::test]
    async fn it_should_send_delete() {
        assert_method_is_routed(Method::DELETE, "delete 123").await;
    }
}
2605
/// Checks a `TestServer` can be stored in, and used from, a `OnceCell`.
#[cfg(test)]
mod test_sync {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use std::cell::OnceCell;

    /// Handler for the only route of the test application.
    async fn route_get() -> &'static str {
        "it works"
    }

    /// Builds the server that gets stored in the cell.
    fn new_server() -> TestServer {
        let router = Router::new().route("/test", get(route_get));
        TestServer::new(router)
    }

    #[tokio::test]
    async fn it_should_be_able_to_be_in_one_cell() {
        let cell: OnceCell<TestServer> = OnceCell::new();

        // The server is created on first access, and only borrowed afterwards.
        let server = cell.get_or_init(new_server);

        let response = server.get("/test").await;
        response.assert_text("it works");
    }
}
2629
/// Tests for `TestServer::is_running`.
#[cfg(test)]
mod test_is_running {
    use super::*;
    use crate::testing::catch_panic_error_message_async;
    use crate::util::new_random_tokio_tcp_listener_with_socket_addr;
    use axum::Router;
    use axum::routing::IntoMakeService;
    use axum::routing::get;
    use axum::serve;
    use pretty_assertions::assert_str_eq;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::time::sleep;

    /// Handler for the only route of the test application.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Runs a served application, shuts it down gracefully, and checks that
    /// `is_running` flips to `false` and further requests panic with a
    /// connection error.
    #[tokio::test]
    async fn it_should_panic_when_run_with_mock_http() {
        // One handle triggers the shutdown, the other is awaited by the server.
        let shutdown_notification = Arc::new(Notify::new());
        let waiting_notification = shutdown_notification.clone();

        // Build an application with a route.
        let app: IntoMakeService<Router> = Router::new()
            .route("/ping", get(get_ping))
            .into_make_service();
        let (listener, ip_port) = new_random_tokio_tcp_listener_with_socket_addr().unwrap();
        let application = serve(listener, app)
            .with_graceful_shutdown(async move { waiting_notification.notified().await });

        // Run the server.
        let server = TestServer::builder().build(application);

        server.get("/ping").await.assert_status_ok();
        assert!(server.is_running());

        // Ask the server to stop, and give it a moment to actually do so.
        shutdown_notification.notify_one();
        sleep(Duration::from_millis(10)).await;

        assert!(!server.is_running());

        let ip = ip_port.ip();
        let port = ip_port.port();

        // The text of a refused connection (message and errno) is platform
        // specific; i.e. the errno is 61 on macOS and 111 on Linux. Rather than
        // hard coding one platform, connect to the closed port ourselves
        // and use whatever error this OS reports.
        let connect_error = tokio::net::TcpStream::connect((ip, port))
            .await
            .expect_err("Nothing should be listening after the server has shut down");

        let expected = format!(
            "Sending request failed, for request GET http://{ip}:{port}/ping,
    client error (Connect)
    tcp connect error
    {connect_error}
"
        );
        let message = catch_panic_error_message_async(server.get("/ping")).await;
        assert_str_eq!(expected, message);
    }
}
2685
#[cfg(test)]
mod test_save_cookies {
    //! Tests for `TestServer::save_cookies`, covering requests made through
    //! the axum_test request API, through reqwest, and a mix of both.

    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    /// Name of the cookie that the `PUT /cookie` route of [`app`] sets.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// With saving turned on, a cookie set by one axum_test request
    /// should be sent back on the next axum_test request.
    #[tokio::test]
    async fn it_should_save_cookies_across_requests_when_enabled() {
        let mut server = TestServer::new(app());

        server.save_cookies();

        save_cookie_using_axum_test(&server).await;
        assert_cookie_using_axum_test(&server).await;
    }

    /// With saving turned on, a cookie set by one reqwest request
    /// should be sent back on the next reqwest request.
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_across_reqwest_requests_when_enabled() {
        let mut server = TestServer::builder().http_transport().build(app());

        server.save_cookies();

        save_cookie_using_reqwest(&server).await;
        // The follow-up request must read the cookie back; issuing a second
        // save here would leave the test with nothing to verify.
        assert_cookie_using_reqwest(&server).await;
    }

    /// Cookies set before `save_cookies()` is called should not be carried over;
    /// cookies set afterwards should be (axum_test requests).
    #[tokio::test]
    async fn it_should_save_cookies_across_axum_test_requests_when_enabled_for_second_request() {
        let mut server = TestServer::builder().http_transport().build(app());

        save_cookie_using_axum_test(&server).await;
        assert_no_cookie_using_axum_test(&server).await;

        server.save_cookies();

        save_cookie_using_axum_test(&server).await;
        assert_cookie_using_axum_test(&server).await;
    }

    /// Cookies set before `save_cookies()` is called should not be carried over;
    /// cookies set afterwards should be (reqwest requests).
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_across_reqwest_requests_when_enabled_for_second_request() {
        let mut server = TestServer::builder().http_transport().build(app());

        save_cookie_using_reqwest(&server).await;
        assert_no_cookie_using_reqwest(&server).await;

        server.save_cookies();

        save_cookie_using_reqwest(&server).await;
        assert_cookie_using_reqwest(&server).await;
    }

    /// A cookie set through reqwest should be visible to a later axum_test request.
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_when_set_by_reqwest_and_read_by_axum_test() {
        let mut server = TestServer::builder().http_transport().build(app());

        server.save_cookies();

        save_cookie_using_reqwest(&server).await;
        assert_cookie_using_axum_test(&server).await;
    }

    /// A cookie set through an axum_test request should be visible to a later reqwest request.
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn it_should_save_cookies_when_set_by_axum_test_and_read_by_reqwest() {
        let mut server = TestServer::builder().http_transport().build(app());

        server.save_cookies();

        save_cookie_using_axum_test(&server).await;
        assert_cookie_using_reqwest(&server).await;
    }

    /// Builds the application under test.
    ///
    /// - `PUT /cookie` sets a cookie named [`TEST_COOKIE_NAME`] whose value is the request body.
    /// - `GET /cookie` responds with the received `cookie` header values, joined by `"; "`.
    fn app() -> Router {
        /// Reads the request body as (lossy) UTF-8 text and stores it in a cookie
        /// that is marked `HttpOnly`, `Secure`, `SameSite=Strict`, with path `/cookie`.
        async fn put_cookie_with_attributes(
            mut cookies: AxumCookieJar,
            request: Request,
        ) -> (AxumCookieJar, &'static str) {
            let body_bytes = request
                .into_body()
                .collect()
                .await
                .expect("Should turn the body into bytes")
                .to_bytes();

            let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
            let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
                .http_only(true)
                .secure(true)
                .same_site(SameSite::Strict)
                .path("/cookie")
                .build();
            cookies = cookies.add(cookie);

            (cookies, "done")
        }

        /// Returns every `cookie` header value joined with `"; "`.
        ///
        /// A header value that is not valid text contributes an empty string,
        /// and the result is empty when no `cookie` header is present.
        async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
            headers
                .get_all("cookie")
                .into_iter()
                .map(|c| c.to_str().unwrap_or("").to_string())
                .reduce(|a, b| a + "; " + &b)
                .unwrap_or_default()
        }

        Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined))
    }

    /// Sends `PUT /cookie` with the body `cookie-found!` using the axum_test request API.
    async fn save_cookie_using_axum_test(server: &TestServer) {
        server.put(&"/cookie").text(&"cookie-found!").await;
    }

    /// Sends `PUT /cookie` with the body `cookie-found!` using reqwest.
    #[cfg(feature = "reqwest")]
    async fn save_cookie_using_reqwest(server: &TestServer) {
        server
            .reqwest_put(&"/cookie")
            .body("cookie-found!".to_string())
            .send()
            .await
            .unwrap();
    }

    /// Asserts that `GET /cookie`, made with the axum_test request API, echoes the test cookie.
    async fn assert_cookie_using_axum_test(server: &TestServer) {
        server
            .get(&"/cookie")
            .await
            .assert_text("test-cookie=cookie-found!");
    }

    /// Asserts that `GET /cookie`, made with reqwest, echoes the test cookie.
    #[cfg(feature = "reqwest")]
    async fn assert_cookie_using_reqwest(server: &TestServer) {
        let response_text = read_cookies_using_reqwest(server).await;

        assert_eq!("test-cookie=cookie-found!", response_text);
    }

    /// Asserts that `GET /cookie`, made with the axum_test request API, echoes no cookies.
    async fn assert_no_cookie_using_axum_test(server: &TestServer) {
        server.get(&"/cookie").await.assert_text("");
    }

    /// Asserts that `GET /cookie`, made with reqwest, echoes no cookies.
    #[cfg(feature = "reqwest")]
    async fn assert_no_cookie_using_reqwest(server: &TestServer) {
        let response_text = read_cookies_using_reqwest(server).await;

        assert_eq!("", response_text);
    }

    /// Sends `GET /cookie` using reqwest and returns the response body as text.
    ///
    /// Panics if the request fails or the body cannot be read.
    #[cfg(feature = "reqwest")]
    async fn read_cookies_using_reqwest(server: &TestServer) -> String {
        server
            .reqwest_get(&"/cookie")
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap()
    }
}