// axum_test/test_request.rs

1use anyhow::anyhow;
2use anyhow::Context;
3use anyhow::Error as AnyhowError;
4use anyhow::Result;
5use axum::body::Body;
6use bytes::Bytes;
7use cookie::time::OffsetDateTime;
8use cookie::Cookie;
9use cookie::CookieJar;
10use http::header;
11use http::header::SET_COOKIE;
12use http::HeaderName;
13use http::HeaderValue;
14use http::Method;
15use http::Request;
16use http_body_util::BodyExt;
17use serde::Serialize;
18use std::fmt::Debug;
19use std::fmt::Display;
20use std::fs::read;
21use std::fs::read_to_string;
22use std::fs::File;
23use std::future::{Future, IntoFuture};
24use std::io::BufReader;
25use std::path::Path;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::Mutex;
29use url::Url;
30
31use crate::internals::ExpectedState;
32use crate::internals::QueryParamsStore;
33use crate::internals::RequestPathFormatter;
34use crate::multipart::MultipartForm;
35use crate::transport_layer::TransportLayer;
36use crate::ServerSharedState;
37use crate::TestResponse;
38
39mod test_request_config;
40pub(crate) use self::test_request_config::*;
41
///
/// A `TestRequest` is for building and executing a HTTP request to the [`TestServer`](crate::TestServer).
///
/// ## Building
///
/// Requests are created by the [`TestServer`](crate::TestServer), using its builder functions.
/// They correspond to the appropriate HTTP method: [`TestServer::get()`](crate::TestServer::get()),
/// [`TestServer::post()`](crate::TestServer::post()), etc.
///
/// See there for documentation.
///
/// ## Customising
///
/// The `TestRequest` allows the caller to fill in the rest of the request
/// to be sent to the server, including the headers, the body, cookies,
/// and the content type, using the relevant functions.
///
/// The TestRequest struct provides a number of methods to set up the request,
/// such as json, text, bytes, expect_failure, content_type, etc.
///
/// ## Sending
///
/// Once fully configured you send the request by awaiting the request object.
///
/// ```rust
/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
/// #
/// # use axum::Router;
/// # use axum_test::TestServer;
/// #
/// # let server = TestServer::new(Router::new())?;
/// #
/// // Build your request
/// let request = server.get(&"/user")
///     .add_header("x-custom-header", "example.com")
///     .content_type("application/yaml");
///
/// // await request to execute
/// let response = request.await;
/// #
/// # Ok(()) }
/// ```
///
/// You will receive a `TestResponse`.
///
/// ## Cookie Saving
///
/// [`TestRequest::save_cookies()`](crate::TestRequest::save_cookies()) and [`TestRequest::do_not_save_cookies()`](crate::TestRequest::do_not_save_cookies())
/// methods allow you to set the request to save cookies to the `TestServer`,
/// for reuse on any future requests.
///
/// This behaviour is **off** by default. It can be changed for all `TestRequests`
/// when building the `TestServer`, by building it with a `TestServerConfig` where `save_cookies` is set to true.
///
/// ## Expecting Failure and Success
///
/// When making a request you can mark it to expect a response within,
/// or outside, of the 2xx range of HTTP status codes.
///
/// If the response returns a status code different to what is expected,
/// then it will panic.
///
/// This is useful when making multiple requests within a test,
/// as it can surface issues sooner rather than later.
///
/// See the [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()),
/// and [`TestRequest::expect_success()`](crate::TestRequest::expect_success()).
///
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct TestRequest {
    // Method, URL, query params, headers, cookies, content type,
    // and cookie-saving flag for this request.
    config: TestRequestConfig,

    // State shared with the `TestServer`; response cookies are saved here
    // when cookie saving is enabled.
    server_state: Arc<Mutex<ServerSharedState>>,
    // Transport used to actually send the built request.
    transport: Arc<Box<dyn TransportLayer>>,

    // Request body; `None` is sent as an empty body.
    body: Option<Body>,

    // Status expectation (success / failure / none) checked once the response arrives.
    expected_state: ExpectedState,
}
122
123impl TestRequest {
124    pub(crate) fn new(
125        server_state: Arc<Mutex<ServerSharedState>>,
126        transport: Arc<Box<dyn TransportLayer>>,
127        config: TestRequestConfig,
128    ) -> Self {
129        let expected_state = config.expected_state;
130
131        Self {
132            config,
133            server_state,
134            transport,
135            body: None,
136            expected_state,
137        }
138    }
139
140    /// Set the body of the request to send up data as Json,
141    /// and changes the content type to `application/json`.
142    pub fn json<J>(self, body: &J) -> Self
143    where
144        J: ?Sized + Serialize,
145    {
146        let body_bytes =
147            serde_json::to_vec(body).expect("It should serialize the content into Json");
148
149        self.bytes(body_bytes.into())
150            .content_type(mime::APPLICATION_JSON.essence_str())
151    }
152
153    /// Sends a payload as a Json request, with the contents coming from a file.
154    pub fn json_from_file<P>(self, path: P) -> Self
155    where
156        P: AsRef<Path>,
157    {
158        let path_ref = path.as_ref();
159        let file = File::open(path_ref)
160            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
161            .unwrap();
162
163        let reader = BufReader::new(file);
164        let payload = serde_json::from_reader::<_, serde_json::Value>(reader)
165            .with_context(|| {
166                format!(
167                    "Failed to deserialize file '{}' as Json",
168                    path_ref.display()
169                )
170            })
171            .unwrap();
172
173        self.json(&payload)
174    }
175
176    /// Set the body of the request to send up data as Yaml,
177    /// and changes the content type to `application/yaml`.
178    #[cfg(feature = "yaml")]
179    pub fn yaml<Y>(self, body: &Y) -> Self
180    where
181        Y: ?Sized + Serialize,
182    {
183        let body = serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
184
185        self.bytes(body.into_bytes().into())
186            .content_type("application/yaml")
187    }
188
189    /// Sends a payload as a Yaml request, with the contents coming from a file.
190    #[cfg(feature = "yaml")]
191    pub fn yaml_from_file<P>(self, path: P) -> Self
192    where
193        P: AsRef<Path>,
194    {
195        let path_ref = path.as_ref();
196        let file = File::open(path_ref)
197            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
198            .unwrap();
199
200        let reader = BufReader::new(file);
201        let payload = serde_yaml::from_reader::<_, serde_yaml::Value>(reader)
202            .with_context(|| {
203                format!(
204                    "Failed to deserialize file '{}' as Yaml",
205                    path_ref.display()
206                )
207            })
208            .unwrap();
209
210        self.yaml(&payload)
211    }
212
213    /// Set the body of the request to send up data as MsgPack,
214    /// and changes the content type to `application/msgpack`.
215    #[cfg(feature = "msgpack")]
216    pub fn msgpack<M>(self, body: &M) -> Self
217    where
218        M: ?Sized + Serialize,
219    {
220        let body_bytes =
221            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
222
223        self.bytes(body_bytes.into())
224            .content_type("application/msgpack")
225    }
226
227    /// Sets the body of the request, with the content type
228    /// of 'application/x-www-form-urlencoded'.
229    pub fn form<F>(self, body: &F) -> Self
230    where
231        F: ?Sized + Serialize,
232    {
233        let body_text =
234            serde_urlencoded::to_string(body).expect("It should serialize the content into a Form");
235
236        self.bytes(body_text.into())
237            .content_type(mime::APPLICATION_WWW_FORM_URLENCODED.essence_str())
238    }
239
240    /// For sending multipart forms.
241    /// The payload is built using [`MultipartForm`](crate::multipart::MultipartForm) and [`Part`](crate::multipart::Part).
242    ///
243    /// This will be sent with the content type of 'multipart/form-data'.
244    ///
245    /// # Simple example
246    ///
247    /// ```rust
248    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
249    /// #
250    /// use axum::Router;
251    /// use axum_test::TestServer;
252    /// use axum_test::multipart::MultipartForm;
253    ///
254    /// let app = Router::new();
255    /// let server = TestServer::new(app)?;
256    ///
257    /// let multipart_form = MultipartForm::new()
258    ///     .add_text("name", "Joe")
259    ///     .add_text("animals", "foxes");
260    ///
261    /// let response = server.post(&"/my-form")
262    ///     .multipart(multipart_form)
263    ///     .await;
264    /// #
265    /// # Ok(()) }
266    /// ```
267    ///
268    /// # Sending byte parts
269    ///
270    /// ```rust
271    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
272    /// #
273    /// use axum::Router;
274    /// use axum_test::TestServer;
275    /// use axum_test::multipart::MultipartForm;
276    /// use axum_test::multipart::Part;
277    ///
278    /// let app = Router::new();
279    /// let server = TestServer::new(app)?;
280    ///
281    /// let readme_bytes = include_bytes!("../README.md");
282    /// let readme_part = Part::bytes(readme_bytes.as_slice())
283    ///     .file_name(&"README.md")
284    ///     .mime_type(&"text/markdown");
285    ///
286    /// let multipart_form = MultipartForm::new()
287    ///     .add_part("file", readme_part);
288    ///
289    /// let response = server.post(&"/my-form")
290    ///     .multipart(multipart_form)
291    ///     .await;
292    /// #
293    /// # Ok(()) }
294    /// ```
295    ///
296    pub fn multipart(mut self, multipart: MultipartForm) -> Self {
297        self.config.content_type = Some(multipart.content_type());
298        self.body = Some(multipart.into());
299
300        self
301    }
302
303    /// Set raw text as the body of the request,
304    /// and sets the content type to `text/plain`.
305    pub fn text<T>(self, raw_text: T) -> Self
306    where
307        T: Display,
308    {
309        let body_text = raw_text.to_string();
310
311        self.bytes(body_text.into())
312            .content_type(mime::TEXT_PLAIN.essence_str())
313    }
314
315    /// Sends a payload as plain text, with the contents coming from a file.
316    pub fn text_from_file<P>(self, path: P) -> Self
317    where
318        P: AsRef<Path>,
319    {
320        let path_ref = path.as_ref();
321        let payload = read_to_string(path_ref)
322            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
323            .unwrap();
324
325        self.text(payload)
326    }
327
328    /// Set raw bytes as the body of the request.
329    ///
330    /// The content type is left unchanged.
331    pub fn bytes(mut self, body_bytes: Bytes) -> Self {
332        let body: Body = body_bytes.into();
333
334        self.body = Some(body);
335        self
336    }
337
338    /// Reads the contents of the file as raw bytes, and sends it within the request.
339    ///
340    /// The content type is left unchanged, and no parsing of the file is done.
341    pub fn bytes_from_file<P>(self, path: P) -> Self
342    where
343        P: AsRef<Path>,
344    {
345        let path_ref = path.as_ref();
346        let payload = read(path_ref)
347            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
348            .unwrap();
349
350        self.bytes(payload.into())
351    }
352
353    /// Set the content type to use for this request in the header.
354    pub fn content_type(mut self, content_type: &str) -> Self {
355        self.config.content_type = Some(content_type.to_string());
356        self
357    }
358
359    /// Adds a Cookie to be sent with this request.
360    pub fn add_cookie(mut self, cookie: Cookie<'_>) -> Self {
361        self.config.cookies.add(cookie.into_owned());
362        self
363    }
364
365    /// Adds many cookies to be used with this request.
366    pub fn add_cookies(mut self, cookies: CookieJar) -> Self {
367        for cookie in cookies.iter() {
368            self.config.cookies.add(cookie.clone());
369        }
370
371        self
372    }
373
374    /// Clears all cookies used internally within this Request,
375    /// including any that came from the `TestServer`.
376    pub fn clear_cookies(mut self) -> Self {
377        self.config.cookies = CookieJar::new();
378        self
379    }
380
381    /// Any cookies returned will be saved to the [`TestServer`](crate::TestServer) that created this,
382    /// which will continue to use those cookies on future requests.
383    pub fn save_cookies(mut self) -> Self {
384        self.config.is_saving_cookies = true;
385        self
386    }
387
388    /// Cookies returned by this will _not_ be saved to the `TestServer`.
389    /// For use by future requests.
390    ///
391    /// This is the default behaviour.
392    /// You can change that default in [`TestServerConfig`](crate::TestServerConfig).
393    pub fn do_not_save_cookies(mut self) -> Self {
394        self.config.is_saving_cookies = false;
395        self
396    }
397
398    /// Adds query parameters to be sent with this request.
399    pub fn add_query_param<V>(self, key: &str, value: V) -> Self
400    where
401        V: Serialize,
402    {
403        self.add_query_params(&[(key, value)])
404    }
405
406    /// Adds the structure given as query parameters for this request.
407    ///
408    /// This is designed to take a list of parameters, or a body of parameters,
409    /// and then serializes them into the parameters of the request.
410    ///
411    /// # Sending a body of parameters using `json!`
412    ///
413    /// ```rust
414    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
415    /// #
416    /// use axum::Router;
417    /// use axum_test::TestServer;
418    /// use serde_json::json;
419    ///
420    /// let app = Router::new();
421    /// let server = TestServer::new(app)?;
422    ///
423    /// let response = server.get(&"/my-end-point")
424    ///     .add_query_params(json!({
425    ///         "username": "Brian",
426    ///         "age": 20
427    ///     }))
428    ///     .await;
429    /// #
430    /// # Ok(()) }
431    /// ```
432    ///
433    /// # Sending a body of parameters with Serde
434    ///
435    /// ```rust
436    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
437    /// #
438    /// use axum::Router;
439    /// use axum_test::TestServer;
440    /// use serde::Deserialize;
441    /// use serde::Serialize;
442    ///
443    /// #[derive(Serialize, Deserialize)]
444    /// struct UserQueryParams {
445    ///     username: String,
446    ///     age: u32,
447    /// }
448    ///
449    /// let app = Router::new();
450    /// let server = TestServer::new(app)?;
451    ///
452    /// let response = server.get(&"/my-end-point")
453    ///     .add_query_params(UserQueryParams {
454    ///         username: "Brian".to_string(),
455    ///         age: 20
456    ///     })
457    ///     .await;
458    /// #
459    /// # Ok(()) }
460    /// ```
461    ///
462    /// # Sending a list of parameters
463    ///
464    /// ```rust
465    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
466    /// #
467    /// use axum::Router;
468    /// use axum_test::TestServer;
469    ///
470    /// let app = Router::new();
471    /// let server = TestServer::new(app)?;
472    ///
473    /// let response = server.get(&"/my-end-point")
474    ///     .add_query_params(&[
475    ///         ("username", "Brian"),
476    ///         ("age", "20"),
477    ///     ])
478    ///     .await;
479    /// #
480    /// # Ok(()) }
481    /// ```
482    ///
483    pub fn add_query_params<V>(mut self, query_params: V) -> Self
484    where
485        V: Serialize,
486    {
487        self.config
488            .query_params
489            .add(query_params)
490            .with_context(|| {
491                format!(
492                    "It should serialize query parameters, for request {}",
493                    self.debug_request_format()
494                )
495            })
496            .unwrap();
497
498        self
499    }
500
501    /// Adds a query param onto the end of the request,
502    /// with no urlencoding of any kind.
503    ///
504    /// This exists to allow custom query parameters,
505    /// such as for the many versions of query param arrays.
506    ///
507    /// ```rust
508    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
509    /// #
510    /// use axum::Router;
511    /// use axum_test::TestServer;
512    ///
513    /// let app = Router::new();
514    /// let server = TestServer::new(app)?;
515    ///
516    /// let response = server.get(&"/my-end-point")
517    ///     .add_raw_query_param(&"my-flag")
518    ///     .add_raw_query_param(&"array[]=123")
519    ///     .add_raw_query_param(&"filter[value]=some-value")
520    ///     .await;
521    /// #
522    /// # Ok(()) }
523    /// ```
524    ///
525    pub fn add_raw_query_param(mut self, query_param: &str) -> Self {
526        self.config.query_params.add_raw(query_param.to_string());
527
528        self
529    }
530
531    /// Clears all query params set,
532    /// including any that came from the [`TestServer`](crate::TestServer).
533    pub fn clear_query_params(mut self) -> Self {
534        self.config.query_params.clear();
535        self
536    }
537
538    /// Adds a header to be sent with this request.
539    ///
540    /// ```rust
541    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
542    /// #
543    /// use axum::Router;
544    /// use axum_test::TestServer;
545    ///
546    /// let app = Router::new();
547    /// let server = TestServer::new(app)?;
548    ///
549    /// let response = server.get(&"/my-end-point")
550    ///     .add_header("x-custom-header", "custom-value")
551    ///     .add_header(http::header::CONTENT_LENGTH, 12345)
552    ///     .add_header(http::header::HOST, "example.com")
553    ///     .await;
554    /// #
555    /// # Ok(()) }
556    /// ```
557    pub fn add_header<N, V>(mut self, name: N, value: V) -> Self
558    where
559        N: TryInto<HeaderName>,
560        N::Error: Debug,
561        V: TryInto<HeaderValue>,
562        V::Error: Debug,
563    {
564        let header_name: HeaderName = name
565            .try_into()
566            .expect("Failed to convert header name to HeaderName");
567        let header_value: HeaderValue = value
568            .try_into()
569            .expect("Failed to convert header vlue to HeaderValue");
570
571        self.config.headers.push((header_name, header_value));
572        self
573    }
574
575    /// Adds an 'AUTHORIZATION' HTTP header to the request,
576    /// with no internal formatting of what is given.
577    pub fn authorization<T>(self, authorization_header: T) -> Self
578    where
579        T: AsRef<str>,
580    {
581        let authorization_header_value = HeaderValue::from_str(authorization_header.as_ref())
582            .expect("Cannot build Authorization HeaderValue from token");
583
584        self.add_header(header::AUTHORIZATION, authorization_header_value)
585    }
586
587    /// Adds an 'AUTHORIZATION' HTTP header to the request,
588    /// in the 'Bearer {token}' format.
589    pub fn authorization_bearer<T>(self, authorization_bearer_token: T) -> Self
590    where
591        T: Display,
592    {
593        let authorization_bearer_header_str = format!("Bearer {authorization_bearer_token}");
594        self.authorization(authorization_bearer_header_str)
595    }
596
597    /// Clears all headers set.
598    pub fn clear_headers(mut self) -> Self {
599        self.config.headers = vec![];
600        self
601    }
602
603    /// Sets the scheme to use when making the request. i.e. http or https.
604    /// The default scheme is 'http'.
605    ///
606    /// ```rust
607    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
608    /// #
609    /// use axum::Router;
610    /// use axum_test::TestServer;
611    ///
612    /// let app = Router::new();
613    /// let server = TestServer::new(app)?;
614    ///
615    /// let response = server
616    ///     .get(&"/my-end-point")
617    ///     .scheme(&"https")
618    ///     .await;
619    /// #
620    /// # Ok(()) }
621    /// ```
622    ///
623    pub fn scheme(mut self, scheme: &str) -> Self {
624        self.config
625            .full_request_url
626            .set_scheme(scheme)
627            .map_err(|_| anyhow!("Scheme '{scheme}' cannot be set to request"))
628            .unwrap();
629        self
630    }
631
632    /// Marks that this request is expected to always return a HTTP
633    /// status code within the 2xx range (200 to 299).
634    ///
635    /// If a code _outside_ of that range is returned,
636    /// then this will panic.
637    ///
638    /// ```rust
639    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
640    /// #
641    /// use axum::Json;
642    /// use axum::Router;
643    /// use axum::routing::put;
644    /// use serde_json::json;
645    ///
646    /// use axum_test::TestServer;
647    ///
648    /// let app = Router::new()
649    ///     .route(&"/todo", put(|| async { unimplemented!() }));
650    ///
651    /// let server = TestServer::new(app)?;
652    ///
653    /// // If this doesn't return a value in the 2xx range,
654    /// // then it will panic.
655    /// server.put(&"/todo")
656    ///     .expect_success()
657    ///     .json(&json!({
658    ///         "task": "buy milk",
659    ///     }))
660    ///     .await;
661    /// #
662    /// # Ok(())
663    /// # }
664    /// ```
665    ///
666    pub fn expect_success(self) -> Self {
667        self.expect_state(ExpectedState::Success)
668    }
669
670    /// Marks that this request is expected to return a HTTP status code
671    /// outside of the 2xx range.
672    ///
673    /// If a code _within_ the 2xx range is returned,
674    /// then this will panic.
675    pub fn expect_failure(self) -> Self {
676        self.expect_state(ExpectedState::Failure)
677    }
678
679    fn expect_state(mut self, expected_state: ExpectedState) -> Self {
680        self.expected_state = expected_state;
681        self
682    }
683
684    async fn send(self) -> Result<TestResponse> {
685        let debug_request_format = self.debug_request_format().to_string();
686
687        let method = self.config.method;
688        let expected_state = self.expected_state;
689        let save_cookies = self.config.is_saving_cookies;
690        let body = self.body.unwrap_or(Body::empty());
691        let url =
692            Self::build_url_query_params(self.config.full_request_url, &self.config.query_params);
693
694        let request = Self::build_request(
695            method.clone(),
696            &url,
697            body,
698            self.config.content_type,
699            self.config.cookies,
700            self.config.headers,
701            &debug_request_format,
702        )?;
703
704        #[allow(unused_mut)] // Allowed for the `ws` use immediately after.
705        let mut http_response = self.transport.send(request).await?;
706
707        #[cfg(feature = "ws")]
708        let websockets = {
709            let maybe_on_upgrade = http_response
710                .extensions_mut()
711                .remove::<hyper::upgrade::OnUpgrade>();
712            let transport_type = self.transport.transport_layer_type();
713
714            crate::internals::TestResponseWebSocket {
715                maybe_on_upgrade,
716                transport_type,
717            }
718        };
719
720        let (parts, response_body) = http_response.into_parts();
721        let response_bytes = response_body.collect().await?.to_bytes();
722
723        if save_cookies {
724            let cookie_headers = parts.headers.get_all(SET_COOKIE).into_iter();
725            ServerSharedState::add_cookies_by_header(&self.server_state, cookie_headers)?;
726        }
727
728        let test_response = TestResponse::new(
729            method,
730            url,
731            parts,
732            response_bytes,
733            #[cfg(feature = "ws")]
734            websockets,
735        );
736
737        // Assert if ok or not.
738        match expected_state {
739            ExpectedState::Success => test_response.assert_status_success(),
740            ExpectedState::Failure => test_response.assert_status_failure(),
741            ExpectedState::None => {}
742        }
743
744        Ok(test_response)
745    }
746
747    fn build_url_query_params(mut url: Url, query_params: &QueryParamsStore) -> Url {
748        // Add all the query params we have
749        if query_params.has_content() {
750            url.set_query(Some(&query_params.to_string()));
751        }
752
753        url
754    }
755
756    fn build_request(
757        method: Method,
758        url: &Url,
759        body: Body,
760        content_type: Option<String>,
761        cookies: CookieJar,
762        headers: Vec<(HeaderName, HeaderValue)>,
763        debug_request_format: &str,
764    ) -> Result<Request<Body>> {
765        let mut request_builder = Request::builder().uri(url.as_str()).method(method);
766
767        // Add all the headers we have.
768        if let Some(content_type) = content_type {
769            let (header_key, header_value) =
770                build_content_type_header(&content_type, debug_request_format)?;
771            request_builder = request_builder.header(header_key, header_value);
772        }
773
774        // Add all the non-expired cookies as headers
775        // Also strip cookies from their attributes, only their names and values should be preserved to conform the HTTP standard
776        let now = OffsetDateTime::now_utc();
777        for cookie in cookies.iter() {
778            let expired = cookie
779                .expires_datetime()
780                .map(|expires| expires <= now)
781                .unwrap_or(false);
782
783            if !expired {
784                let cookie_raw = cookie.stripped().to_string();
785                let header_value = HeaderValue::from_str(&cookie_raw)?;
786                request_builder = request_builder.header(header::COOKIE, header_value);
787            }
788        }
789
790        // Put headers into the request
791        for (header_name, header_value) in headers {
792            request_builder = request_builder.header(header_name, header_value);
793        }
794
795        let request = request_builder.body(body).with_context(|| {
796            format!("Expect valid hyper Request to be built, for request {debug_request_format}")
797        })?;
798
799        Ok(request)
800    }
801
802    fn debug_request_format(&self) -> RequestPathFormatter<'_> {
803        RequestPathFormatter::new(
804            &self.config.method,
805            self.config.full_request_url.as_str(),
806            Some(&self.config.query_params),
807        )
808    }
809}
810
811impl TryFrom<TestRequest> for Request<Body> {
812    type Error = AnyhowError;
813
814    fn try_from(test_request: TestRequest) -> Result<Request<Body>> {
815        let debug_request_format = test_request.debug_request_format().to_string();
816        let url = TestRequest::build_url_query_params(
817            test_request.config.full_request_url,
818            &test_request.config.query_params,
819        );
820        let body = test_request.body.unwrap_or(Body::empty());
821
822        TestRequest::build_request(
823            test_request.config.method,
824            &url,
825            body,
826            test_request.config.content_type,
827            test_request.config.cookies,
828            test_request.config.headers,
829            &debug_request_format,
830        )
831    }
832}
833
834impl IntoFuture for TestRequest {
835    type Output = TestResponse;
836    type IntoFuture = Pin<Box<dyn Future<Output = TestResponse> + Send>>;
837
838    fn into_future(self) -> Self::IntoFuture {
839        Box::pin(async { self.send().await.context("Sending request failed").unwrap() })
840    }
841}
842
843fn build_content_type_header(
844    content_type: &str,
845    debug_request_format: &str,
846) -> Result<(HeaderName, HeaderValue)> {
847    let header_value = HeaderValue::from_str(content_type).with_context(|| {
848        format!(
849            "Failed to store header content type '{content_type}', for request {debug_request_format}"
850        )
851    })?;
852
853    Ok((header::CONTENT_TYPE, header_value))
854}
855
#[cfg(test)]
mod test_content_type {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;

