axum_test/
test_request.rs

1use anyhow::Context;
2use anyhow::Error as AnyhowError;
3use anyhow::Result;
4use anyhow::anyhow;
5use axum::body::Body;
6use bytes::Bytes;
7use cookie::Cookie;
8use cookie::CookieJar;
9use cookie::time::OffsetDateTime;
10use http::HeaderName;
11use http::HeaderValue;
12use http::Method;
13use http::Request;
14use http::header;
15use http::header::SET_COOKIE;
16use http_body_util::BodyExt;
17use serde::Serialize;
18use std::fmt::Debug;
19use std::fmt::Display;
20use std::fs::File;
21use std::fs::read;
22use std::fs::read_to_string;
23use std::future::{Future, IntoFuture};
24use std::io::BufReader;
25use std::path::Path;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::Mutex;
29use url::Url;
30
31use crate::ServerSharedState;
32use crate::TestResponse;
33use crate::internals::ExpectedState;
34use crate::internals::QueryParamsStore;
35use crate::internals::RequestPathFormatter;
36use crate::multipart::MultipartForm;
37use crate::transport_layer::TransportLayer;
38
39mod test_request_config;
40pub(crate) use self::test_request_config::*;
41
42///
43/// A `TestRequest` is for building and executing a HTTP request to the [`TestServer`](crate::TestServer).
44///
45/// ## Building
46///
/// Requests are created by the [`TestServer`](crate::TestServer), using its builder functions.
48/// They correspond to the appropriate HTTP method: [`TestServer::get()`](crate::TestServer::get()),
49/// [`TestServer::post()`](crate::TestServer::post()), etc.
50///
51/// See there for documentation.
52///
53/// ## Customising
54///
55/// The `TestRequest` allows the caller to fill in the rest of the request
56/// to be sent to the server. Including the headers, the body, cookies,
57/// and the content type, using the relevant functions.
58///
59/// The TestRequest struct provides a number of methods to set up the request,
60/// such as json, text, bytes, expect_failure, content_type, etc.
61///
62/// ## Sending
63///
64/// Once fully configured you send the request by awaiting the request object.
65///
66/// ```rust
67/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
68/// #
69/// # use axum::Router;
70/// # use axum_test::TestServer;
71/// #
72/// # let server = TestServer::new(Router::new())?;
73/// #
74/// // Build your request
75/// let request = server.get(&"/user")
76///     .add_header("x-custom-header", "example.com")
77///     .content_type("application/yaml");
78///
79/// // await request to execute
80/// let response = request.await;
81/// #
82/// # Ok(()) }
83/// ```
84///
85/// You will receive a `TestResponse`.
86///
87/// ## Cookie Saving
88///
89/// [`TestRequest::save_cookies()`](crate::TestRequest::save_cookies()) and [`TestRequest::do_not_save_cookies()`](crate::TestRequest::do_not_save_cookies())
90/// methods allow you to set the request to save cookies to the `TestServer`,
91/// for reuse on any future requests.
92///
93/// This behaviour is **off** by default, and can be changed for all `TestRequests`
94/// when building the `TestServer`. By building it with a `TestServerConfig` where `save_cookies` is set to true.
95///
96/// ## Expecting Failure and Success
97///
98/// When making a request you can mark it to expect a response within,
99/// or outside, of the 2xx range of HTTP status codes.
100///
101/// If the response returns a status code different to what is expected,
102/// then it will panic.
103///
104/// This is useful when making multiple requests within a test.
105/// As it can find issues earlier than later.
106///
107/// See the [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()),
108/// and [`TestRequest::expect_success()`](crate::TestRequest::expect_success()).
109///
110#[derive(Debug)]
111#[must_use = "futures do nothing unless polled"]
112pub struct TestRequest {
113    config: TestRequestConfig,
114
115    server_state: Arc<Mutex<ServerSharedState>>,
116    transport: Arc<Box<dyn TransportLayer>>,
117
118    body: Option<Body>,
119
120    expected_state: ExpectedState,
121}
122
123impl TestRequest {
124    pub(crate) fn new(
125        server_state: Arc<Mutex<ServerSharedState>>,
126        transport: Arc<Box<dyn TransportLayer>>,
127        config: TestRequestConfig,
128    ) -> Self {
129        let expected_state = config.expected_state;
130
131        Self {
132            config,
133            server_state,
134            transport,
135            body: None,
136            expected_state,
137        }
138    }
139
140    /// Set the body of the request to send up data as Json,
141    /// and changes the content type to `application/json`.
142    pub fn json<J>(self, body: &J) -> Self
143    where
144        J: ?Sized + Serialize,
145    {
146        let body_bytes =
147            serde_json::to_vec(body).expect("It should serialize the content into Json");
148
149        self.bytes(body_bytes.into())
150            .content_type(mime::APPLICATION_JSON.essence_str())
151    }
152
153    /// Sends a payload as a Json request, with the contents coming from a file.
154    pub fn json_from_file<P>(self, path: P) -> Self
155    where
156        P: AsRef<Path>,
157    {
158        let path_ref = path.as_ref();
159        let file = File::open(path_ref)
160            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
161            .unwrap();
162
163        let reader = BufReader::new(file);
164        let payload = serde_json::from_reader::<_, serde_json::Value>(reader)
165            .with_context(|| {
166                format!(
167                    "Failed to deserialize file '{}' as Json",
168                    path_ref.display()
169                )
170            })
171            .unwrap();
172
173        self.json(&payload)
174    }
175
176    /// Set the body of the request to send up data as Yaml,
177    /// and changes the content type to `application/yaml`.
178    #[cfg(feature = "yaml")]
179    pub fn yaml<Y>(self, body: &Y) -> Self
180    where
181        Y: ?Sized + Serialize,
182    {
183        let body = serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
184
185        self.bytes(body.into_bytes().into())
186            .content_type("application/yaml")
187    }
188
189    /// Sends a payload as a Yaml request, with the contents coming from a file.
190    #[cfg(feature = "yaml")]
191    pub fn yaml_from_file<P>(self, path: P) -> Self
192    where
193        P: AsRef<Path>,
194    {
195        let path_ref = path.as_ref();
196        let file = File::open(path_ref)
197            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
198            .unwrap();
199
200        let reader = BufReader::new(file);
201        let payload = serde_yaml::from_reader::<_, serde_yaml::Value>(reader)
202            .with_context(|| {
203                format!(
204                    "Failed to deserialize file '{}' as Yaml",
205                    path_ref.display()
206                )
207            })
208            .unwrap();
209
210        self.yaml(&payload)
211    }
212
213    /// Set the body of the request to send up data as MsgPack,
214    /// and changes the content type to `application/msgpack`.
215    #[cfg(feature = "msgpack")]
216    pub fn msgpack<M>(self, body: &M) -> Self
217    where
218        M: ?Sized + Serialize,
219    {
220        let body_bytes =
221            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
222
223        self.bytes(body_bytes.into())
224            .content_type("application/msgpack")
225    }
226
227    /// Sets the body of the request, with the content type
228    /// of 'application/x-www-form-urlencoded'.
229    pub fn form<F>(self, body: &F) -> Self
230    where
231        F: ?Sized + Serialize,
232    {
233        let body_text =
234            serde_urlencoded::to_string(body).expect("It should serialize the content into a Form");
235
236        self.bytes(body_text.into())
237            .content_type(mime::APPLICATION_WWW_FORM_URLENCODED.essence_str())
238    }
239
240    /// For sending multipart forms.
241    /// The payload is built using [`MultipartForm`](crate::multipart::MultipartForm) and [`Part`](crate::multipart::Part).
242    ///
243    /// This will be sent with the content type of 'multipart/form-data'.
244    ///
245    /// # Simple example
246    ///
247    /// ```rust
248    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
249    /// #
250    /// use axum::Router;
251    /// use axum_test::TestServer;
252    /// use axum_test::multipart::MultipartForm;
253    ///
254    /// let app = Router::new();
255    /// let server = TestServer::new(app)?;
256    ///
257    /// let multipart_form = MultipartForm::new()
258    ///     .add_text("name", "Joe")
259    ///     .add_text("animals", "foxes");
260    ///
261    /// let response = server.post(&"/my-form")
262    ///     .multipart(multipart_form)
263    ///     .await;
264    /// #
265    /// # Ok(()) }
266    /// ```
267    ///
268    /// # Sending byte parts
269    ///
270    /// ```rust
271    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
272    /// #
273    /// use axum::Router;
274    /// use axum_test::TestServer;
275    /// use axum_test::multipart::MultipartForm;
276    /// use axum_test::multipart::Part;
277    ///
278    /// let app = Router::new();
279    /// let server = TestServer::new(app)?;
280    ///
281    /// let readme_bytes = include_bytes!("../README.md");
282    /// let readme_part = Part::bytes(readme_bytes.as_slice())
283    ///     .file_name(&"README.md")
284    ///     .mime_type(&"text/markdown");
285    ///
286    /// let multipart_form = MultipartForm::new()
287    ///     .add_part("file", readme_part);
288    ///
289    /// let response = server.post(&"/my-form")
290    ///     .multipart(multipart_form)
291    ///     .await;
292    /// #
293    /// # Ok(()) }
294    /// ```
295    ///
296    pub fn multipart(mut self, multipart: MultipartForm) -> Self {
297        self.config.content_type = Some(multipart.content_type());
298        self.body = Some(multipart.into());
299
300        self
301    }
302
303    /// Set raw text as the body of the request,
304    /// and sets the content type to `text/plain`.
305    pub fn text<T>(self, raw_text: T) -> Self
306    where
307        T: Display,
308    {
309        let body_text = raw_text.to_string();
310
311        self.bytes(body_text.into())
312            .content_type(mime::TEXT_PLAIN.essence_str())
313    }
314
315    /// Sends a payload as plain text, with the contents coming from a file.
316    pub fn text_from_file<P>(self, path: P) -> Self
317    where
318        P: AsRef<Path>,
319    {
320        let path_ref = path.as_ref();
321        let payload = read_to_string(path_ref)
322            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
323            .unwrap();
324
325        self.text(payload)
326    }
327
328    /// Set raw bytes as the body of the request.
329    ///
330    /// The content type is left unchanged.
331    pub fn bytes(mut self, body_bytes: Bytes) -> Self {
332        let body: Body = body_bytes.into();
333
334        self.body = Some(body);
335        self
336    }
337
338    /// Reads the contents of the file as raw bytes, and sends it within the request.
339    ///
340    /// The content type is left unchanged, and no parsing of the file is done.
341    pub fn bytes_from_file<P>(self, path: P) -> Self
342    where
343        P: AsRef<Path>,
344    {
345        let path_ref = path.as_ref();
346        let payload = read(path_ref)
347            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
348            .unwrap();
349
350        self.bytes(payload.into())
351    }
352
353    /// Set the content type to use for this request in the header.
354    pub fn content_type(mut self, content_type: &str) -> Self {
355        self.config.content_type = Some(content_type.to_string());
356        self
357    }
358
359    /// Adds a Cookie to be sent with this request.
360    pub fn add_cookie(mut self, cookie: Cookie<'_>) -> Self {
361        self.config.cookies.add(cookie.into_owned());
362        self
363    }
364
365    /// Adds many cookies to be used with this request.
366    pub fn add_cookies(mut self, cookies: CookieJar) -> Self {
367        for cookie in cookies.iter() {
368            self.config.cookies.add(cookie.clone());
369        }
370
371        self
372    }
373
374    /// Clears all cookies used internally within this Request,
375    /// including any that came from the `TestServer`.
376    pub fn clear_cookies(mut self) -> Self {
377        self.config.cookies = CookieJar::new();
378        self
379    }
380
381    /// Any cookies returned will be saved to the [`TestServer`](crate::TestServer) that created this,
382    /// which will continue to use those cookies on future requests.
383    pub fn save_cookies(mut self) -> Self {
384        self.config.is_saving_cookies = true;
385        self
386    }
387
388    /// Cookies returned by this will _not_ be saved to the `TestServer`.
389    /// For use by future requests.
390    ///
391    /// This is the default behaviour.
392    /// You can change that default in [`TestServerConfig`](crate::TestServerConfig).
393    pub fn do_not_save_cookies(mut self) -> Self {
394        self.config.is_saving_cookies = false;
395        self
396    }
397
398    /// Adds query parameters to be sent with this request.
399    pub fn add_query_param<V>(self, key: &str, value: V) -> Self
400    where
401        V: Serialize,
402    {
403        self.add_query_params(&[(key, value)])
404    }
405
406    /// Adds the structure given as query parameters for this request.
407    ///
408    /// This is designed to take a list of parameters, or a body of parameters,
409    /// and then serializes them into the parameters of the request.
410    ///
411    /// # Sending a body of parameters using `json!`
412    ///
413    /// ```rust
414    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
415    /// #
416    /// use axum::Router;
417    /// use axum_test::TestServer;
418    /// use serde_json::json;
419    ///
420    /// let app = Router::new();
421    /// let server = TestServer::new(app)?;
422    ///
423    /// let response = server.get(&"/my-end-point")
424    ///     .add_query_params(json!({
425    ///         "username": "Brian",
426    ///         "age": 20
427    ///     }))
428    ///     .await;
429    /// #
430    /// # Ok(()) }
431    /// ```
432    ///
433    /// # Sending a body of parameters with Serde
434    ///
435    /// ```rust
436    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
437    /// #
438    /// use axum::Router;
439    /// use axum_test::TestServer;
440    /// use serde::Deserialize;
441    /// use serde::Serialize;
442    ///
443    /// #[derive(Serialize, Deserialize)]
444    /// struct UserQueryParams {
445    ///     username: String,
446    ///     age: u32,
447    /// }
448    ///
449    /// let app = Router::new();
450    /// let server = TestServer::new(app)?;
451    ///
452    /// let response = server.get(&"/my-end-point")
453    ///     .add_query_params(UserQueryParams {
454    ///         username: "Brian".to_string(),
455    ///         age: 20
456    ///     })
457    ///     .await;
458    /// #
459    /// # Ok(()) }
460    /// ```
461    ///
462    /// # Sending a list of parameters
463    ///
464    /// ```rust
465    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
466    /// #
467    /// use axum::Router;
468    /// use axum_test::TestServer;
469    ///
470    /// let app = Router::new();
471    /// let server = TestServer::new(app)?;
472    ///
473    /// let response = server.get(&"/my-end-point")
474    ///     .add_query_params(&[
475    ///         ("username", "Brian"),
476    ///         ("age", "20"),
477    ///     ])
478    ///     .await;
479    /// #
480    /// # Ok(()) }
481    /// ```
482    ///
483    pub fn add_query_params<V>(mut self, query_params: V) -> Self
484    where
485        V: Serialize,
486    {
487        self.config
488            .query_params
489            .add(query_params)
490            .with_context(|| {
491                format!(
492                    "It should serialize query parameters, for request {}",
493                    self.debug_request_format()
494                )
495            })
496            .unwrap();
497
498        self
499    }
500
501    /// Adds a query param onto the end of the request,
502    /// with no urlencoding of any kind.
503    ///
504    /// This exists to allow custom query parameters,
505    /// such as for the many versions of query param arrays.
506    ///
507    /// ```rust
508    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
509    /// #
510    /// use axum::Router;
511    /// use axum_test::TestServer;
512    ///
513    /// let app = Router::new();
514    /// let server = TestServer::new(app)?;
515    ///
516    /// let response = server.get(&"/my-end-point")
517    ///     .add_raw_query_param(&"my-flag")
518    ///     .add_raw_query_param(&"array[]=123")
519    ///     .add_raw_query_param(&"filter[value]=some-value")
520    ///     .await;
521    /// #
522    /// # Ok(()) }
523    /// ```
524    ///
525    pub fn add_raw_query_param(mut self, query_param: &str) -> Self {
526        self.config.query_params.add_raw(query_param.to_string());
527
528        self
529    }
530
531    /// Clears all query params set,
532    /// including any that came from the [`TestServer`](crate::TestServer).
533    pub fn clear_query_params(mut self) -> Self {
534        self.config.query_params.clear();
535        self
536    }
537
538    /// Adds a header to be sent with this request.
539    ///
540    /// ```rust
541    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
542    /// #
543    /// use axum::Router;
544    /// use axum_test::TestServer;
545    ///
546    /// let app = Router::new();
547    /// let server = TestServer::new(app)?;
548    ///
549    /// let response = server.get(&"/my-end-point")
550    ///     .add_header("x-custom-header", "custom-value")
551    ///     .add_header(http::header::CONTENT_LENGTH, 12345)
552    ///     .add_header(http::header::HOST, "example.com")
553    ///     .await;
554    /// #
555    /// # Ok(()) }
556    /// ```
557    pub fn add_header<N, V>(mut self, name: N, value: V) -> Self
558    where
559        N: TryInto<HeaderName>,
560        N::Error: Debug,
561        V: TryInto<HeaderValue>,
562        V::Error: Debug,
563    {
564        let header_name: HeaderName = name
565            .try_into()
566            .expect("Failed to convert header name to HeaderName");
567        let header_value: HeaderValue = value
568            .try_into()
569            .expect("Failed to convert header vlue to HeaderValue");
570
571        self.config.headers.push((header_name, header_value));
572        self
573    }
574
575    /// Adds an 'AUTHORIZATION' HTTP header to the request,
576    /// with no internal formatting of what is given.
577    pub fn authorization<T>(self, authorization_header: T) -> Self
578    where
579        T: AsRef<str>,
580    {
581        let authorization_header_value = HeaderValue::from_str(authorization_header.as_ref())
582            .expect("Cannot build Authorization HeaderValue from token");
583
584        self.add_header(header::AUTHORIZATION, authorization_header_value)
585    }
586
587    /// Adds an 'AUTHORIZATION' HTTP header to the request,
588    /// in the 'Bearer {token}' format.
589    pub fn authorization_bearer<T>(self, authorization_bearer_token: T) -> Self
590    where
591        T: Display,
592    {
593        let authorization_bearer_header_str = format!("Bearer {authorization_bearer_token}");
594        self.authorization(authorization_bearer_header_str)
595    }
596
597    /// Clears all headers set.
598    pub fn clear_headers(mut self) -> Self {
599        self.config.headers = vec![];
600        self
601    }
602
603    /// Sets the scheme to use when making the request. i.e. http or https.
604    /// The default scheme is 'http'.
605    ///
606    /// ```rust
607    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
608    /// #
609    /// use axum::Router;
610    /// use axum_test::TestServer;
611    ///
612    /// let app = Router::new();
613    /// let server = TestServer::new(app)?;
614    ///
615    /// let response = server
616    ///     .get(&"/my-end-point")
617    ///     .scheme(&"https")
618    ///     .await;
619    /// #
620    /// # Ok(()) }
621    /// ```
622    ///
623    pub fn scheme(mut self, scheme: &str) -> Self {
624        self.config
625            .full_request_url
626            .set_scheme(scheme)
627            .map_err(|_| anyhow!("Scheme '{scheme}' cannot be set to request"))
628            .unwrap();
629        self
630    }
631
632    /// Marks that this request is expected to always return a HTTP
633    /// status code within the 2xx range (200 to 299).
634    ///
635    /// If a code _outside_ of that range is returned,
636    /// then this will panic.
637    ///
638    /// ```rust
639    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
640    /// #
641    /// use axum::Json;
642    /// use axum::Router;
643    /// use axum::http::StatusCode;
644    /// use axum::routing::put;
645    /// use serde_json::json;
646    ///
647    /// use axum_test::TestServer;
648    ///
649    /// let app = Router::new()
650    ///     .route(&"/todo", put(|| async { StatusCode::NOT_FOUND }));
651    ///
652    /// let server = TestServer::new(app)?;
653    ///
654    /// // If this doesn't return a value in the 2xx range,
655    /// // then it will panic.
656    /// server.put(&"/todo")
657    ///     .expect_success()
658    ///     .json(&json!({
659    ///         "task": "buy milk",
660    ///     }))
661    ///     .await;
662    /// #
663    /// # Ok(())
664    /// # }
665    /// ```
666    ///
667    pub fn expect_success(self) -> Self {
668        self.expect_state(ExpectedState::Success)
669    }
670
671    /// Marks that this request is expected to return a HTTP status code
672    /// outside of the 2xx range.
673    ///
674    /// If a code _within_ the 2xx range is returned,
675    /// then this will panic.
676    pub fn expect_failure(self) -> Self {
677        self.expect_state(ExpectedState::Failure)
678    }
679
680    fn expect_state(mut self, expected_state: ExpectedState) -> Self {
681        self.expected_state = expected_state;
682        self
683    }
684
685    async fn send(self) -> Result<TestResponse> {
686        let debug_request_format = self.debug_request_format().to_string();
687
688        let method = self.config.method;
689        let expected_state = self.expected_state;
690        let save_cookies = self.config.is_saving_cookies;
691        let body = self.body.unwrap_or(Body::empty());
692        let url =
693            Self::build_url_query_params(self.config.full_request_url, &self.config.query_params);
694
695        let request = Self::build_request(
696            method.clone(),
697            &url,
698            body,
699            self.config.content_type,
700            self.config.cookies,
701            self.config.headers,
702            &debug_request_format,
703        )?;
704
705        #[allow(unused_mut)] // Allowed for the `ws` use immediately after.
706        let mut http_response = self.transport.send(request).await?;
707
708        #[cfg(feature = "ws")]
709        let websockets = {
710            let maybe_on_upgrade = http_response
711                .extensions_mut()
712                .remove::<hyper::upgrade::OnUpgrade>();
713            let transport_type = self.transport.transport_layer_type();
714
715            crate::internals::TestResponseWebSocket {
716                maybe_on_upgrade,
717                transport_type,
718            }
719        };
720
721        let (parts, response_body) = http_response.into_parts();
722        let response_bytes = response_body.collect().await?.to_bytes();
723
724        if save_cookies {
725            let cookie_headers = parts.headers.get_all(SET_COOKIE).into_iter();
726            ServerSharedState::add_cookies_by_header(&self.server_state, cookie_headers)?;
727        }
728
729        let test_response = TestResponse::new(
730            method,
731            url,
732            parts,
733            response_bytes,
734            #[cfg(feature = "ws")]
735            websockets,
736        );
737
738        // Assert if ok or not.
739        match expected_state {
740            ExpectedState::Success => test_response.assert_status_success(),
741            ExpectedState::Failure => test_response.assert_status_failure(),
742            ExpectedState::None => {}
743        }
744
745        Ok(test_response)
746    }
747
748    fn build_url_query_params(mut url: Url, query_params: &QueryParamsStore) -> Url {
749        // Add all the query params we have
750        if query_params.has_content() {
751            url.set_query(Some(&query_params.to_string()));
752        }
753
754        url
755    }
756
757    fn build_request(
758        method: Method,
759        url: &Url,
760        body: Body,
761        content_type: Option<String>,
762        cookies: CookieJar,
763        headers: Vec<(HeaderName, HeaderValue)>,
764        debug_request_format: &str,
765    ) -> Result<Request<Body>> {
766        let mut request_builder = Request::builder().uri(url.as_str()).method(method);
767
768        // Add all the headers we have.
769        if let Some(content_type) = content_type {
770            let (header_key, header_value) =
771                build_content_type_header(&content_type, debug_request_format)?;
772            request_builder = request_builder.header(header_key, header_value);
773        }
774
775        // Add all the non-expired cookies as headers
776        // Also strip cookies from their attributes, only their names and values should be preserved to conform the HTTP standard
777        let now = OffsetDateTime::now_utc();
778        for cookie in cookies.iter() {
779            let expired = cookie
780                .expires_datetime()
781                .map(|expires| expires <= now)
782                .unwrap_or(false);
783
784            if !expired {
785                let cookie_raw = cookie.stripped().to_string();
786                let header_value = HeaderValue::from_str(&cookie_raw)?;
787                request_builder = request_builder.header(header::COOKIE, header_value);
788            }
789        }
790
791        // Put headers into the request
792        for (header_name, header_value) in headers {
793            request_builder = request_builder.header(header_name, header_value);
794        }
795
796        let request = request_builder.body(body).with_context(|| {
797            format!("Expect valid hyper Request to be built, for request {debug_request_format}")
798        })?;
799
800        Ok(request)
801    }
802
803    fn debug_request_format(&self) -> RequestPathFormatter<'_> {
804        RequestPathFormatter::new(
805            &self.config.method,
806            self.config.full_request_url.as_str(),
807            Some(&self.config.query_params),
808        )
809    }
810}
811
812impl TryFrom<TestRequest> for Request<Body> {
813    type Error = AnyhowError;
814
815    fn try_from(test_request: TestRequest) -> Result<Request<Body>> {
816        let debug_request_format = test_request.debug_request_format().to_string();
817        let url = TestRequest::build_url_query_params(
818            test_request.config.full_request_url,
819            &test_request.config.query_params,
820        );
821        let body = test_request.body.unwrap_or(Body::empty());
822
823        TestRequest::build_request(
824            test_request.config.method,
825            &url,
826            body,
827            test_request.config.content_type,
828            test_request.config.cookies,
829            test_request.config.headers,
830            &debug_request_format,
831        )
832    }
833}
834
835impl IntoFuture for TestRequest {
836    type Output = TestResponse;
837    type IntoFuture = Pin<Box<dyn Future<Output = TestResponse> + Send>>;
838
839    fn into_future(self) -> Self::IntoFuture {
840        Box::pin(async { self.send().await.context("Sending request failed").unwrap() })
841    }
842}
843
844fn build_content_type_header(
845    content_type: &str,
846    debug_request_format: &str,
847) -> Result<(HeaderName, HeaderValue)> {
848    let header_value = HeaderValue::from_str(content_type).with_context(|| {
849        format!(
850            "Failed to store header content type '{content_type}', for request {debug_request_format}"
851        )
852    })?;
853
854    Ok((header::CONTENT_TYPE, header_value))
855}
856
#[cfg(test)]
mod test_content_type {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// The path of the route under test.
    const ROUTE: &str = "/content_type";

