axum_test/
test_request.rs

1use anyhow::Context;
2use anyhow::Error as AnyhowError;
3use anyhow::Result;
4use anyhow::anyhow;
5use axum::body::Body;
6use bytes::Bytes;
7use cookie::Cookie;
8use cookie::CookieJar;
9use cookie::time::OffsetDateTime;
10use http::HeaderName;
11use http::HeaderValue;
12use http::Method;
13use http::Request;
14use http::header;
15use http::header::SET_COOKIE;
16use http_body_util::BodyExt;
17use serde::Serialize;
18use std::fmt::Debug;
19use std::fmt::Display;
20use std::fs::File;
21use std::fs::read;
22use std::fs::read_to_string;
23use std::future::{Future, IntoFuture};
24use std::io::BufReader;
25use std::path::Path;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::Mutex;
29use url::Url;
30
31use crate::ServerSharedState;
32use crate::TestResponse;
33use crate::internals::ExpectedState;
34use crate::internals::QueryParamsStore;
35use crate::internals::RequestPathFormatter;
36use crate::multipart::MultipartForm;
37use crate::transport_layer::TransportLayer;
38
39mod test_request_config;
40pub(crate) use self::test_request_config::*;
41
42///
43/// A `TestRequest` is for building and executing a HTTP request to the [`TestServer`](crate::TestServer).
44///
45/// ## Building
46///
47/// Requests are created by the [`TestServer`](crate::TestServer), using it's builder functions.
48/// They correspond to the appropriate HTTP method: [`TestServer::get()`](crate::TestServer::get()),
49/// [`TestServer::post()`](crate::TestServer::post()), etc.
50///
51/// See there for documentation.
52///
53/// ## Customising
54///
55/// The `TestRequest` allows the caller to fill in the rest of the request
56/// to be sent to the server. Including the headers, the body, cookies,
57/// and the content type, using the relevant functions.
58///
59/// The TestRequest struct provides a number of methods to set up the request,
60/// such as json, text, bytes, expect_failure, content_type, etc.
61///
62/// ## Sending
63///
64/// Once fully configured you send the request by awaiting the request object.
65///
66/// ```rust
67/// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
68/// #
69/// # use axum::Router;
70/// # use axum_test::TestServer;
71/// #
72/// # let server = TestServer::new(Router::new())?;
73/// #
74/// // Build your request
75/// let request = server.get(&"/user")
76///     .add_header("x-custom-header", "example.com")
77///     .content_type("application/yaml");
78///
79/// // await request to execute
80/// let response = request.await;
81/// #
82/// # Ok(()) }
83/// ```
84///
85/// You will receive a `TestResponse`.
86///
87/// ## Cookie Saving
88///
89/// [`TestRequest::save_cookies()`](crate::TestRequest::save_cookies()) and [`TestRequest::do_not_save_cookies()`](crate::TestRequest::do_not_save_cookies())
90/// methods allow you to set the request to save cookies to the `TestServer`,
91/// for reuse on any future requests.
92///
93/// This behaviour is **off** by default, and can be changed for all `TestRequests`
94/// when building the `TestServer`. By building it with a `TestServerConfig` where `save_cookies` is set to true.
95///
96/// ## Expecting Failure and Success
97///
98/// When making a request you can mark it to expect a response within,
99/// or outside, of the 2xx range of HTTP status codes.
100///
101/// If the response returns a status code different to what is expected,
102/// then it will panic.
103///
104/// This is useful when making multiple requests within a test.
105/// As it can find issues earlier than later.
106///
107/// See the [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()),
108/// and [`TestRequest::expect_success()`](crate::TestRequest::expect_success()).
109///
110#[derive(Debug)]
111#[must_use = "futures do nothing unless polled"]
112pub struct TestRequest {
113    config: TestRequestConfig,
114
115    server_state: Arc<Mutex<ServerSharedState>>,
116    transport: Arc<Box<dyn TransportLayer>>,
117
118    body: Option<Body>,
119
120    expected_state: ExpectedState,
121}
122
123impl TestRequest {
124    pub(crate) fn new(
125        server_state: Arc<Mutex<ServerSharedState>>,
126        transport: Arc<Box<dyn TransportLayer>>,
127        config: TestRequestConfig,
128    ) -> Self {
129        let expected_state = config.expected_state;
130
131        Self {
132            config,
133            server_state,
134            transport,
135            body: None,
136            expected_state,
137        }
138    }
139
140    /// Set the body of the request to send up data as Json,
141    /// and changes the content type to `application/json`.
142    pub fn json<J>(self, body: &J) -> Self
143    where
144        J: ?Sized + Serialize,
145    {
146        let body_bytes =
147            serde_json::to_vec(body).expect("It should serialize the content into Json");
148
149        self.bytes(body_bytes.into())
150            .content_type(mime::APPLICATION_JSON.essence_str())
151    }
152
153    /// Sends a payload as a Json request, with the contents coming from a file.
154    pub fn json_from_file<P>(self, path: P) -> Self
155    where
156        P: AsRef<Path>,
157    {
158        let path_ref = path.as_ref();
159        let file = File::open(path_ref)
160            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
161            .unwrap();
162
163        let reader = BufReader::new(file);
164        let payload = serde_json::from_reader::<_, serde_json::Value>(reader)
165            .with_context(|| {
166                format!(
167                    "Failed to deserialize file '{}' as Json",
168                    path_ref.display()
169                )
170            })
171            .unwrap();
172
173        self.json(&payload)
174    }
175
176    /// Set the body of the request to send up data as Yaml,
177    /// and changes the content type to `application/yaml`.
178    #[cfg(feature = "yaml")]
179    pub fn yaml<Y>(self, body: &Y) -> Self
180    where
181        Y: ?Sized + Serialize,
182    {
183        let body = serde_yaml::to_string(body).expect("It should serialize the content into Yaml");
184
185        self.bytes(body.into_bytes().into())
186            .content_type("application/yaml")
187    }
188
189    /// Sends a payload as a Yaml request, with the contents coming from a file.
190    #[cfg(feature = "yaml")]
191    pub fn yaml_from_file<P>(self, path: P) -> Self
192    where
193        P: AsRef<Path>,
194    {
195        let path_ref = path.as_ref();
196        let file = File::open(path_ref)
197            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
198            .unwrap();
199
200        let reader = BufReader::new(file);
201        let payload = serde_yaml::from_reader::<_, serde_yaml::Value>(reader)
202            .with_context(|| {
203                format!(
204                    "Failed to deserialize file '{}' as Yaml",
205                    path_ref.display()
206                )
207            })
208            .unwrap();
209
210        self.yaml(&payload)
211    }
212
213    /// Set the body of the request to send up data as MsgPack,
214    /// and changes the content type to `application/msgpack`.
215    #[cfg(feature = "msgpack")]
216    pub fn msgpack<M>(self, body: &M) -> Self
217    where
218        M: ?Sized + Serialize,
219    {
220        let body_bytes =
221            ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack");
222
223        self.bytes(body_bytes.into())
224            .content_type("application/msgpack")
225    }
226
227    /// Sets the body of the request, with the content type
228    /// of 'application/x-www-form-urlencoded'.
229    pub fn form<F>(self, body: &F) -> Self
230    where
231        F: ?Sized + Serialize,
232    {
233        let body_text =
234            serde_urlencoded::to_string(body).expect("It should serialize the content into a Form");
235
236        self.bytes(body_text.into())
237            .content_type(mime::APPLICATION_WWW_FORM_URLENCODED.essence_str())
238    }
239
240    /// For sending multipart forms.
241    /// The payload is built using [`MultipartForm`](crate::multipart::MultipartForm) and [`Part`](crate::multipart::Part).
242    ///
243    /// This will be sent with the content type of 'multipart/form-data'.
244    ///
245    /// # Simple example
246    ///
247    /// ```rust
248    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
249    /// #
250    /// use axum::Router;
251    /// use axum_test::TestServer;
252    /// use axum_test::multipart::MultipartForm;
253    ///
254    /// let app = Router::new();
255    /// let server = TestServer::new(app)?;
256    ///
257    /// let multipart_form = MultipartForm::new()
258    ///     .add_text("name", "Joe")
259    ///     .add_text("animals", "foxes");
260    ///
261    /// let response = server.post(&"/my-form")
262    ///     .multipart(multipart_form)
263    ///     .await;
264    /// #
265    /// # Ok(()) }
266    /// ```
267    ///
268    /// # Sending byte parts
269    ///
270    /// ```rust
271    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
272    /// #
273    /// use axum::Router;
274    /// use axum_test::TestServer;
275    /// use axum_test::multipart::MultipartForm;
276    /// use axum_test::multipart::Part;
277    ///
278    /// let app = Router::new();
279    /// let server = TestServer::new(app)?;
280    ///
281    /// let readme_bytes = include_bytes!("../README.md");
282    /// let readme_part = Part::bytes(readme_bytes.as_slice())
283    ///     .file_name(&"README.md")
284    ///     .mime_type(&"text/markdown");
285    ///
286    /// let multipart_form = MultipartForm::new()
287    ///     .add_part("file", readme_part);
288    ///
289    /// let response = server.post(&"/my-form")
290    ///     .multipart(multipart_form)
291    ///     .await;
292    /// #
293    /// # Ok(()) }
294    /// ```
295    ///
296    pub fn multipart(mut self, multipart: MultipartForm) -> Self {
297        self.config.content_type = Some(multipart.content_type());
298        self.body = Some(multipart.into());
299
300        self
301    }
302
303    /// Set raw text as the body of the request,
304    /// and sets the content type to `text/plain`.
305    pub fn text<T>(self, raw_text: T) -> Self
306    where
307        T: Display,
308    {
309        let body_text = raw_text.to_string();
310
311        self.bytes(body_text.into())
312            .content_type(mime::TEXT_PLAIN.essence_str())
313    }
314
315    /// Sends a payload as plain text, with the contents coming from a file.
316    pub fn text_from_file<P>(self, path: P) -> Self
317    where
318        P: AsRef<Path>,
319    {
320        let path_ref = path.as_ref();
321        let payload = read_to_string(path_ref)
322            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
323            .unwrap();
324
325        self.text(payload)
326    }
327
328    /// Set raw bytes as the body of the request.
329    ///
330    /// The content type is left unchanged.
331    pub fn bytes(mut self, body_bytes: Bytes) -> Self {
332        let body: Body = body_bytes.into();
333
334        self.body = Some(body);
335        self
336    }
337
338    /// Reads the contents of the file as raw bytes, and sends it within the request.
339    ///
340    /// The content type is left unchanged, and no parsing of the file is done.
341    pub fn bytes_from_file<P>(self, path: P) -> Self
342    where
343        P: AsRef<Path>,
344    {
345        let path_ref = path.as_ref();
346        let payload = read(path_ref)
347            .with_context(|| format!("Failed to read from file '{}'", path_ref.display()))
348            .unwrap();
349
350        self.bytes(payload.into())
351    }
352
353    /// Set the content type to use for this request in the header.
354    pub fn content_type(mut self, content_type: &str) -> Self {
355        self.config.content_type = Some(content_type.to_string());
356        self
357    }
358
359    /// Adds a Cookie to be sent with this request.
360    pub fn add_cookie(mut self, cookie: Cookie<'_>) -> Self {
361        self.config.cookies.add(cookie.into_owned());
362        self
363    }
364
365    /// Adds many cookies to be used with this request.
366    pub fn add_cookies(mut self, cookies: CookieJar) -> Self {
367        for cookie in cookies.iter() {
368            self.config.cookies.add(cookie.clone());
369        }
370
371        self
372    }
373
374    /// Clears all cookies used internally within this Request,
375    /// including any that came from the `TestServer`.
376    pub fn clear_cookies(mut self) -> Self {
377        self.config.cookies = CookieJar::new();
378        self
379    }
380
381    /// Any cookies returned will be saved to the [`TestServer`](crate::TestServer) that created this,
382    /// which will continue to use those cookies on future requests.
383    pub fn save_cookies(mut self) -> Self {
384        self.config.is_saving_cookies = true;
385        self
386    }
387
388    /// Cookies returned by this will _not_ be saved to the `TestServer`.
389    /// For use by future requests.
390    ///
391    /// This is the default behaviour.
392    /// You can change that default in [`TestServerConfig`](crate::TestServerConfig).
393    pub fn do_not_save_cookies(mut self) -> Self {
394        self.config.is_saving_cookies = false;
395        self
396    }
397
398    /// Adds query parameters to be sent with this request.
399    pub fn add_query_param<V>(self, key: &str, value: V) -> Self
400    where
401        V: Serialize,
402    {
403        self.add_query_params(&[(key, value)])
404    }
405
406    /// Adds the structure given as query parameters for this request.
407    ///
408    /// This is designed to take a list of parameters, or a body of parameters,
409    /// and then serializes them into the parameters of the request.
410    ///
411    /// # Sending a body of parameters using `json!`
412    ///
413    /// ```rust
414    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
415    /// #
416    /// use axum::Router;
417    /// use axum_test::TestServer;
418    /// use serde_json::json;
419    ///
420    /// let app = Router::new();
421    /// let server = TestServer::new(app)?;
422    ///
423    /// let response = server.get(&"/my-end-point")
424    ///     .add_query_params(json!({
425    ///         "username": "Brian",
426    ///         "age": 20
427    ///     }))
428    ///     .await;
429    /// #
430    /// # Ok(()) }
431    /// ```
432    ///
433    /// # Sending a body of parameters with Serde
434    ///
435    /// ```rust
436    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
437    /// #
438    /// use axum::Router;
439    /// use axum_test::TestServer;
440    /// use serde::Deserialize;
441    /// use serde::Serialize;
442    ///
443    /// #[derive(Serialize, Deserialize)]
444    /// struct UserQueryParams {
445    ///     username: String,
446    ///     age: u32,
447    /// }
448    ///
449    /// let app = Router::new();
450    /// let server = TestServer::new(app)?;
451    ///
452    /// let response = server.get(&"/my-end-point")
453    ///     .add_query_params(UserQueryParams {
454    ///         username: "Brian".to_string(),
455    ///         age: 20
456    ///     })
457    ///     .await;
458    /// #
459    /// # Ok(()) }
460    /// ```
461    ///
462    /// # Sending a list of parameters
463    ///
464    /// ```rust
465    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
466    /// #
467    /// use axum::Router;
468    /// use axum_test::TestServer;
469    ///
470    /// let app = Router::new();
471    /// let server = TestServer::new(app)?;
472    ///
473    /// let response = server.get(&"/my-end-point")
474    ///     .add_query_params(&[
475    ///         ("username", "Brian"),
476    ///         ("age", "20"),
477    ///     ])
478    ///     .await;
479    /// #
480    /// # Ok(()) }
481    /// ```
482    ///
483    pub fn add_query_params<V>(mut self, query_params: V) -> Self
484    where
485        V: Serialize,
486    {
487        self.config
488            .query_params
489            .add(query_params)
490            .with_context(|| {
491                format!(
492                    "It should serialize query parameters, for request {}",
493                    self.debug_request_format()
494                )
495            })
496            .unwrap();
497
498        self
499    }
500
501    /// Adds a query param onto the end of the request,
502    /// with no urlencoding of any kind.
503    ///
504    /// This exists to allow custom query parameters,
505    /// such as for the many versions of query param arrays.
506    ///
507    /// ```rust
508    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
509    /// #
510    /// use axum::Router;
511    /// use axum_test::TestServer;
512    ///
513    /// let app = Router::new();
514    /// let server = TestServer::new(app)?;
515    ///
516    /// let response = server.get(&"/my-end-point")
517    ///     .add_raw_query_param(&"my-flag")
518    ///     .add_raw_query_param(&"array[]=123")
519    ///     .add_raw_query_param(&"filter[value]=some-value")
520    ///     .await;
521    /// #
522    /// # Ok(()) }
523    /// ```
524    ///
525    pub fn add_raw_query_param(mut self, query_param: &str) -> Self {
526        self.config.query_params.add_raw(query_param.to_string());
527
528        self
529    }
530
531    /// Clears all query params set,
532    /// including any that came from the [`TestServer`](crate::TestServer).
533    pub fn clear_query_params(mut self) -> Self {
534        self.config.query_params.clear();
535        self
536    }
537
538    /// Adds a header to be sent with this request.
539    ///
540    /// ```rust
541    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
542    /// #
543    /// use axum::Router;
544    /// use axum_test::TestServer;
545    ///
546    /// let app = Router::new();
547    /// let server = TestServer::new(app)?;
548    ///
549    /// let response = server.get(&"/my-end-point")
550    ///     .add_header("x-custom-header", "custom-value")
551    ///     .add_header(http::header::CONTENT_LENGTH, 12345)
552    ///     .add_header(http::header::HOST, "example.com")
553    ///     .await;
554    /// #
555    /// # Ok(()) }
556    /// ```
557    pub fn add_header<N, V>(mut self, name: N, value: V) -> Self
558    where
559        N: TryInto<HeaderName>,
560        N::Error: Debug,
561        V: TryInto<HeaderValue>,
562        V::Error: Debug,
563    {
564        let header_name: HeaderName = name
565            .try_into()
566            .expect("Failed to convert header name to HeaderName");
567        let header_value: HeaderValue = value
568            .try_into()
569            .expect("Failed to convert header vlue to HeaderValue");
570
571        self.config.headers.push((header_name, header_value));
572        self
573    }
574
575    /// Adds an 'AUTHORIZATION' HTTP header to the request,
576    /// with no internal formatting of what is given.
577    pub fn authorization<T>(self, authorization_header: T) -> Self
578    where
579        T: AsRef<str>,
580    {
581        let authorization_header_value = HeaderValue::from_str(authorization_header.as_ref())
582            .expect("Cannot build Authorization HeaderValue from token");
583
584        self.add_header(header::AUTHORIZATION, authorization_header_value)
585    }
586
587    /// Adds an 'AUTHORIZATION' HTTP header to the request,
588    /// in the 'Bearer {token}' format.
589    pub fn authorization_bearer<T>(self, authorization_bearer_token: T) -> Self
590    where
591        T: Display,
592    {
593        let authorization_bearer_header_str = format!("Bearer {authorization_bearer_token}");
594        self.authorization(authorization_bearer_header_str)
595    }
596
597    /// Clears all headers set.
598    pub fn clear_headers(mut self) -> Self {
599        self.config.headers = vec![];
600        self
601    }
602
603    /// Sets the scheme to use when making the request. i.e. http or https.
604    /// The default scheme is 'http'.
605    ///
606    /// ```rust
607    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
608    /// #
609    /// use axum::Router;
610    /// use axum_test::TestServer;
611    ///
612    /// let app = Router::new();
613    /// let server = TestServer::new(app)?;
614    ///
615    /// let response = server
616    ///     .get(&"/my-end-point")
617    ///     .scheme(&"https")
618    ///     .await;
619    /// #
620    /// # Ok(()) }
621    /// ```
622    ///
623    pub fn scheme(mut self, scheme: &str) -> Self {
624        self.config
625            .full_request_url
626            .set_scheme(scheme)
627            .map_err(|_| anyhow!("Scheme '{scheme}' cannot be set to request"))
628            .unwrap();
629        self
630    }
631
632    /// Marks that this request is expected to always return a HTTP
633    /// status code within the 2xx range (200 to 299).
634    ///
635    /// If a code _outside_ of that range is returned,
636    /// then this will panic.
637    ///
638    /// ```rust
639    /// # async fn test() -> Result<(), Box<dyn ::std::error::Error>> {
640    /// #
641    /// use axum::Json;
642    /// use axum::Router;
643    /// use axum::http::StatusCode;
644    /// use axum::routing::put;
645    /// use serde_json::json;
646    ///
647    /// use axum_test::TestServer;
648    ///
649    /// let app = Router::new()
650    ///     .route(&"/todo", put(|| async { StatusCode::NOT_FOUND }));
651    ///
652    /// let server = TestServer::new(app)?;
653    ///
654    /// // If this doesn't return a value in the 2xx range,
655    /// // then it will panic.
656    /// server.put(&"/todo")
657    ///     .expect_success()
658    ///     .json(&json!({
659    ///         "task": "buy milk",
660    ///     }))
661    ///     .await;
662    /// #
663    /// # Ok(())
664    /// # }
665    /// ```
666    ///
667    pub fn expect_success(self) -> Self {
668        self.expect_state(ExpectedState::Success)
669    }
670
671    /// Marks that this request is expected to return a HTTP status code
672    /// outside of the 2xx range.
673    ///
674    /// If a code _within_ the 2xx range is returned,
675    /// then this will panic.
676    pub fn expect_failure(self) -> Self {
677        self.expect_state(ExpectedState::Failure)
678    }
679
680    fn expect_state(mut self, expected_state: ExpectedState) -> Self {
681        self.expected_state = expected_state;
682        self
683    }
684
685    async fn send(self) -> Result<TestResponse> {
686        let debug_request_format = self.debug_request_format().to_string();
687
688        let method = self.config.method;
689        let expected_state = self.expected_state;
690        let save_cookies = self.config.is_saving_cookies;
691        let body = self.body.unwrap_or(Body::empty());
692        let url =
693            Self::build_url_query_params(self.config.full_request_url, &self.config.query_params);
694
695        let request = Self::build_request(
696            method.clone(),
697            &url,
698            body,
699            self.config.content_type,
700            self.config.cookies,
701            self.config.headers,
702            &debug_request_format,
703        )?;
704
705        #[allow(unused_mut)] // Allowed for the `ws` use immediately after.
706        let mut http_response = self.transport.send(request).await?;
707
708        #[cfg(feature = "ws")]
709        let websockets = {
710            let maybe_on_upgrade = http_response
711                .extensions_mut()
712                .remove::<hyper::upgrade::OnUpgrade>();
713            let transport_type = self.transport.transport_layer_type();
714
715            crate::internals::TestResponseWebSocket {
716                maybe_on_upgrade,
717                transport_type,
718            }
719        };
720
721        let version = http_response.version();
722        let (parts, response_body) = http_response.into_parts();
723        let response_bytes = response_body.collect().await?.to_bytes();
724
725        if save_cookies {
726            let cookie_headers = parts.headers.get_all(SET_COOKIE).into_iter();
727            ServerSharedState::add_cookies_by_header(&self.server_state, cookie_headers)?;
728        }
729
730        let test_response = TestResponse::new(
731            version,
732            method,
733            url,
734            parts,
735            response_bytes,
736            #[cfg(feature = "ws")]
737            websockets,
738        );
739
740        // Assert if ok or not.
741        match expected_state {
742            ExpectedState::Success => test_response.assert_status_success(),
743            ExpectedState::Failure => test_response.assert_status_failure(),
744            ExpectedState::None => {}
745        }
746
747        Ok(test_response)
748    }
749
750    fn build_url_query_params(mut url: Url, query_params: &QueryParamsStore) -> Url {
751        // Add all the query params we have
752        if query_params.has_content() {
753            url.set_query(Some(&query_params.to_string()));
754        }
755
756        url
757    }
758
759    fn build_request(
760        method: Method,
761        url: &Url,
762        body: Body,
763        content_type: Option<String>,
764        cookies: CookieJar,
765        headers: Vec<(HeaderName, HeaderValue)>,
766        debug_request_format: &str,
767    ) -> Result<Request<Body>> {
768        let mut request_builder = Request::builder().uri(url.as_str()).method(method);
769
770        // Add all the headers we have.
771        if let Some(content_type) = content_type {
772            let (header_key, header_value) =
773                build_content_type_header(&content_type, debug_request_format)?;
774            request_builder = request_builder.header(header_key, header_value);
775        }
776
777        // Add all the non-expired cookies as headers
778        // Also strip cookies from their attributes, only their names and values should be preserved to conform the HTTP standard
779        let now = OffsetDateTime::now_utc();
780        for cookie in cookies.iter() {
781            let expired = cookie
782                .expires_datetime()
783                .map(|expires| expires <= now)
784                .unwrap_or(false);
785
786            if !expired {
787                let cookie_raw = cookie.stripped().to_string();
788                let header_value = HeaderValue::from_str(&cookie_raw)?;
789                request_builder = request_builder.header(header::COOKIE, header_value);
790            }
791        }
792
793        // Put headers into the request
794        for (header_name, header_value) in headers {
795            request_builder = request_builder.header(header_name, header_value);
796        }
797
798        let request = request_builder.body(body).with_context(|| {
799            format!("Expect valid hyper Request to be built, for request {debug_request_format}")
800        })?;
801
802        Ok(request)
803    }
804
805    fn debug_request_format(&self) -> RequestPathFormatter<'_> {
806        RequestPathFormatter::new(
807            &self.config.method,
808            self.config.full_request_url.as_str(),
809            Some(&self.config.query_params),
810        )
811    }
812}
813
814impl TryFrom<TestRequest> for Request<Body> {
815    type Error = AnyhowError;
816
817    fn try_from(test_request: TestRequest) -> Result<Request<Body>> {
818        let debug_request_format = test_request.debug_request_format().to_string();
819        let url = TestRequest::build_url_query_params(
820            test_request.config.full_request_url,
821            &test_request.config.query_params,
822        );
823        let body = test_request.body.unwrap_or(Body::empty());
824
825        TestRequest::build_request(
826            test_request.config.method,
827            &url,
828            body,
829            test_request.config.content_type,
830            test_request.config.cookies,
831            test_request.config.headers,
832            &debug_request_format,
833        )
834    }
835}
836
837impl IntoFuture for TestRequest {
838    type Output = TestResponse;
839    type IntoFuture = Pin<Box<dyn Future<Output = TestResponse> + Send>>;
840
841    fn into_future(self) -> Self::IntoFuture {
842        Box::pin(async { self.send().await.context("Sending request failed").unwrap() })
843    }
844}
845
846fn build_content_type_header(
847    content_type: &str,
848    debug_request_format: &str,
849) -> Result<(HeaderName, HeaderValue)> {
850    let header_value = HeaderValue::from_str(content_type).with_context(|| {
851        format!(
852            "Failed to store header content type '{content_type}', for request {debug_request_format}"
853        )
854    })?;
855
856    Ok((header::CONTENT_TYPE, header_value))
857}
858
#[cfg(test)]
mod test_content_type {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;

