Skip to main content

axum_test/
test_request.rs

1use crate::TestResponse;
2use crate::internals::ErrorMessage;
3use crate::internals::ExpectedState;
4use crate::internals::QueryParamsStore;
5use crate::internals::RequestPathFormatter;
6use crate::multipart::MultipartForm;
7use crate::transport_layer::TransportLayer;
8use anyhow::Context;
9use anyhow::Error as AnyhowError;
10use anyhow::Result;
11use axum::body::Body;
12use bytes::Bytes;
13use cookie::Cookie;
14use cookie::CookieJar;
15use cookie::time::OffsetDateTime;
16use http::HeaderName;
17use http::HeaderValue;
18use http::Method;
19use http::Request;
20use http::header;
21use http::header::SET_COOKIE;
22use http_body_util::BodyExt;
23use serde::Serialize;
24use std::fmt::Debug;
25use std::fmt::Display;
26use std::fs::File;
27use std::fs::read;
28use std::fs::read_to_string;
29use std::future::{Future, IntoFuture};
30use std::io::BufReader;
31use std::path::Path;
32use std::pin::Pin;
33use std::sync::Arc;
34use url::Url;
35
36mod test_request_config;
37pub(crate) use self::test_request_config::*;
38
39///
40/// A `TestRequest` is for building and executing a HTTP request to the [`TestServer`](crate::TestServer).
41///
42/// ## Building
43///
44/// Requests are created by the [`TestServer`](crate::TestServer), using it's builder functions.
45/// They correspond to the appropriate HTTP method: [`TestServer::get()`](crate::TestServer::get()),
46/// [`TestServer::post()`](crate::TestServer::post()), etc.
47///
48/// See there for documentation.
49///
50/// ## Customising
51///
52/// The `TestRequest` allows the caller to fill in the rest of the request
53/// to be sent to the server. Including the headers, the body, cookies,
54/// and the content type, using the relevant functions.
55///
56/// The TestRequest struct provides a number of methods to set up the request,
57/// such as json, text, bytes, expect_failure, content_type, etc.
58///
59/// ## Sending
60///
61/// Once fully configured you send the request by awaiting the request object.
62///
63/// ```rust
64/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
65/// #
66/// # use axum::Router;
67/// # use axum_test::TestServer;
68/// #
69/// # let server = TestServer::new(Router::new());
70/// #
71/// // Build your request
72/// let request = server.get(&"/user")
73///     .add_header("x-custom-header", "example.com")
74///     .content_type("application/yaml");
75///
76/// // await request to execute
77/// let response = request.await;
78/// #
79/// # Ok(()) }
80/// ```
81///
82/// You will receive a `TestResponse`.
83///
84/// ## Cookie Saving
85///
86/// [`TestRequest::save_cookies()`](crate::TestRequest::save_cookies()) and [`TestRequest::do_not_save_cookies()`](crate::TestRequest::do_not_save_cookies())
87/// methods allow you to set the request to save cookies to the `TestServer`,
88/// for reuse on any future requests.
89///
90/// This behaviour is **off** by default, and can be changed for all `TestRequests`
91/// when building the `TestServer`. By building it with a `TestServerConfig` where `save_cookies` is set to true.
92///
93/// ## Expecting Failure and Success
94///
95/// When making a request you can mark it to expect a response within,
96/// or outside, of the 2xx range of HTTP status codes.
97///
98/// If the response returns a status code different to what is expected,
99/// then it will panic.
100///
101/// This is useful when making multiple requests within a test.
102/// As it can find issues earlier than later.
103///
104/// See the [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()),
105/// and [`TestRequest::expect_success()`](crate::TestRequest::expect_success()).
106///
107#[derive(Debug)]
108#[must_use = "futures do nothing unless polled"]
109pub struct TestRequest {
110    config: TestRequestConfig,
111    transport: Arc<Box<dyn TransportLayer>>,
112    body: Option<Body>,
113    expected_state: ExpectedState,
114}
115
116impl TestRequest {
117    pub(crate) fn new(transport: Arc<Box<dyn TransportLayer>>, config: TestRequestConfig) -> Self {
118        let expected_state = config.expected_state;
119
120        Self {
121            config,
122            transport,
123            body: None,
124            expected_state,
125        }
126    }
127
128    /// Set the body of the request to send up data as Json,
129    /// and changes the content type to `application/json`.
130    pub fn json<J>(self, body: &J) -> Self
131    where
132        J: ?Sized + Serialize,
133    {
134        let body_bytes =
135            serde_json::to_vec(body).expect("It should serialize the content into Json");
136
137        self.bytes(body_bytes.into())
138            .content_type(mime::APPLICATION_JSON.essence_str())
139    }
140
141    /// Sends a payload as a Json request, with the contents coming from a file.
142    pub fn json_from_file<P>(self, path: P) -> Self
143    where
144        P: AsRef<Path>,
145    {
146        let path_ref = path.as_ref();
147        let file = File::open(path_ref)
148            .error_message_fn(|| format!("Failed to read from file '{}'", path_ref.display()));
149
150        let reader = BufReader::new(file);
151        let payload =
152            serde_json::from_reader::<_, serde_json::Value>(reader).error_message_fn(|| {
153                format!(
154                    "Failed to deserialize file '{}' as Json",
155                    path_ref.display()
156                )
157            });
158
159        self.json(&payload)
160    }
161
162    /// Set the body of the request to send up data as Yaml,
163    /// and changes the content type to `application/yaml`.
164    #[cfg(feature = "yaml")]
165    pub fn yaml<Y>(self, body: &Y) -> Self
166    where
167        Y: ?Sized + Serialize,
168    {
169        let body = serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
170
171        self.bytes(body.into_bytes().into())
172            .content_type("application/yaml")
173    }
174
175    /// Sends a payload as a Yaml request, with the contents coming from a file.
176    #[cfg(feature = "yaml")]
177    pub fn yaml_from_file<P>(self, path: P) -> Self
178    where
179        P: AsRef<Path>,
180    {
181        let path_ref = path.as_ref();
182        let file = File::open(path_ref)
183            .error_message_fn(|| format!("Failed to read from file '{}'", path_ref.display()));
184
185        let reader = BufReader::new(file);
186        let payload =
187            serde_yaml::from_reader::<_, serde_yaml::Value>(reader).error_message_fn(|| {
188                format!(
189                    "Failed to deserialize file '{}' as Yaml",
190                    path_ref.display()
191                )
192            });
193
194        self.yaml(&payload)
195    }
196
197    /// Set the body of the request to send up data as MsgPack,
198    /// and changes the content type to `application/msgpack`.
199    #[cfg(feature = "msgpack")]
200    pub fn msgpack<M>(self, body: &M) -> Self
201    where
202        M: ?Sized + Serialize,
203    {
204        let body_bytes =
205            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
206
207        self.bytes(body_bytes.into())
208            .content_type("application/msgpack")
209    }
210
211    /// Sets the body of the request, with the content type
212    /// of 'application/x-www-form-urlencoded'.
213    pub fn form<F>(self, body: &F) -> Self
214    where
215        F: ?Sized + Serialize,
216    {
217        let body_text =
218            serde_urlencoded::to_string(body).expect("It should serialize the content into a Form");
219
220        self.bytes(body_text.into())
221            .content_type(mime::APPLICATION_WWW_FORM_URLENCODED.essence_str())
222    }
223
224    /// For sending multipart forms.
225    /// The payload is built using [`MultipartForm`](crate::multipart::MultipartForm) and [`Part`](crate::multipart::Part).
226    ///
227    /// This will be sent with the content type of 'multipart/form-data'.
228    ///
229    /// # Simple example
230    ///
231    /// ```rust
232    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
233    /// #
234    /// use axum::Router;
235    /// use axum_test::TestServer;
236    /// use axum_test::multipart::MultipartForm;
237    ///
238    /// let app = Router::new();
239    /// let server = TestServer::new(app);
240    ///
241    /// let multipart_form = MultipartForm::new()
242    ///     .add_text("name", "Joe")
243    ///     .add_text("animals", "foxes");
244    ///
245    /// let response = server.post(&"/my-form")
246    ///     .multipart(multipart_form)
247    ///     .await;
248    /// #
249    /// # Ok(()) }
250    /// ```
251    ///
252    /// # Sending byte parts
253    ///
254    /// ```rust
255    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
256    /// #
257    /// use axum::Router;
258    /// use axum_test::TestServer;
259    /// use axum_test::multipart::MultipartForm;
260    /// use axum_test::multipart::Part;
261    ///
262    /// let app = Router::new();
263    /// let server = TestServer::new(app);
264    ///
265    /// let readme_bytes = include_bytes!("../README.md");
266    /// let readme_part = Part::bytes(readme_bytes.as_slice())
267    ///     .file_name(&"README.md")
268    ///     .mime_type(&"text/markdown");
269    ///
270    /// let multipart_form = MultipartForm::new()
271    ///     .add_part("file", readme_part);
272    ///
273    /// let response = server.post(&"/my-form")
274    ///     .multipart(multipart_form)
275    ///     .await;
276    /// #
277    /// # Ok(()) }
278    /// ```
279    ///
280    pub fn multipart(mut self, multipart: MultipartForm) -> Self {
281        self.config.content_type = Some(multipart.content_type());
282        self.body = Some(multipart.into());
283
284        self
285    }
286
287    /// Set raw text as the body of the request,
288    /// and sets the content type to `text/plain`.
289    pub fn text<T>(self, raw_text: T) -> Self
290    where
291        T: Display,
292    {
293        let body_text = raw_text.to_string();
294
295        self.bytes(body_text.into())
296            .content_type(mime::TEXT_PLAIN.essence_str())
297    }
298
299    /// Sends a payload as plain text, with the contents coming from a file.
300    pub fn text_from_file<P>(self, path: P) -> Self
301    where
302        P: AsRef<Path>,
303    {
304        let path_ref = path.as_ref();
305        let payload = read_to_string(path_ref)
306            .error_message_fn(|| format!("Failed to read from file '{}'", path_ref.display()));
307
308        self.text(payload)
309    }
310
311    /// Set raw bytes as the body of the request.
312    ///
313    /// The content type is left unchanged.
314    pub fn bytes(mut self, body_bytes: Bytes) -> Self {
315        let body: Body = body_bytes.into();
316
317        self.body = Some(body);
318        self
319    }
320
321    /// Reads the contents of the file as raw bytes, and sends it within the request.
322    ///
323    /// The content type is left unchanged, and no parsing of the file is done.
324    pub fn bytes_from_file<P>(self, path: P) -> Self
325    where
326        P: AsRef<Path>,
327    {
328        let path_ref = path.as_ref();
329        let payload = read(path_ref)
330            .error_message_fn(|| format!("Failed to read from file '{}'", path_ref.display()));
331
332        self.bytes(payload.into())
333    }
334
335    /// Set the content type to use for this request in the header.
336    pub fn content_type(mut self, content_type: &str) -> Self {
337        self.config.content_type = Some(content_type.to_string());
338        self
339    }
340
341    /// Adds a Cookie to be sent with this request.
342    pub fn add_cookie(mut self, cookie: Cookie<'_>) -> Self {
343        self.config.cookies.add(cookie.into_owned());
344        self
345    }
346
347    /// Adds many cookies to be used with this request.
348    pub fn add_cookies(mut self, cookies: CookieJar) -> Self {
349        for cookie in cookies.iter() {
350            self.config.cookies.add(cookie.clone());
351        }
352
353        self
354    }
355
356    /// Clears all cookies used internally within this Request,
357    /// including any that came from the `TestServer`.
358    pub fn clear_cookies(mut self) -> Self {
359        self.config.cookies = CookieJar::new();
360        self
361    }
362
363    /// Any cookies returned will be saved to the [`TestServer`](crate::TestServer) that created this,
364    /// which will continue to use those cookies on future requests.
365    pub fn save_cookies(mut self) -> Self {
366        self.config.is_saving_cookies = true;
367        self
368    }
369
370    /// Cookies returned by this will _not_ be saved to the `TestServer`.
371    /// For use by future requests.
372    ///
373    /// This is the default behaviour.
374    /// You can change that default in [`TestServerConfig`](crate::TestServerConfig).
375    pub fn do_not_save_cookies(mut self) -> Self {
376        self.config.is_saving_cookies = false;
377        self
378    }
379
380    /// Adds query parameters to be sent with this request.
381    pub fn add_query_param<V>(self, key: &str, value: V) -> Self
382    where
383        V: Serialize,
384    {
385        self.add_query_params(&[(key, value)])
386    }
387
388    /// Adds the structure given as query parameters for this request.
389    ///
390    /// This is designed to take a list of parameters, or a body of parameters,
391    /// and then serializes them into the parameters of the request.
392    ///
393    /// # Sending a body of parameters using `json!`
394    ///
395    /// ```rust
396    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
397    /// #
398    /// use axum::Router;
399    /// use axum_test::TestServer;
400    /// use serde_json::json;
401    ///
402    /// let app = Router::new();
403    /// let server = TestServer::new(app);
404    ///
405    /// let response = server.get(&"/my-end-point")
406    ///     .add_query_params(json!({
407    ///         "username": "Brian",
408    ///         "age": 20
409    ///     }))
410    ///     .await;
411    /// #
412    /// # Ok(()) }
413    /// ```
414    ///
415    /// # Sending a body of parameters with Serde
416    ///
417    /// ```rust
418    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
419    /// #
420    /// use axum::Router;
421    /// use axum_test::TestServer;
422    /// use serde::Deserialize;
423    /// use serde::Serialize;
424    ///
425    /// #[derive(Serialize, Deserialize)]
426    /// struct UserQueryParams {
427    ///     username: String,
428    ///     age: u32,
429    /// }
430    ///
431    /// let app = Router::new();
432    /// let server = TestServer::new(app);
433    ///
434    /// let response = server.get(&"/my-end-point")
435    ///     .add_query_params(UserQueryParams {
436    ///         username: "Brian".to_string(),
437    ///         age: 20
438    ///     })
439    ///     .await;
440    /// #
441    /// # Ok(()) }
442    /// ```
443    ///
444    /// # Sending a list of parameters
445    ///
446    /// ```rust
447    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
448    /// #
449    /// use axum::Router;
450    /// use axum_test::TestServer;
451    ///
452    /// let app = Router::new();
453    /// let server = TestServer::new(app);
454    ///
455    /// let response = server.get(&"/my-end-point")
456    ///     .add_query_params(&[
457    ///         ("username", "Brian"),
458    ///         ("age", "20"),
459    ///     ])
460    ///     .await;
461    /// #
462    /// # Ok(()) }
463    /// ```
464    ///
465    pub fn add_query_params<V>(mut self, query_params: V) -> Self
466    where
467        V: Serialize,
468    {
469        self.config
470            .query_params
471            .add(query_params)
472            .error_request("It should serialize query parameters", &self);
473
474        self
475    }
476
477    /// Adds a query param onto the end of the request,
478    /// with no urlencoding of any kind.
479    ///
480    /// This exists to allow custom query parameters,
481    /// such as for the many versions of query param arrays.
482    ///
483    /// ```rust
484    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
485    /// #
486    /// use axum::Router;
487    /// use axum_test::TestServer;
488    ///
489    /// let app = Router::new();
490    /// let server = TestServer::new(app);
491    ///
492    /// let response = server.get(&"/my-end-point")
493    ///     .add_raw_query_param(&"my-flag")
494    ///     .add_raw_query_param(&"array[]=123")
495    ///     .add_raw_query_param(&"filter[value]=some-value")
496    ///     .await;
497    /// #
498    /// # Ok(()) }
499    /// ```
500    ///
501    pub fn add_raw_query_param(mut self, query_param: &str) -> Self {
502        self.config.query_params.add_raw(query_param.to_string());
503
504        self
505    }
506
507    /// Clears all query params set,
508    /// including any that came from the [`TestServer`](crate::TestServer).
509    pub fn clear_query_params(mut self) -> Self {
510        self.config.query_params.clear();
511
512        self
513    }
514
515    /// Adds a header to be sent with this request.
516    ///
517    /// ```rust
518    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
519    /// #
520    /// use axum::Router;
521    /// use axum_test::TestServer;
522    ///
523    /// let app = Router::new();
524    /// let server = TestServer::new(app);
525    ///
526    /// let response = server.get(&"/my-end-point")
527    ///     .add_header("x-custom-header", "custom-value")
528    ///     .add_header(http::header::CONTENT_LENGTH, 12345)
529    ///     .add_header(http::header::HOST, "example.com")
530    ///     .await;
531    /// #
532    /// # Ok(()) }
533    /// ```
534    pub fn add_header<N, V>(mut self, name: N, value: V) -> Self
535    where
536        N: TryInto<HeaderName>,
537        N::Error: Debug,
538        V: TryInto<HeaderValue>,
539        V::Error: Debug,
540    {
541        let header_name: HeaderName = name
542            .try_into()
543            .expect("Failed to convert header name to HeaderName");
544        let header_value: HeaderValue = value
545            .try_into()
546            .expect("Failed to convert header vlue to HeaderValue");
547
548        self.config.headers.push((header_name, header_value));
549        self
550    }
551
552    /// Adds an 'AUTHORIZATION' HTTP header to the request,
553    /// with no internal formatting of what is given.
554    pub fn authorization<T>(self, authorization_header: T) -> Self
555    where
556        T: AsRef<str>,
557    {
558        let authorization_header_value = HeaderValue::from_str(authorization_header.as_ref())
559            .expect("Cannot build Authorization HeaderValue from token");
560
561        self.add_header(header::AUTHORIZATION, authorization_header_value)
562    }
563
564    /// Adds an 'AUTHORIZATION' HTTP header to the request,
565    /// in the 'Bearer {token}' format.
566    pub fn authorization_bearer<T>(self, authorization_bearer_token: T) -> Self
567    where
568        T: Display,
569    {
570        let authorization_bearer_header_str = format!("Bearer {authorization_bearer_token}");
571        self.authorization(authorization_bearer_header_str)
572    }
573
574    /// Clears all headers set.
575    pub fn clear_headers(mut self) -> Self {
576        self.config.headers = vec![];
577        self
578    }
579
580    /// Marks that this request is expected to always return a HTTP
581    /// status code within the 2xx range (200 to 299).
582    ///
583    /// If a code _outside_ of that range is returned,
584    /// then this will panic.
585    ///
586    /// ```rust
587    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
588    /// #
589    /// use axum::Json;
590    /// use axum::Router;
591    /// use axum::http::StatusCode;
592    /// use axum::routing::put;
593    /// use serde_json::json;
594    ///
595    /// use axum_test::TestServer;
596    ///
597    /// let app = Router::new()
598    ///     .route(&"/todo", put(|| async { StatusCode::NOT_FOUND }));
599    ///
600    /// let server = TestServer::new(app);
601    ///
602    /// // If this doesn't return a value in the 2xx range,
603    /// // then it will panic.
604    /// server.put(&"/todo")
605    ///     .expect_success()
606    ///     .json(&json!({
607    ///         "task": "buy milk",
608    ///     }))
609    ///     .await;
610    /// #
611    /// # Ok(())
612    /// # }
613    /// ```
614    ///
615    pub fn expect_success(self) -> Self {
616        self.expect_state(ExpectedState::Success)
617    }
618
619    /// Marks that this request is expected to return a HTTP status code
620    /// outside of the 2xx range.
621    ///
622    /// If a code _within_ the 2xx range is returned,
623    /// then this will panic.
624    pub fn expect_failure(self) -> Self {
625        self.expect_state(ExpectedState::Failure)
626    }
627
628    fn expect_state(mut self, expected_state: ExpectedState) -> Self {
629        self.expected_state = expected_state;
630        self
631    }
632
633    async fn send(self) -> Result<TestResponse> {
634        let debug_request_format = self.debug_request_format().to_string();
635
636        let method = self.config.method;
637        let expected_state = self.expected_state;
638        let save_cookies = self.config.is_saving_cookies;
639        let body = self.body.unwrap_or(Body::empty());
640        let full_request_url =
641            Self::build_url_query_params(self.config.full_request_url, &self.config.query_params);
642
643        let request = Self::build_request(
644            method.clone(),
645            &full_request_url,
646            body,
647            self.config.content_type,
648            self.config.cookies,
649            self.config.headers,
650            &debug_request_format,
651        )?;
652
653        #[allow(unused_mut)] // Allowed for the `ws` use immediately after.
654        let mut http_response = self.transport.send(request).await?;
655
656        #[cfg(feature = "ws")]
657        let websockets = {
658            let maybe_on_upgrade = http_response
659                .extensions_mut()
660                .remove::<hyper::upgrade::OnUpgrade>();
661            let transport_type = self.transport.transport_layer_type();
662
663            crate::internals::TestResponseWebSocket {
664                maybe_on_upgrade,
665                transport_type,
666            }
667        };
668
669        let version = http_response.version();
670        let (response_parts, response_body) = http_response.into_parts();
671        let response_bytes = response_body.collect().await?.to_bytes();
672
673        if save_cookies {
674            let cookie_headers = response_parts.headers.get_all(SET_COOKIE).into_iter();
675            self.config
676                .atomic_cookie_jar
677                .add_cookies_by_headers(cookie_headers)?;
678        }
679
680        let test_response = TestResponse::new(
681            version,
682            method,
683            full_request_url,
684            response_parts,
685            response_bytes,
686            #[cfg(feature = "ws")]
687            websockets,
688        );
689
690        // Assert if ok or not.
691        match expected_state {
692            ExpectedState::Success => {
693                test_response.assert_status_success();
694            }
695            ExpectedState::Failure => {
696                test_response.assert_status_failure();
697            }
698            ExpectedState::None => {}
699        }
700
701        Ok(test_response)
702    }
703
704    fn build_url_query_params(mut url: Url, query_params: &QueryParamsStore) -> Url {
705        // Add all the query params we have
706        if query_params.has_content() {
707            url.set_query(Some(&query_params.to_string()));
708        }
709
710        url
711    }
712
713    fn build_request(
714        method: Method,
715        url: &Url,
716        body: Body,
717        content_type: Option<String>,
718        cookies: CookieJar,
719        headers: Vec<(HeaderName, HeaderValue)>,
720        debug_request_format: &str,
721    ) -> Result<Request<Body>> {
722        let mut request_builder = Request::builder().uri(url.as_str()).method(method);
723
724        // Add all the headers we have.
725        if let Some(content_type) = content_type {
726            let (header_key, header_value) =
727                build_content_type_header(&content_type, debug_request_format)?;
728            request_builder = request_builder.header(header_key, header_value);
729        }
730
731        // Add all the non-expired cookies as headers
732        // Also strip cookies from their attributes, only their names and values should be preserved to conform the HTTP standard
733        let now = OffsetDateTime::now_utc();
734        for cookie in cookies.iter() {
735            let expired = cookie
736                .expires_datetime()
737                .map(|expires| expires <= now)
738                .unwrap_or(false);
739
740            if !expired {
741                let cookie_raw = cookie.stripped().to_string();
742                let header_value = HeaderValue::from_str(&cookie_raw)?;
743                request_builder = request_builder.header(header::COOKIE, header_value);
744            }
745        }
746
747        // Put headers into the request
748        for (header_name, header_value) in headers {
749            request_builder = request_builder.header(header_name, header_value);
750        }
751
752        let request = request_builder.body(body).with_context(|| {
753            format!("Expect valid hyper Request to be built, for request {debug_request_format}")
754        })?;
755
756        Ok(request)
757    }
758
759    pub(crate) fn debug_request_format(&self) -> RequestPathFormatter<'_, Url> {
760        RequestPathFormatter::new(
761            &self.config.method,
762            &self.config.full_request_url,
763            Some(&self.config.query_params),
764        )
765    }
766}
767
768impl TryFrom<TestRequest> for Request<Body> {
769    type Error = AnyhowError;
770
771    fn try_from(test_request: TestRequest) -> Result<Request<Body>> {
772        let debug_request_format = test_request.debug_request_format().to_string();
773        let url = TestRequest::build_url_query_params(
774            test_request.config.full_request_url,
775            &test_request.config.query_params,
776        );
777        let body = test_request.body.unwrap_or(Body::empty());
778
779        TestRequest::build_request(
780            test_request.config.method,
781            &url,
782            body,
783            test_request.config.content_type,
784            test_request.config.cookies,
785            test_request.config.headers,
786            &debug_request_format,
787        )
788    }
789}
790
791impl IntoFuture for TestRequest {
792    type Output = TestResponse;
793    type IntoFuture = Pin<Box<dyn Future<Output = TestResponse> + Send>>;
794
795    fn into_future(self) -> Self::IntoFuture {
796        Box::pin(async {
797            let debug_request_format = self.debug_request_format().to_string();
798
799            self.send()
800                .await
801                .map_err(|err| {
802                    use std::fmt::Write;
803                    let mut output = err.to_string();
804                    if let Some(inner) = err.source() {
805                        write!(
806                            output,
807                            "
808    {inner}"
809                        )
810                        .unwrap();
811
812                        // TODO: get rid of this hack and do this properly.
813                        // It exists to ensure the 'connection refused' part of an error shows up when the server isn't running.
814                        // See: test `it_should_panic_when_run_with_mock_http`
815                        if let Some(inner_2) = inner.source() {
816                            write!(
817                                output,
818                                "
819    {inner_2}"
820                            )
821                            .unwrap();
822                        }
823                    }
824
825                    output
826                })
827                .error_message_fn(|| {
828                    format!("Sending request failed, for request {debug_request_format}")
829                })
830        })
831    }
832}
833
834fn build_content_type_header(
835    content_type: &str,
836    debug_request_format: &str,
837) -> Result<(HeaderName, HeaderValue)> {
838    let header_value = HeaderValue::from_str(content_type).with_context(|| {
839        format!(
840            "Failed to store header content type '{content_type}', for request {debug_request_format}"
841        )
842    })?;
843
844    Ok((header::CONTENT_TYPE, header_value))
845}
846
#[cfg(test)]
mod test_content_type {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Handler echoing back the request's `Content-Type` header,
    /// or an empty string when no such header was sent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// Application exposing `get_content_type` at `/content_type`.
    fn build_app() -> Router {
        Router::new().route("/content_type", get(get_content_type))
    }