    /// Handler that returns the request's content type, or an empty string when it is missing.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// Builds an application with the single content type route.
    fn new_app() -> Router {
        Router::new().route(ROUTE, get(get_content_type))
    }

    #[tokio::test]
    async fn it_should_not_set_a_content_type_by_default() {
        let server = TestServer::new(new_app()).expect("Should create test server");

        let response = server.get(ROUTE).await;

        assert_eq!(response.text(), "");
    }

    #[tokio::test]
    async fn it_should_override_server_content_type_when_present() {
        // The server default is set, and the request is expected to win over it.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(new_app())
            .expect("Should create test server");

        let response = server.get(ROUTE).content_type("application/json").await;

        assert_eq!(response.text(), "application/json");
    }

    #[tokio::test]
    async fn it_should_set_content_type_when_present() {
        let server = TestServer::new(new_app()).expect("Should create test server");

        let response = server.get(ROUTE).content_type("application/custom").await;

        assert_eq!(response.text(), "application/custom");
    }
}
925
/// Tests for sending JSON request bodies with `.json(...)`.
#[cfg(test)]
mod test_json {
    use crate::TestServer;
    use axum::Json;
    use axum::Router;
    use axum::extract::DefaultBodyLimit;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use rand::random;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Renders the decoded JSON body, falling back to "pandas" when `pets` is missing.
        async fn post_json(Json(body): Json<TestJson>) -> String {
            let pets = body.pets.unwrap_or_else(|| "pandas".to_string());
            format!("json: {}, {}, {}", body.name, body.age, pets)
        }