    /// Handler echoing back the request's `Content-Type`, or "" when it is missing.
    async fn get_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    /// Router serving the echo handler at `/content_type`.
    fn content_type_app() -> Router {
        Router::new().route("/content_type", get(get_content_type))
    }

    #[tokio::test]
    async fn it_should_not_set_a_content_type_by_default() {
        let server = TestServer::new(content_type_app()).expect("Should create test server");

        let received = server.get(&"/content_type").await.text();

        assert_eq!(received, "");
    }

    #[tokio::test]
    async fn it_should_override_server_content_type_when_present() {
        // The server default must lose to the per-request content type.
        let server = TestServer::builder()
            .default_content_type("text/plain")
            .build(content_type_app())
            .expect("Should create test server");

        let received = server
            .get(&"/content_type")
            .content_type(&"application/json")
            .await
            .text();

        assert_eq!(received, "application/json");
    }

    #[tokio::test]
    async fn it_should_set_content_type_when_present() {
        let server = TestServer::new(content_type_app()).expect("Should create test server");

        let received = server
            .get(&"/content_type")
            .content_type(&"application/custom")
            .await
            .text();

        assert_eq!(received, "application/custom");
    }
}
927
#[cfg(test)]
mod test_json {
    use crate::TestServer;
    use axum::Json;
    use axum::Router;
    use axum::extract::DefaultBodyLimit;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use rand::random;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Handler replying with the request's `Content-Type` header, or an empty
    /// body when the header is absent. Panics if the header is not valid text.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the JSON body and summarises it; missing pets default to "pandas".
        async fn describe(Json(payload): Json<TestJson>) -> String {
            let pets = payload.pets.unwrap_or_else(|| "pandas".to_string());
            format!("json: {}, {}, {}", payload.name, payload.age, pets)
        }

