1use bytes::Bytes;
39use http::StatusCode;
40use std::net::SocketAddr;
41use tokio::net::TcpListener;
42
43pub struct TestClient {
44 client: reqwest::Client,
45 addr: SocketAddr,
46}
47
48impl TestClient {
49 pub async fn new(svc: axum::Router) -> Self {
50 let listener = TcpListener::bind("127.0.0.1:0")
51 .await
52 .expect("Could not bind ephemeral socket");
53 let addr = listener.local_addr().unwrap();
54 #[cfg(feature = "withtrace")]
55 println!("Listening on {}", addr);
56
57 tokio::spawn(async move {
58 let server = axum::serve(listener, svc);
59 server.await.expect("server error");
60 });
61
62 #[cfg(feature = "cookies")]
63 let client = reqwest::Client::builder()
64 .redirect(reqwest::redirect::Policy::none())
65 .cookie_store(true)
66 .build()
67 .unwrap();
68
69 #[cfg(not(feature = "cookies"))]
70 let client = reqwest::Client::builder()
71 .redirect(reqwest::redirect::Policy::none())
72 .build()
73 .unwrap();
74
75 TestClient { client, addr }
76 }
77
78 pub fn base_url(&self) -> String {
83 format!("http://{}", self.addr)
84 }
85
86 pub fn get(&self, url: &str) -> RequestBuilder {
87 RequestBuilder {
88 builder: self.client.get(format!("http://{}{}", self.addr, url)),
89 }
90 }
91
92 pub fn head(&self, url: &str) -> RequestBuilder {
93 RequestBuilder {
94 builder: self.client.head(format!("http://{}{}", self.addr, url)),
95 }
96 }
97
98 pub fn post(&self, url: &str) -> RequestBuilder {
99 RequestBuilder {
100 builder: self.client.post(format!("http://{}{}", self.addr, url)),
101 }
102 }
103
104 pub fn put(&self, url: &str) -> RequestBuilder {
105 RequestBuilder {
106 builder: self.client.put(format!("http://{}{}", self.addr, url)),
107 }
108 }
109
110 pub fn patch(&self, url: &str) -> RequestBuilder {
111 RequestBuilder {
112 builder: self.client.patch(format!("http://{}{}", self.addr, url)),
113 }
114 }
115
116 pub fn delete(&self, url: &str) -> RequestBuilder {
117 RequestBuilder {
118 builder: self.client.delete(format!("http://{}{}", self.addr, url)),
119 }
120 }
121}
122
123pub struct RequestBuilder {
124 builder: reqwest::RequestBuilder,
125}
126
127impl RequestBuilder {
128 pub async fn send(self) -> TestResponse {
129 TestResponse {
130 response: self.builder.send().await.unwrap(),
131 }
132 }
133
134 pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
135 self.builder = self.builder.body(body);
136 self
137 }
138
139 pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
140 self.builder = self.builder.form(&form);
141 self
142 }
143
144 pub fn json<T>(mut self, json: &T) -> Self
145 where
146 T: serde::Serialize,
147 {
148 self.builder = self.builder.json(json);
149 self
150 }
151
152 pub fn header(mut self, key: &str, value: &str) -> Self {
153 self.builder = self.builder.header(key, value);
154 self
155 }
156
157 pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
158 self.builder = self.builder.multipart(form);
159 self
160 }
161}
162
163pub struct TestResponse {
169 response: reqwest::Response,
170}
171
172impl TestResponse {
173 pub async fn text(self) -> String {
174 self.response.text().await.unwrap()
175 }
176
177 #[allow(dead_code)]
178 pub async fn bytes(self) -> Bytes {
179 self.response.bytes().await.unwrap()
180 }
181
182 pub async fn json<T>(self) -> T
183 where
184 T: serde::de::DeserializeOwned,
185 {
186 self.response.json().await.unwrap()
187 }
188
189 pub fn status(&self) -> StatusCode {
190 StatusCode::from_u16(self.response.status().as_u16()).unwrap()
191 }
192
193 pub fn headers(&self) -> &reqwest::header::HeaderMap {
194 self.response.headers()
195 }
196
197 pub async fn chunk(&mut self) -> Option<Bytes> {
198 self.response.chunk().await.unwrap()
199 }
200
201 pub async fn chunk_text(&mut self) -> Option<String> {
202 let chunk = self.chunk().await?;
203 Some(String::from_utf8(chunk.to_vec()).unwrap())
204 }
205
206 pub fn into_inner(self) -> reqwest::Response {
208 self.response
209 }
210}
211
212impl AsRef<reqwest::Response> for TestResponse {
213 fn as_ref(&self) -> &reqwest::Response {
214 &self.response
215 }
216}
217
#[cfg(test)]
mod tests {
    use axum::{
        response::Html,
        routing::{get, post},
        Json, Router,
    };
    use http::StatusCode;
    use serde::{Deserialize, Serialize};

    use super::TestClient;

    /// Form accepted by `handle_form`; fields other than `val` are ignored.
    #[derive(Deserialize)]
    struct FooForm {
        val: String,
    }

    /// Echoes the submitted `val` field back as an HTML body.
    async fn handle_form(
        axum::Form(FooForm { val }): axum::Form<FooForm>,
    ) -> (StatusCode, Html<String>) {
        (StatusCode::OK, Html(val))
    }

    /// JSON payload round-tripped through the echo handler.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestPayload {
        name: String,
        age: i32,
    }

    #[tokio::test]
    async fn test_get_request() {
        let router = Router::new().route("/", get(|| async {}));
        let client = TestClient::new(router).await;

        let response = client.get("/").send().await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_form_request() {
        let client = TestClient::new(Router::new().route("/", post(handle_form))).await;
        let fields = [("val", "bar"), ("baz", "quux")];

        let response = client.post("/").form(&fields).send().await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response.text().await;
        assert_eq!(body, "bar");
    }

    #[tokio::test]
    async fn test_post_request_with_json() {
        // Handler that sends back whatever JSON it receives.
        let echo = post(|Json(value): Json<serde_json::Value>| async move { Json(value) });
        let client = TestClient::new(Router::new().route("/", echo)).await;
        let expected = TestPayload {
            name: "Alice".to_owned(),
            age: 30,
        };

        let response = client
            .post("/")
            .header("Content-Type", "application/json")
            .json(&expected)
            .send()
            .await;

        assert_eq!(response.status(), StatusCode::OK);
        let raw = response.text().await;
        let echoed: TestPayload = serde_json::from_str(&raw).unwrap();
        assert_eq!(echoed, expected);
    }
}