        let router = Router::new().route("/json", post(post_json));
        let server = TestServer::new(router).expect("Should create test server");

        // Send a fully populated body and read back what the handler saw.
        let payload = TestJson {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let response = server.post(&"/json").json(&payload).await;

        assert_eq!(response.text(), "json: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // An empty JSON object is enough; only the header seen by the handler matters here.
        let response = server.post(&"/content_type").json(&json!({})).await;

        assert_eq!(response.text(), "application/json");
    }

    #[tokio::test]
    async fn it_should_pass_large_json_blobs_over_http() {
        const LARGE_BLOB_SIZE: usize = 16777216; // 16mb

        #[derive(Deserialize, Serialize, PartialEq, Debug)]
        struct TestLargeJson {
            items: Vec<String>,
        }

        /// Sends the decoded body straight back as JSON.
        async fn echo_json(Json(payload): Json<TestLargeJson>) -> Json<TestLargeJson> {
            Json(payload)
        }

        // Keep appending random numeric strings until their combined length reaches the target.
        let mut items = Vec::new();
        let mut total_len = 0;
        while total_len < LARGE_BLOB_SIZE {
            let entry = random::<u64>().to_string();
            total_len += entry.len();
            items.push(entry);
        }
        let large_json_blob = TestLargeJson { items };

        // Raise the body limit to twice the target size so the blob is accepted by the route.
        let body_limit = DefaultBodyLimit::max(LARGE_BLOB_SIZE * 2);
        let router = Router::new()
            .route("/json", post(echo_json))
            .layer(body_limit);

        // This test runs over the HTTP transport, with success expected by default.
        let server = TestServer::builder()
            .http_transport()
            .expect_success_by_default()
            .build(router)
            .expect("Should create test server");

        // The echoed JSON must equal what was sent.
        let response = server.post(&"/json").json(&large_json_blob).await;
        response.assert_json(&large_json_blob);
    }
}
1042
/// Tests for loading a JSON request body from disk with `.json_from_file(...)`.
#[cfg(test)]
mod test_json_from_file {
    use crate::TestServer;
    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
        }