        let server = TestServer::new(Router::new().route("/json", post(describe)))
            .expect("Should create test server");

        let payload = TestJson {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let text = server.post(&"/json").json(&payload).await.text();

        assert_eq!(text, "json: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server.post(&"/content_type").json(&json!({})).await.text();

        assert_eq!(text, "application/json");
    }

    #[tokio::test]
    async fn it_should_pass_large_json_blobs_over_http() {
        // 16 MiB of item text. The body limit below is doubled, presumably to
        // leave headroom for JSON quoting and separators.
        const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024;

        #[derive(Deserialize, Serialize, PartialEq, Debug)]
        struct TestLargeJson {
            items: Vec<String>,
        }

        /// Echoes the decoded JSON straight back to the caller.
        async fn echo(Json(body): Json<TestLargeJson>) -> Json<TestLargeJson> {
            Json(body)
        }

        // Keep appending random numbers until their combined text length
        // reaches the target size.
        let mut items = Vec::new();
        let mut total_len = 0;
        while total_len < LARGE_BLOB_SIZE {
            let item = random::<u64>().to_string();
            total_len += item.len();
            items.push(item);
        }
        let large_json_blob = TestLargeJson { items };

        let app = Router::new()
            .route("/json", post(echo))
            .layer(DefaultBodyLimit::max(LARGE_BLOB_SIZE * 2));

        let server = TestServer::builder()
            .http_transport()
            .expect_success_by_default()
            .build(app)
            .expect("Should create test server");

        server
            .post(&"/json")
            .json(&large_json_blob)
            .await
            .assert_json(&large_json_blob);
    }
}
1044
#[cfg(test)]
mod test_json_from_file {
    use crate::TestServer;
    use axum::Json;
    use axum::Router;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Fixture file whose contents the tests expect to decode to name "Joe", age 20.
    const EXAMPLE_JSON_FILE: &str = "files/example.json";