    /// Route handler that echoes the request's `Content-Type` header back
    /// as the response body, or an empty string when none was sent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(content_type) => content_type.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// The app shared by every test here: a single `/content_type` echo route.
    fn new_content_type_app() -> Router {
        Router::new().route("/content_type", get(get_content_type))
    }

    #[tokio::test]
    async fn it_should_not_set_a_content_type_by_default() {
        let server =
            TestServer::new(new_content_type_app()).expect("Should create test server");

        let received_content_type = server.get(&"/content_type").await.text();

        assert_eq!(received_content_type, "");
    }

    #[tokio::test]
    async fn it_should_override_server_content_type_when_present() {
        // The server default of `text/plain` must be replaced by the request's own value.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(new_content_type_app())
            .expect("Should create test server");

        let received_content_type = server
            .get(&"/content_type")
            .content_type(&"application/json")
            .await
            .text();

        assert_eq!(received_content_type, "application/json");
    }

    #[tokio::test]
    async fn it_should_set_content_type_when_present() {
        let server =
            TestServer::new(new_content_type_app()).expect("Should create test server");

        let received_content_type = server
            .get(&"/content_type")
            .content_type(&"application/custom")
            .await
            .text();

        assert_eq!(received_content_type, "application/custom");
    }
}
924
#[cfg(test)]
mod test_json {
    use crate::TestServer;
    use axum::extract::DefaultBodyLimit;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use rand::random;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Handler which echoes the request's `Content-Type` header back as the body,
    /// or an empty string when no such header was sent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_owned(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        let payload = TestJson {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };

        // A route which deserialises the JSON body and renders it back as text.
        let router = Router::new().route(
            "/json",
            post(|Json(body): Json<TestJson>| async move {
                let pets = body.pets.unwrap_or_else(|| "pandas".to_string());
                format!("json: {}, {}, {}", body.name, body.age, pets)
            }),
        );

        let server = TestServer::new(router).expect("Should create test server");

        // Send the request and read the response.
        let response_text = server.post(&"/json").json(&payload).await.text();

        assert_eq!(response_text, "json: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        let content_type = server.post(&"/content_type").json(&json!({})).await.text();

        assert_eq!(content_type, "application/json");
    }

    #[tokio::test]
    async fn it_should_pass_large_json_blobs_over_http() {
        const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024; // 16mb

        #[derive(Deserialize, Serialize, PartialEq, Debug)]
        struct TestLargeJson {
            items: Vec<String>,
        }

        // Keep generating random numeric strings until their combined
        // length reaches the target blob size.
        let mut items = Vec::new();
        let mut total_len = 0;
        while total_len < LARGE_BLOB_SIZE {
            let entry = random::<u64>().to_string();
            total_len += entry.len();
            items.push(entry);
        }
        let large_json_blob = TestLargeJson { items };

        // Echo the JSON straight back; the body limit is raised so a 16mb
        // payload is accepted.
        let router = Router::new()
            .route(
                "/json",
                post(|Json(body): Json<TestLargeJson>| async { Json(body) }),
            )
            .layer(DefaultBodyLimit::max(LARGE_BLOB_SIZE * 2));

        let server = TestServer::builder()
            .http_transport()
            .expect_success_by_default()
            .build(router)
            .expect("Should create test server");

        let response = server.post(&"/json").json(&large_json_blob).await;
        response.assert_json(&large_json_blob);
    }
}
1041
#[cfg(test)]
mod test_json_from_file {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Fixture file whose contents are sent as the JSON request body.
    const EXAMPLE_JSON_FILE: &str = "files/example.json";

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
        }

        let router = Router::new().route(
            "/json",
            post(|Json(body): Json<TestJson>| async move {
                format!("json: {}, {}", body.name, body.age)
            }),
        );
        let server = TestServer::new(router).expect("Should create test server");

        // Send the file contents and read back the rendered text.
        let response_text = server
            .post(&"/json")
            .json_from_file(&EXAMPLE_JSON_FILE)
            .await
            .text();

        assert_eq!(response_text, "json: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        // Echo the `Content-Type` header, or an empty string when absent.
        let router = Router::new().route(
            "/content_type",
            post(|headers: HeaderMap| async move {
                headers
                    .get(CONTENT_TYPE)
                    .map_or_else(String::new, |value| value.to_str().unwrap().to_owned())
            }),
        );
        let server = TestServer::new(router).expect("Should create test server");

        let content_type = server
            .post(&"/content_type")
            .json_from_file(&EXAMPLE_JSON_FILE)
            .await
            .text();

        assert_eq!(content_type, "application/json");
    }
}
1108
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_yaml::Yaml;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the YAML body and renders its fields as text.
        async fn describe_yaml(Yaml(body): Yaml<TestYaml>) -> String {
            let pets = body.pets.unwrap_or_else(|| "pandas".to_string());
            format!("yaml: {}, {}, {}", body.name, body.age, pets)
        }

        let router = Router::new().route("/yaml", post(describe_yaml));
        let server = TestServer::new(router).expect("Should create test server");

        let payload = TestYaml {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let response_text = server.post(&"/yaml").yaml(&payload).await.text();

        assert_eq!(response_text, "yaml: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        /// Echoes the `Content-Type` header, or an empty string when absent.
        async fn echo_content_type(headers: HeaderMap) -> String {
            headers
                .get(CONTENT_TYPE)
                .map(|value| value.to_str().unwrap().to_owned())
                .unwrap_or_default()
        }

        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        let content_type = server.post(&"/content_type").yaml(&json!({})).await.text();

        assert_eq!(content_type, "application/yaml");
    }
}
1183
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml_from_file {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_yaml::Yaml;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Fixture file whose contents are sent as the YAML request body.
    const EXAMPLE_YAML_FILE: &str = "files/example.yaml";

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map_or_else(String::new, |value| value.to_str().unwrap().to_owned())
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
        }

        let router = Router::new().route(
            "/yaml",
            post(|Yaml(body): Yaml<TestYaml>| async move {
                format!("yaml: {}, {}", body.name, body.age)
            }),
        );
        let server = TestServer::new(router).expect("Should create test server");

        let response_text = server
            .post(&"/yaml")
            .yaml_from_file(&EXAMPLE_YAML_FILE)
            .await
            .text();

        assert_eq!(response_text, "yaml: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        let content_type = server
            .post(&"/content_type")
            .yaml_from_file(&EXAMPLE_YAML_FILE)
            .await
            .text();

        assert_eq!(content_type, "application/yaml");
    }
}
1251
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_msgpack::MsgPack;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    #[tokio::test]
    async fn it_should_pass_msgpack_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestMsgPack {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the MessagePack body and renders its fields as text,
        /// labelled `msgpack:` so the output names the codec under test.
        async fn get_msgpack(MsgPack(msgpack): MsgPack<TestMsgPack>) -> String {
            format!(
                "msgpack: {}, {}, {}",
                msgpack.name,
                msgpack.age,
                msgpack.pets.unwrap_or_else(|| "pandas".to_string())
            )
        }

        // Build an application with a route.
        let app = Router::new().route("/msgpack", post(get_msgpack));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request and read the response.
        let text = server
            .post(&"/msgpack")
            .msgpack(&TestMsgPack {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "msgpack: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_msgpack_content_type_for_msgpack() {
        /// Echoes the `Content-Type` header, or an empty string when absent.
        async fn get_content_type(headers: HeaderMap) -> String {
            headers
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().unwrap().to_string())
                .unwrap_or_default()
        }

        // Build an application with a route.
        let app = Router::new().route("/content_type", post(get_content_type));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request and read the response.
        let text = server
            .post(&"/content_type")
            .msgpack(&json!({}))
            .await
            .text();

        assert_eq!(text, "application/msgpack");
    }
}
1328
#[cfg(test)]
mod test_form {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Form;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Echoes the request's `Content-Type` header back as the body,
    /// or an empty string when it is missing.
    async fn get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map_or_else(String::new, |value| value.to_str().unwrap().to_owned())
    }

    #[tokio::test]
    async fn it_should_pass_form_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestForm {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the url-encoded form and renders its fields as text.
        async fn get_form(Form(form): Form<TestForm>) -> String {
            let TestForm { name, age, pets } = form;
            let pets = pets.unwrap_or_else(|| "pandas".to_string());
            format!("form: {}, {}, {}", name, age, pets)
        }

        let router = Router::new().route("/form", post(get_form));
        let server = TestServer::new(router).expect("Should create test server");

        let submitted = TestForm {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        server
            .post(&"/form")
            .form(&submitted)
            .await
            .assert_text("form: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_form_content_type_for_form() {
        #[derive(Serialize)]
        struct MyForm {
            message: String,
        }

        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        let submitted = MyForm {
            message: "hello".to_string(),
        };
        server
            .post(&"/content_type")
            .form(&submitted)
            .await
            .assert_text("application/x-www-form-urlencoded");
    }
}
1406
#[cfg(test)]
mod test_bytes {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Reads the whole request body and returns it decoded as (lossy) UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let raw = collected.to_bytes();

        String::from_utf8_lossy(&raw).into_owned()
    }

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_owned())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let router = Router::new().route("/bytes", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        let echoed = server
            .post(&"/bytes")
            .bytes("hello!".as_bytes().into())
            .await
            .text();

        assert_eq!(echoed, "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // An explicitly set content type must survive setting a raw body.
        let content_type = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes("hello!".as_bytes().into())
            .await
            .text();

        assert_eq!(content_type, "application/testing");
    }
}
1473
#[cfg(test)]
mod test_bytes_from_file {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Fixture file whose raw contents are sent as the request body.
    const EXAMPLE_TEXT_FILE: &str = "files/example.txt";