    #[tokio::test]
    async fn it_should_not_set_a_content_type_by_default() {
        let server = TestServer::new(build_app());

        let response_text = server.get(&"/content_type").await.text();

        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_server_content_type_when_present() {
        // The server default should lose to the per-request content type.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(build_app());

        let response = server
            .get(&"/content_type")
            .content_type(&"application/json")
            .await;

        assert_eq!(response.text(), "application/json");
    }

    #[tokio::test]
    async fn it_should_set_content_type_when_present() {
        let server = TestServer::new(build_app());

        let response = server
            .get(&"/content_type")
            .content_type(&"application/custom")
            .await;

        assert_eq!(response.text(), "application/custom");
    }
}
914
#[cfg(test)]
mod test_json {
    use crate::TestServer;
    use axum::Json;
    use axum::Router;
    use axum::extract::DefaultBodyLimit;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use rand::random;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Route handler replying with the request's `Content-Type` header value,
    /// or an empty string when the header is missing.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_owned(),
            None => String::new(),
        }
    }

    /// A body sent with `.json()` can be deserialised by an axum `Json` extractor.
    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        async fn describe_json(Json(payload): Json<TestJson>) -> String {
            let pets = payload.pets.unwrap_or_else(|| "pandas".to_string());
            format!("json: {}, {}, {}", payload.name, payload.age, pets)
        }

        let router = Router::new().route("/json", post(describe_json));
        let server = TestServer::new(router);

        let payload = TestJson {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let response = server.post(&"/json").json(&payload).await;

        assert_eq!(response.text(), "json: Joe, 20, foxes");
    }

    /// `.json()` sets the `Content-Type` header to `application/json`.
    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server.post(&"/content_type").json(&json!({})).await;

        assert_eq!(response.text(), "application/json");
    }

    /// A ~16mb JSON payload survives a round trip over the real HTTP transport.
    #[tokio::test]
    async fn it_should_pass_large_json_blobs_over_http() {
        const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024; // 16mb

        #[derive(Deserialize, Serialize, PartialEq, Debug)]
        struct TestLargeJson {
            items: Vec<String>,
        }

        async fn echo_json(Json(body): Json<TestLargeJson>) -> Json<TestLargeJson> {
            Json(body)
        }

        // Keep adding random numbers until their combined digit count reaches the target size.
        let mut items = Vec::new();
        let mut total_len = 0;
        while total_len < LARGE_BLOB_SIZE {
            let item = random::<u64>().to_string();
            total_len += item.len();
            items.push(item);
        }
        let large_json_blob = TestLargeJson { items };

        // Raise the body limit so the multi-megabyte payload is accepted.
        let router = Router::new()
            .route("/json", post(echo_json))
            .layer(DefaultBodyLimit::max(LARGE_BLOB_SIZE * 2));

        let server = TestServer::builder()
            .http_transport()
            .expect_success_by_default()
            .build(router);

        server
            .post(&"/json")
            .json(&large_json_blob)
            .await
            .assert_json(&large_json_blob);
    }
}
1030
#[cfg(test)]
mod test_json_from_file {
    use crate::TestServer;
    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Route handler replying with the request's `Content-Type` header value,
    /// or an empty string when the header is missing.
    async fn echo_content_type(headers: HeaderMap) -> String {
        if let Some(value) = headers.get(CONTENT_TYPE) {
            value.to_str().unwrap().to_string()
        } else {
            String::new()
        }
    }