        /// Renders the two fields of the decoded JSON body.
        async fn post_json(Json(body): Json<TestJson>) -> String {
            format!("json: {}, {}", body.name, body.age)
        }

        let router = Router::new().route("/json", post(post_json));
        let server = TestServer::new(router).expect("Should create test server");

        // The body comes from the fixture file rather than an in-memory value.
        let response = server
            .post(&"/json")
            .json_from_file(&"files/example.json")
            .await;

        // The assertion pins the fixture's contents as seen by the handler.
        assert_eq!(response.text(), "json: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // Only the header that reaches the handler is checked here.
        let response = server
            .post(&"/content_type")
            .json_from_file(&"files/example.json")
            .await;

        assert_eq!(response.text(), "application/json");
    }
}
1109
/// Tests for sending YAML request bodies with `.yaml(...)` (requires the `yaml` feature).
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_yaml::Yaml;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Renders the decoded YAML body, falling back to "pandas" when `pets` is missing.
        async fn post_yaml(Yaml(body): Yaml<TestYaml>) -> String {
            let pets = body.pets.unwrap_or_else(|| "pandas".to_string());
            format!("yaml: {}, {}, {}", body.name, body.age, pets)
        }

        let router = Router::new().route("/yaml", post(post_yaml));
        let server = TestServer::new(router).expect("Should create test server");

        // Send a fully populated body and read back what the handler saw.
        let payload = TestYaml {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let response = server.post(&"/yaml").yaml(&payload).await;

        assert_eq!(response.text(), "yaml: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // An empty object is enough; only the header seen by the handler matters here.
        let response = server.post(&"/content_type").yaml(&json!({})).await;

        assert_eq!(response.text(), "application/yaml");
    }
}
1184
/// Tests for loading a YAML request body from disk with `.yaml_from_file(...)`
/// (requires the `yaml` feature).
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_yaml::Yaml;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
        }

        /// Renders the two fields of the decoded YAML body.
        async fn post_yaml(Yaml(body): Yaml<TestYaml>) -> String {
            format!("yaml: {}, {}", body.name, body.age)
        }

        let router = Router::new().route("/yaml", post(post_yaml));
        let server = TestServer::new(router).expect("Should create test server");

        // The body comes from the fixture file rather than an in-memory value.
        let response = server
            .post(&"/yaml")
            .yaml_from_file(&"files/example.yaml")
            .await;

        // The assertion pins the fixture's contents as seen by the handler.
        assert_eq!(response.text(), "yaml: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // Only the header that reaches the handler is checked here.
        let response = server
            .post(&"/content_type")
            .yaml_from_file(&"files/example.yaml")
            .await;

        assert_eq!(response.text(), "application/yaml");
    }
}
1252
/// Tests for sending MessagePack request bodies with `.msgpack(...)`
/// (requires the `msgpack` feature).
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_msgpack::MsgPack;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    #[tokio::test]
    async fn it_should_pass_msgpack_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestMsgPack {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        // Renders the decoded MessagePack body, falling back to "pandas" when `pets` is missing.
        // The reply is labelled "msgpack:" so a failure message names the right format.
        async fn get_msgpack(MsgPack(msgpack): MsgPack<TestMsgPack>) -> String {
            format!(
                "msgpack: {}, {}, {}",
                msgpack.name,
                msgpack.age,
                msgpack.pets.unwrap_or_else(|| "pandas".to_string())
            )
        }

        // Build an application with a route.
        let app = Router::new().route("/msgpack", post(get_msgpack));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send a fully populated body and read back what the handler saw.
        let text = server
            .post(&"/msgpack")
            .msgpack(&TestMsgPack {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "msgpack: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_msgpack_content_type_for_msgpack() {
        // Replies with the request's `Content-Type` header, or an empty string when it is absent.
        async fn get_content_type(headers: HeaderMap) -> String {
            headers
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().unwrap().to_string())
                .unwrap_or_else(|| "".to_string())
        }

        // Build an application with a route.
        let app = Router::new().route("/content_type", post(get_content_type));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // An empty object is enough; only the header seen by the handler matters here.
        let text = server
            .post(&"/content_type")
            .msgpack(&json!({}))
            .await
            .text();

        assert_eq!(text, "application/msgpack");
    }
}
1329
/// Tests for sending form request bodies with `.form(...)`.
#[cfg(test)]
mod test_form {
    use crate::TestServer;
    use axum::Form;
    use axum::Router;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_form_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestForm {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Renders the decoded form, falling back to "pandas" when `pets` is missing.
        async fn get_form(Form(fields): Form<TestForm>) -> String {
            let pets = fields.pets.unwrap_or_else(|| "pandas".to_string());
            format!("form: {}, {}, {}", fields.name, fields.age, pets)
        }

        let router = Router::new().route("/form", post(get_form));
        let server = TestServer::new(router).expect("Should create test server");

        // Send a fully populated form and check what the handler saw.
        let payload = TestForm {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let response = server.post(&"/form").form(&payload).await;

        response.assert_text("form: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_form_content_type_for_form() {
        #[derive(Serialize)]
        struct MyForm {
            message: String,
        }

        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // The form contents are irrelevant; only the header seen by the handler is checked.
        let payload = MyForm {
            message: "hello".to_string(),
        };
        let response = server.post(&"/content_type").form(&payload).await;

        response.assert_text("application/x-www-form-urlencoded");
    }
}
1407
/// Tests for sending raw byte request bodies with `.bytes(...)`.
#[cfg(test)]
mod test_bytes {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Reads the whole request body and replies with it, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let body_bytes = collected.to_bytes();

        String::from_utf8_lossy(&body_bytes).to_string()
    }

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let router = Router::new().route("/bytes", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        // Send the raw bytes of "hello!" and read back what the handler received.
        let payload = "hello!".as_bytes();
        let response = server.post(&"/bytes").bytes(payload.into()).await;

        assert_eq!(response.text(), "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // The content type is set before the body; attaching bytes must leave it as it was.
        let payload = "hello!".as_bytes();
        let response = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes(payload.into())
            .await;

        assert_eq!(response.text(), "application/testing");
    }
}
1474
/// Tests for loading a raw byte request body from disk with `.bytes_from_file(...)`.
#[cfg(test)]
mod test_bytes_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Reads the whole request body and replies with it, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let body_bytes = collected.to_bytes();

        String::from_utf8_lossy(&body_bytes).to_string()
    }

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let router = Router::new().route("/bytes", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        // The body comes from the fixture file rather than an in-memory value.
        let response = server
            .post(&"/bytes")
            .bytes_from_file(&"files/example.txt")
            .await;

        // The assertion pins the fixture's contents as seen by the handler.
        assert_eq!(response.text(), "hello!");
    }

    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // The content type is set before the body; loading bytes must leave it as it was.
        let response = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes_from_file(&"files/example.txt")
            .await;