    /// Reads the whole request body and returns it decoded as (lossy) UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let raw = collected.to_bytes();

        String::from_utf8_lossy(&raw).into_owned()
    }

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_owned())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let router = Router::new().route("/bytes", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        let echoed = server
            .post(&"/bytes")
            .bytes_from_file(&EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(echoed, "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // An explicitly set content type must survive loading a raw body from disk.
        let content_type = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes_from_file(&EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(content_type, "application/testing");
    }
}
1540
#[cfg(test)]
mod test_text {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Size of the body used by the large-blob round-trip tests.
    const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024; // 16mb

    /// Reads the whole request body and returns it decoded as (lossy) UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let raw = collected.to_bytes();

        String::from_utf8_lossy(&raw).into_owned()
    }

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_owned())
            .unwrap_or_default()
    }

    /// Router exposing `POST /text`, which echoes the request body.
    fn text_echo_router() -> Router {
        Router::new().route("/text", post(echo_body))
    }

    /// Sends a `LARGE_BLOB_SIZE` text body through `server` and checks
    /// it is echoed back byte-for-byte.
    async fn assert_large_text_round_trips(server: TestServer) {
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let echoed = server.post(&"/text").text(&large_blob).await.text();

        assert_eq!(echoed.len(), LARGE_BLOB_SIZE);
        assert_eq!(echoed, large_blob);
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let server = TestServer::new(text_echo_router()).expect("Should create test server");

        let echoed = server.post(&"/text").text(&"hello!").await.text();

        assert_eq!(echoed, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        let content_type = server.post(&"/content_type").text(&"hello!").await.text();

        assert_eq!(content_type, "text/plain");
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_mock_http() {
        let server = TestServer::builder()
            .mock_transport()
            .build(text_echo_router())
            .expect("Should create test server");

        assert_large_text_round_trips(server).await;
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_http() {
        let server = TestServer::builder()
            .http_transport()
            .build(text_echo_router())
            .expect("Should create test server");

        assert_large_text_round_trips(server).await;
    }
}
1664
#[cfg(test)]
mod test_text_from_file {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Fixture file whose contents are sent as the text request body.
    const EXAMPLE_TEXT_FILE: &str = "files/example.txt";

    /// Reads the whole request body and returns it decoded as (lossy) UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let raw = collected.to_bytes();

        String::from_utf8_lossy(&raw).into_owned()
    }

    /// Echoes the request's `Content-Type` header, or an empty string when absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_owned())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let router = Router::new().route("/text", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        let echoed = server
            .post(&"/text")
            .text_from_file(&EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(echoed, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let router = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        let content_type = server
            .post(&"/content_type")
            .text_from_file(&EXAMPLE_TEXT_FILE)
            .await
            .text();

        assert_eq!(content_type, "text/plain");
    }
}
1731
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Handler answering with a plain text body and the default `200 OK`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering with `202 Accepted`, a 2xx status other than `200`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// Server exposing `GET /ping`.
    fn ping_server() -> TestServer {
        TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = ping_server();

        server.get(&"/ping").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)))
            .expect("Should create test server");

        server.get(&"/accepted").expect_success().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        // No routes are registered, so the requested path is unknown.
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get(&"/some_unknown_route").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = ping_server();
        // The per-request expectation must win over the server-wide one.
        server.expect_failure();

        server.get(&"/ping").expect_success().await;
    }
}
1801
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Handler answering with a plain text body and the default `200 OK`.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler answering with `202 Accepted`, a 2xx status other than `200`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // No routes are registered, so the requested path is unknown.
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get(&"/some_unknown_route").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get(&"/ping").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let app = Router::new().route("/accepted", get(get_accepted));
        let server = TestServer::new(app).expect("Should create test server");

        server.get(&"/accepted").expect_failure().await;
    }

    /// Named to mirror the equivalent test in `test_expect_success`
    /// (the previous name repeated the word "should").
    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new()).expect("Should create test server");
        // The per-request expectation must win over the server-wide one.
        server.expect_success();

        server.get(&"/some_unknown_route").expect_failure().await;
    }
}
1868
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::time::Duration;
    use cookie::time::OffsetDateTime;
    use cookie::Cookie;

    /// Name of the cookie the tests send and the handler looks up.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Returns the value of the `TEST_COOKIE_NAME` cookie the server received,
    /// or `"cookie-not-found"` when it was not sent.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_send_non_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_add(Duration::minutes(10))
                .unwrap(),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_not_send_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        // Expire the cookie well in the past rather than at exactly `now`, so the
        // test does not depend on the clock having advanced before the request
        // is built.
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_sub(Duration::minutes(10))
                .unwrap(),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