    /// Handler replying with the request's `Content-Type` header, or an empty
    /// body when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_json_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestJson {
            name: String,
            age: u32,
        }

        /// Decodes the JSON body and summarises its fields as text.
        async fn describe(Json(payload): Json<TestJson>) -> String {
            format!("json: {}, {}", payload.name, payload.age)
        }

        let server = TestServer::new(Router::new().route("/json", post(describe)))
            .expect("Should create test server");

        let response = server
            .post(&"/json")
            .json_from_file(&EXAMPLE_JSON_FILE)
            .await;

        assert_eq!(response.text(), "json: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_json_content_type_for_json() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .post(&"/content_type")
            .json_from_file(&EXAMPLE_JSON_FILE)
            .await;

        assert_eq!(response.text(), "application/json");
    }
}
1111
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_yaml::Yaml;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    /// Handler replying with the request's `Content-Type` header, or an empty
    /// body when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the YAML body and summarises it; missing pets default to "pandas".
        async fn describe(Yaml(payload): Yaml<TestYaml>) -> String {
            let pets = payload.pets.unwrap_or_else(|| "pandas".to_string());
            format!("yaml: {}, {}, {}", payload.name, payload.age, pets)
        }

        let server = TestServer::new(Router::new().route("/yaml", post(describe)))
            .expect("Should create test server");

        let payload = TestYaml {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };
        let text = server.post(&"/yaml").yaml(&payload).await.text();

        assert_eq!(text, "yaml: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server.post(&"/content_type").yaml(&json!({})).await.text();

        assert_eq!(text, "application/yaml");
    }
}
1186
#[cfg(feature = "yaml")]
#[cfg(test)]
mod test_yaml_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_yaml::Yaml;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    /// Fixture file whose contents the tests expect to decode to name "Joe", age 20.
    const EXAMPLE_YAML_FILE: &str = "files/example.yaml";

    /// Handler replying with the request's `Content-Type` header, or an empty
    /// body when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_yaml_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestYaml {
            name: String,
            age: u32,
        }

        /// Decodes the YAML body and summarises its fields as text.
        async fn describe(Yaml(payload): Yaml<TestYaml>) -> String {
            format!("yaml: {}, {}", payload.name, payload.age)
        }

        let server = TestServer::new(Router::new().route("/yaml", post(describe)))
            .expect("Should create test server");

        let response = server
            .post(&"/yaml")
            .yaml_from_file(&EXAMPLE_YAML_FILE)
            .await;

        assert_eq!(response.text(), "yaml: Joe, 20");
    }

    #[tokio::test]
    async fn it_should_pass_yaml_content_type_for_yaml() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .post(&"/content_type")
            .yaml_from_file(&EXAMPLE_YAML_FILE)
            .await;

        assert_eq!(response.text(), "application/yaml");
    }
}
1254
#[cfg(feature = "msgpack")]
#[cfg(test)]
mod test_msgpack {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::post;
    use axum_msgpack::MsgPack;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    #[tokio::test]
    async fn it_should_pass_msgpack_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestMsgPack {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the MessagePack body and summarises it as text; missing
        /// pets default to "pandas". The `msgpack:` prefix names the format
        /// under test so a failure message points at the right extractor.
        async fn post_msgpack(MsgPack(msgpack): MsgPack<TestMsgPack>) -> String {
            format!(
                "msgpack: {}, {}, {}",
                msgpack.name,
                msgpack.age,
                msgpack.pets.unwrap_or_else(|| "pandas".to_string())
            )
        }

        // Build an application with a route.
        let app = Router::new().route("/msgpack", post(post_msgpack));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request and read the handler's summary back.
        let text = server
            .post(&"/msgpack")
            .msgpack(&TestMsgPack {
                name: "Joe".to_string(),
                age: 20,
                pets: Some("foxes".to_string()),
            })
            .await
            .text();

        assert_eq!(text, "msgpack: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_msgpck_content_type_for_msgpack() {
        /// Replies with the request's `Content-Type` header, or an empty body
        /// when the header is absent.
        async fn get_content_type(headers: HeaderMap) -> String {
            headers
                .get(CONTENT_TYPE)
                .map(|h| h.to_str().unwrap().to_string())
                .unwrap_or_default()
        }

        // Build an application with a route.
        let app = Router::new().route("/content_type", post(get_content_type));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send the request and read back the content type the server saw.
        let text = server
            .post(&"/content_type")
            .msgpack(&json!({}))
            .await
            .text();

        assert_eq!(text, "application/msgpack");
    }
}
1331
#[cfg(test)]
mod test_form {
    use crate::TestServer;
    use axum::Form;
    use axum::Router;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use serde::Deserialize;
    use serde::Serialize;