        assert_eq!(response.text(), "application/testing");
    }
}
1541
/// Tests for sending plain text request bodies with `.text(...)`.
#[cfg(test)]
mod test_text {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Reads the whole request body and replies with it, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let body_bytes = collected.to_bytes();

        String::from_utf8_lossy(&body_bytes).to_string()
    }

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let router = Router::new().route("/text", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        // Send a short text body and read back what the handler received.
        let response = server.post(&"/text").text(&"hello!").await;

        assert_eq!(response.text(), "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // Only the header that reaches the handler is checked here.
        let response = server.post(&"/content_type").text(&"hello!").await;

        assert_eq!(response.text(), "text/plain");
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_mock_http() {
        const LARGE_BLOB_SIZE: usize = 16777216; // 16mb
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let router = Router::new().route("/text", post(echo_body));

        // This variant runs over the mock transport.
        let server = TestServer::builder()
            .mock_transport()
            .build(router)
            .expect("Should create test server");

        let echoed = server.post(&"/text").text(&large_blob).await.text();

        // Check the length first so a size mismatch does not print two 16mb strings.
        assert_eq!(echoed.len(), LARGE_BLOB_SIZE);
        assert_eq!(echoed, large_blob);
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_http() {
        const LARGE_BLOB_SIZE: usize = 16777216; // 16mb
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let router = Router::new().route("/text", post(echo_body));

        // This variant runs over the HTTP transport.
        let server = TestServer::builder()
            .http_transport()
            .build(router)
            .expect("Should create test server");

        let echoed = server.post(&"/text").text(&large_blob).await.text();

        // Check the length first so a size mismatch does not print two 16mb strings.
        assert_eq!(echoed.len(), LARGE_BLOB_SIZE);
        assert_eq!(echoed, large_blob);
    }
}
1665
/// Tests for loading a plain text request body from disk with `.text_from_file(...)`.
#[cfg(test)]
mod test_text_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Reads the whole request body and replies with it, decoded lossily as UTF-8.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");
        let body_bytes = collected.to_bytes();

        String::from_utf8_lossy(&body_bytes).to_string()
    }

    /// Replies with the request's `Content-Type` header, or an empty string when it is absent.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => "".to_string(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let router = Router::new().route("/text", post(echo_body));
        let server = TestServer::new(router).expect("Should create test server");

        // The body comes from the fixture file rather than an in-memory value.
        let response = server
            .post(&"/text")
            .text_from_file(&"files/example.txt")
            .await;

        // The assertion pins the fixture's contents as seen by the handler.
        assert_eq!(response.text(), "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let router = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(router).expect("Should create test server");

        // Only the header that reaches the handler is checked here.
        let response = server
            .post(&"/content_type")
            .text_from_file(&"files/example.txt")
            .await;

        assert_eq!(response.text(), "text/plain");
    }
}
1732
/// Tests for `.expect_success()` on a request.
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler that replies with the body "pong!".
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler that replies with the `ACCEPTED` status code and nothing else.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let router = Router::new().route("/ping", get(get_ping));
        let server = TestServer::new(router).expect("Should create test server");

        // Awaiting must complete without panicking.
        server.get(&"/ping").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let router = Router::new().route("/accepted", get(get_accepted));
        let server = TestServer::new(router).expect("Should create test server");

        // A success status other than the default one must also be accepted.
        server.get(&"/accepted").expect_success().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        // No routes are registered, so the path requested below is unknown to the router.
        let router = Router::new();
        let server = TestServer::new(router).expect("Should create test server");

        // `#[should_panic]`: awaiting this request is expected to panic.
        server.get(&"/some_unknown_route").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let router = Router::new().route("/ping", get(get_ping));

        // The server is told to expect failures ...
        let mut server = TestServer::new(router).expect("Should create test server");
        server.expect_failure();

        // ... but this request asks for success, and must not panic on a successful reply.
        server.get(&"/ping").expect_success().await;
    }
}
1802
/// Tests for `.expect_failure()` on a request.
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // No routes are registered, so the path requested below is unknown to the router.
        let app = Router::new();

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Awaiting must complete without panicking.
        server.get(&"/some_unknown_route").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        async fn get_ping() -> &'static str {
            "pong!"
        }

        // Build an application with a route.
        let app = Router::new().route("/ping", get(get_ping));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // `#[should_panic]`: a successful reply is expected to make this panic.
        server.get(&"/ping").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        async fn get_accepted() -> StatusCode {
            StatusCode::ACCEPTED
        }

        // Build an application with a route.
        let app = Router::new().route("/accepted", get(get_accepted));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // `#[should_panic]`: an `ACCEPTED` reply is expected to make this panic too.
        server.get(&"/accepted").expect_failure().await;
    }

    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        // No routes are registered, so the path requested below is unknown to the router.
        let app = Router::new();

        // The server is told to expect successes ...
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();

        // ... but this request asks for failure, and must not panic on the failed reply.
        server.get(&"/some_unknown_route").expect_failure().await;
    }
}
1869
/// Tests for attaching a single cookie to a request with `.add_cookie(...)`.
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;
    use cookie::time::Duration;
    use cookie::time::OffsetDateTime;

    /// Name of the cookie that the tests send and the handler looks up.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of the test cookie, or "cookie-not-found" when the
    /// request did not carry it. The received jar is returned alongside the text.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // A cookie with no expiry set must reach the handler.
        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_send_non_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // A cookie expiring ten minutes from now must still reach the handler.
        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_add(Duration::minutes(10))
                .unwrap(),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_not_send_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // The expiry is the current instant; the handler must not see this cookie.
        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(OffsetDateTime::now_utc());
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
1927
#[cfg(test)]
mod test_add_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use cookie::SameSite;

    /// Handler listing every received cookie as `name=value`, sorted and
    /// joined with `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut pairs: Vec<String> = Vec::new();
        for cookie in cookies.iter() {
            pairs.push(format!("{}={}", cookie.name(), cookie.value()));
        }
        // Sort so the output does not depend on the jar's iteration order.
        pairs.sort();

        pairs.join(", ")
    }

    /// Handler returning the raw `cookie` request header values, joined with
    /// `"; "`. Values that are not valid header strings become empty strings,
    /// and the result is empty when no `cookie` header is present.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        let raw_values: Vec<&str> = headers
            .get_all("cookie")
            .into_iter()
            .map(|value| value.to_str().unwrap_or(""))
            .collect();

        raw_values.join("; ")
    }

    /// Every cookie in a jar passed to `add_cookies` should reach the handler.
    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let router = Router::new().route("/cookies", get(route_get_cookies));
        let server = TestServer::new(router).expect("Should create test server");

        // Fill a jar with two distinct cookies to send up together.
        let mut jar = CookieJar::new();
        jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        jar.add(Cookie::new("second-cookie", "other-cookie"));

        let response = server.get("/cookies").add_cookies(jar).await;
        response.assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }

    /// Cookie attributes (HttpOnly, Secure, SameSite, Path) should not appear
    /// in the request header; only `name=value` is expected.
    #[tokio::test]
    async fn it_should_send_all_cookies_stripped_by_their_attributes() {
        const TEST_COOKIE_NAME: &str = "test-cookie";
        const TEST_COOKIE_VALUE: &str = "my-custom-cookie";

        let router = Router::new().route("/cookies", get(get_cookie_headers_joined));
        let server = TestServer::new(router).expect("Should create test server");

        // Build a cookie carrying several attributes and place it in a jar.
        let decorated = Cookie::build((TEST_COOKIE_NAME, TEST_COOKIE_VALUE))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        let mut jar = CookieJar::new();
        jar.add(decorated);

        let response = server.get("/cookies").add_cookies(jar).await;
        response.assert_text(format!("{}={}", TEST_COOKIE_NAME, TEST_COOKIE_VALUE));
    }
}
2004
#[cfg(test)]
mod test_save_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    /// Name of the cookie set by the PUT handler.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that reads the request body as (lossy) UTF-8 text and stores it
    /// in the test cookie, decorated with HttpOnly, Secure, SameSite=Strict
    /// and a `/cookie` path. Responds with the updated jar and `"done"`.
    async fn put_cookie_with_attributes(
        mut cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes");
        let raw_body = collected.to_bytes();
        let value = String::from_utf8_lossy(&raw_body).to_string();

        cookies = cookies.add(
            Cookie::build((TEST_COOKIE_NAME, value))
                .http_only(true)
                .secure(true)
                .same_site(SameSite::Strict)
                .path("/cookie")
                .build(),
        );

        (cookies, "done")
    }

    /// Handler returning the raw `cookie` request header values, joined with
    /// `"; "`. Values that are not valid header strings become empty strings,
    /// and the result is empty when no `cookie` header is present.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        let raw_values: Vec<&str> = headers
            .get_all("cookie")
            .into_iter()
            .map(|value| value.to_str().unwrap_or(""))
            .collect();

        raw_values.join("; ")
    }

    /// A cookie saved from a response should be sent on the next request as a
    /// bare `name=value` pair, without the attributes it was set with.
    #[tokio::test]
    async fn it_should_strip_cookies_from_their_attributes() {
        let router = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(router).expect("Should create test server");

        // Have the server set the cookie, and keep it for later requests.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Only the cookie's name and value should be sent back up.
        let echoed = server.get("/cookie").await.text();
        let expected = format!("{}=cookie-found!", TEST_COOKIE_NAME);

        assert_eq!(echoed, expected);
    }
}
2074
#[cfg(test)]
mod test_do_not_save_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    /// Name of the cookie set by the PUT handler.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that reads the request body as (lossy) UTF-8 text and stores it
    /// in the test cookie, decorated with HttpOnly, Secure, SameSite=Strict
    /// and a `/cookie` path. Responds with the updated jar and `"done"`.
    async fn put_cookie_with_attributes(
        mut cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes");
        let raw_body = collected.to_bytes();
        let value = String::from_utf8_lossy(&raw_body).to_string();

        cookies = cookies.add(
            Cookie::build((TEST_COOKIE_NAME, value))
                .http_only(true)
                .secure(true)
                .same_site(SameSite::Strict)
                .path("/cookie")
                .build(),
        );

        (cookies, "done")
    }

    /// Handler returning the raw `cookie` request header values, joined with
    /// `"; "`. Values that are not valid header strings become empty strings,
    /// and the result is empty when no `cookie` header is present.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        let raw_values: Vec<&str> = headers
            .get_all("cookie")
            .into_iter()
            .map(|value| value.to_str().unwrap_or(""))
            .collect();

        raw_values.join("; ")
    }

    /// With `do_not_save_cookies` on the request, a cookie set by the response
    /// should not be sent on the following request.
    #[tokio::test]
    async fn it_should_not_save_cookies_when_set() {
        let router = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(router).expect("Should create test server");

        // The response sets a cookie, but this request opts out of saving it.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // The follow-up request is expected to carry no cookies at all.
        let echoed = server.get("/cookie").await.text();

        assert_eq!(echoed, "");
    }

    /// `do_not_save_cookies` on the request should win even when the server
    /// was built with `save_cookies` turned on.
    #[tokio::test]
    async fn it_should_override_test_server_and_not_save_cookies_when_set() {
        let router = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::builder()
            .save_cookies()
            .build(router)
            .expect("Should create test server");

        // The response sets a cookie, but this request opts out of saving it.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // The follow-up request is expected to carry no cookies at all.
        let echoed = server.get("/cookie").await.text();

        assert_eq!(echoed, "");
    }
}
2167
#[cfg(test)]
mod test_clear_cookies {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::Cookie as AxumCookie;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use http_body_util::BodyExt;