1926
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use cookie::SameSite;

    /// Handler returning every received cookie as `name=value`, sorted and
    /// joined with `", "` so the output does not depend on header ordering.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut all_cookies = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect::<Vec<String>>();
        all_cookies.sort();

        all_cookies.join(", ")
    }

    /// Handler returning the raw `Cookie` request header(s) joined with `"; "`,
    /// or an empty string when none were sent. Non-UTF-8 header values are
    /// rendered as empty strings.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or("").to_string())
            .reduce(|a, b| a + "; " + &b)
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let server = TestServer::new(app).expect("Should create test server");

        // Build cookies to send up
        let cookie_1 = Cookie::new("first-cookie", "my-custom-cookie");
        let cookie_2 = Cookie::new("second-cookie", "other-cookie");
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie_1);
        cookie_jar.add(cookie_2);

        server
            .get("/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_stripped_by_their_attributes() {
        let app = Router::new().route("/cookies", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        const TEST_COOKIE_NAME: &str = "test-cookie";
        const TEST_COOKIE_VALUE: &str = "my-custom-cookie";

        // Build a cookie carrying attributes; only `name=value` should be sent up.
        let cookie = Cookie::build((TEST_COOKIE_NAME, TEST_COOKIE_VALUE))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie);

        server
            .get("/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text(format!("{}={}", TEST_COOKIE_NAME, TEST_COOKIE_VALUE));
    }
}
2003
#[cfg(test)]
mod test_save_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that stores the request body as the test cookie's value, with
    /// `HttpOnly`, `Secure`, `SameSite=Strict` and `Path=/cookie` attributes,
    /// and replies with `"done"`.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Handler returning the raw `Cookie` request header(s) joined with `"; "`,
    /// or an empty string when none were sent.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or("").to_string())
            .reduce(|a, b| a + "; " + &b)
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_strip_cookies_from_their_attributes() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        // Create a cookie, and have the server remember it.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Only the cookie name and value should come back; the attributes are stripped.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, format!("{}=cookie-found!", TEST_COOKIE_NAME));
    }
}
2073
#[cfg(test)]
mod test_do_not_save_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that stores the request body as the test cookie's value, with
    /// `HttpOnly`, `Secure`, `SameSite=Strict` and `Path=/cookie` attributes,
    /// and replies with `"done"`.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Handler returning the raw `Cookie` request header(s) joined with `"; "`,
    /// or an empty string when none were sent.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or("").to_string())
            .reduce(|a, b| a + "; " + &b)
            .unwrap_or_default()
    }

    /// Router shared by both tests: PUT sets the cookie, GET echoes cookie headers.
    fn build_app() -> Router {
        Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined))
    }

    #[tokio::test]
    async fn it_should_not_save_cookies_when_set() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Create a cookie, but tell the request not to save it.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // Check that no cookie is sent on the follow-up request.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_test_server_and_not_save_cookies_when_set() {
        let server = TestServer::builder()
            .save_cookies()
            .build(build_app())
            .expect("Should create test server");

        // Create a cookie; the request-level setting overrides the server's `save_cookies`.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // Check that no cookie is sent on the follow-up request.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, "");
    }
}
2166
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::Cookie as AxumCookie;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler echoing the test cookie's value, or `"cookie-not-found"` when the
    /// request did not carry it.
    async fn get_cookie(cookies: AxumCookieJar) -> (AxumCookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    /// Handler that stores the request body as the test cookie's value and
    /// replies with `"done"`.
    async fn put_cookie(cookies: AxumCookieJar, request: Request) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
        let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text);

        (cookies.add(cookie), "done")
    }

    #[tokio::test]
    async fn it_should_clear_cookie_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server
            .get("/cookie")
            .add_cookie(cookie)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookie_jar_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie);

        let response_text = server
            .get("/cookie")
            .add_cookies(cookie_jar)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_saved_by_past_request() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // Create a cookie, and have the server remember it.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Clearing cookies on the request must stop the saved cookie being sent.
        let response_text = server.get("/cookie").clear_cookies().await.text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_added_to_test_server() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        server.add_cookie(cookie);

        // Clearing cookies on the request must stop the server-level cookie being sent.
        let response_text = server.get("/cookie").clear_cookies().await.text();

        assert_eq!(response_text, "cookie-not-found");
    }
}
2280
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    /// Rejects with 400 "Missing test header" when the header is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Handler echoing the test header's bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send a request with the header
        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .await;

        // Check it sent back the right text
        response.assert_text(TEST_HEADER_CONTENT);
    }
}
2339
#[cfg(test)]
mod test_authorization {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server exposing `/auth-header`, which echoes the raw
    /// `Authorization` header back as text (400 when the header is absent).
    /// The returned server expects every response to be successful.
    fn new_test_server() -> TestServer {
        /// Extractor holding the `Authorization` header value as a string.
        struct AuthorizationValue(String);

        impl<S: Sync> FromRequestParts<S> for AuthorizationValue {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<AuthorizationValue, Self::Rejection> {
                match parts.headers.get(header::AUTHORIZATION) {
                    Some(raw) => Ok(AuthorizationValue(raw.to_str().unwrap().to_string())),
                    None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
                }
            }
        }

        async fn echo_authorization(AuthorizationValue(value): AuthorizationValue) -> String {
            value
        }

        let router = Router::new().route("/auth-header", get(echo_authorization));
        let mut server = TestServer::new(router).expect("Should create test server");
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // Attach an explicit Authorization value to the request.
        let response = server.get("/auth-header").authorization("Bearer abc123").await;

        // The route echoes the header back verbatim.
        response.assert_text("Bearer abc123");
    }
}
2397
#[cfg(test)]
mod test_authorization_bearer {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server exposing `/auth-header`, which echoes the token from a
    /// `Authorization: Bearer <token>` header. The server expects every response
    /// to be successful, so a missing header or missing `Bearer ` prefix fails the test.
    fn new_test_server() -> TestServer {
        /// Extractor holding the token part of a bearer `Authorization` header.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let value = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))?;

                // Require the scheme prefix rather than deleting it wherever it appears,
                // so the test fails if `authorization_bearer` omits `Bearer `.
                let token = value.strip_prefix("Bearer ").ok_or((
                    StatusCode::BAD_REQUEST,
                    "Authorization header is not a bearer token",
                ))?;

                Ok(TestHeader(token.to_string()))
            }
        }

        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        // Build an application with a route.
        let app = Router::new().route("/auth-header", get(ping_auth_header));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // Send a request with the header
        let response = server
            .get("/auth-header")
            .authorization_bearer("abc123")
            .await;

        // Check it sent back the token, stripped of its `Bearer ` prefix
        response.assert_text("abc123");
    }
}
2455
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the `test-header` request header.
    /// Rejects with 400 "Missing test header" when the header is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Handler echoing the test header's bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_request() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Add the header, then clear it before sending.
        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .clear_headers()
            .await;

        // The extractor should reject the request as the header is gone.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, with the header set at the server level.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Clear the server-level header on this request.
        let response = server.get("/header").clear_headers().await;

        // The extractor should reject the request as the header is gone.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
