axum_test/
test_request.rs

1use anyhow::anyhow;
2use anyhow::Context;
3use anyhow::Error as AnyhowError;
4use anyhow::Result;
5use auto_future::AutoFuture;
6use axum::body::Body;
7use bytes::Bytes;
8use cookie::time::OffsetDateTime;
9use cookie::Cookie;
10use cookie::CookieJar;
11use http::header;
12use http::header::SET_COOKIE;
13use http::HeaderName;
14use http::HeaderValue;
15use http::Method;
16use http::Request;
17use http_body_util::BodyExt;
18use serde::Serialize;
19use std::fmt::Debug;
20use std::fmt::Display;
21use std::fs::read;
22use std::fs::read_to_string;
23use std::fs::File;
24use std::future::IntoFuture;
25use std::io::BufReader;
26use std::path::Path;
27use std::sync::Arc;
28use std::sync::Mutex;
29use url::Url;
30
31use crate::internals::ExpectedState;
32use crate::internals::QueryParamsStore;
33use crate::internals::RequestPathFormatter;
34use crate::multipart::MultipartForm;
35use crate::transport_layer::TransportLayer;
36use crate::ServerSharedState;
37use crate::TestResponse;
38
39mod test_request_config;
40pub(crate) use self::test_request_config::*;
41
42///
43/// A `TestRequest` is for building and executing a HTTP request to the [`TestServer`](crate::TestServer).
44///
45/// ## Building
46///
47/// Requests are created by the [`TestServer`](crate::TestServer), using it's builder functions.
48/// They correspond to the appropriate HTTP method: [`TestServer::get()`](crate::TestServer::get()),
49/// [`TestServer::post()`](crate::TestServer::post()), etc.
50///
51/// See there for documentation.
52///
53/// ## Customising
54///
55/// The `TestRequest` allows the caller to fill in the rest of the request
56/// to be sent to the server. Including the headers, the body, cookies,
57/// and the content type, using the relevant functions.
58///
59/// The TestRequest struct provides a number of methods to set up the request,
60/// such as json, text, bytes, expect_failure, content_type, etc.
61///
62/// ## Sending
63///
64/// Once fully configured you send the request by awaiting the request object.
65///
66/// ```rust
67/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
68/// #
69/// # use axum::Router;
70/// # use axum_test::TestServer;
71/// #
72/// # let server = TestServer::new(Router::new())?;
73/// #
74/// // Build your request
75/// let request = server.get(&"/user")
76///     .add_header("x-custom-header", "example.com")
77///     .content_type("application/yaml");
78///
79/// // await request to execute
80/// let response = request.await;
81/// #
82/// # Ok(()) }
83/// ```
84///
85/// You will receive a `TestResponse`.
86///
87/// ## Cookie Saving
88///
89/// [`TestRequest::save_cookies()`](crate::TestRequest::save_cookies()) and [`TestRequest::do_not_save_cookies()`](crate::TestRequest::do_not_save_cookies())
90/// methods allow you to set the request to save cookies to the `TestServer`,
91/// for reuse on any future requests.
92///
93/// This behaviour is **off** by default, and can be changed for all `TestRequests`
94/// when building the `TestServer`. By building it with a `TestServerConfig` where `save_cookies` is set to true.
95///
96/// ## Expecting Failure and Success
97///
98/// When making a request you can mark it to expect a response within,
99/// or outside, of the 2xx range of HTTP status codes.
100///
101/// If the response returns a status code different to what is expected,
102/// then it will panic.
103///
104/// This is useful when making multiple requests within a test.
105/// As it can find issues earlier than later.
106///
107/// See the [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()),
108/// and [`TestRequest::expect_success()`](crate::TestRequest::expect_success()).
109///
110#[derive(Debug)]
111#[must_use = "futures do nothing unless polled"]
112pub struct TestRequest {
113    config: TestRequestConfig,
114
115    server_state: Arc<Mutex<ServerSharedState>>,
116    transport: Arc<Box<dyn TransportLayer>>,
117
118    body: Option<Body>,
119
120    expected_state: ExpectedState,
121}
122
123impl TestRequest {
124    pub(crate) fn new(
125        server_state: Arc<Mutex<ServerSharedState>>,
126        transport: Arc<Box<dyn TransportLayer>>,
127        config: TestRequestConfig,
128    ) -> Self {
129        let expected_state = config.expected_state;
130
131        Self {
132            config,
133            server_state,
134            transport,
135            body: None,
136            expected_state,
137        }
138    }
139
140    /// Set the body of the request to send up data as Json,
141    /// and changes the content type to `application/json`.
142    pub fn json<J>(self, body: &J) -> Self
143    where
144        J: ?Sized + Serialize,
145    {
146        let body_bytes =
147            serde_json::to_vec(body).expect("It should serialize the content into Json");
148
149        self.bytes(body_bytes.into())
150            .content_type(mime::APPLICATION_JSON.essence_str())
151    }
152
153    /// Sends a payload as a Json request, with the contents coming from a file.
154    pub fn json_from_file<P>(self, path: P) -> Self
155    where
156        P: AsRef<Path>,
157    {
158        let path_ref = path.as_ref();
159        let file = File::open(path_ref)
160            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
161            .unwrap();
162
163        let reader = BufReader::new(file);
164        let payload = serde_json::from_reader::<_, serde_json::Value>(reader)
165            .with_context(|| {
166                format!(
167                    "Failed to deserialize file '{}' as Json",
168                    path_ref.display()
169                )
170            })
171            .unwrap();
172
173        self.json(&payload)
174    }
175
176    /// Set the body of the request to send up data as Yaml,
177    /// and changes the content type to `application/yaml`.
178    #[cfg(feature = "yaml")]
179    pub fn yaml<Y>(self, body: &Y) -> Self
180    where
181        Y: ?Sized + Serialize,
182    {
183        let body = serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
184
185        self.bytes(body.into_bytes().into())
186            .content_type("application/yaml")
187    }
188
189    /// Sends a payload as a Yaml request, with the contents coming from a file.
190    #[cfg(feature = "yaml")]
191    pub fn yaml_from_file<P>(self, path: P) -> Self
192    where
193        P: AsRef<Path>,
194    {
195        let path_ref = path.as_ref();
196        let file = File::open(path_ref)
197            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
198            .unwrap();
199
200        let reader = BufReader::new(file);
201        let payload = serde_yaml::from_reader::<_, serde_yaml::Value>(reader)
202            .with_context(|| {
203                format!(
204                    "Failed to deserialize file '{}' as Yaml",
205                    path_ref.display()
206                )
207            })
208            .unwrap();
209
210        self.yaml(&payload)
211    }
212
213    /// Set the body of the request to send up data as MsgPack,
214    /// and changes the content type to `application/msgpack`.
215    #[cfg(feature = "msgpack")]
216    pub fn msgpack<M>(self, body: &M) -> Self
217    where
218        M: ?Sized + Serialize,
219    {
220        let body_bytes =
221            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
222
223        self.bytes(body_bytes.into())
224            .content_type("application/msgpack")
225    }
226
227    /// Sets the body of the request, with the content type
228    /// of 'application/x-www-form-urlencoded'.
229    pub fn form<F>(self, body: &F) -> Self
230    where
231        F: ?Sized + Serialize,
232    {
233        let body_text =
234            serde_urlencoded::to_string(body).expect("It should serialize the content into a Form");
235
236        self.bytes(body_text.into())
237            .content_type(mime::APPLICATION_WWW_FORM_URLENCODED.essence_str())
238    }
239
240    /// For sending multipart forms.
241    /// The payload is built using [`MultipartForm`](crate::multipart::MultipartForm) and [`Part`](crate::multipart::Part).
242    ///
243    /// This will be sent with the content type of 'multipart/form-data'.
244    ///
245    /// # Simple example
246    ///
247    /// ```rust
248    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
249    /// #
250    /// use axum::Router;
251    /// use axum_test::TestServer;
252    /// use axum_test::multipart::MultipartForm;
253    ///
254    /// let app = Router::new();
255    /// let server = TestServer::new(app)?;
256    ///
257    /// let multipart_form = MultipartForm::new()
258    ///     .add_text("name", "Joe")
259    ///     .add_text("animals", "foxes");
260    ///
261    /// let response = server.post(&"/my-form")
262    ///     .multipart(multipart_form)
263    ///     .await;
264    /// #
265    /// # Ok(()) }
266    /// ```
267    ///
268    /// # Sending byte parts
269    ///
270    /// ```rust
271    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
272    /// #
273    /// use axum::Router;
274    /// use axum_test::TestServer;
275    /// use axum_test::multipart::MultipartForm;
276    /// use axum_test::multipart::Part;
277    ///
278    /// let app = Router::new();
279    /// let server = TestServer::new(app)?;
280    ///
281    /// let readme_bytes = include_bytes!("../README.md");
282    /// let readme_part = Part::bytes(readme_bytes.as_slice())
283    ///     .file_name(&"README.md")
284    ///     .mime_type(&"text/markdown");
285    ///
286    /// let multipart_form = MultipartForm::new()
287    ///     .add_part("file", readme_part);
288    ///
289    /// let response = server.post(&"/my-form")
290    ///     .multipart(multipart_form)
291    ///     .await;
292    /// #
293    /// # Ok(()) }
294    /// ```
295    ///
296    pub fn multipart(mut self, multipart: MultipartForm) -> Self {
297        self.config.content_type = Some(multipart.content_type());
298        self.body = Some(multipart.into());
299
300        self
301    }
302
303    /// Set raw text as the body of the request,
304    /// and sets the content type to `text/plain`.
305    pub fn text<T>(self, raw_text: T) -> Self
306    where
307        T: Display,
308    {
309        let body_text = format!("{}", raw_text);
310
311        self.bytes(body_text.into())
312            .content_type(mime::TEXT_PLAIN.essence_str())
313    }
314
315    /// Sends a payload as plain text, with the contents coming from a file.
316    pub fn text_from_file<P>(self, path: P) -> Self
317    where
318        P: AsRef<Path>,
319    {
320        let path_ref = path.as_ref();
321        let payload = read_to_string(path_ref)
322            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
323            .unwrap();
324
325        self.text(payload)
326    }
327
328    /// Set raw bytes as the body of the request.
329    ///
330    /// The content type is left unchanged.
331    pub fn bytes(mut self, body_bytes: Bytes) -> Self {
332        let body: Body = body_bytes.into();
333
334        self.body = Some(body);
335        self
336    }
337
338    /// Reads the contents of the file as raw bytes, and sends it within the request.
339    ///
340    /// The content type is left unchanged, and no parsing of the file is done.
341    pub fn bytes_from_file<P>(self, path: P) -> Self
342    where
343        P: AsRef<Path>,
344    {
345        let path_ref = path.as_ref();
346        let payload = read(path_ref)
347            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
348            .unwrap();
349
350        self.bytes(payload.into())
351    }
352
353    /// Set the content type to use for this request in the header.
354    pub fn content_type(mut self, content_type: &str) -> Self {
355        self.config.content_type = Some(content_type.to_string());
356        self
357    }
358
359    /// Adds a Cookie to be sent with this request.
360    pub fn add_cookie(mut self, cookie: Cookie<'_>) -> Self {
361        self.config.cookies.add(cookie.into_owned());
362        self
363    }
364
365    /// Adds many cookies to be used with this request.
366    pub fn add_cookies(mut self, cookies: CookieJar) -> Self {
367        for cookie in cookies.iter() {
368            self.config.cookies.add(cookie.clone());
369        }
370
371        self
372    }
373
374    /// Clears all cookies used internally within this Request,
375    /// including any that came from the `TestServer`.
376    pub fn clear_cookies(mut self) -> Self {
377        self.config.cookies = CookieJar::new();
378        self
379    }
380
381    /// Any cookies returned will be saved to the [`TestServer`](crate::TestServer) that created this,
382    /// which will continue to use those cookies on future requests.
383    pub fn save_cookies(mut self) -> Self {
384        self.config.is_saving_cookies = true;
385        self
386    }
387
388    /// Cookies returned by this will _not_ be saved to the `TestServer`.
389    /// For use by future requests.
390    ///
391    /// This is the default behaviour.
392    /// You can change that default in [`TestServerConfig`](crate::TestServerConfig).
393    pub fn do_not_save_cookies(mut self) -> Self {
394        self.config.is_saving_cookies = false;
395        self
396    }
397
398    /// Adds query parameters to be sent with this request.
399    pub fn add_query_param<V>(self, key: &str, value: V) -> Self
400    where
401        V: Serialize,
402    {
403        self.add_query_params(&[(key, value)])
404    }
405
406    /// Adds the structure given as query parameters for this request.
407    ///
408    /// This is designed to take a list of parameters, or a body of parameters,
409    /// and then serializes them into the parameters of the request.
410    ///
411    /// # Sending a body of parameters using `json!`
412    ///
413    /// ```rust
414    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
415    /// #
416    /// use axum::Router;
417    /// use axum_test::TestServer;
418    /// use serde_json::json;
419    ///
420    /// let app = Router::new();
421    /// let server = TestServer::new(app)?;
422    ///
423    /// let response = server.get(&"/my-end-point")
424    ///     .add_query_params(json!({
425    ///         "username": "Brian",
426    ///         "age": 20
427    ///     }))
428    ///     .await;
429    /// #
430    /// # Ok(()) }
431    /// ```
432    ///
433    /// # Sending a body of parameters with Serde
434    ///
435    /// ```rust
436    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
437    /// #
438    /// use axum::Router;
439    /// use axum_test::TestServer;
440    /// use serde::Deserialize;
441    /// use serde::Serialize;
442    ///
443    /// #[derive(Serialize, Deserialize)]
444    /// struct UserQueryParams {
445    ///     username: String,
446    ///     age: u32,
447    /// }
448    ///
449    /// let app = Router::new();
450    /// let server = TestServer::new(app)?;
451    ///
452    /// let response = server.get(&"/my-end-point")
453    ///     .add_query_params(UserQueryParams {
454    ///         username: "Brian".to_string(),
455    ///         age: 20
456    ///     })
457    ///     .await;
458    /// #
459    /// # Ok(()) }
460    /// ```
461    ///
462    /// # Sending a list of parameters
463    ///
464    /// ```rust
465    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
466    /// #
467    /// use axum::Router;
468    /// use axum_test::TestServer;
469    ///
470    /// let app = Router::new();
471    /// let server = TestServer::new(app)?;
472    ///
473    /// let response = server.get(&"/my-end-point")
474    ///     .add_query_params(&[
475    ///         ("username", "Brian"),
476    ///         ("age", "20"),
477    ///     ])
478    ///     .await;
479    /// #
480    /// # Ok(()) }
481    /// ```
482    ///
483    pub fn add_query_params<V>(mut self, query_params: V) -> Self
484    where
485        V: Serialize,
486    {
487        self.config
488            .query_params
489            .add(query_params)
490            .with_context(|| {
491                format!(
492                    "It should serialize query parameters, for request {}",
493                    self.debug_request_format()
494                )
495            })
496            .unwrap();
497
498        self
499    }
500
501    /// Adds a query param onto the end of the request,
502    /// with no urlencoding of any kind.
503    ///
504    /// This exists to allow custom query parameters,
505    /// such as for the many versions of query param arrays.
506    ///
507    /// ```rust
508    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
509    /// #
510    /// use axum::Router;
511    /// use axum_test::TestServer;
512    ///
513    /// let app = Router::new();
514    /// let server = TestServer::new(app)?;
515    ///
516    /// let response = server.get(&"/my-end-point")
517    ///     .add_raw_query_param(&"my-flag")
518    ///     .add_raw_query_param(&"array[]=123")
519    ///     .add_raw_query_param(&"filter[value]=some-value")
520    ///     .await;
521    /// #
522    /// # Ok(()) }
523    /// ```
524    ///
525    pub fn add_raw_query_param(mut self, query_param: &str) -> Self {
526        self.config.query_params.add_raw(query_param.to_string());
527
528        self
529    }
530
531    /// Clears all query params set,
532    /// including any that came from the [`TestServer`](crate::TestServer).
533    pub fn clear_query_params(mut self) -> Self {
534        self.config.query_params.clear();
535        self
536    }
537
538    /// Adds a header to be sent with this request.
539    ///
540    /// ```rust
541    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
542    /// #
543    /// use axum::Router;
544    /// use axum_test::TestServer;
545    ///
546    /// let app = Router::new();
547    /// let server = TestServer::new(app)?;
548    ///
549    /// let response = server.get(&"/my-end-point")
550    ///     .add_header("x-custom-header", "custom-value")
551    ///     .add_header(http::header::CONTENT_LENGTH, 12345)
552    ///     .add_header(http::header::HOST, "example.com")
553    ///     .await;
554    /// #
555    /// # Ok(()) }
556    /// ```
557    pub fn add_header<N, V>(mut self, name: N, value: V) -> Self
558    where
559        N: TryInto<HeaderName>,
560        N::Error: Debug,
561        V: TryInto<HeaderValue>,
562        V::Error: Debug,
563    {
564        let header_name: HeaderName = name
565            .try_into()
566            .expect("Failed to convert header name to HeaderName");
567        let header_value: HeaderValue = value
568            .try_into()
569            .expect("Failed to convert header vlue to HeaderValue");
570
571        self.config.headers.push((header_name, header_value));
572        self
573    }
574
575    /// Adds an 'AUTHORIZATION' HTTP header to the request,
576    /// with no internal formatting of what is given.
577    pub fn authorization<T>(self, authorization_header: T) -> Self
578    where
579        T: AsRef<str>,
580    {
581        let authorization_header_value = HeaderValue::from_str(authorization_header.as_ref())
582            .expect("Cannot build Authorization HeaderValue from token");
583
584        self.add_header(header::AUTHORIZATION, authorization_header_value)
585    }
586
587    /// Adds an 'AUTHORIZATION' HTTP header to the request,
588    /// in the 'Bearer {token}' format.
589    pub fn authorization_bearer<T>(self, authorization_bearer_token: T) -> Self
590    where
591        T: Display,
592    {
593        let authorization_bearer_header_str = format!("Bearer {authorization_bearer_token}");
594        self.authorization(authorization_bearer_header_str)
595    }
596
597    /// Clears all headers set.
598    pub fn clear_headers(mut self) -> Self {
599        self.config.headers = vec![];
600        self
601    }
602
603    /// Sets the scheme to use when making the request. i.e. http or https.
604    /// The default scheme is 'http'.
605    ///
606    /// ```rust
607    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
608    /// #
609    /// use axum::Router;
610    /// use axum_test::TestServer;
611    ///
612    /// let app = Router::new();
613    /// let server = TestServer::new(app)?;
614    ///
615    /// let response = server
616    ///     .get(&"/my-end-point")
617    ///     .scheme(&"https")
618    ///     .await;
619    /// #
620    /// # Ok(()) }
621    /// ```
622    ///
623    pub fn scheme(mut self, scheme: &str) -> Self {
624        self.config
625            .full_request_url
626            .set_scheme(scheme)
627            .map_err(|_| anyhow!("Scheme '{scheme}' cannot be set to request"))
628            .unwrap();
629        self
630    }
631
632    /// Marks that this request is expected to always return a HTTP
633    /// status code within the 2xx range (200 to 299).
634    ///
635    /// If a code _outside_ of that range is returned,
636    /// then this will panic.
637    ///
638    /// ```rust
639    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
640    /// #
641    /// use axum::Json;
642    /// use axum::Router;
643    /// use axum::routing::put;
644    /// use serde_json::json;
645    ///
646    /// use axum_test::TestServer;
647    ///
648    /// let app = Router::new()
649    ///     .route(&"/todo", put(|| async { unimplemented!() }));
650    ///
651    /// let server = TestServer::new(app)?;
652    ///
653    /// // If this doesn't return a value in the 2xx range,
654    /// // then it will panic.
655    /// server.put(&"/todo")
656    ///     .expect_success()
657    ///     .json(&json!({
658    ///         "task": "buy milk",
659    ///     }))
660    ///     .await;
661    /// #
662    /// # Ok(())
663    /// # }
664    /// ```
665    ///
666    pub fn expect_success(self) -> Self {
667        self.expect_state(ExpectedState::Success)
668    }
669
670    /// Marks that this request is expected to return a HTTP status code
671    /// outside of the 2xx range.
672    ///
673    /// If a code _within_ the 2xx range is returned,
674    /// then this will panic.
675    pub fn expect_failure(self) -> Self {
676        self.expect_state(ExpectedState::Failure)
677    }
678
679    fn expect_state(mut self, expected_state: ExpectedState) -> Self {
680        self.expected_state = expected_state;
681        self
682    }
683
684    async fn send(self) -> Result<TestResponse> {
685        let debug_request_format = self.debug_request_format().to_string();
686
687        let method = self.config.method;
688        let expected_state = self.expected_state;
689        let save_cookies = self.config.is_saving_cookies;
690        let body = self.body.unwrap_or(Body::empty());
691        let url =
692            Self::build_url_query_params(self.config.full_request_url, &self.config.query_params);
693
694        let request = Self::build_request(
695            method.clone(),
696            &url,
697            body,
698            self.config.content_type,
699            self.config.cookies,
700            self.config.headers,
701            &debug_request_format,
702        )?;
703
704        #[allow(unused_mut)] // Allowed for the `ws` use immediately after.
705        let mut http_response = self.transport.send(request).await?;
706
707        #[cfg(feature = "ws")]
708        let websockets = {
709            let maybe_on_upgrade = http_response
710                .extensions_mut()
711                .remove::<hyper::upgrade::OnUpgrade>();
712            let transport_type = self.transport.transport_layer_type();
713
714            crate::internals::TestResponseWebSocket {
715                maybe_on_upgrade,
716                transport_type,
717            }
718        };
719
720        let (parts, response_body) = http_response.into_parts();
721        let response_bytes = response_body.collect().await?.to_bytes();
722
723        if save_cookies {
724            let cookie_headers = parts.headers.get_all(SET_COOKIE).into_iter();
725            ServerSharedState::add_cookies_by_header(&self.server_state, cookie_headers)?;
726        }
727
728        let test_response = TestResponse::new(
729            method,
730            url,
731            parts,
732            response_bytes,
733            #[cfg(feature = "ws")]
734            websockets,
735        );
736
737        // Assert if ok or not.
738        match expected_state {
739            ExpectedState::Success => test_response.assert_status_success(),
740            ExpectedState::Failure => test_response.assert_status_failure(),
741            ExpectedState::None => {}
742        }
743
744        Ok(test_response)
745    }
746
747    fn build_url_query_params(mut url: Url, query_params: &QueryParamsStore) -> Url {
748        // Add all the query params we have
749        if query_params.has_content() {
750            url.set_query(Some(&query_params.to_string()));
751        }
752
753        url
754    }
755
756    fn build_request(
757        method: Method,
758        url: &Url,
759        body: Body,
760        content_type: Option<String>,
761        cookies: CookieJar,
762        headers: Vec<(HeaderName, HeaderValue)>,
763        debug_request_format: &str,
764    ) -> Result<Request<Body>> {
765        let mut request_builder = Request::builder().uri(url.as_str()).method(method);
766
767        // Add all the headers we have.
768        if let Some(content_type) = content_type {
769            let (header_key, header_value) =
770                build_content_type_header(&content_type, debug_request_format)?;
771            request_builder = request_builder.header(header_key, header_value);
772        }
773
774        // Add all the non-expired cookies as headers
775        // Also strip cookies from their attributes, only their names and values should be preserved to conform the HTTP standard
776        let now = OffsetDateTime::now_utc();
777        for cookie in cookies.iter() {
778            let expired = cookie
779                .expires_datetime()
780                .map(|expires| expires <= now)
781                .unwrap_or(false);
782
783            if !expired {
784                let cookie_raw = cookie.stripped().to_string();
785                let header_value = HeaderValue::from_str(&cookie_raw)?;
786                request_builder = request_builder.header(header::COOKIE, header_value);
787            }
788        }
789
790        // Put headers into the request
791        for (header_name, header_value) in headers {
792            request_builder = request_builder.header(header_name, header_value);
793        }
794
795        let request = request_builder.body(body).with_context(|| {
796            format!("Expect valid hyper Request to be built, for request {debug_request_format}")
797        })?;
798
799        Ok(request)
800    }
801
802    fn debug_request_format(&self) -> RequestPathFormatter<'_> {
803        RequestPathFormatter::new(
804            &self.config.method,
805            self.config.full_request_url.as_str(),
806            Some(&self.config.query_params),
807        )
808    }
809}
810
811impl TryFrom<TestRequest> for Request<Body> {
812    type Error = AnyhowError;
813
814    fn try_from(test_request: TestRequest) -> Result<Request<Body>> {
815        let debug_request_format = test_request.debug_request_format().to_string();
816        let url = TestRequest::build_url_query_params(
817            test_request.config.full_request_url,
818            &test_request.config.query_params,
819        );
820        let body = test_request.body.unwrap_or(Body::empty());
821
822        TestRequest::build_request(
823            test_request.config.method,
824            &url,
825            body,
826            test_request.config.content_type,
827            test_request.config.cookies,
828            test_request.config.headers,
829            &debug_request_format,
830        )
831    }
832}
833
834impl IntoFuture for TestRequest {
835    type Output = TestResponse;
836    type IntoFuture = AutoFuture<TestResponse>;
837
838    fn into_future(self) -> Self::IntoFuture {
839        AutoFuture::new(async { self.send().await.context("Sending request failed").unwrap() })
840    }
841}
842
843fn build_content_type_header(
844    content_type: &str,
845    debug_request_format: &str,
846) -> Result<(HeaderName, HeaderValue)> {
847    let header_value = HeaderValue::from_str(content_type).with_context(|| {
848        format!(
849            "Failed to store header content type '{content_type}', for request {debug_request_format}"
850        )
851    })?;
852
853    Ok((header::CONTENT_TYPE, header_value))
854}
855
#[cfg(test)]
mod test_content_type {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;