    /// Name of the cookie the handlers below read and write.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Handler that replies with the value of the test cookie, or with
    /// `"cookie-not-found"` when the request did not carry it.
    async fn get_cookie(cookies: AxumCookieJar) -> (AxumCookieJar, String) {
        let body = match cookies.get(TEST_COOKIE_NAME) {
            Some(found) => found.value().to_string(),
            None => "cookie-not-found".to_string(),
        };

        (cookies, body)
    }

    /// Handler that reads the request body as (lossy) UTF-8 text and stores it
    /// as the value of the test cookie. Responds with the updated jar and
    /// `"done"`.
    async fn put_cookie(
        mut cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes");
        let raw_body = collected.to_bytes();
        let value = String::from_utf8_lossy(&raw_body).to_string();

        cookies = cookies.add(AxumCookie::new(TEST_COOKIE_NAME, value));

        (cookies, "done")
    }

    /// `clear_cookies` should drop a cookie added earlier on the same request.
    #[tokio::test]
    async fn it_should_clear_cookie_added_to_request() {
        let router = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(router).expect("Should create test server");

        let response = server
            .get("/cookie")
            .add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"))
            .clear_cookies()
            .await;

        assert_eq!(response.text(), "cookie-not-found");
    }

    /// `clear_cookies` should drop a whole jar added earlier on the same
    /// request.
    #[tokio::test]
    async fn it_should_clear_cookie_jar_added_to_request() {
        let router = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(router).expect("Should create test server");

        let mut jar = CookieJar::new();
        jar.add(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response = server
            .get("/cookie")
            .add_cookies(jar)
            .clear_cookies()
            .await;

        assert_eq!(response.text(), "cookie-not-found");
    }

    /// `clear_cookies` should also drop cookies saved from an earlier
    /// response.
    #[tokio::test]
    async fn it_should_clear_cookies_saved_by_past_request() {
        let router = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let server = TestServer::new(router).expect("Should create test server");

        // Have the server set the cookie, and keep it for later requests.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // With the cookies cleared, the handler should not see the saved one.
        let response = server.get("/cookie").clear_cookies().await;

        assert_eq!(response.text(), "cookie-not-found");
    }

    /// `clear_cookies` should also drop cookies registered on the server
    /// itself.
    #[tokio::test]
    async fn it_should_clear_cookies_added_to_test_server() {
        let router = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let mut server = TestServer::new(router).expect("Should create test server");

        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        // With the cookies cleared, the handler should not see the server's cookie.
        let response = server.get("/cookie").clear_cookies().await;

        assert_eq!(response.text(), "cookie-not-found");
    }
}
2281
#[cfg(test)]
mod test_add_header {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Name of the custom header sent by the test.
    const TEST_HEADER_NAME: &str = "test-header";
    /// Value of the custom header sent by the test.
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the test header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Copies the test header's bytes out of the request, or rejects with
        /// `400 Bad Request` and `"Missing test header"` when it is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            match parts.headers.get(HeaderName::from_static(TEST_HEADER_NAME)) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Handler echoing the extracted header bytes as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    /// A header attached with `add_header` should reach the handler.
    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let router = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(router).expect("Should create test server");