2536
#[cfg(test)]
mod test_add_query_params {
    use crate::TestServer;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn echo_message(query: AxumStdQuery<QueryParam>) -> String {
        query.0.message
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` and `other` joined by a dash.
    async fn echo_message_and_other(query: AxumStdQuery<QueryParam2>) -> String {
        let QueryParam2 { message, other } = query.0;
        format!("{}-{}", message, other)
    }

    /// Router with `/query` (one param) and `/query-2` (two params).
    fn build_app() -> Router {
        let single = get(echo_message);
        let double = get(echo_message_and_other);

        Router::new().route("/query", single).route("/query-2", double)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let params = QueryParam {
            message: "it works".to_string(),
        };
        let response = server.get("/query").add_query_params(params).await;

        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query")
            .add_query_params(&[("message", "it works")])
            .await;

        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query-2")
            .add_query_params(&[("message", "it works"), ("other", "yup")])
            .await;

        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Parameters from separate calls should accumulate.
        let response = server
            .get("/query-2")
            .add_query_params(&[("message", "it works")])
            .add_query_params(&[("other", "yup")])
            .await;

        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        let response = server.get("/query-2").add_query_params(params).await;

        response.assert_text("it works-yup");
    }
}
2643
#[cfg(test)]
mod test_add_raw_query_param {
    use crate::TestServer;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Repeated-key query parameters, parsed via `axum_extra`'s `Query`.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Echoes `items` then `arrs`, each list joined with `", "`.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        // `join` on an empty list yields "", so both lists can be appended
        // unconditionally; when both are present they are concatenated back-to-back.
        format!("{}{}", params.items.join(", "), params.arrs.join(", "))
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_raw_query_param("message=it-works")
            .await
            .assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-extra")
            .add_raw_query_param("items=one&items=two&items=three")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-extra")
            .add_raw_query_param("arrs[]=one")
            .add_raw_query_param("arrs[]=two")
            .add_raw_query_param("arrs[]=three")
            .await
            .assert_text("one, two, three");
    }
}
2736
#[cfg(test)]
mod test_add_query_param {
    use crate::TestServer;
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn echo_message(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` and `other` joined by a dash.
    async fn echo_message_with_other(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(echo_message));
        let server = TestServer::new(router).expect("Should create test server");

        let response = server
            .get("/query")
            .add_query_param("message", "it works")
            .await;

        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(echo_message_with_other));
        let server = TestServer::new(router).expect("Should create test server");