    /// Handler echoing back the request's `Content-Type`, or an empty string when absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_owned(),
            None => String::new(),
        }
    }

    /// Builds an app with a single `/content_type` route served by `get_content_type`.
    fn new_content_type_app() -> Router {
        Router::new().route("/content_type", get(get_content_type))
    }

    #[tokio::test]
    async fn it_should_not_set_a_content_type_by_default() {
        let server = TestServer::new(new_content_type_app()).expect("Should create test server");

        let response_text = server.get(&"/content_type").await.text();

        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_server_content_type_when_present() {
        // The server default should lose to the per-request content type.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(new_content_type_app())
            .expect("Should create test server");

        let response_text = server
            .get(&"/content_type")
            .content_type(&"application/json")
            .await
            .text();

        assert_eq!(response_text, "application/json");
    }

    #[tokio::test]
    async fn it_should_set_content_type_when_present() {
        let server = TestServer::new(new_content_type_app()).expect("Should create test server");

        let response_text = server
            .get(&"/content_type")
            .content_type(&"application/custom")
            .await
            .text();

        assert_eq!(response_text, "application/custom");
    }
}
924
#[cfg(test)]
mod test_json {
    use crate::TestServer;
    use axum::extract::DefaultBodyLimit;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use rand::random;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        // Build an application with a route that formats the decoded JSON fields.
        let app = Router::new().route(
            "/json",
            post(|Json(json): Json<TestJson>| async move {
                format!(
                    "json: {}, {}, {}",
                    json.name,
                    json.age,
                    json.pets.unwrap_or_else(|| "pandas".to_string())
                )
            }),
        );

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request.
        let text = server
            .post(&"/json")
            .json(&TestJson {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "json: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // `.json()` is expected to set the content type without an explicit `.content_type()`.
        let text = server.post(&"/content_type").json(&json!({})).await.text();

        assert_eq!(text, "application/json");
    }

    #[tokio::test]
    async fn it_should_pass_large_json_blobs_over_http() {
        // Target total length (16 MiB) of the string items; the serialized JSON is larger.
        const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024;

        #[derive(Deserialize, Serialize, PartialEq, Debug)]
        struct TestLargeJson {
            items: Vec<String>,
        }

        // Append random numeric strings until their summed length reaches the target.
        let mut size = 0;
        let mut items = vec![];
        while size < LARGE_BLOB_SIZE {
            let item = random::<u64>().to_string();
            size += item.len();
            items.push(item);
        }
        let large_json_blob = TestLargeJson { items };

        // Echo the JSON back. The body limit is raised explicitly so axum does not reject
        // the payload; JSON quoting and separators push the body past LARGE_BLOB_SIZE,
        // hence the 2x headroom.
        let app = Router::new()
            .route(
                "/json",
                post(|Json(json): Json<TestLargeJson>| async { Json(json) }),
            )
            .layer(DefaultBodyLimit::max(LARGE_BLOB_SIZE * 2));

        // Run the server over a real HTTP transport, expecting success by default.
        let server = TestServer::builder()
            .http_transport()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server");

        server
            .post(&"/json")
            .json(&large_json_blob)
            .await
            .assert_json(&large_json_blob);
    }
}
1041
#[cfg(test)]
mod test_json_from_file {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        // The fixture `files/example.json` must deserialize into this shape with
        // name "Joe" and age 20 for the assertion at the end to hold.
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
        }