    /// The contents of `files/example.json` reach the handler's `Json` extractor.
    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
        }

        async fn describe_json(Json(payload): Json<TestJson>) -> String {
            format!("json: {}, {}", payload.name, payload.age)
        }

        let server = TestServer::new(Router::new().route("/json", post(describe_json)));

        let response = server
            .post(&"/json")
            .json_from_file(&"files/example.json")
            .await;

        assert_eq!(response.text(), "json: Joe, 20");
    }

    /// `.json_from_file()` sets the `Content-Type` header to `application/json`.
    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server
            .post(&"/content_type")
            .json_from_file(&"files/example.json")
            .await;

        assert_eq!(response.text(), "application/json");
    }
}
1097
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_yaml::Yaml;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Route handler replying with the request's `Content-Type` header value,
    /// or an empty string when the header is missing.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_owned())
            .unwrap_or_default()
    }

    /// A body sent with `.yaml()` can be deserialised by an `axum_yaml::Yaml` extractor.
    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        async fn describe_yaml(Yaml(payload): Yaml<TestYaml>) -> String {
            let pets = payload.pets.unwrap_or_else(|| "pandas".to_string());
            format!("yaml: {}, {}, {}", payload.name, payload.age, pets)
        }

        let server = TestServer::new(Router::new().route("/yaml", post(describe_yaml)));

        let payload = TestYaml {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let response = server.post(&"/yaml").yaml(&payload).await;

        assert_eq!(response.text(), "yaml: Joe, 20, foxes");
    }

    /// `.yaml()` sets the `Content-Type` header to `application/yaml`.
    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server.post(&"/content_type").yaml(&json!({})).await;

        assert_eq!(response.text(), "application/yaml");
    }
}
1172
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_yaml::Yaml;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Route handler replying with the request's `Content-Type` header value,
    /// or an empty string when the header is missing.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// The contents of `files/example.yaml` reach the handler's `Yaml` extractor.
    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
        }

        async fn describe_yaml(Yaml(payload): Yaml<TestYaml>) -> String {
            format!("yaml: {}, {}", payload.name, payload.age)
        }

        let server = TestServer::new(Router::new().route("/yaml", post(describe_yaml)));

        let response = server
            .post(&"/yaml")
            .yaml_from_file(&"files/example.yaml")
            .await;

        assert_eq!(response.text(), "yaml: Joe, 20");
    }

    /// `.yaml_from_file()` sets the `Content-Type` header to `application/yaml`.
    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server
            .post(&"/content_type")
            .yaml_from_file(&"files/example.yaml")
            .await;

        assert_eq!(response.text(), "application/yaml");
    }
}
1240
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_msgpack::MsgPack;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// A body sent with `.msgpack()` can be decoded by an `axum_msgpack::MsgPack` extractor.
    #[tokio::test]
    async fn it_should_pass_msgpack_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestMsgPack {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        // Echo the decoded fields back, prefixed with the format under test so a
        // failure message clearly belongs to the msgpack tests.
        async fn get_msgpack(MsgPack(msgpack): MsgPack<TestMsgPack>) -> String {
            format!(
                "msgpack: {}, {}, {}",
                msgpack.name,
                msgpack.age,
                msgpack.pets.unwrap_or_else(|| "pandas".to_string())
            )
        }

        // Build an application with a route.
        let app = Router::new().route("/msgpack", post(get_msgpack));

        // Run the server.
        let server = TestServer::new(app);

        // Get the request.
        let text = server
            .post(&"/msgpack")
            .msgpack(&TestMsgPack {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "msgpack: Joe, 20, foxes");
    }

    /// `.msgpack()` sets the `Content-Type` header to `application/msgpack`.
    #[tokio::test]
    async fn it_should_pass_msgpack_content_type_for_msgpack() {
        /// Replies with the request's `Content-Type` header, or `""` when it is absent.
        async fn get_content_type(headers: HeaderMap) -> String {
            headers
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().unwrap().to_string())
                .unwrap_or_default()
        }

        // Build an application with a route.
        let app = Router::new().route("/content_type", post(get_content_type));

        // Run the server.
        let server = TestServer::new(app);

        // Get the request.
        let text = server
            .post(&"/content_type")
            .msgpack(&json!({}))
            .await
            .text();

        assert_eq!(text, "application/msgpack");
    }
}
1317
#[cfg(test)]
mod test_form {
    use crate::TestServer;
    use axum::Form;
    use axum::Router;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// A body sent with `.form()` can be read by an axum `Form` extractor.
    #[tokio::test]
    async fn it_should_pass_form_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestForm {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        async fn get_form(Form(submitted): Form<TestForm>) -> String {
            let pets = submitted.pets.unwrap_or_else(|| "pandas".to_string());
            format!("form: {}, {}, {}", submitted.name, submitted.age, pets)
        }

        let server = TestServer::new(Router::new().route("/form", post(get_form)));

        let form = TestForm {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };

        server
            .post(&"/form")
            .form(&form)
            .await
            .assert_text("form: Joe, 20, foxes");
    }

