1use axum::{
2 Router,
3 body::{Body, Bytes, to_bytes},
4};
5use http::{HeaderMap, HeaderName, HeaderValue, Method, Request, header::CONTENT_TYPE};
6use serde::Serialize;
7use std::{error::Error, fmt};
8use tower::ServiceExt;
9
10use crate::response::TestResponse;
11
12pub struct TestRequest {
38 router: Router,
39 method: Method,
40 path: String,
41 body: Body,
42 headers: HeaderMap,
43 content_type: Option<&'static str>,
44}
45
46impl TestRequest {
47 pub(crate) fn new(router: Router, method: Method, path: String) -> Self {
48 Self {
49 router,
50 method,
51 path,
52 body: Body::empty(),
53 headers: HeaderMap::new(),
54 content_type: None,
55 }
56 }
57
58 pub fn header<N, V>(mut self, name: N, value: V) -> Self
63 where
64 N: TryInto<HeaderName>,
65 N::Error: Into<http::Error>,
66 V: TryInto<HeaderValue>,
67 V::Error: Into<http::Error>,
68 {
69 self = self
70 .try_header(name, value)
71 .expect("test request header was invalid");
72 self
73 }
74
75 pub fn try_header<N, V>(mut self, name: N, value: V) -> std::result::Result<Self, http::Error>
77 where
78 N: TryInto<HeaderName>,
79 N::Error: Into<http::Error>,
80 V: TryInto<HeaderValue>,
81 V::Error: Into<http::Error>,
82 {
83 let name = name.try_into().map_err(Into::into)?;
84 let value = value.try_into().map_err(Into::into)?;
85 self.headers.insert(name, value);
86 Ok(self)
87 }
88
89 pub fn text(mut self, body: impl Into<String>) -> Self {
93 self.body = Body::from(body.into());
94 self.content_type = Some("text/plain; charset=utf-8");
95 self
96 }
97
98 pub fn body(mut self, body: impl Into<Bytes>) -> Self {
102 self.body = Body::from(body.into());
103 self
104 }
105
106 pub fn json<T: Serialize>(self, body: &T) -> Self {
110 self.try_json(body).expect("test JSON serialization failed")
111 }
112
113 pub fn try_json<T: Serialize>(
115 mut self,
116 body: &T,
117 ) -> std::result::Result<Self, serde_json::Error> {
118 self.body = Body::from(serde_json::to_vec(body)?);
119 self.content_type = Some("application/json");
120 Ok(self)
121 }
122
123 pub fn query<T: Serialize>(mut self, query: &T) -> Self {
128 self = self
129 .try_query(query)
130 .expect("test query serialization failed");
131 self
132 }
133
134 pub fn try_query<T: Serialize>(
136 mut self,
137 query: &T,
138 ) -> std::result::Result<Self, serde_urlencoded::ser::Error> {
139 let query = serde_urlencoded::to_string(query)?;
140 if !query.is_empty() {
141 self.path = append_query(&self.path, &query);
142 }
143 Ok(self)
144 }
145
146 pub async fn send(self) -> TestResponse {
148 self.try_send().await.expect("test request send failed")
149 }
150
151 pub async fn try_send(self) -> Result<TestResponse, TestRequestError> {
153 let mut builder = Request::builder().method(self.method).uri(self.path);
154 if let Some(content_type) = self.content_type {
155 builder = builder.header(CONTENT_TYPE, content_type);
156 }
157 for (name, value) in self.headers {
158 if let Some(name) = name {
159 builder = builder.header(name, value);
160 }
161 }
162 let request = builder.body(self.body).map_err(TestRequestError::Request)?;
163 let response = match self.router.oneshot(request).await {
164 Ok(response) => response,
165 Err(error) => match error {},
166 };
167 let status = response.status();
168 let headers = response.headers().clone();
169 let body = to_bytes(response.into_body(), usize::MAX)
170 .await
171 .map_err(TestRequestError::Body)?;
172
173 Ok(TestResponse::new(status, headers, body))
174 }
175}
176
/// Error returned by `TestRequest::try_send`.
#[derive(Debug)]
pub enum TestRequestError {
    /// The `http::Request` could not be built from the method, path and body
    /// (for example, the path is not a valid URI).
    Request(http::Error),
    /// The response body could not be collected into bytes.
    Body(axum::Error),
}
185
186impl fmt::Display for TestRequestError {
187 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
188 match self {
189 Self::Request(error) => write!(formatter, "test request build failed: {error}"),
190 Self::Body(error) => write!(formatter, "test response body read failed: {error}"),
191 }
192 }
193}
194
195impl Error for TestRequestError {
196 fn source(&self) -> Option<&(dyn Error + 'static)> {
197 match self {
198 Self::Request(error) => Some(error),
199 Self::Body(error) => Some(error),
200 }
201 }
202}
203
/// Appends an already-encoded `query` to `path`'s query string.
///
/// Adds `?` when the path has no query yet, `&` when it has one, and nothing when
/// the path already ends with `?` or `&`. Any `#fragment` stays at the end, so the
/// query is inserted before it.
fn append_query(path: &str, query: &str) -> String {
    // Split off the fragment first, for two reasons. A query must precede it, and
    // `http`'s URI parser discards everything from `#` onward, so a query appended
    // after the fragment would be silently lost. Also, a `?` inside the fragment
    // does not start a query.
    let (base, fragment) = match path.find('#') {
        Some(index) => path.split_at(index),
        None => (path, ""),
    };
    let separator = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        ""
    } else {
        "&"
    };
    format!("{base}{separator}{query}{fragment}")
}