        // Build an application with a route that formats the decoded JSON fields.
        let app = Router::new().route(
            "/json",
            post(|Json(json): Json<TestJson>| async move {
                format!("json: {}, {}", json.name, json.age)
            }),
        );

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request with the body loaded from disk.
        let text = server
            .post(&"/json")
            .json_from_file(&"files/example.json")
            .await
            .text();

        assert_eq!(text, "json: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Loading JSON from a file should set the same content type as `.json()`.
        let text = server
            .post(&"/content_type")
            .json_from_file(&"files/example.json")
            .await
            .text();

        assert_eq!(text, "application/json");
    }
}
1108
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_yaml::Yaml;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        // Build an application with a route that formats the decoded YAML fields.
        let app = Router::new().route(
            "/yaml",
            post(|Yaml(yaml): Yaml<TestYaml>| async move {
                format!(
                    "yaml: {}, {}, {}",
                    yaml.name,
                    yaml.age,
                    yaml.pets.unwrap_or_else(|| "pandas".to_string())
                )
            }),
        );

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request.
        let text = server
            .post(&"/yaml")
            .yaml(&TestYaml {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "yaml: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Any serializable value works as a YAML body; `json!` is just a convenient builder.
        let text = server.post(&"/content_type").yaml(&json!({})).await.text();

        assert_eq!(text, "application/yaml");
    }
}
1183
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml_from_file {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_yaml::Yaml;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        // The fixture `files/example.yaml` must deserialize into this shape with
        // name "Joe" and age 20 for the assertion at the end to hold.
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
        }

        // Build an application with a route that formats the decoded YAML fields.
        let app = Router::new().route(
            "/yaml",
            post(|Yaml(yaml): Yaml<TestYaml>| async move {
                format!("yaml: {}, {}", yaml.name, yaml.age)
            }),
        );

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request with the body loaded from disk.
        let text = server
            .post(&"/yaml")
            .yaml_from_file(&"files/example.yaml")
            .await
            .text();

        assert_eq!(text, "yaml: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Loading YAML from a file should set the same content type as `.yaml()`.
        let text = server
            .post(&"/content_type")
            .yaml_from_file(&"files/example.yaml")
            .await
            .text();

        assert_eq!(text, "application/yaml");
    }
}
1251
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Router;
    use axum_msgpack::MsgPack;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_msgpack_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestMsgPack {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the MessagePack body and formats its fields.
        async fn get_msgpack(MsgPack(msgpack): MsgPack<TestMsgPack>) -> String {
            format!(
                "msgpack: {}, {}, {}",
                msgpack.name,
                msgpack.age,
                msgpack.pets.unwrap_or_else(|| "pandas".to_string())
            )
        }

        // Build an application with a route.
        let app = Router::new().route("/msgpack", post(get_msgpack));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request.
        let text = server
            .post(&"/msgpack")
            .msgpack(&TestMsgPack {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "msgpack: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_msgpack_content_type_for_msgpack() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Any serializable value works as a MessagePack body; `json!` is just a convenient builder.
        let text = server
            .post(&"/content_type")
            .msgpack(&json!({}))
            .await
            .text();

        assert_eq!(text, "application/msgpack");
    }
}
1328
#[cfg(test)]
mod test_form {
    use crate::TestServer;
    use axum::routing::post;
    use axum::Form;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use serde::Deserialize;
    use serde::Serialize;

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_form_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestForm {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the url-encoded form body and formats its fields.
        async fn get_form(Form(form): Form<TestForm>) -> String {
            format!(
                "form: {}, {}, {}",
                form.name,
                form.age,
                form.pets.unwrap_or_else(|| "pandas".to_string())
            )
        }

        // Build an application with a route.
        let app = Router::new().route("/form", post(get_form));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request.
        server
            .post(&"/form")
            .form(&TestForm {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .assert_text("form: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_form_content_type_for_form() {
        #[derive(Serialize)]
        struct MyForm {
            message: String,
        }

        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // `.form()` is expected to set the url-encoded content type on its own.
        server
            .post(&"/content_type")
            .form(&MyForm {
                message: "hello".to_string(),
            })
            .await
            .assert_text("application/x-www-form-urlencoded");
    }
}
1406
#[cfg(test)]
mod test_bytes {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Reads the whole request body and responds with it as text.
    /// Invalid UTF-8 sequences are replaced rather than rejected.
    async fn route_echo_body(request: Request) -> String {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes")
            .to_bytes();

        String::from_utf8_lossy(&body_bytes).into_owned()
    }

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let app = Router::new().route("/bytes", post(route_echo_body));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server
            .post(&"/bytes")
            .bytes("hello!".as_bytes().into())
            .await
            .text();

        assert_eq!(text, "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Setting a raw bytes body must keep the content type chosen earlier.
        let text = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes("hello!".as_bytes().into())
            .await
            .text();

        assert_eq!(text, "application/testing");
    }
}
1474
#[cfg(test)]
mod test_bytes_from_file {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Reads the whole request body and responds with it as text.
    /// Invalid UTF-8 sequences are replaced rather than rejected.
    async fn route_echo_body(request: Request) -> String {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes")
            .to_bytes();

        String::from_utf8_lossy(&body_bytes).into_owned()
    }

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let app = Router::new().route("/bytes", post(route_echo_body));
        let server = TestServer::new(app).expect("Should create test server");

        // The assertion below relies on the fixture `files/example.txt` containing exactly "hello!".
        let text = server
            .post(&"/bytes")
            .bytes_from_file(&"files/example.txt")
            .await
            .text();

        assert_eq!(text, "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Loading a raw bytes body from disk must keep the content type chosen earlier.
        let text = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes_from_file(&"files/example.txt")
            .await
            .text();

        assert_eq!(text, "application/testing");
    }
}
1542
#[cfg(test)]
mod test_text {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Size of the payload used by the large-blob tests (16 MiB).
    const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024;

    /// Reads the whole request body and responds with it as text.
    /// Invalid UTF-8 sequences are replaced rather than rejected.
    async fn route_echo_body(request: Request) -> String {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes")
            .to_bytes();

        String::from_utf8_lossy(&body_bytes).into_owned()
    }

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    /// Builds a `LARGE_BLOB_SIZE`-byte ASCII string.
    fn large_blob() -> String {
        "X".repeat(LARGE_BLOB_SIZE)
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let app = Router::new().route("/text", post(route_echo_body));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server.post(&"/text").text(&"hello!").await.text();

        assert_eq!(text, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // `.text()` is expected to set the content type without an explicit `.content_type()`.
        let text = server.post(&"/content_type").text(&"hello!").await.text();

        assert_eq!(text, "text/plain");
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_mock_http() {
        let large_blob = large_blob();

        let app = Router::new().route("/text", post(route_echo_body));
        let server = TestServer::builder()
            .mock_transport()
            .build(app)
            .expect("Should create test server");

        let text = server.post(&"/text").text(&large_blob).await.text();

        assert_eq!(text.len(), LARGE_BLOB_SIZE);
        assert_eq!(text, large_blob);
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_http() {
        let large_blob = large_blob();

        let app = Router::new().route("/text", post(route_echo_body));
        let server = TestServer::builder()
            .http_transport()
            .build(app)
            .expect("Should create test server");

        let text = server.post(&"/text").text(&large_blob).await.text();

        assert_eq!(text.len(), LARGE_BLOB_SIZE);
        assert_eq!(text, large_blob);
    }
}
1669
#[cfg(test)]
mod test_text_from_file {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::post;
    use axum::Router;
    use http::header::CONTENT_TYPE;
    use http::HeaderMap;
    use http_body_util::BodyExt;

    /// Reads the whole request body and responds with it as text.
    /// Invalid UTF-8 sequences are replaced rather than rejected.
    async fn route_echo_body(request: Request) -> String {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes")
            .to_bytes();

        String::from_utf8_lossy(&body_bytes).into_owned()
    }

    /// Responds with the request's `Content-Type` header value,
    /// or an empty body when the header is missing.
    async fn route_get_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let app = Router::new().route("/text", post(route_echo_body));
        let server = TestServer::new(app).expect("Should create test server");

        // The assertion below relies on the fixture `files/example.txt` containing exactly "hello!".
        let text = server
            .post(&"/text")
            .text_from_file(&"files/example.txt")
            .await
            .text();

        assert_eq!(text, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let app = Router::new().route("/content_type", post(route_get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        // Loading text from a file should set the same content type as `.text()`.
        let text = server
            .post(&"/content_type")
            .text_from_file(&"files/example.txt")
            .await
            .text();

        assert_eq!(text, "text/plain");
    }
}
1737
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Handler that always answers `200 OK` with a short text body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler that answers with `202 Accepted`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// A plain `200 OK` satisfies `expect_success`.
    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");

        server.get(&"/ping").expect_success().await;
    }

    /// Any 2xx status, not just 200, satisfies `expect_success`.
    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)))
            .expect("Should create test server");

        server.get(&"/accepted").expect_success().await;
    }

    /// An unrouted path on an empty router must fail the `expect_success` check.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get(&"/some_unknown_route").expect_success().await;
    }

    /// A per-request `expect_success` takes precedence over the server-wide `expect_failure`.
    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");
        server.expect_failure();

        server.get(&"/ping").expect_success().await;
    }
}
1807
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use http::StatusCode;

    /// Handler that always answers `200 OK` with a short text body.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler that answers with `202 Accepted`.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    /// A request to an unrouted path (404) satisfies `expect_failure`.
    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        let app = Router::new();
        let server = TestServer::new(app).expect("Should create test server");

        server.get(&"/some_unknown_route").expect_failure().await;
    }

    /// A `200 OK` response must fail the `expect_failure` check.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let app = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(app).expect("Should create test server");

        server.get(&"/ping").expect_failure().await;
    }

    /// Any 2xx status counts as success, so `expect_failure` must reject it.
    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let app = Router::new().route("/accepted", get(get_accepted));
        let server = TestServer::new(app).expect("Should create test server");

        server.get(&"/accepted").expect_failure().await;
    }

    /// A per-request `expect_failure` takes precedence over the server-wide `expect_success`.
    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let app = Router::new();
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();

        server.get(&"/some_unknown_route").expect_failure().await;
    }
}
1874
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::time::Duration;
    use cookie::time::OffsetDateTime;
    use cookie::Cookie;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Responds with the value of `TEST_COOKIE_NAME` from the request,
    /// or "cookie-not-found" when the request did not carry it.
    /// The jar is returned unchanged alongside the body.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_send_non_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_add(Duration::minutes(10))
                .expect("now + 10 minutes should be representable"),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_not_send_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // Expire the cookie clearly in the past, so the test does not hinge on how an
        // expiry of exactly "now" is compared when the request is built.
        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_sub(Duration::minutes(10))
                .expect("now - 10 minutes should be representable"),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