    /// `.form()` sets the `Content-Type` header to `application/x-www-form-urlencoded`.
    #[tokio::test]
    async fn it_should_pass_form_content_type_for_form() {
        #[derive(Serialize)]
        struct MyForm {
            message: String,
        }

        async fn get_content_type(headers: HeaderMap) -> String {
            if let Some(value) = headers.get(CONTENT_TYPE) {
                value.to_str().unwrap().to_string()
            } else {
                String::new()
            }
        }

        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router);

        let form = MyForm {
            message: "hello".to_string(),
        };

        server
            .post(&"/content_type")
            .form(&form)
            .await
            .assert_text("application/x-www-form-urlencoded");
    }
}
1395
#[cfg(test)]
mod test_bytes {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Route handler replying with the raw request body, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");

        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Route handler replying with the request's `Content-Type` header, or `""` if absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_owned())
            .unwrap_or_default()
    }

    /// Bytes passed to `.bytes()` arrive unchanged as the request body.
    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/bytes", post(echo_body)));

        let response = server
            .post(&"/bytes")
            .bytes("hello!".as_bytes().into())
            .await;

        assert_eq!(response.text(), "hello!");
    }

    /// `.bytes()` leaves a previously set `Content-Type` untouched.
    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes("hello!".as_bytes().into())
            .await;

        assert_eq!(response.text(), "application/testing");
    }
}
1462
#[cfg(test)]
mod test_bytes_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Route handler replying with the raw request body, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let body = request.into_body();
        let collected = body.collect().await.expect("Should read body to bytes");
        let bytes = collected.to_bytes();

        String::from_utf8_lossy(&bytes).to_string()
    }

    /// Route handler replying with the request's `Content-Type` header, or `""` if absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// `.bytes_from_file()` sends the file's contents as the request body.
    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/bytes", post(echo_body)));

        let response = server
            .post(&"/bytes")
            .bytes_from_file(&"files/example.txt")
            .await;

        assert_eq!(response.text(), "hello!");
    }

    /// `.bytes_from_file()` leaves a previously set `Content-Type` untouched.
    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes_from_file(&"files/example.txt")
            .await;

        assert_eq!(response.text(), "application/testing");
    }
}
1529
#[cfg(test)]
mod test_text {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Length in bytes of the payload used by the large-blob tests (16mb).
    const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024;