        // Each call adds one more parameter to the request.
        let response = server
            .get("/query-2")
            .add_query_param("message", "it works")
            .add_query_param("other", "yup")
            .await;

        response.assert_text("it works-yup");
    }
}
2798
#[cfg(test)]
mod test_clear_query_params {
    use crate::TestServer;
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of the two optional query parameters were received.
    async fn describe_query_params(Query(received): Query<QueryParams>) -> String {
        let has_first = received.first.is_some();
        let has_second = received.second.is_some();

        format!("has first? {}, has second? {}", has_first, has_second)
    }

    /// Server with a single `/query` route reporting which params arrived.
    fn new_server() -> TestServer {
        let router = Router::new().route("/query", get(describe_query_params));
        TestServer::new(router).expect("Should create test server")
    }

    /// Query params with both fields populated.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let server = new_server();

        let response = server
            .get("/query")
            .add_query_params(both_params())
            .clear_query_params()
            .await;

        response.assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let server = new_server();

        // Params added after clearing must still be sent.
        let response = server
            .get("/query")
            .add_query_params(both_params())
            .clear_query_params()
            .add_query_params(both_params())
            .await;

        response.assert_text("has first? true, has second? true");
    }
}
2866
#[cfg(test)]
mod test_scheme {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::Router;