1932
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use cookie::SameSite;

    /// Lists every cookie the server received as `name=value`,
    /// sorted alphabetically and separated by `, `.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in cookies.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        pairs.sort_unstable();

        pairs.join(", ")
    }

    /// Joins the raw value of every `cookie` request header with `; `.
    ///
    /// Values that are not valid header text are treated as empty strings;
    /// the result is empty when no `cookie` header was sent at all.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        let values: Vec<String> = headers
            .get_all("cookie")
            .iter()
            .map(|value| value.to_str().unwrap_or("").to_string())
            .collect();

        values.join("; ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let router = Router::new().route("/cookies", get(route_get_cookies));
        let server = TestServer::new(router).expect("Should create test server");

        // Two independent cookies, sent up together through a single jar.
        let mut jar = CookieJar::new();
        for (name, value) in [
            ("first-cookie", "my-custom-cookie"),
            ("second-cookie", "other-cookie"),
        ] {
            jar.add(Cookie::new(name, value));
        }

        let response = server.get("/cookies").add_cookies(jar).await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_stripped_by_their_attributes() {
        // Name and value of the one cookie sent up.
        const TEST_COOKIE_NAME: &str = "test-cookie";
        const TEST_COOKIE_VALUE: &str = "my-custom-cookie";

        let router = Router::new().route("/cookies", get(get_cookie_headers_joined));
        let server = TestServer::new(router).expect("Should create test server");

        // A cookie carrying several attributes; only `name=value` should reach the server.
        let mut jar = CookieJar::new();
        jar.add(
            Cookie::build((TEST_COOKIE_NAME, TEST_COOKIE_VALUE))
                .http_only(true)
                .secure(true)
                .same_site(SameSite::Strict)
                .path("/cookie")
                .build(),
        );

        let expected = format!("{TEST_COOKIE_NAME}={TEST_COOKIE_VALUE}");
        server
            .get("/cookies")
            .add_cookies(jar)
            .await
            .assert_text(expected);
    }
}
2009
#[cfg(test)]
mod test_save_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    /// Name of the cookie set by [`put_cookie_with_attributes`].
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Sets the [`TEST_COOKIE_NAME`] cookie to the (lossily UTF-8 decoded) request body,
    /// with the `HttpOnly`, `Secure`, strict `SameSite` and `Path` attributes set.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes");
        let body_text = String::from_utf8_lossy(&collected.to_bytes()).into_owned();

        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Joins the raw value of every `cookie` request header with `; `.
    ///
    /// Values that are not valid header text are treated as empty strings;
    /// the result is empty when no `cookie` header was sent at all.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        let values: Vec<String> = headers
            .get_all("cookie")
            .iter()
            .map(|value| value.to_str().unwrap_or("").to_string())
            .collect();

        values.join("; ")
    }

    #[tokio::test]
    async fn it_should_strip_cookies_from_their_attributes() {
        let router = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(router).expect("Should create test server");

        // The PUT sets a cookie with attributes, and this request asks for it to be saved.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Only `name=value` should be sent back up, without any of the attributes.
        let expected = format!("{TEST_COOKIE_NAME}=cookie-found!");
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, expected);
    }
}
2079
#[cfg(test)]
mod test_do_not_save_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    /// Name of the cookie set by [`put_cookie_with_attributes`].
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Sets the [`TEST_COOKIE_NAME`] cookie to the (lossily UTF-8 decoded) request body,
    /// with the `HttpOnly`, `Secure`, strict `SameSite` and `Path` attributes set.
    async fn put_cookie_with_attributes(
        mut cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        cookies = cookies.add(cookie);

        (cookies, "done")
    }

    /// Joins the raw value of every `cookie` request header with `; `.
    ///
    /// Values that are not valid header text are treated as empty strings;
    /// the result is empty when no `cookie` header was sent at all.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or("").to_string())
            .reduce(|a, b| a + "; " + &b)
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_not_save_cookies_when_set() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        // The PUT sets a cookie, but this request opts out of saving it.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // Nothing was saved, so no `cookie` header is sent and the echo is empty.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_test_server_and_not_save_cookies_when_set() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::builder()
            .save_cookies()
            .build(app)
            .expect("Should create test server");

        // The server is built with `save_cookies()`, but this request overrides
        // that setting and opts out of saving the cookie the PUT sets.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // Nothing was saved, so no `cookie` header is sent and the echo is empty.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, "");
    }
}
2172
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::routing::put;
    use axum::Router;
    use axum_extra::extract::cookie::Cookie as AxumCookie;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use http_body_util::BodyExt;

    /// Name of the cookie every test in this module works with.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of the [`TEST_COOKIE_NAME`] cookie,
    /// or `cookie-not-found` when the request did not carry it.
    async fn get_cookie(cookies: AxumCookieJar) -> (AxumCookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    /// Sets the [`TEST_COOKIE_NAME`] cookie (with no extra attributes) to the
    /// lossily UTF-8 decoded request body.
    async fn put_cookie(
        mut cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text: String = String::from_utf8_lossy(&body_bytes).to_string();
        let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text);
        cookies = cookies.add(cookie);

        (cookies, "done")
    }

    #[tokio::test]
    async fn it_should_clear_cookie_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server
            .get("/cookie")
            .add_cookie(cookie)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookie_jar_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie);

        let response_text = server
            .get("/cookie")
            .add_cookies(cookie_jar)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_saved_by_past_request() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // Create a cookie, and save it for later requests.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Precondition: without clearing, the saved cookie is sent up.
        // Without this check the final assertion would also pass if nothing had been saved.
        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "cookie-found!");

        // Clearing drops the saved cookie from this request.
        let response_text = server.get("/cookie").clear_cookies().await.text();
        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_added_to_test_server() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        server.add_cookie(cookie);

        // Precondition: without clearing, the server-level cookie is sent up.
        let response_text = server.get("/cookie").await.text();
        assert_eq!(response_text, "my-custom-cookie");

        // Clearing drops it from this request.
        let response_text = server.get("/cookie").clear_cookies().await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