    /// Route handler replying with the full request body, decoded lossily as UTF-8.
    ///
    /// Shared by every test in this module that reads back what it sent.
    async fn echo_body(request: Request) -> String {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes")
            .to_bytes();

        String::from_utf8_lossy(&body_bytes).to_string()
    }

    /// Route handler replying with the request's `Content-Type` header, or `""` if absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    /// Builds the large text payload: `LARGE_BLOB_SIZE` ASCII `X` characters.
    fn build_large_blob() -> String {
        // `str::repeat` builds the string directly, rather than pushing
        // millions of one-character slices into a growing `String`.
        "X".repeat(LARGE_BLOB_SIZE)
    }

    /// Text passed to `.text()` arrives unchanged as the request body.
    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/text", post(echo_body)));

        let text = server.post(&"/text").text(&"hello!").await.text();

        assert_eq!(text, "hello!");
    }

    /// `.text()` sets the `Content-Type` header to `text/plain`.
    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app);

        let text = server.post(&"/content_type").text(&"hello!").await.text();

        assert_eq!(text, "text/plain");
    }

    /// A 16mb text body survives a round trip over the mock transport.
    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_mock_http() {
        let large_blob = build_large_blob();

        let app = Router::new().route("/text", post(echo_body));
        let server = TestServer::builder().mock_transport().build(app);

        let text = server.post(&"/text").text(&large_blob).await.text();

        assert_eq!(text.len(), LARGE_BLOB_SIZE);
        assert_eq!(text, large_blob);
    }

    /// A 16mb text body survives a round trip over the real HTTP transport.
    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_http() {
        let large_blob = build_large_blob();

        let app = Router::new().route("/text", post(echo_body));
        let server = TestServer::builder().http_transport().build(app);

        let text = server.post(&"/text").text(&large_blob).await.text();

        assert_eq!(text.len(), LARGE_BLOB_SIZE);
        assert_eq!(text, large_blob);
    }
}
1647
#[cfg(test)]
mod test_text_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Route handler replying with the request body, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let body = request.into_body();
        let collected = body.collect().await.expect("Should read body to bytes");
        let bytes = collected.to_bytes();

        String::from_utf8_lossy(&bytes).to_string()
    }

    /// Route handler replying with the request's `Content-Type` header, or `""` if absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(header_value) => header_value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// `.text_from_file()` sends the file's contents as the request body.
    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/text", post(echo_body)));

        let response = server
            .post(&"/text")
            .text_from_file(&"files/example.txt")
            .await;

        assert_eq!(response.text(), "hello!");
    }

    /// `.text_from_file()` sets the `Content-Type` header to `text/plain`.
    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router);

        let response = server
            .post(&"/content_type")
            .text_from_file(&"files/example.txt")
            .await;

        assert_eq!(response.text(), "text/plain");
    }
}
1714
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// Route handler replying with the body `pong!`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Route handler replying with `202 Accepted`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// `expect_success` does not panic for a successful response.
    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = TestServer::new(Router::new().route("/ping", get(get_ping)));

        server.get(&"/ping").expect_success().await;
    }

    /// `expect_success` accepts any 2xx status, not just `200`.
    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)));

        server.get(&"/accepted").expect_success().await;
    }

    /// `expect_success` panics with a descriptive message when the route is missing.
    #[tokio::test]
    async fn it_should_panic_on_404() {
        let server = TestServer::new(Router::new());

        let request = server.get(&"/some_unknown_route").expect_success();
        let message = catch_panic_error_message_async(request).await;

        assert_str_eq!(
            "Expect status code within 2xx range, received 404 (Not Found), for request GET http://localhost/some_unknown_route, with body ''",
            message
        );
    }

    /// A request-level `expect_success` overrides the server-wide `expect_failure`.
    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new().route("/ping", get(get_ping)));
        server.expect_failure();

        server.get(&"/ping").expect_success().await;
    }
}
1791
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use crate::testing::catch_panic_error_message_async;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use pretty_assertions::assert_str_eq;

    /// `expect_failure` does not panic when the route is missing (404).
    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // Build an application with no routes.
        let app = Router::new();

        // Run the server.
        let server = TestServer::new(app);

        // Get the request.
        server.get(&"/some_unknown_route").expect_failure().await;
    }

    /// `expect_failure` panics with a descriptive message on a `200 OK`.
    #[tokio::test]
    async fn it_should_panic_if_success_is_returned() {
        async fn get_ping() -> &'static str {
            "pong!"
        }

        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Run the server.
        let server = TestServer::new(app);

        // Get the request.
        let message = catch_panic_error_message_async(server.get(&"/ping").expect_failure()).await;
        assert_str_eq!(
            "Expect status code outside 2xx range, received 200 (OK), for request GET http://localhost/ping, with body 'pong!'",
            message
        );
    }

    /// `expect_failure` treats every 2xx status as success, so `202` also panics.
    #[tokio::test]
    async fn it_should_panic_on_other_2xx_status_code() {
        async fn get_accepted() -> StatusCode {
            StatusCode::ACCEPTED
        }

        // Build an application with a route.
        let app = Router::new().route("/accepted", get(get_accepted));

        // Run the server.
        let server = TestServer::new(app);

        // Get the request.
        let message =
            catch_panic_error_message_async(server.get(&"/accepted").expect_failure()).await;
        assert_str_eq!(
            "Expect status code outside 2xx range, received 202 (Accepted), for request GET http://localhost/accepted, with body ''",
            message
        );
    }

    /// A request-level `expect_failure` overrides the server-wide `expect_success`.
    ///
    /// Named to mirror the equivalent test in `test_expect_success`.
    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        // Build an application with no routes.
        let app = Router::new();

        // Run the server, expecting success by default.
        let mut server = TestServer::new(app);
        server.expect_success();

        // The per-request setting should win, so the 404 must not panic.
        server.get(&"/some_unknown_route").expect_failure().await;
    }
}
1867
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;
    use cookie::time::Duration;
    use cookie::time::OffsetDateTime;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Route handler replying with the value of the `test-cookie` cookie,
    /// or `cookie-not-found` when the request did not carry it.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    /// A cookie added with `add_cookie` is sent with the request.
    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app);

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    /// A cookie whose expiry lies in the future is still sent.
    #[tokio::test]
    async fn it_should_send_non_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app);

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_add(Duration::minutes(10))
                .unwrap(),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    /// A cookie whose expiry lies in the past is dropped rather than sent.
    #[tokio::test]
    async fn it_should_not_send_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app);

        // Expire well in the past, mirroring the +10 minute future case above,
        // so the result does not depend on how an expiry of exactly "now" is handled.
        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_sub(Duration::minutes(10))
                .unwrap(),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