    #[tokio::test]
    async fn it_should_pass_form_up_to_be_read() {
        #[derive(Deserialize, Serialize)]
        struct TestForm {
            name: String,
            age: u32,
            pets: Option<String>,
        }

        /// Decodes the url-encoded form and summarises it; missing pets
        /// default to "pandas".
        async fn post_form(Form(form): Form<TestForm>) -> String {
            let pets = form.pets.unwrap_or_else(|| "pandas".to_string());
            format!("form: {}, {}, {}", form.name, form.age, pets)
        }

        let server = TestServer::new(Router::new().route("/form", post(post_form)))
            .expect("Should create test server");

        let submitted = TestForm {
            name: "Joe".to_string(),
            age: 20,
            pets: Some("foxes".to_string()),
        };

        server
            .post(&"/form")
            .form(&submitted)
            .await
            .assert_text("form: Joe, 20, foxes");
    }

    #[tokio::test]
    async fn it_should_pass_form_content_type_for_form() {
        #[derive(Serialize)]
        struct MyForm {
            message: String,
        }

        /// Replies with the request's `Content-Type` header, or an empty body
        /// when the header is absent.
        async fn get_content_type(headers: HeaderMap) -> String {
            match headers.get(CONTENT_TYPE) {
                Some(value) => value.to_str().unwrap().to_string(),
                None => String::new(),
            }
        }

        let app = Router::new().route("/content_type", post(get_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let submitted = MyForm {
            message: "hello".to_string(),
        };

        server
            .post(&"/content_type")
            .form(&submitted)
            .await
            .assert_text("application/x-www-form-urlencoded");
    }
}
1409
#[cfg(test)]
mod test_bytes {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Collects the full request body and echoes it back as text; invalid
    /// UTF-8 is replaced lossily rather than rejected.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");

        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Replies with the request's `Content-Type` header, or an empty body
    /// when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|value| value.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/bytes", post(echo_body)))
            .expect("Should create test server");

        let response = server
            .post(&"/bytes")
            .bytes("hello!".as_bytes().into())
            .await;

        assert_eq!(response.text(), "hello!");
    }

    /// Sending raw bytes must leave an explicitly chosen content type alone.
    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes("hello!".as_bytes().into())
            .await;

        assert_eq!(response.text(), "application/testing");
    }
}
1476
#[cfg(test)]
mod test_bytes_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Fixture file the tests expect to contain exactly `hello!`.
    const EXAMPLE_TEXT_FILE: &str = "files/example.txt";

    /// Collects the full request body and echoes it back as text; invalid
    /// UTF-8 is replaced lossily rather than rejected.
    async fn echo_body(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");

        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Replies with the request's `Content-Type` header, or an empty body
    /// when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_bytes_up_to_be_read() {
        let server = TestServer::new(Router::new().route("/bytes", post(echo_body)))
            .expect("Should create test server");

        let response = server
            .post(&"/bytes")
            .bytes_from_file(&EXAMPLE_TEXT_FILE)
            .await;

        assert_eq!(response.text(), "hello!");
    }

    /// Loading bytes from a file must leave an explicitly chosen content type alone.
    #[tokio::test]
    async fn it_should_not_change_content_type() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .post(&"/content_type")
            .content_type(&"application/testing")
            .bytes_from_file(&EXAMPLE_TEXT_FILE)
            .await;

        assert_eq!(response.text(), "application/testing");
    }
}
1543
#[cfg(test)]
mod test_text {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Size of the payload used by the large-blob tests: 16 MiB.
    const LARGE_BLOB_SIZE: usize = 16 * 1024 * 1024;

    /// Reads the whole request body and echoes it back as text. Invalid UTF-8
    /// is replaced lossily rather than rejected.
    async fn echo_body_text(request: Request) -> String {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes")
            .to_bytes();

        String::from_utf8_lossy(&body_bytes).to_string()
    }

    /// Replies with the request's `Content-Type` header, or an empty body
    /// when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        headers
            .get(CONTENT_TYPE)
            .map(|h| h.to_str().unwrap().to_string())
            .unwrap_or_default()
    }

    /// Router exposing [`echo_body_text`] at `/text`.
    fn text_router() -> Router {
        Router::new().route("/text", post(echo_body_text))
    }

    /// Sends a `LARGE_BLOB_SIZE` text body through `server` and asserts it
    /// is echoed back unchanged.
    async fn assert_large_text_round_trip(server: &TestServer) {
        // `str::repeat` sizes the buffer once instead of pushing 16M
        // one-character slices through an iterator.
        let large_blob = "X".repeat(LARGE_BLOB_SIZE);

        let text = server.post(&"/text").text(&large_blob).await.text();

        assert_eq!(text.len(), LARGE_BLOB_SIZE);
        assert_eq!(text, large_blob);
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let server = TestServer::new(text_router()).expect("Should create test server");

        let text = server.post(&"/text").text(&"hello!").await.text();

        assert_eq!(text, "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let text = server.post(&"/content_type").text(&"hello!").await.text();

        assert_eq!(text, "text/plain");
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_mock_http() {
        let server = TestServer::builder()
            .mock_transport()
            .build(text_router())
            .expect("Should create test server");

        assert_large_text_round_trip(&server).await;
    }

    #[tokio::test]
    async fn it_should_pass_large_text_blobs_over_http() {
        let server = TestServer::builder()
            .http_transport()
            .build(text_router())
            .expect("Should create test server");

        assert_large_text_round_trip(&server).await;
    }
}
1667
#[cfg(test)]
mod test_text_from_file {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::post;
    use http::HeaderMap;
    use http::header::CONTENT_TYPE;
    use http_body_util::BodyExt;

    /// Fixture file the tests expect to contain exactly `hello!`.
    const EXAMPLE_TEXT_FILE: &str = "files/example.txt";

    /// Collects the full request body and echoes it back as text; invalid
    /// UTF-8 is replaced lossily rather than rejected.
    async fn echo_body_text(request: Request) -> String {
        let collected = request
            .into_body()
            .collect()
            .await
            .expect("Should read body to bytes");

        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    /// Replies with the request's `Content-Type` header, or an empty body
    /// when the header is absent.
    async fn echo_content_type(headers: HeaderMap) -> String {
        match headers.get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap().to_string(),
            None => String::new(),
        }
    }

    #[tokio::test]
    async fn it_should_pass_text_up_to_be_read() {
        let app = Router::new().route("/text", post(echo_body_text));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .post(&"/text")
            .text_from_file(&EXAMPLE_TEXT_FILE)
            .await;

        assert_eq!(response.text(), "hello!");
    }

    #[tokio::test]
    async fn it_should_pass_text_content_type_for_text() {
        let app = Router::new().route("/content_type", post(echo_content_type));
        let server = TestServer::new(app).expect("Should create test server");

        let response = server
            .post(&"/content_type")
            .text_from_file(&EXAMPLE_TEXT_FILE)
            .await;

        assert_eq!(response.text(), "text/plain");
    }
}
1734
#[cfg(test)]
mod test_expect_success {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler returning a plain text body with the default success status.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler returning `202 Accepted`, a 2xx status other than 200.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_success_is_returned() {
        let server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");

        server.get(&"/ping").expect_success().await;
    }

    #[tokio::test]
    async fn it_should_not_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)))
            .expect("Should create test server");

        server.get(&"/accepted").expect_success().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_404() {
        // An empty router answers every path with 404.
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get(&"/some_unknown_route").expect_success().await;
    }

    /// A request-level `expect_success` must take precedence over the
    /// server-wide `expect_failure`.
    #[tokio::test]
    async fn it_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");
        server.expect_failure();

        server.get(&"/ping").expect_success().await;
    }
}
1804
#[cfg(test)]
mod test_expect_failure {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;