        // Attach the custom header to the outgoing request.
        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        let response = server.get("/header").add_header(name, value).await;

        // The handler echoes the header value straight back.
        response.assert_text(TEST_HEADER_CONTENT)
    }
}
2340
#[cfg(test)]
mod test_authorization {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server with a single `/auth-header` route that echoes the
    /// `Authorization` header, and with `expect_success` turned on.
    fn new_test_server() -> TestServer {
        /// Extractor holding the `Authorization` header as text.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            /// Reads the `Authorization` header, or rejects with
            /// `400 Bad Request` and `"Missing test header"` when absent.
            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                match parts.headers.get(header::AUTHORIZATION) {
                    Some(value) => Ok(TestHeader(value.to_str().unwrap().to_string())),
                    None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
                }
            }
        }

        /// Handler echoing the extracted header text as the response body.
        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        let router = Router::new().route("/auth-header", get(ping_auth_header));

        let mut server = TestServer::new(router).expect("Should create test server");
        server.expect_success();

        server
    }

    /// The value given to `authorization` should arrive verbatim in the
    /// `Authorization` header.
    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        let response = server
            .get("/auth-header")
            .authorization("Bearer abc123")
            .await;

        // The handler echoes the header value straight back.
        response.assert_text("Bearer abc123")
    }
}
2398
#[cfg(test)]
mod test_authorization_bearer {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server with a single `/auth-header` route that echoes the
    /// `Authorization` header with every `"Bearer "` occurrence removed, and
    /// with `expect_success` turned on.
    fn new_test_server() -> TestServer {
        /// Extractor holding the `Authorization` header text minus `"Bearer "`.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            /// Reads the `Authorization` header and removes `"Bearer "` from
            /// it, or rejects with `400 Bad Request` and
            /// `"Missing test header"` when the header is absent.
            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                match parts.headers.get(header::AUTHORIZATION) {
                    Some(value) => {
                        let token = value.to_str().unwrap().replace("Bearer ", "");
                        Ok(TestHeader(token))
                    }
                    None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
                }
            }
        }

        /// Handler echoing the extracted token as the response body.
        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        let router = Router::new().route("/auth-header", get(ping_auth_header));

        let mut server = TestServer::new(router).expect("Should create test server");
        server.expect_success();

        server
    }

    /// The token given to `authorization_bearer` should come back from the
    /// handler once the `"Bearer "` text has been removed.
    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        let response = server
            .get("/auth-header")
            .authorization_bearer("abc123")
            .await;

        // The handler echoes the bare token back.
        response.assert_text("abc123")
    }
}
2456
#[cfg(test)]
mod test_clear_headers {
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Name of the custom header used by the tests.
    const TEST_HEADER_NAME: &str = "test-header";
    /// Value of the custom header used by the tests.
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Extractor holding the raw bytes of the test header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Copies the test header's bytes out of the request, or rejects with
        /// `400 Bad Request` and `"Missing test header"` when it is absent.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            match parts.headers.get(HeaderName::from_static(TEST_HEADER_NAME)) {
                Some(value) => Ok(TestHeader(value.as_bytes().to_vec())),
                None => Err((StatusCode::BAD_REQUEST, "Missing test header")),
            }
        }
    }

    /// Handler echoing the extracted header bytes as the response body.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    /// `clear_headers` should drop a header added earlier on the same request,
    /// so the handler rejects the request as missing the header.
    #[tokio::test]
    async fn it_should_clear_headers_added_to_request() {
        let router = Router::new().route("/header", get(ping_header));
        let server = TestServer::new(router).expect("Should create test server");

        // Attach the header, then clear all headers before sending.
        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        let response = server
            .get("/header")
            .add_header(name, value)
            .clear_headers()
            .await;

        // The extractor should have rejected the request.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }

    /// `clear_headers` should also drop headers registered on the server.
    #[tokio::test]
    async fn it_should_clear_headers_added_to_server() {
        let router = Router::new().route("/header", get(ping_header));

        // Register the header on the server, so requests would normally carry it.
        let mut server = TestServer::new(router).expect("Should create test server");
        let name = HeaderName::from_static(TEST_HEADER_NAME);
        let value = HeaderValue::from_static(TEST_HEADER_CONTENT);
        server.add_header(name, value);

        // Clear the headers on this one request before sending.
        let response = server.get("/header").clear_headers().await;

        // The extractor should have rejected the request.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
2537
#[cfg(test)]
mod test_add_query_params {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Query string with a single `message` parameter.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Handler replying with the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query string with `message` and `other` parameters.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Handler replying with `message` and `other` joined by a dash.
    async fn get_query_param_2(AxumStdQuery(params): AxumStdQuery<QueryParam2>) -> String {
        let QueryParam2 { message, other } = params;
        format!("{}-{}", message, other)
    }

    /// Router exposing both query handlers.
    fn build_app() -> Router {
        let single = get(get_query_param);
        let double = get(get_query_param_2);

        Router::new()
            .route("/query", single)
            .route("/query-2", double)
    }

    /// Parameters given as a serializable struct should reach the handler.
    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let params = QueryParam {
            message: "it works".to_string(),
        };
        let response = server.get("/query").add_query_params(params).await;

        response.assert_text("it works");
    }

    /// Parameters given as a slice of key/value pairs should reach the handler.
    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query")
            .add_query_params(&[("message", "it works")])
            .await;

        response.assert_text("it works");
    }

    /// Several pairs given in one call should all reach the handler.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query-2")
            .add_query_params(&[("message", "it works"), ("other", "yup")])
            .await;

        response.assert_text("it works-yup");
    }

    /// Pairs given across separate calls should accumulate.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query-2")
            .add_query_params(&[("message", "it works")])
            .add_query_params(&[("other", "yup")])
            .await;

        response.assert_text("it works-yup");
    }

    /// Parameters given as a JSON object should reach the handler.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let params = json!({
            "message": "it works",
            "other": "yup"
        });
        let response = server.get("/query-2").add_query_params(params).await;

        response.assert_text("it works-yup");
    }
}
2644
#[cfg(test)]
mod test_add_raw_query_param {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;
    use std::fmt::Write;

    /// Query string with a single `message` parameter.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Handler replying with the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    /// Query string with two optional repeated parameters: `items` and
    /// `arrs[]`. Each defaults to an empty list when absent.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Handler replying with the `items` values joined by `", "`, followed
    /// directly by the `arrs[]` values joined the same way. Empty lists
    /// contribute nothing.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        let mut output = String::new();

        // `items` first, then `arrs`, with no separator between the two groups.
        for group in [&params.items, &params.arrs] {
            if !group.is_empty() {
                write!(output, "{}", group.join(", ")).unwrap();
            }
        }