2286
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Name and content of the custom header the test sends.
    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Raw bytes of the `test-header` request header.
    ///
    /// Extraction is rejected with `400 Bad Request` when the header is absent.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            match parts.headers.get(HeaderName::from_static(TEST_HEADER_NAME)) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Echoes the extracted header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let router = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(router).expect("Should create test server");

        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);

        // The route echoes the header, so the body should be exactly its content.
        server
            .get("/header")
            .add_header(name, value)
            .await
            .assert_text(TEST_HEADER_CONTENT);
    }
}
2345
#[cfg(test)]
mod test_authorization {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the `Authorization`
    /// header verbatim. The server is configured with `expect_success()`.
    fn new_test_server() -> TestServer {
        // The `Authorization` header as text. Rejected with `400 Bad Request`
        // when absent; panics if the value cannot be read as a `&str`.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let Some(value) = parts.headers.get(header::AUTHORIZATION) else {
                    return Err((StatusCode::BAD_REQUEST, "Missing test header"));
                };

                Ok(TestHeader(value.to_str().unwrap().to_string()))
            }
        }

        // Echoes the captured header back as the response body.
        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        let router = Router::new().route("/auth-header", get(ping_auth_header));
        let mut server = TestServer::new(router).expect("Should create test server");
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // The route echoes the header, so the body is the full header value.
        server
            .get("/auth-header")
            .authorization("Bearer abc123")
            .await
            .assert_text("Bearer abc123");
    }
}
2403
#[cfg(test)]
mod test_authorization_bearer {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the token from an
    /// `Authorization: Bearer <token>` header. The server is configured with
    /// `expect_success()`.
    fn new_test_server() -> TestServer {
        // The token part of a bearer `Authorization` header.
        //
        // Rejected with `400 Bad Request` when the header is missing, is not readable
        // as text, or does not start with `Bearer `. Requiring the prefix, rather than
        // deleting it wherever it appears, is what lets the test notice a header that
        // was sent without the `Bearer ` scheme.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let value = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))?
                    .to_str()
                    .map_err(|_| (StatusCode::BAD_REQUEST, "Authorization header is not text"))?;

                let token = value.strip_prefix("Bearer ").ok_or((
                    StatusCode::BAD_REQUEST,
                    "Authorization header is missing the `Bearer ` prefix",
                ))?;

                Ok(TestHeader(token.to_string()))
            }
        }

        // Echoes the extracted token back as the response body.
        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        // Build an application with a route.
        let app = Router::new().route("/auth-header", get(ping_auth_header));

        // Run the server, with `expect_success()` enabled.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // Send a request with a bearer token.
        let response = server
            .get("/auth-header")
            .authorization_bearer("abc123")
            .await;

        // The route only replies with the token when the `Bearer ` prefix was present.
        response.assert_text("abc123")
    }
}
2461
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use axum::Router;
    use http::request::Parts;
    use http::HeaderName;
    use http::HeaderValue;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Name and content of the custom header that the tests add and then clear.
    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Raw bytes of the `test-header` request header.
    ///
    /// When the header is absent, extraction is rejected with `400 Bad Request`
    /// and the body `Missing test header`, which is what the tests check for.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            let name = HeaderName::from_static(TEST_HEADER_NAME);
            if let Some(value) = parts.headers.get(name) {
                return Ok(TestHeader(value.as_bytes().to_vec()));
            }

            Err((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Echoes the extracted header bytes back as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_request() {
        let router = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(router).expect("Should create test server");

        // Add the header, then clear it again before the request is sent.
        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .clear_headers()
            .await;

        // The extractor should have rejected the request for lacking the header.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_server() {
        let router = Router::new().route("/header", get(ping_header));
        let mut server = TestServer::new(router).expect("Should create test server");

        // Set the header at the server level rather than on the request.
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Clearing on the request should also drop the server-level header.
        let response = server.get("/header").clear_headers().await;

        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
2542
#[cfg(test)]
mod test_add_query_params {
    use crate::TestServer;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Query string with a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(
        AxumStdQuery(QueryParam { message }): AxumStdQuery<QueryParam>,
    ) -> String {
        message
    }

    /// Query string with both a `message` and an `other` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` and `other`, joined by a `-`.
    async fn get_query_param_2(
        AxumStdQuery(QueryParam2 { message, other }): AxumStdQuery<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    /// Router with `/query` (one param) and `/query-2` (two params).
    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2))
    }

    /// Starts a test server over [`build_app`].
    fn new_test_server() -> TestServer {
        TestServer::new(build_app()).expect("Should create test server")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let server = new_test_server();

        // Params built from a serializable struct.
        let params = QueryParam {
            message: "it works".to_string(),
        };
        server
            .get("/query")
            .add_query_params(params)
            .await
            .assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let server = new_test_server();

        let response = server
            .get("/query")
            .add_query_params(&[("message", "it works")])
            .await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let server = new_test_server();

        let response = server
            .get("/query-2")
            .add_query_params(&[("message", "it works"), ("other", "yup")])
            .await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let server = new_test_server();

        // Each call adds to the params already set, rather than replacing them.
        let response = server
            .get("/query-2")
            .add_query_params(&[("message", "it works")])
            .add_query_params(&[("other", "yup")])
            .await;
        response.assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let server = new_test_server();

        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        let response = server.get("/query-2").add_query_params(params).await;
        response.assert_text("it works-yup");
    }
}
2649
#[cfg(test)]
mod test_add_raw_query_param {
    use crate::TestServer;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum::Router;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    /// Query string with a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query string holding two lists: repeated `items=…` keys, and repeated
    /// `arrs[]=…` keys (the bracketed array form). Both default to empty.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Echoes every `items` value followed by every `arrs` value, all joined with `, `.
    ///
    /// Both lists share one separator, so a query using both styles still produces
    /// a single readable list instead of gluing the last `items` entry directly
    /// onto the first `arrs` entry.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        params
            .items
            .iter()
            .chain(params.arrs.iter())
            .map(String::as_str)
            .collect::<Vec<&str>>()
            .join(", ")
    }

    /// Router with `/query` (std `Query`) and `/query-extra` (axum-extra `Query`).
    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_raw_query_param("message=it-works")
            .await
            .assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-extra")
            .add_raw_query_param("items=one&items=two&items=three")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-extra")
            .add_raw_query_param("arrs[]=one")
            .add_raw_query_param("arrs[]=two")
            .add_raw_query_param("arrs[]=three")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_both_array_query_param_styles_together() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Mix repeated plain keys with repeated bracketed keys in one request.
        server
            .get("/query-extra")
            .add_raw_query_param("items=one&items=two")
            .add_raw_query_param("arrs[]=three")
            .add_raw_query_param("arrs[]=four")
            .await
            .assert_text("one, two, three, four");
    }
}
2742
#[cfg(test)]
mod test_add_query_param {
    use crate::TestServer;
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    /// Query string with a single `message` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(Query(QueryParam { message }): Query<QueryParam>) -> String {
        message
    }

    /// Query string with both a `message` and an `other` field.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes `message` and `other`, joined by a `-`.
    async fn get_query_param_2(
        Query(QueryParam2 { message, other }): Query<QueryParam2>,
    ) -> String {
        format!("{message}-{other}")
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let server = TestServer::new(router).expect("Should create test server");

        let response = server
            .get("/query")
            .add_query_param("message", "it works")
            .await;
        response.assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let server = TestServer::new(router).expect("Should create test server");

        // Each call adds one more key/value pair to the query string.
        let response = server
            .get("/query-2")
            .add_query_param("message", "it works")
            .add_query_param("other", "yup")
            .await;
        response.assert_text("it works-yup");
    }
}
2804
#[cfg(test)]
mod test_clear_query_params {
    use crate::TestServer;
    use axum::extract::Query;
    use axum::routing::get;
    use axum::Router;
    use serde::Deserialize;
    use serde::Serialize;