    /// Handler returning a plain text body with the default success status.
    async fn get_ping() -> &'static str {
        "pong!"
    }

    /// Handler returning `202 Accepted`, a 2xx status other than 200.
    async fn get_accepted() -> StatusCode {
        StatusCode::ACCEPTED
    }

    #[tokio::test]
    async fn it_should_not_panic_if_expect_failure_on_404() {
        // An empty router answers every path with 404.
        let server = TestServer::new(Router::new()).expect("Should create test server");

        server.get(&"/some_unknown_route").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_if_success_is_returned() {
        let server = TestServer::new(Router::new().route("/ping", get(get_ping)))
            .expect("Should create test server");

        server.get(&"/ping").expect_failure().await;
    }

    #[tokio::test]
    #[should_panic]
    async fn it_should_panic_on_other_2xx_status_code() {
        let server = TestServer::new(Router::new().route("/accepted", get(get_accepted)))
            .expect("Should create test server");

        server.get(&"/accepted").expect_failure().await;
    }

    /// A request-level `expect_failure` must take precedence over the
    /// server-wide `expect_success`.
    #[tokio::test]
    async fn it_should_should_override_what_test_server_has_set() {
        let mut server = TestServer::new(Router::new()).expect("Should create test server");
        server.expect_success();

        server.get(&"/some_unknown_route").expect_failure().await;
    }
}
1871
#[cfg(test)]
mod test_add_cookie {
    use crate::TestServer;
    use axum::Router;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar;
    use cookie::Cookie;
    use cookie::time::Duration;
    use cookie::time::OffsetDateTime;

    /// Name of the cookie the `/cookie` route looks up.
    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Replies with the value of the `TEST_COOKIE_NAME` cookie, or
    /// `"cookie-not-found"` when the request did not carry it. The jar is
    /// returned unchanged alongside the body.
    async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    #[tokio::test]
    async fn it_should_send_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_send_non_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_add(Duration::minutes(10))
                .expect("Expiry ten minutes from now should be representable"),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "my-custom-cookie");
    }

    #[tokio::test]
    async fn it_should_not_send_expired_cookies_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // Expire the cookie well in the past, mirroring the +10 minute case
        // above. An expiry of exactly "now" would make the outcome depend on
        // clock resolution and on whether the expiry comparison is strict.
        let mut cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        cookie.set_expires(
            OffsetDateTime::now_utc()
                .checked_sub(Duration::minutes(10))
                .expect("Expiry ten minutes ago should be representable"),
        );
        let response_text = server.get(&"/cookie").add_cookie(cookie).await.text();
        assert_eq!(response_text, "cookie-not-found");
    }
}
1929
#[cfg(test)]
mod test_add_cookies {
    // Tests for `TestRequest::add_cookies`, which sends every cookie in a jar.
    use crate::TestServer;
    use axum::Router;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use cookie::SameSite;

    /// Returns every received cookie as `name=value`, sorted and joined by `", "`.
    async fn route_get_cookies(cookies: AxumCookieJar) -> String {
        let mut all_cookies: Vec<String> = cookies
            .iter()
            .map(|cookie| format!("{}={}", cookie.name(), cookie.value()))
            .collect();
        // Sorted so the output does not depend on the jar's iteration order.
        all_cookies.sort();

        all_cookies.join(", ")
    }

    /// Returns the raw `cookie` request header values joined by `"; "`.
    ///
    /// Values that fail `HeaderValue::to_str` are treated as empty strings;
    /// no `cookie` header at all yields an empty string.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or(""))
            .collect::<Vec<&str>>()
            .join("; ")
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_added_by_jar() {
        let app = Router::new().route("/cookies", get(route_get_cookies));
        let server = TestServer::new(app).expect("Should create test server");

        // Build cookies to send up
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new("first-cookie", "my-custom-cookie"));
        cookie_jar.add(Cookie::new("second-cookie", "other-cookie"));

        server
            .get("/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text("first-cookie=my-custom-cookie, second-cookie=other-cookie");
    }

    #[tokio::test]
    async fn it_should_send_all_cookies_stripped_by_their_attributes() {
        let app = Router::new().route("/cookies", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        const TEST_COOKIE_NAME: &str = "test-cookie";
        const TEST_COOKIE_VALUE: &str = "my-custom-cookie";

        // Build a cookie carrying attributes; only `name=value` should be sent.
        let cookie = Cookie::build((TEST_COOKIE_NAME, TEST_COOKIE_VALUE))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();
        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(cookie);

        server
            .get("/cookies")
            .add_cookies(cookie_jar)
            .await
            .assert_text(format!("{}={}", TEST_COOKIE_NAME, TEST_COOKIE_VALUE));
    }
}
2006
#[cfg(test)]
mod test_save_cookies {
    // Tests for `TestRequest::save_cookies`: cookies set by the response are
    // kept and sent on later requests made through the same server.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Sets `TEST_COOKIE_NAME` to the request body text, attaching the
    /// http-only, secure, `SameSite::Strict` and `/cookie` path attributes.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Returns the raw `cookie` request header values joined by `"; "`.
    ///
    /// Values that fail `HeaderValue::to_str` are treated as empty strings;
    /// no `cookie` header at all yields an empty string.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or(""))
            .collect::<Vec<&str>>()
            .join("; ")
    }

    #[tokio::test]
    async fn it_should_strip_cookies_from_their_attributes() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        // Create a cookie.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Check, only the cookie names and their values should come back.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, format!("{}=cookie-found!", TEST_COOKIE_NAME));
    }
}
2076
#[cfg(test)]
mod test_do_not_save_cookies {
    // Tests for `TestRequest::do_not_save_cookies`: cookies set by the
    // response must not be sent on later requests, even when the server was
    // built with `save_cookies()`.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::http::header::HeaderMap;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::SameSite;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Sets `TEST_COOKIE_NAME` to the request body text, attaching the
    /// http-only, secure, `SameSite::Strict` and `/cookie` path attributes.
    async fn put_cookie_with_attributes(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = Cookie::build((TEST_COOKIE_NAME, body_text))
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .path("/cookie")
            .build();

        (cookies.add(cookie), "done")
    }

    /// Returns the raw `cookie` request header values joined by `"; "`.
    ///
    /// Values that fail `HeaderValue::to_str` are treated as empty strings;
    /// no `cookie` header at all yields an empty string.
    async fn get_cookie_headers_joined(headers: HeaderMap) -> String {
        headers
            .get_all("cookie")
            .into_iter()
            .map(|c| c.to_str().unwrap_or(""))
            .collect::<Vec<&str>>()
            .join("; ")
    }

    #[tokio::test]
    async fn it_should_not_save_cookies_when_set() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::new(app).expect("Should create test server");

        // Create a cookie, but ask for it not to be saved.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // Check nothing was saved, so no cookie is sent back up.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, "");
    }

    #[tokio::test]
    async fn it_should_override_test_server_and_not_save_cookies_when_set() {
        let app = Router::new()
            .route("/cookie", put(put_cookie_with_attributes))
            .route("/cookie", get(get_cookie_headers_joined));
        let server = TestServer::builder()
            .save_cookies()
            .build(app)
            .expect("Should create test server");

        // Create a cookie; the request-level setting overrides the server's.
        server
            .put("/cookie")
            .text("cookie-found!")
            .do_not_save_cookies()
            .await;

        // Check nothing was saved, so no cookie is sent back up.
        let response_text = server.get("/cookie").await.text();

        assert_eq!(response_text, "");
    }
}
2169
#[cfg(test)]
mod test_clear_cookies {
    // Tests for `TestRequest::clear_cookies`: cookies added to the request,
    // saved from earlier responses, or added to the server are not sent.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;
    use axum::routing::put;
    use axum_extra::extract::cookie::Cookie as AxumCookie;
    use axum_extra::extract::cookie::CookieJar as AxumCookieJar;
    use cookie::Cookie;
    use cookie::CookieJar;
    use http_body_util::BodyExt;

    const TEST_COOKIE_NAME: &str = "test-cookie";

    /// Echoes the value of the `TEST_COOKIE_NAME` cookie sent with the request,
    /// or `"cookie-not-found"` when the request carries no such cookie.
    async fn get_cookie(cookies: AxumCookieJar) -> (AxumCookieJar, String) {
        let cookie_value = cookies
            .get(TEST_COOKIE_NAME)
            .map(|c| c.value().to_string())
            .unwrap_or_else(|| "cookie-not-found".to_string());

        (cookies, cookie_value)
    }

    /// Sets `TEST_COOKIE_NAME` to the request body text, with no attributes.
    async fn put_cookie(
        cookies: AxumCookieJar,
        request: Request,
    ) -> (AxumCookieJar, &'static str) {
        let body_bytes = request
            .into_body()
            .collect()
            .await
            .expect("Should turn the body into bytes")
            .to_bytes();

        let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
        let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text);