        output
    }

    /// Router exposing both query handlers.
    fn build_app() -> Router {
        let plain = get(get_query_param);
        let extra = get(get_query_param_extra);

        Router::new()
            .route("/query", plain)
            .route("/query-extra", extra)
    }

    /// A raw `key=value` string should reach the handler unchanged.
    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query")
            .add_raw_query_param("message=it-works")
            .await;

        response.assert_text("it-works");
    }

    /// Repeated keys inside one raw string should arrive as a list.
    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query-extra")
            .add_raw_query_param("items=one&items=two&items=three")
            .await;

        response.assert_text("one, two, three");
    }

    /// Repeated keys given across separate raw calls should arrive as a list.
    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        let server = TestServer::new(build_app()).expect("Should create test server");

        let response = server
            .get("/query-extra")
            .add_raw_query_param("arrs[]=one")
            .add_raw_query_param("arrs[]=two")
            .add_raw_query_param("arrs[]=three")
            .await;

        response.assert_text("one, two, three");
    }
}
2737
#[cfg(test)]
mod test_add_query_param {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;

    /// Query string with a single `message` parameter.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Handler replying with the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    /// Query string with `message` and `other` parameters.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Handler replying with `message` and `other` joined by a dash.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        let QueryParam2 { message, other } = params;
        format!("{}-{}", message, other)
    }

    /// A single key/value pair should reach the handler.
    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        let router = Router::new().route("/query", get(get_query_param));
        let server = TestServer::new(router).expect("Should create test server");

        let response = server
            .get("/query")
            .add_query_param("message", "it works")
            .await;

        response.assert_text("it works");
    }

    /// Pairs added across separate calls should accumulate.
    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        let router = Router::new().route("/query-2", get(get_query_param_2));
        let server = TestServer::new(router).expect("Should create test server");

        let response = server
            .get("/query-2")
            .add_query_param("message", "it works")
            .add_query_param("other", "yup")
            .await;

        response.assert_text("it works-yup");
    }
}
2799
#[cfg(test)]
mod test_clear_query_params {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;

    /// Query string with two optional parameters.
    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Handler reporting which of the two parameters were present.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        let has_first = params.first.is_some();
        let has_second = params.second.is_some();

        format!("has first? {}, has second? {}", has_first, has_second)
    }

    /// `clear_query_params` should drop every parameter added before it.
    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        let router = Router::new().route("/query", get(get_query_params));
        let server = TestServer::new(router).expect("Should create test server");

        let params = QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        };
        let response = server
            .get("/query")
            .add_query_params(params)
            .clear_query_params()
            .await;

        response.assert_text("has first? false, has second? false");
    }

    /// Parameters added after `clear_query_params` should still be sent.
    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        let router = Router::new().route("/query", get(get_query_params));
        let server = TestServer::new(router).expect("Should create test server");

        let cleared = QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        };
        let replacement = QueryParams {
            first: Some("first".to_string()),
            second: Some("second".to_string()),
        };
        let response = server
            .get("/query")
            .add_query_params(cleared)
            .clear_query_params()
            .add_query_params(replacement)
            .await;

        response.assert_text("has first? true, has second? true");
    }
}
2867
#[cfg(test)]
mod test_scheme {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;

    /// Responds with the scheme of the URI the request arrived with.
    ///
    /// Panics if the URI has no scheme, which fails the calling test.
    async fn route_get_scheme(request: Request) -> String {
        request.uri().scheme_str().unwrap().to_string()
    }

    /// Builds a default test server exposing `route_get_scheme` at `/scheme`.
    fn new_test_server() -> TestServer {
        let router = Router::new().route("/scheme", get(route_get_scheme));

        TestServer::builder().build(router).unwrap()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = new_test_server();

        server.get("/scheme").await.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_http_when_set() {
        let server = new_test_server();

        server
            .get("/scheme")
            .scheme("http")
            .await
            .assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_when_set() {
        let server = new_test_server();

        server
            .get("/scheme")
            .scheme("https")
            .await
            .assert_text("https");
    }

    #[tokio::test]
    async fn it_should_override_test_server_when_set() {
        // The server-wide scheme is https ...
        let mut server = new_test_server();
        server.scheme("https");

        // ... and the request-level setting must take precedence over it.
        server
            .get("/scheme")
            .scheme("http") // set it back to http
            .await
            .assert_text("http");
    }
}
2925
#[cfg(test)]
mod test_multipart {
    use crate::TestServer;
    use crate::multipart::MultipartForm;
    use crate::multipart::Part;
    use axum::Json;
    use axum::Router;
    use axum::extract::Multipart;
    use axum::routing::post;
    use serde_json::Value;
    use serde_json::json;

    /// Responds with one line per received part, of the form
    /// `"{name} is {len} bytes, {content_type}"`.
    ///
    /// Panics if a part has no name or no content type, which fails the calling test.
    async fn route_post_multipart(mut multipart: Multipart) -> Json<Vec<String>> {
        let mut fields = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let name = field.name().unwrap().to_string();
            let content_type = field.content_type().unwrap().to_owned();
            let data = field.bytes().await.unwrap();

            let field_stats = format!("{name} is {} bytes, {content_type}", data.len());
            fields.push(field_stats);
        }

        Json(fields)
    }

    /// Responds with the name, text body, and `x-part-header-test` header value
    /// of every received part.
    ///
    /// Panics if a part lacks a name or that header, or if its body is not valid UTF-8.
    async fn route_post_multipart_headers(mut multipart: Multipart) -> Json<Vec<Value>> {
        let mut sent_part_headers = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            // Name and headers are read first, as `bytes()` consumes the field.
            let part_name = field.name().unwrap().to_string();
            let part_header_value = field
                .headers()
                .get("x-part-header-test")
                .unwrap()
                .to_str()
                .unwrap()
                .to_string();
            let part_text = String::from_utf8(field.bytes().await.unwrap().into()).unwrap();

            sent_part_headers.push(json!({
                "name": part_name,
                "text": part_text,
                "header": part_header_value,
            }));
        }

        Json(sent_part_headers)
    }

    /// Router exposing both multipart inspection routes.
    fn test_router() -> Router {
        Router::new()
            .route("/multipart", post(route_post_multipart))
            .route("/multipart_headers", post(route_post_multipart_headers))
    }

    /// Builds a test server for `test_router` using the mock transport.
    fn mock_server() -> TestServer {
        TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server")
    }

    /// Posts three text fields to `/multipart` and asserts the byte length and
    /// content type reported for each. Shared by the per-transport tests so both
    /// transports are checked against the same expectations.
    async fn assert_text_field_stats(server: &TestServer) {
        let form = MultipartForm::new()
            .add_text("penguins?", "lots")
            .add_text("animals", "🦊🦊🦊")
            .add_text("carrots", 123_u32);

        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec![
                "penguins? is 4 bytes, text/plain".to_string(),
                "animals is 12 bytes, text/plain".to_string(),
                "carrots is 3 bytes, text/plain".to_string(),
            ]);
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_mock_transport() {
        let server = mock_server();

        assert_text_field_stats(&server).await;
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_http_transport() {
        let server = TestServer::builder()
            .http_transport()
            .build(test_router())
            .expect("Should create test server");

        assert_text_field_stats(&server).await;
    }

    #[tokio::test]
    async fn it_should_send_text_parts_as_text() {
        let server = mock_server();

        let form = MultipartForm::new().add_part("animals", Part::text("🦊🦊🦊"));

        // Send the form and check the part arrives as plain text.
        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 12 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_custom_mime_type() {
        let server = mock_server();

        let form = MultipartForm::new().add_part(
            "animals",
            Part::bytes("🦊,🦊,🦊".as_bytes()).mime_type(mime::TEXT_CSV),
        );

        // Send the form and check the explicit mime type is what the route sees.
        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 14 bytes, text/csv".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_using_include_bytes() {
        let server = mock_server();

        let form = MultipartForm::new().add_part(
            "file",
            Part::bytes(include_bytes!("../files/example.txt").as_slice())
                .mime_type(mime::TEXT_PLAIN),
        );

        // Send the form and check the embedded file contents are received.
        server
            .post("/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["file is 6 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_form_headers_in_parts() {
        let server = mock_server();

        let form = MultipartForm::new()
            .add_part(
                "part_1",
                Part::text("part_1_text").add_header("x-part-header-test", "part_1_header"),
            )
            .add_part(
                "part_2",
                Part::text("part_2_text").add_header("x-part-header-test", "part_2_header"),
            );

        // Send the form and check each part keeps its own custom header.
        server
            .post("/multipart_headers")
            .multipart(form)
            .await
            .assert_json(&json!([
                {
                    "name": "part_1",
                    "text": "part_1_text",
                    "header": "part_1_header",
                },
                {
                    "name": "part_2",
                    "text": "part_2_text",
                    "header": "part_2_header",
                },
            ]));
    }
}