    /// Responds with the scheme of the incoming request URI as plain text.
    ///
    /// Panics (failing the test) if the URI carries no scheme.
    async fn route_get_scheme(request: Request) -> String {
        let scheme = request.uri().scheme_str().unwrap();
        scheme.to_string()
    }

    /// Builds a fresh server whose only route is `GET /scheme`.
    fn scheme_server() -> TestServer {
        let app = Router::new().route("/scheme", get(route_get_scheme));
        TestServer::builder().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = scheme_server();

        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_http_when_set() {
        let server = scheme_server();

        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_when_set() {
        let server = scheme_server();

        let response = server.get("/scheme").scheme(&"https").await;
        response.assert_text("https");
    }

    #[tokio::test]
    async fn it_should_override_test_server_when_set() {
        // The server-wide default is https ...
        let mut server = scheme_server();
        server.scheme(&"https");

        // ... but a per-request scheme takes precedence over it.
        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }
}
2924
#[cfg(test)]
mod test_multipart {
    use crate::multipart::MultipartForm;
    use crate::multipart::Part;
    use crate::TestServer;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Describes every received multipart field as
    /// `"<name> is <byte length> bytes, <content type>"`.
    ///
    /// Panics (failing the test) if a field is missing a name or a content type.
    async fn route_post_multipart(mut multipart: Multipart) -> Json<Vec<String>> {
        let mut fields = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let name = field.name().unwrap().to_string();
            let content_type = field.content_type().unwrap().to_owned();
            let data = field.bytes().await.unwrap();

            let field_stats = format!("{name} is {} bytes, {content_type}", data.len());
            fields.push(field_stats);
        }

        Json(fields)
    }

    /// Echoes each part's name, UTF-8 body text, and its `x-part-header-test` header.
    ///
    /// Panics (failing the test) if a part lacks that header or its body is not UTF-8.
    async fn route_post_multipart_headers(mut multipart: Multipart) -> Json<Vec<Value>> {
        let mut sent_part_headers = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let part_name = field.name().unwrap().to_string();
            let part_header_value = field
                .headers()
                .get("x-part-header-test")
                .unwrap()
                .to_str()
                .unwrap()
                .to_string();
            let part_text = String::from_utf8(field.bytes().await.unwrap().into()).unwrap();

            sent_part_headers.push(json!({
                "name": part_name,
                "text": part_text,
                "header": part_header_value,
            }))
        }

        Json(sent_part_headers)
    }

    fn test_router() -> Router {
        Router::new()
            .route("/multipart", post(route_post_multipart))
            .route("/multipart_headers", post(route_post_multipart_headers))
    }

    /// The form shared by the transport tests.
    ///
    /// It holds two text parts plus a number that is sent as text.
    fn stats_form() -> MultipartForm {
        MultipartForm::new()
            .add_text("penguins?", "lots")
            .add_text("animals", "🦊🦊🦊")
            .add_text("carrots", 123_u32)
    }

    /// What `/multipart` should report for [`stats_form`].
    ///
    /// Each fox emoji is 4 bytes in UTF-8, hence 12 bytes for three of them.
    fn expected_stats() -> Vec<String> {
        vec![
            "penguins? is 4 bytes, text/plain".to_string(),
            "animals is 12 bytes, text/plain".to_string(),
            "carrots is 3 bytes, text/plain".to_string(),
        ]
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_mock_transport() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(stats_form())
            .await
            .assert_json(&expected_stats());
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_http_transport() {
        // Run the server.
        let server = TestServer::builder()
            .http_transport()
            .build(test_router())
            .expect("Should create test server");

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(stats_form())
            .await
            .assert_json(&expected_stats());
    }

    #[tokio::test]
    async fn it_should_send_text_parts_as_text() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part("animals", Part::text("🦊🦊🦊"));

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 12 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_custom_mime_type() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        // Two commas plus three 4-byte emoji = 14 bytes.
        let form = MultipartForm::new().add_part(
            "animals",
            Part::bytes("🦊,🦊,🦊".as_bytes()).mime_type(mime::TEXT_CSV),
        );

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 14 bytes, text/csv".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_using_include_bytes() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part(
            "file",
            Part::bytes(include_bytes!("../files/example.txt").as_slice())
                .mime_type(mime::TEXT_PLAIN),
        );

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["file is 6 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_form_headers_in_parts() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new()
            .add_part(
                "part_1",
                Part::text("part_1_text").add_header("x-part-header-test", "part_1_header"),
            )
            .add_part(
                "part_2",
                Part::text("part_2_text").add_header("x-part-header-test", "part_2_header"),
            );

        // Get the request.
        server
            .post(&"/multipart_headers")
            .multipart(form)
            .await
            .assert_json(&json!([
                {
                    "name": "part_1",
                    "text": "part_1_text",
                    "header": "part_1_header",
                },
                {
                    "name": "part_2",
                    "text": "part_2_text",
                    "header": "part_2_header",
                },
            ]));
    }
}