        (cookies.add(cookie), "done")
    }

    #[tokio::test]
    async fn it_should_clear_cookie_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie");
        let response_text = server
            .get("/cookie")
            .add_cookie(cookie)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookie_jar_added_to_request() {
        let app = Router::new().route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        let mut cookie_jar = CookieJar::new();
        cookie_jar.add(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        let response_text = server
            .get("/cookie")
            .add_cookies(cookie_jar)
            .clear_cookies()
            .await
            .text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_saved_by_past_request() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let server = TestServer::new(app).expect("Should create test server");

        // Create and save a cookie.
        server
            .put("/cookie")
            .text("cookie-found!")
            .save_cookies()
            .await;

        // Check the saved cookie does NOT come back once cleared.
        let response_text = server.get("/cookie").clear_cookies().await.text();

        assert_eq!(response_text, "cookie-not-found");
    }

    #[tokio::test]
    async fn it_should_clear_cookies_added_to_test_server() {
        let app = Router::new()
            .route("/cookie", put(put_cookie))
            .route("/cookie", get(get_cookie));
        let mut server = TestServer::new(app).expect("Should create test server");

        server.add_cookie(Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"));

        // Check the server-level cookie does NOT come back once cleared.
        let response_text = server.get("/cookie").clear_cookies().await.text();

        assert_eq!(response_text, "cookie-not-found");
    }
}
2283
#[cfg(test)]
mod test_add_header {
    // Tests for `TestRequest::add_header`.
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Raw bytes of the `TEST_HEADER_NAME` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Extracts the test header, rejecting with 400 when it is missing.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Responds with the bytes of the test header it received.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Send a request with the header
        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .await;

        // Check it sent back the right text
        response.assert_text(TEST_HEADER_CONTENT);
    }
}
2342
#[cfg(test)]
mod test_authorization {
    // Tests for `TestRequest::authorization`, which sets the raw
    // `Authorization` header value.
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the `Authorization`
    /// header, with `expect_success()` enabled.
    fn new_test_server() -> TestServer {
        /// The `Authorization` header value as text.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            /// Rejects with 400 when the header is missing, or when its value
            /// cannot be read as a string (rather than panicking the handler).
            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let value = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))?;
                let text = value
                    .to_str()
                    .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid test header"))?;

                Ok(TestHeader(text.to_string()))
            }
        }

        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        // Build an application with a route.
        let app = Router::new().route("/auth-header", get(ping_auth_header));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // Send a request with the header
        let response = server
            .get("/auth-header")
            .authorization("Bearer abc123")
            .await;

        // Check it sent back the right text
        response.assert_text("Bearer abc123");
    }
}
2400
#[cfg(test)]
mod test_authorization_bearer {
    // Tests for `TestRequest::authorization_bearer`, which should send
    // `Authorization: Bearer <token>`.
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    /// Builds a server whose `/auth-header` route echoes the bearer token from
    /// the `Authorization` header, with `expect_success()` enabled.
    fn new_test_server() -> TestServer {
        /// The `Authorization` header value with a leading `"Bearer "` removed.
        struct TestHeader(String);

        impl<S: Sync> FromRequestParts<S> for TestHeader {
            type Rejection = (StatusCode, &'static str);

            /// Rejects with 400 when the header is missing, or when its value
            /// cannot be read as a string (rather than panicking the handler).
            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<TestHeader, Self::Rejection> {
                let value = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))?;
                let text = value
                    .to_str()
                    .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid test header"))?;

                // Only strip the scheme prefix; `replace` would also remove any
                // "Bearer " occurring inside the token itself.
                let token = text.strip_prefix("Bearer ").unwrap_or(text);

                Ok(TestHeader(token.to_string()))
            }
        }

        async fn ping_auth_header(TestHeader(header): TestHeader) -> String {
            header
        }

        // Build an application with a route.
        let app = Router::new().route("/auth-header", get(ping_auth_header));

        // Run the server.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.expect_success();

        server
    }

    #[tokio::test]
    async fn it_should_send_header_added_to_request() {
        let server = new_test_server();

        // Send a request with the header
        let response = server
            .get("/auth-header")
            .authorization_bearer("abc123")
            .await;

        // Check it sent back the right text
        response.assert_text("abc123");
    }
}
2458
#[cfg(test)]
mod test_clear_headers {
    // Tests for `TestRequest::clear_headers`: headers added to the request or
    // to the server are not sent.
    use super::*;
    use crate::TestServer;
    use axum::Router;
    use axum::extract::FromRequestParts;
    use axum::routing::get;
    use http::HeaderName;
    use http::HeaderValue;
    use http::request::Parts;
    use hyper::StatusCode;
    use std::marker::Sync;

    const TEST_HEADER_NAME: &str = "test-header";
    const TEST_HEADER_CONTENT: &str = "Test header content";

    /// Raw bytes of the `TEST_HEADER_NAME` request header.
    struct TestHeader(Vec<u8>);

    impl<S: Sync> FromRequestParts<S> for TestHeader {
        type Rejection = (StatusCode, &'static str);

        /// Extracts the test header, rejecting with 400 when it is missing.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> Result<TestHeader, Self::Rejection> {
            parts
                .headers
                .get(HeaderName::from_static(TEST_HEADER_NAME))
                .map(|v| TestHeader(v.as_bytes().to_vec()))
                .ok_or((StatusCode::BAD_REQUEST, "Missing test header"))
        }
    }

    /// Responds with the bytes of the test header it received.
    async fn ping_header(TestHeader(header): TestHeader) -> Vec<u8> {
        header
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_request() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Add the header, then clear it before sending.
        let response = server
            .get("/header")
            .add_header(
                HeaderName::from_static(TEST_HEADER_NAME),
                HeaderValue::from_static(TEST_HEADER_CONTENT),
            )
            .clear_headers()
            .await;

        // Check the extractor rejected the request because the header is gone.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }

    #[tokio::test]
    async fn it_should_clear_headers_added_to_server() {
        // Build an application with a route.
        let app = Router::new().route("/header", get(ping_header));

        // Run the server, with the header set at the server level.
        let mut server = TestServer::new(app).expect("Should create test server");
        server.add_header(
            HeaderName::from_static(TEST_HEADER_NAME),
            HeaderValue::from_static(TEST_HEADER_CONTENT),
        );

        // Clear the server-level header for this request.
        let response = server.get("/header").clear_headers().await;

        // Check the extractor rejected the request because the header is gone.
        response.assert_status_bad_request();
        response.assert_text("Missing test header");
    }
}
2539
#[cfg(test)]
mod test_add_query_params {
    // Tests for `TestRequest::add_query_params` with structs, pairs and JSON.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes the `message` and `other` query parameters as `message-other`.
    async fn get_query_param_2(AxumStdQuery(params): AxumStdQuery<QueryParam2>) -> String {
        format!("{}-{}", params.message, params.other)
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-2", get(get_query_param_2))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_serialization() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_query_params(QueryParam {
                message: "it works".to_string(),
            })
            .await
            .assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_query_params(&[("message", "it works")])
            .await
            .assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_params() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-2")
            .add_query_params(&[("message", "it works"), ("other", "yup")])
            .await
            .assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-2")
            .add_query_params(&[("message", "it works")])
            .add_query_params(&[("other", "yup")])
            .await
            .assert_text("it works-yup");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_json() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-2")
            .add_query_params(json!({
                "message": "it works",
                "other": "yup"
            }))
            .await
            .assert_text("it works-yup");
    }
}
2646
#[cfg(test)]
mod test_add_raw_query_param {
    // Tests for `TestRequest::add_raw_query_param`, which appends query text
    // verbatim.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query as AxumStdQuery;
    use axum::routing::get;
    use axum_extra::extract::Query as AxumExtraQuery;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(AxumStdQuery(params): AxumStdQuery<QueryParam>) -> String {
        params.message
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParamExtra {
        #[serde(default)]
        items: Vec<String>,

        #[serde(default, rename = "arrs[]")]
        arrs: Vec<String>,
    }

    /// Returns the `items` values joined by `", "`, immediately followed (with
    /// no separator) by the `arrs[]` values joined by `", "`.
    async fn get_query_param_extra(
        AxumExtraQuery(params): AxumExtraQuery<QueryParamExtra>,
    ) -> String {
        // Joining an empty list yields "", so no emptiness checks are needed.
        let mut output = params.items.join(", ");
        output.push_str(&params.arrs.join(", "));

        output
    }

    fn build_app() -> Router {
        Router::new()
            .route("/query", get(get_query_param))
            .route("/query-extra", get(get_query_param_extra))
    }

    #[tokio::test]
    async fn it_should_pass_up_query_param_as_is() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_raw_query_param("message=it-works")
            .await
            .assert_text("it-works");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_one_string() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-extra")
            .add_raw_query_param("items=one&items=two&items=three")
            .await
            .assert_text("one, two, three");
    }