    /// Two optional query params, so the route can report which ones arrived.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Params with both `first` and `second` set.
    fn both_params() -> QueryParams {
        QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        }
    }

    /// Reports whether each of the two params was present on the request.
    async fn get_query_params(Query(QueryParams { first, second }): Query<QueryParams>) -> String {
        let (has_first, has_second) = (first.is_some(), second.is_some());
        format!("has first? {has_first}, has second? {has_second}")
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let router = Router::new().route("/query", get(get_query_params));
        let server = TestServer::new(router).expect("Should create test server");

        // Everything added before the clear is discarded.
        let response = server
            .get("/query")
            .add_query_params(both_params())
            .clear_query_params()
            .await;
        response.assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let router = Router::new().route("/query", get(get_query_params));
        let server = TestServer::new(router).expect("Should create test server");

        // Params added after the clear are still sent.
        let response = server
            .get("/query")
            .add_query_params(both_params())
            .clear_query_params()
            .add_query_params(both_params())
            .await;
        response.assert_text("has first? true, has second? true");
    }
}
2872
#[cfg(test)]
mod test_scheme {
    use crate::TestServer;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::Router;

    /// Echoes back the scheme component of the incoming request URI.
    ///
    /// Panics (failing the test) if the URI carries no scheme.
    async fn route_get_scheme(request: Request) -> String {
        let scheme = request.uri().scheme_str().unwrap();
        String::from(scheme)
    }