1925
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use cookie::SameSite;

    /// Handler returning every cookie the server received as `name=value` pairs,
    /// sorted and joined with `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut all_cookies = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect::<Vec<String>>();
        // Sorted so the response does not depend on the jar's iteration order.
        all_cookies.sort();

        all_cookies.join(", ")
    }

    /// Handler returning the raw `Cookie` request header(s) verbatim, joined with `"; "`.
    ///
    /// Responds with an empty body when no `Cookie` header was sent. A header value
    /// that is not visible ASCII is rendered as an empty string.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|value| value.to_str().unwrap_or(""))
            .collect::<Vec<&str>>()
            .join("; ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let server = TestServer::new(app);

        // Build cookies to send up
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));

        server
            .get(&"/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_stripped_by_their_attributes() {
        const TEST_COOKIE_NAME: &str = "test-cookie";
        const TEST_COOKIE_VALUE: &str = "my-custom-cookie";

        let app = Router::new().route("/cookies", get(get_cookie_headers_joined));
        let server = TestServer::new(app);

        // Build a cookie carrying attributes; only `name=value` should reach the server.
        let cookie = Cookie::build((TEST_COOKIE_NAME, TEST_COOKIE_VALUE))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie);

        server
            .get(&"/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text(format!("{TEST_COOKIE_NAME}={TEST_COOKIE_VALUE}"));
    }
}
2002
#[cfg(test)]
mod test_save_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    #[tokio::test]
    async fn it_should_save_cookies_across_requests_when_enabled() {
        let server = TestServer::new(app());

        // Create a cookie, saving the cookies from the response for later requests.
        server
            .put(&"/cookie")
            .text(&"cookie-found!")
            .save_cookies()
            .await;

        // Only the saved cookie's name and value should come back; its attributes
        // (HttpOnly, Secure, SameSite, Path) must not be sent up.
        let response_text = server.get(&"/cookie").await.text();

        assert_eq!(response_text, format!("{TEST_COOKIE_NAME}=cookie-found!"));
    }

    /// Builds a router where `PUT /cookie` sets `test-cookie` (with attributes) to the
    /// request body, and `GET /cookie` echoes the raw `Cookie` request header(s).
    fn app() -> Router {
        /// Sets `test-cookie` (HttpOnly, Secure, SameSite=Strict, Path=/cookie) with the
        /// request body as its value, and responds with `done`.
        async fn put_cookie_with_attributes(
            cookies: AxumCookieJar,
            request: Request,
        ) -> (AxumCookieJar, &'static str) {
            let body_bytes = request
                .into_body()
                .collect()
                .await
                .expect("Should turn the body into bytes")
                .to_bytes();

            let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
            let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
                .http_only(true)
                .secure(true)
                .same_site(SameSite::Strict)
                .path("/cookie")
                .build();

            (cookies.add(cookie), "done")
        }

        /// Returns every `Cookie` request header joined with `"; "` (empty when none).
        async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
            headers
                .get_all("cookie")
                .into_iter()
                .map(|value| value.to_str().unwrap_or(""))
                .collect::<Vec<&str>>()
                .join("; ")
        }

        Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined))
    }
}
2075
#[cfg(test)]
mod test_do_not_save_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that sets `test-cookie` (HttpOnly, Secure, SameSite=Strict, Path=/cookie)
    /// with the request body as its value, and responds with `done`.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Handler returning every `Cookie` request header joined with `"; "`
    /// (an empty body when none were sent).
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|value| value.to_str().unwrap_or(""))
            .collect::<Vec<&str>>()
            .join("; ")
    }

    /// Router with `PUT /cookie` setting the cookie and `GET /cookie` echoing it back.
    fn build_app() -> Router {
        Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined))
    }

    #[tokio::test]
    async fn it_should_not_save_cookies_when_set() {
        let server = TestServer::new(build_app());

        // Create a cookie, explicitly asking for it not to be saved.
        server
            .put(&"/cookie")
            .text(&"cookie-found!")
            .do_not_save_cookies()
            .await;

        // Nothing was saved, so no `Cookie` header should be sent up.
        let response_text = server.get(&"/cookie").await.text();

        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_test_server_and_not_save_cookies_when_set() {
        let server = TestServer::builder().save_cookies().build(build_app());

        // Create a cookie; the request-level setting overrides the server's `save_cookies`.
        server
            .put(&"/cookie")
            .text(&"cookie-found!")
            .do_not_save_cookies()
            .await;

        // Nothing was saved, so no `Cookie` header should be sent up.
        let response_text = server.get(&"/cookie").await.text();

        assert_eq!(response_text, "");
    }
}
2165
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::Cookie as AxumCookie;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler returning the value of `test-cookie`, or `cookie-not-found` when absent.
    async fn get_cookie(cookies: AxumCookieJar) -> (AxumCookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    /// Handler setting `test-cookie` to the request body, responding with `done`.
    async fn put_cookie(cookies: AxumCookieJar, request: Request) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text);

        (cookies.add(cookie), "done")
    }

    #[tokio::test]
    async fn it_should_clear_cookie_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app);

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server
            .get(&"/cookie")
            .add_cookie(cookie)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookie_jar_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app);

        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response_text = server
            .get(&"/cookie")
            .add_cookies(cookie_jar)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_saved_by_past_request() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let server = TestServer::new(app);

        // Create a cookie, saving it for later requests.
        server
            .put(&"/cookie")
            .text(&"cookie-found!")
            .save_cookies()
            .await;

        // Control: without clearing, the saved cookie is sent back up.
        server.get(&"/cookie").await.assert_text("cookie-found!");

        // Clearing on this request means the saved cookie must not be sent.
        let response_text = server.get(&"/cookie").clear_cookies().await.text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_added_to_test_server() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app);

        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        // Control: without clearing, the server-level cookie is sent.
        server.get(&"/cookie").await.assert_text("my-custom-cookie");

        // Clearing on this request drops the server-level cookie too.
        let response_text = server.get(&"/cookie").clear_cookies().await.text();

        assert_eq!(response_text, "cookie-not-found");
    }
}
2279
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor capturing the raw bytes of the `test-header` request header.
    ///
    /// Rejects with `400 Bad Request` and `Missing test header` when it is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            let header_name = HeaderName::from_static(TEST_HEADER_NAME);

            match parts.headers.get(header_name) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Handler echoing the header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = TestServer::new(Router::new().route("/header", get(ping_header)));

        let response = server
            .get(&"/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .await;

        // The handler sends the header's content straight back.
        response.assert_text(TEST_HEADER_CONTENT);
    }
}
2337
#[cfg(test)]
mod test_authorization {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Extractor holding the request's `Authorization` header as a string.
    ///
    /// Rejects with `400 Bad Request` when the header is absent, and panics when the
    /// header value cannot be read as a string (`to_str().unwrap()`).
    struct AuthorizationHeader(String);