    #[tokio::test]
    async fn it_should_pass_up_array_query_params_as_multiple_params() {
        // Run the server.
        let server = TestServer::new(build_app()).expect("Should create test server");

        // Get the request.
        server
            .get("/query-extra")
            .add_raw_query_param("arrs[]=one")
            .add_raw_query_param("arrs[]=two")
            .add_raw_query_param("arrs[]=three")
            .await
            .assert_text("one, two, three");
    }
}
2739
#[cfg(test)]
mod test_add_query_param {
    // Tests for `TestRequest::add_query_param`, which adds one key/value pair.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam {
        message: String,
    }

    /// Echoes the `message` query parameter.
    async fn get_query_param(Query(params): Query<QueryParam>) -> String {
        params.message
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParam2 {
        message: String,
        other: String,
    }

    /// Echoes the `message` and `other` query parameters as `message-other`.
    async fn get_query_param_2(Query(params): Query<QueryParam2>) -> String {
        format!("{}-{}", params.message, params.other)
    }

    #[tokio::test]
    async fn it_should_pass_up_query_params_from_pairs() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_param));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_query_param("message", "it works")
            .await
            .assert_text("it works");
    }

    #[tokio::test]
    async fn it_should_pass_up_multiple_query_params_from_multiple_calls() {
        // Build an application with a route.
        let app = Router::new().route("/query-2", get(get_query_param_2));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Get the request.
        server
            .get("/query-2")
            .add_query_param("message", "it works")
            .add_query_param("other", "yup")
            .await
            .assert_text("it works-yup");
    }
}
2801
#[cfg(test)]
mod test_clear_query_params {
    // Tests for `TestRequest::clear_query_params`.
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Query;
    use axum::routing::get;
    use serde::Deserialize;
    use serde::Serialize;

    #[derive(Debug, Deserialize, Serialize)]
    struct QueryParams {
        first: Option<String>,
        second: Option<String>,
    }

    /// Reports which of the optional `first` / `second` parameters arrived.
    async fn get_query_params(Query(params): Query<QueryParams>) -> String {
        format!(
            "has first? {}, has second? {}",
            params.first.is_some(),
            params.second.is_some()
        )
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_params));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Get the request.
        server
            .get("/query")
            .add_query_params(QueryParams {
                first: Some("first".to_string()),
                second: Some("second".to_string()),
            })
            .clear_query_params()
            .await
            .assert_text("has first? false, has second? false");
    }

    #[tokio::test]
    async fn it_should_clear_all_params_set_and_allow_replacement() {
        // Build an application with a route.
        let app = Router::new().route("/query", get(get_query_params));

        // Run the server.
        let server = TestServer::new(app).expect("Should create test server");

        // Re-add only `first` after clearing. If the clear were a no-op,
        // `second` would still be sent, so this distinguishes the two cases.
        server
            .get("/query")
            .add_query_params(QueryParams {
                first: Some("first".to_string()),
                second: Some("second".to_string()),
            })
            .clear_query_params()
            .add_query_params(&[("first", "first")])
            .await
            .assert_text("has first? true, has second? false");
    }
}
2869
#[cfg(test)]
mod test_scheme {
    use crate::TestServer;
    use axum::Router;
    use axum::extract::Request;
    use axum::routing::get;

    /// Replies with the scheme of the incoming request URI as plain text.
    ///
    /// Unwraps the scheme, so it panics if the URI has none.
    async fn route_get_scheme(request: Request) -> String {
        let scheme = request.uri().scheme_str().unwrap();
        String::from(scheme)
    }

    /// Creates a server exposing only `GET /scheme`, using the builder defaults.
    fn build_scheme_server() -> TestServer {
        let app = Router::new().route("/scheme", get(route_get_scheme));
        TestServer::builder().build(app).unwrap()
    }

    #[tokio::test]
    async fn it_should_return_http_by_default() {
        let server = build_scheme_server();

        let response = server.get("/scheme").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_http_when_set() {
        let server = build_scheme_server();

        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }

    #[tokio::test]
    async fn it_should_return_https_when_set() {
        let server = build_scheme_server();

        let response = server.get("/scheme").scheme(&"https").await;
        response.assert_text("https");
    }

    #[tokio::test]
    async fn it_should_override_test_server_when_set() {
        let mut server = build_scheme_server();
        // Server-wide default becomes https...
        server.scheme(&"https");

        // ...but the per-request scheme is what the handler sees.
        let response = server.get("/scheme").scheme(&"http").await;
        response.assert_text("http");
    }
}
2927
#[cfg(test)]
mod test_multipart {
    use crate::TestServer;
    use crate::multipart::MultipartForm;
    use crate::multipart::Part;
    use axum::Json;
    use axum::Router;
    use axum::extract::Multipart;
    use axum::routing::post;
    use serde_json::Value;
    use serde_json::json;

    /// Summarises each received multipart field as `"<name> is <len> bytes, <content type>"`.
    ///
    /// Panics if a field has no name or no content type.
    async fn route_post_multipart(mut multipart: Multipart) -> Json<Vec<String>> {
        let mut fields = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let name = field.name().unwrap().to_string();
            let content_type = field.content_type().unwrap().to_owned();
            let data = field.bytes().await.unwrap();

            let field_stats = format!("{name} is {} bytes, {content_type}", data.len());
            fields.push(field_stats);
        }

        Json(fields)
    }

    /// Echoes back each part as `{ "name", "text", "header" }`.
    ///
    /// `header` is the part's `x-part-header-test` value. Panics if a part is
    /// missing that header, has no name, or has a non-UTF-8 body.
    async fn route_post_multipart_headers(mut multipart: Multipart) -> Json<Vec<Value>> {
        let mut sent_part_headers = vec![];

        while let Some(field) = multipart.next_field().await.unwrap() {
            let part_name = field.name().unwrap().to_string();
            let part_header_value = field
                .headers()
                .get("x-part-header-test")
                .unwrap()
                .to_str()
                .unwrap()
                .to_string();
            let part_text = String::from_utf8(field.bytes().await.unwrap().into()).unwrap();

            sent_part_headers.push(json!({
                "name": part_name,
                "text": part_text,
                "header": part_header_value,
            }))
        }

        Json(sent_part_headers)
    }

    fn test_router() -> Router {
        Router::new()
            .route("/multipart", post(route_post_multipart))
            .route("/multipart_headers", post(route_post_multipart_headers))
    }

    /// The text-only form shared by the per-transport stats tests.
    ///
    /// Sharing it keeps the mock and HTTP transport tests identical.
    fn stats_form() -> MultipartForm {
        MultipartForm::new()
            .add_text("penguins?", "lots")
            .add_text("animals", "🦊🦊🦊")
            .add_text("carrots", 123_u32)
    }

    /// What `route_post_multipart` should report for `stats_form()`.
    ///
    /// Each fox emoji is 4 bytes in UTF-8, hence 12 bytes for "animals".
    fn expected_stats() -> Vec<String> {
        vec![
            "penguins? is 4 bytes, text/plain".to_string(),
            "animals is 12 bytes, text/plain".to_string(),
            "carrots is 3 bytes, text/plain".to_string(),
        ]
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_mock_transport() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(stats_form())
            .await
            .assert_json(&expected_stats());
    }

    #[tokio::test]
    async fn it_should_get_multipart_stats_on_http_transport() {
        // Run the server.
        let server = TestServer::builder()
            .http_transport()
            .build(test_router())
            .expect("Should create test server");

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(stats_form())
            .await
            .assert_json(&expected_stats());
    }

    #[tokio::test]
    async fn it_should_send_text_parts_as_text() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part("animals", Part::text("🦊🦊🦊"));

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 12 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_custom_mime_type() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part(
            "animals",
            Part::bytes("🦊,🦊,🦊".as_bytes()).mime_type(mime::TEXT_CSV),
        );

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["animals is 14 bytes, text/csv".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_using_include_bytes() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new().add_part(
            "file",
            Part::bytes(include_bytes!("../files/example.txt").as_slice())
                .mime_type(mime::TEXT_PLAIN),
        );

        // Get the request.
        server
            .post(&"/multipart")
            .multipart(form)
            .await
            .assert_json(&vec!["file is 6 bytes, text/plain".to_string()]);
    }

    #[tokio::test]
    async fn it_should_send_form_headers_in_parts() {
        // Run the server.
        let server = TestServer::builder()
            .mock_transport()
            .build(test_router())
            .expect("Should create test server");

        let form = MultipartForm::new()
            .add_part(
                "part_1",
                Part::text("part_1_text").add_header("x-part-header-test", "part_1_header"),
            )
            .add_part(
                "part_2",
                Part::text("part_2_text").add_header("x-part-header-test", "part_2_header"),
            );

        // Get the request.
        server
            .post(&"/multipart_headers")
            .multipart(form)
            .await
            .assert_json(&json!([
                {
                    "name": "part_1",
                    "text": "part_1_text",
                    "header": "part_1_header",
                },
                {
                    "name": "part_2",
                    "text": "part_2_text",
                    "header": "part_2_header",
                },
            ]));
    }
}