    /// Builds a server whose only route is `GET /scheme`, served by [`route_get_scheme`].
    fn scheme_server() -> TestServer {
        let router = Router::new().route("/scheme", get(route_get_scheme));
        TestServer::builder().build(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = scheme_server();
        let response = server.get("/scheme").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_http_when_set() {
        let server = scheme_server();
        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_when_set() {
        let server = scheme_server();
        let response = server.get("/scheme").scheme(&"https").await;
        response.assert_text("https");
    }

    /// A scheme set on the request wins over the one configured on the server.
    #[tokio::test]
    async fn it_should_override_test_server_when_set() {
        let mut server = scheme_server();
        server.scheme(&"https");

        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }
}
2930
#[cfg(test)]
mod test_multipart {
    use crate::multipart::MultipartForm;
    use crate::multipart::Part;
    use crate::TestServer;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Json;
    use axum::Router;
    use serde_json::json;
    use serde_json::Value;

    /// Summarises each multipart field as `"{name} is {len} bytes, {content_type}"`.
    ///
    /// Panics (failing the test) if reading a field fails, or if a field has no name
    /// or no content type.
    async fn route_post_multipart(mut multipart: Multipart) -> Json<Vec<String>> {
        let mut fields = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let name = field.name().unwrap().to_string();
            let content_type = field.content_type().unwrap().to_owned();
            let data = field.bytes().await.unwrap();

            let field_stats = format!("{name} is {} bytes, {content_type}", data.len());
            fields.push(field_stats);
        }

        Json(fields)
    }

    /// Echoes each part's name, UTF-8 body and `x-part-header-test` header as JSON.
    ///
    /// Panics (failing the test) if a part lacks a name or that header, or if its body
    /// is not valid UTF-8.
    async fn route_post_multipart_headers(mut multipart: Multipart) -> Json<Vec<Value>> {
        let mut sent_part_headers = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let part_name = field.name().unwrap().to_string();
            let part_header_value = field
                .headers()
                .get("x-part-header-test")
                .unwrap()
                .to_str()
                .unwrap()
                .to_string();
            let part_text = String::from_utf8(field.bytes().await.unwrap().into()).unwrap();

            sent_part_headers.push(json!({
                "name": part_name,
                "text": part_text,
                "header": part_header_value,
            }))
        }

        Json(sent_part_headers)
    }

    fn test_router() -> Router {
        Router::new()
            .route("/multipart", post(route_post_multipart))
            .route("/multipart_headers", post(route_post_multipart_headers))
    }

    /// Builds a server over [`test_router`] using the mock transport.
    fn mock_server() -> TestServer {
        TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server")
    }

    /// The form shared by the per-transport stats tests: two text fields and a number.
    fn stats_form() -> MultipartForm {
        MultipartForm::new()
            .add_text("penguins?", "lots")
            .add_text("animals", "🦊🦊🦊")
            .add_text("carrots", 123_u32)
    }

    /// What `route_post_multipart` should report for [`stats_form`].
    ///
    /// Each 🦊 is 4 bytes in UTF-8, and `123` is sent as its 3-character text form.
    fn expected_stats() -> Vec<String> {
        vec![
            "penguins? is 4 bytes, text/plain".to_string(),
            "animals is 12 bytes, text/plain".to_string(),
            "carrots is 3 bytes, text/plain".to_string(),
        ]
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_mock_transport() {
        let server = mock_server();

        server
            .post(&"/multipart")
            .multipart(stats_form())
            .await
            .assert_json(&expected_stats());
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_http_transport() {
        let server = TestServer::builder()
            .http_transport()
            .build(test_router())
            .expect("Should create test server");

        server
            .post(&"/multipart")
            .multipart(stats_form())
            .await
            .assert_json(&expected_stats());
    }

    #[tokio::test]
    async fn it_should_send_text_parts_as_text() {
        let server = mock_server();

        let form = MultipartForm::new().add_part("animals", Part::text("🦊🦊🦊"));

        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 12 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_custom_mime_type() {
        let server = mock_server();

        let form = MultipartForm::new().add_part(
            "animals",
            Part::bytes("🦊,🦊,🦊".as_bytes()).mime_type(mime::TEXT_CSV),
        );

        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 14 bytes, text/csv".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_using_include_bytes() {
        let server = mock_server();

        let form = MultipartForm::new().add_part(
            "file",
            Part::bytes(include_bytes!("../files/example.txt").as_slice())
                .mime_type(mime::TEXT_PLAIN),
        );

        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["file is 6 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_form_headers_in_parts() {
        let server = mock_server();

        let form = MultipartForm::new()
            .add_part(
                "part_1",
                Part::text("part_1_text").add_header("x-part-header-test", "part_1_header"),
            )
            .add_part(
                "part_2",
                Part::text("part_2_text").add_header("x-part-header-test", "part_2_header"),
            );

        server
            .post(&"/multipart_headers")
            .multipart(form)
            .await
            .assert_json(&json!([
                {
                    "name": "part_1",
                    "text": "part_1_text",
                    "header": "part_1_header",
                },
                {
                    "name": "part_2",
                    "text": "part_2_text",
                    "header": "part_2_header",
                },
            ]));
    }
}