    impl<S: Sync> FromRequestParts<S> for AuthorizationHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<AuthorizationHeader, Self::Rejection> {
            let Some(value) = parts.headers.get(header::AUTHORIZATION) else {
                return Err((StatusCode::BAD_REQUEST, "Missing test header"));
            };

            Ok(AuthorizationHeader(value.to_str().unwrap().to_string()))
        }
    }

    /// Handler echoing the received `Authorization` header back as the body.
    async fn ping_auth_header(AuthorizationHeader(header): AuthorizationHeader) -> String {
        header
    }

    /// Builds a server routing `/auth-header` to [`ping_auth_header`], expecting every
    /// request to succeed.
    fn new_test_server() -> TestServer {
        let app = Router::new().route("/auth-header", get(ping_auth_header));

        let mut server = TestServer::new(app);
        server.expect_success();
        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let response = new_test_server()
            .get(&"/auth-header")
            .authorization("Bearer abc123")
            .await;

        // The whole header value comes back unchanged.
        response.assert_text("Bearer abc123");
    }
}
2394
#[cfg(test)]
mod test_authorization_bearer {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the bearer token from the
    /// `Authorization` header; every request is expected to succeed.
    fn new_test_server() -> TestServer {
        /// Extractor holding the token from an `Authorization: Bearer <token>` header.
        ///
        /// Rejects with `400 Bad Request` when the header is missing, cannot be read as a
        /// string, or does not start with the `Bearer ` prefix.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                // `strip_prefix` requires the `Bearer ` prefix to be present and only removes
                // it from the start. A `replace("Bearer ", "")` would also accept a bare
                // token, hiding a missing prefix from the test.
                parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok())
                    .and_then(|value| value.strip_prefix("Bearer "))
                    .map(|token| TestHeader(token.to_string()))
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
            }
        }

        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        // Build an application with a route.
        let app = Router::new().route("/auth-header", get(ping_auth_header));

        // Run the server.
        let mut server = TestServer::new(app);
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // The extractor rejects the request (failing `expect_success`) unless the header
        // is sent as `Bearer <token>`.
        server
            .get(&"/auth-header")
            .authorization_bearer("abc123")
            .await
            // Check only the token comes back, with the `Bearer ` prefix stripped.
            .assert_text("abc123");
    }
}
2451
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor capturing the raw bytes of the `test-header` request header.
    ///
    /// Rejects with `400 Bad Request` and `Missing test header` when it is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Handler echoing the header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_request() {
        let app = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(app);

        // Add the header, then clear it before the request is sent.
        let response = server
            .get(&"/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .clear_headers()
            .await;

        // The header never arrived, so the extractor rejects the request.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_server() {
        let app = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(app);
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Control: without clearing, the server-level header is sent.
        server.get(&"/header").await.assert_text(TEST_HEADER_CONTENT);

        // Clearing on this request drops the server-level header too.
        let response = server.get(&"/header").clear_headers().await;

        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
2532
#[cfg(test)]
mod test_add_query_params {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Query string carrying a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query string carrying both `message` and `other`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// `GET /query`: responds with the `message` parameter.
    async fn get_query_param(
        AxumStdQuery(QueryParam { message }): AxumStdQuery<QueryParam>,
    ) -> String {
        message
    }

    /// `GET /query-2`: responds with `"{message}-{other}"`.
    async fn get_query_param_2(
        AxumStdQuery(QueryParam2 { message, other }): AxumStdQuery<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    /// Router exposing `/query` and `/query-2`.
    fn build_app() -> Router {
        let router = Router::new().route("/query", get(get_query_param));
        router.route("/query-2", get(get_query_param_2))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let server = TestServer::new(build_app());
        let params = QueryParam {
            message: "it works".to_string(),
        };

        let response = server.get(&"/query").add_query_params(params).await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let server = TestServer::new(build_app());

        let response = server
            .get(&"/query")
            .add_query_params(&[("message", "it works")])
            .await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let server = TestServer::new(build_app());

        let response = server
            .get(&"/query-2")
            .add_query_params(&[("message", "it works"), ("other", "yup")])
            .await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let server = TestServer::new(build_app());

        // Each call appends to the parameters already set.
        let response = server
            .get(&"/query-2")
            .add_query_params(&[("message", "it works")])
            .add_query_params(&[("other", "yup")])
            .await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let server = TestServer::new(build_app());
        let params = json!({
            "message": "it works",
            "other": "yup"
        });

        let response = server.get(&"/query-2").add_query_params(params).await;
        response.assert_text("it works-yup");
    }
}
2639
#[cfg(test)]
mod test_add_raw_query_param {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// `GET /query`: responds with the `message` parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated query keys, parsed with `axum_extra`'s `Query` so they collect into Vecs.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// `GET /query-extra`: responds with `items` then `arrs`, each joined with `", "`.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        // Joining an empty Vec yields "", so no emptiness checks are needed. The tests
        // below only populate one of the two lists at a time.
        format!("{}{}", params.items.join(", "), params.arrs.join(", "))
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let server = TestServer::new(build_app());

        // Get the request.
        server
            .get(&"/query")
            .add_raw_query_param(&"message=it-works")
            .await
            .assert_text(&"it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let server = TestServer::new(build_app());

        // Get the request.
        server
            .get(&"/query-extra")
            .add_raw_query_param(&"items=one&items=two&items=three")
            .await
            .assert_text(&"one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let server = TestServer::new(build_app());

        // Get the request.
        server
            .get(&"/query-extra")
            .add_raw_query_param(&"arrs[]=one")
            .add_raw_query_param(&"arrs[]=two")
            .add_raw_query_param(&"arrs[]=three")
            .await
            .assert_text(&"one, two, three");
    }
}
2732
#[cfg(test)]
mod test_add_query_param {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;

    /// Query with a single required `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Query with required `message` and `other` fields.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Responds with the `message` query value.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Responds with `message` and `other` joined by a hyphen.
    async fn get_query_param_2(Query(query): Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = query;
        [message, other].join("-")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let app = Router::new().route("/query", get(get_query_param));
        let server = TestServer::new(app);

        let response = server
            .get(&"/query")
            .add_query_param("message", "it works")
            .await;

        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let app = Router::new().route("/query-2", get(get_query_param_2));
        let server = TestServer::new(app);

        // Each call adds one more parameter to the request.
        let response = server
            .get(&"/query-2")
            .add_query_param("message", "it works")
            .add_query_param("other", "yup")
            .await;

        response.assert_text("it works-yup");
    }
}
2794
#[cfg(test)]
mod test_clear_query_params {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;

    /// Two optional query parameters; the handler reports only whether each arrived.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    impl QueryParams {
        /// Both parameters populated with their own names.
        fn both() -> Self {
            QueryParams {
                first: Some("first".to_string()),
                second: Some("second".to_string()),
            }
        }
    }

    /// Responds with whether `first` and `second` were present on the query string.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();

        format!("has first? {has_first}, has second? {has_second}")
    }

    /// Builds a server routing `/query` to [`get_query_params`].
    fn new_server() -> TestServer {
        let app = Router::new().route("/query", get(get_query_params));
        TestServer::new(app)
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let server = new_server();

        let response = server
            .get(&"/query")
            .add_query_params(QueryParams::both())
            .clear_query_params()
            .await;

        // Everything added before the clear is gone.
        response.assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let server = new_server();

        let response = server
            .get(&"/query")
            .add_query_params(QueryParams::both())
            .clear_query_params()
            .add_query_params(QueryParams::both())
            .await;

        // Parameters added after the clear are still sent.
        response.assert_text("has first? true, has second? true");
    }
}
2862
#[cfg(test)]
mod test_multipart {
    use crate::TestServer;
    use crate::multipart::MultipartForm;
    use crate::multipart::Part;
    use axum::Json;
    use axum::Router;
    use axum::extract::Multipart;
    use axum::routing::post;
    use serde_json::Value;
    use serde_json::json;

    /// Test handler that returns one summary line per multipart field, in the order the
    /// fields were received, formatted as `"{name} is {byte length} bytes, {content type}"`.
    ///
    /// Every step is unwrapped, so this is only suitable as a test fixture: a field without
    /// a name or content type, or a malformed body, makes it panic.
    async fn route_post_multipart(mut multipart: Multipart) -> Json<Vec<String>> {
        let mut fields = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            // Copy the metadata out before `bytes()` consumes the field.
            let name = field.name().unwrap().to_string();
            let content_type = field.content_type().unwrap().to_owned();
            // `len()` is the byte length, so multi-byte UTF-8 text is observable in the output.
            let data = field.bytes().await.unwrap();

            let field_stats = format!("{name} is {} bytes, {content_type}", data.len());
            fields.push(field_stats);
        }

        Json(fields)
    }

    /// Test handler that echoes each multipart part back as a JSON object with its `name`,
    /// its body decoded as UTF-8 (`text`), and the value of its `x-part-header-test`
    /// header (`header`).
    ///
    /// Every step is unwrapped: a part missing that header, or with a non-UTF-8 body or
    /// header value, makes it panic.
    async fn route_post_multipart_headers(mut multipart: Multipart) -> Json<Vec<Value>> {
        let mut sent_part_headers = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            // Read the name and header first; `bytes()` below consumes the field.
            let part_name = field.name().unwrap().to_string();
            let part_header_value = field
                .headers()
                .get("x-part-header-test")
                .unwrap()
                .to_str()
                .unwrap()
                .to_string();
            let part_text = String::from_utf8(field.bytes().await.unwrap().into()).unwrap();

            sent_part_headers.push(json!({
                "name": part_name,
                "text": part_text,
                "header": part_header_value,
            }))
        }

        Json(sent_part_headers)
    }

    /// Router exposing the two multipart handlers above.
    fn test_router() -> Router {
        Router::new()
            .route("/multipart", post(route_post_multipart))
            .route("/multipart_headers", post(route_post_multipart_headers))
    }

    /// Form shared by the transport tests: plain ASCII text, multi-byte UTF-8 text
    /// (each 🦊 is 4 bytes), and a non-string value that is sent as its text form.
    fn example_stats_form() -> MultipartForm {
        MultipartForm::new()
            .add_text("penguins?", "lots")
            .add_text("animals", "🦊🦊🦊")
            .add_text("carrots", 123_u32)
    }

    /// The summaries `route_post_multipart` should return for [`example_stats_form`].
    fn example_stats_expected() -> Vec<String> {
        vec![
            "penguins? is 4 bytes, text/plain".to_string(),
            "animals is 12 bytes, text/plain".to_string(),
            "carrots is 3 bytes, text/plain".to_string(),
        ]
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_mock_transport() {
        // Run the server.
        let server = TestServer::builder().mock_transport().build(test_router());

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(example_stats_form())
            .await
            .assert_json(&example_stats_expected());
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_http_transport() {
        // Run the server.
        let server = TestServer::builder().http_transport().build(test_router());

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(example_stats_form())
            .await
            .assert_json(&example_stats_expected());
    }

    #[tokio::test]
    async fn it_should_send_text_parts_as_text() {
        // Run the server.
        let server = TestServer::builder().mock_transport().build(test_router());

        let form = MultipartForm::new().add_part("animals", Part::text("🦊🦊🦊"));

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 12 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_custom_mime_type() {
        // Run the server.
        let server = TestServer::builder().mock_transport().build(test_router());

        let form = MultipartForm::new().add_part(
            "animals",
            Part::bytes("🦊,🦊,🦊".as_bytes()).mime_type(mime::TEXT_CSV),
        );

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 14 bytes, text/csv".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_using_include_bytes() {
        // Run the server.
        let server = TestServer::builder().mock_transport().build(test_router());

        // The path is resolved relative to this source file at compile time.
        let form = MultipartForm::new().add_part(
            "file",
            Part::bytes(include_bytes!("../files/example.txt").as_slice())
                .mime_type(mime::TEXT_PLAIN),
        );

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["file is 6 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_form_headers_in_parts() {
        // Run the server.
        let server = TestServer::builder().mock_transport().build(test_router());

        // Each part carries a distinct header value, so the handler output shows that
        // headers stay attached to their own part.
        let form = MultipartForm::new()
            .add_part(
                "part_1",
                Part::text("part_1_text").add_header("x-part-header-test", "part_1_header"),
            )
            .add_part(
                "part_2",
                Part::text("part_2_text").add_header("x-part-header-test", "part_2_header"),
            );

        // Get the request.
        server
            .post(&"/multipart_headers")
            .multipart(form)
            .await
            .assert_json(&json!([
                {
                    "name": "part_1",
                    "text": "part_1_text",
                    "header": "part_1_header",
                },
                {
                    "name": "part_2",
                    "text": "part_2_text",
                    "header": "part_2_header",
                },
            ]));
    }
}