axum_test_helper/
lib.rs

//! # Axum Test Helper
//! A standalone copy of the `TestClient` helper from axum's own test suite.
//!
//! ## Features
//! - `cookies` - Enables a cookie store in the test client, so cookies set by
//!   responses are sent on subsequent requests.
//! - `withtrace` - Prints the address the test server is listening on when a
//!   `TestClient` is created.
//!
//! ## Example
//! ```rust
//! use axum::Router;
//! use axum::http::StatusCode;
//! use axum::routing::get;
//! use axum_test_helper::TestClient;
//!
//! fn main() {
//!     let async_block = async {
//!         // you can replace this Router with your own app
//!         let app = Router::new().route("/", get(|| async {}));
//!
//!         // start the TestClient with the previously declared Router
//!         // (`new` is async because it binds a socket and spawns the server)
//!         let client = TestClient::new(app).await;
//!
//!         let res = client.get("/").send().await;
//!         assert_eq!(res.status(), StatusCode::OK);
//!     };
//!
//!     // Create a runtime for executing the async block. This runtime is local
//!     // to the main function and does not require any global setup.
//!     let runtime = tokio::runtime::Builder::new_current_thread()
//!         .enable_all()
//!         .build()
//!         .unwrap();
//!
//!     // Use the local runtime to block on the async block.
//!     runtime.block_on(async_block);
//! }
//! ```
37
38use bytes::Bytes;
39use http::StatusCode;
40use std::net::SocketAddr;
41use tokio::net::TcpListener;
42
43pub struct TestClient {
44    client: reqwest::Client,
45    addr: SocketAddr,
46}
47
48impl TestClient {
49    pub async fn new(svc: axum::Router) -> Self {
50        let listener = TcpListener::bind("127.0.0.1:0")
51            .await
52            .expect("Could not bind ephemeral socket");
53        let addr = listener.local_addr().unwrap();
54        #[cfg(feature = "withtrace")]
55        println!("Listening on {}", addr);
56
57        tokio::spawn(async move {
58            let server = axum::serve(listener, svc);
59            server.await.expect("server error");
60        });
61
62        #[cfg(feature = "cookies")]
63        let client = reqwest::Client::builder()
64            .redirect(reqwest::redirect::Policy::none())
65            .cookie_store(true)
66            .build()
67            .unwrap();
68
69        #[cfg(not(feature = "cookies"))]
70        let client = reqwest::Client::builder()
71            .redirect(reqwest::redirect::Policy::none())
72            .build()
73            .unwrap();
74
75        TestClient { client, addr }
76    }
77
78    /// returns the base URL (http://ip:port) for this TestClient
79    ///
80    /// this is useful when trying to check if Location headers in responses
81    /// are generated correctly as Location contains an absolute URL
82    pub fn base_url(&self) -> String {
83        format!("http://{}", self.addr)
84    }
85
86    pub fn get(&self, url: &str) -> RequestBuilder {
87        RequestBuilder {
88            builder: self.client.get(format!("http://{}{}", self.addr, url)),
89        }
90    }
91
92    pub fn head(&self, url: &str) -> RequestBuilder {
93        RequestBuilder {
94            builder: self.client.head(format!("http://{}{}", self.addr, url)),
95        }
96    }
97
98    pub fn post(&self, url: &str) -> RequestBuilder {
99        RequestBuilder {
100            builder: self.client.post(format!("http://{}{}", self.addr, url)),
101        }
102    }
103
104    pub fn put(&self, url: &str) -> RequestBuilder {
105        RequestBuilder {
106            builder: self.client.put(format!("http://{}{}", self.addr, url)),
107        }
108    }
109
110    pub fn patch(&self, url: &str) -> RequestBuilder {
111        RequestBuilder {
112            builder: self.client.patch(format!("http://{}{}", self.addr, url)),
113        }
114    }
115
116    pub fn delete(&self, url: &str) -> RequestBuilder {
117        RequestBuilder {
118            builder: self.client.delete(format!("http://{}{}", self.addr, url)),
119        }
120    }
121}
122
123pub struct RequestBuilder {
124    builder: reqwest::RequestBuilder,
125}
126
127impl RequestBuilder {
128    pub async fn send(self) -> TestResponse {
129        TestResponse {
130            response: self.builder.send().await.unwrap(),
131        }
132    }
133
134    pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
135        self.builder = self.builder.body(body);
136        self
137    }
138
139    pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
140        self.builder = self.builder.form(&form);
141        self
142    }
143
144    pub fn json<T>(mut self, json: &T) -> Self
145    where
146        T: serde::Serialize,
147    {
148        self.builder = self.builder.json(json);
149        self
150    }
151
152    pub fn header(mut self, key: &str, value: &str) -> Self {
153        self.builder = self.builder.header(key, value);
154        self
155    }
156
157    pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
158        self.builder = self.builder.multipart(form);
159        self
160    }
161}
162
163/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s.
164///
165/// This is conventient for tests where panics are what you want. For access to
166/// non-panicking versions or the complete `Response` API use `into_inner()` or
167/// `as_ref()`.
168pub struct TestResponse {
169    response: reqwest::Response,
170}
171
172impl TestResponse {
173    pub async fn text(self) -> String {
174        self.response.text().await.unwrap()
175    }
176
177    #[allow(dead_code)]
178    pub async fn bytes(self) -> Bytes {
179        self.response.bytes().await.unwrap()
180    }
181
182    pub async fn json<T>(self) -> T
183    where
184        T: serde::de::DeserializeOwned,
185    {
186        self.response.json().await.unwrap()
187    }
188
189    pub fn status(&self) -> StatusCode {
190        StatusCode::from_u16(self.response.status().as_u16()).unwrap()
191    }
192
193    pub fn headers(&self) -> &reqwest::header::HeaderMap {
194        self.response.headers()
195    }
196
197    pub async fn chunk(&mut self) -> Option<Bytes> {
198        self.response.chunk().await.unwrap()
199    }
200
201    pub async fn chunk_text(&mut self) -> Option<String> {
202        let chunk = self.chunk().await?;
203        Some(String::from_utf8(chunk.to_vec()).unwrap())
204    }
205
206    /// Get the inner [`reqwest::Response`] for less convenient but more complete access.
207    pub fn into_inner(self) -> reqwest::Response {
208        self.response
209    }
210}
211
212impl AsRef<reqwest::Response> for TestResponse {
213    fn as_ref(&self) -> &reqwest::Response {
214        &self.response
215    }
216}
217
#[cfg(test)]
mod tests {
    use axum::response::Html;
    use axum::routing::{get, post, put};
    use axum::{Json, Router};
    use http::StatusCode;
    use serde::{Deserialize, Serialize};

    #[derive(Deserialize)]
    struct FooForm {
        val: String,
    }

    /// Echoes the `val` field of a URL-encoded form back as HTML.
    async fn handle_form(axum::Form(form): axum::Form<FooForm>) -> (StatusCode, Html<String>) {
        (StatusCode::OK, Html(form.val))
    }

    #[tokio::test]
    async fn test_get_request() {
        let app = Router::new().route("/", get(|| async {}));
        let client = super::TestClient::new(app).await;
        let res = client.get("/").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_head_request() {
        let app = Router::new().route("/", get(|| async { "body" }));
        let client = super::TestClient::new(app).await;
        let res = client.head("/").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        // A HEAD response carries no body.
        assert_eq!(res.text().await, "");
    }

    #[tokio::test]
    async fn test_put_patch_delete_requests() {
        let app = Router::new().route(
            "/",
            put(|| async { "put" })
                .patch(|| async { "patch" })
                .delete(|| async { "delete" }),
        );
        let client = super::TestClient::new(app).await;
        assert_eq!(client.put("/").send().await.text().await, "put");
        assert_eq!(client.patch("/").send().await.text().await, "patch");
        assert_eq!(client.delete("/").send().await.text().await, "delete");
    }

    #[tokio::test]
    async fn test_base_url() {
        let client = super::TestClient::new(Router::new()).await;
        assert!(client.base_url().starts_with("http://127.0.0.1:"));
    }

    #[tokio::test]
    async fn test_post_form_request() {
        let app = Router::new().route("/", post(handle_form));
        let client = super::TestClient::new(app).await;
        // The extra `baz` field must be ignored by the form extractor.
        let form = [("val", "bar"), ("baz", "quux")];
        let res = client.post("/").form(&form).send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "bar");
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestPayload {
        name: String,
        age: i32,
    }

    #[tokio::test]
    async fn test_post_request_with_json() {
        let app = Router::new().route(
            "/",
            post(|json_value: Json<serde_json::Value>| async { json_value }),
        );
        let client = super::TestClient::new(app).await;
        let payload = TestPayload {
            name: "Alice".to_owned(),
            age: 30,
        };
        let res = client
            .post("/")
            // Explicit header also exercises `RequestBuilder::header`.
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        // Decode through the crate's own helper so `TestResponse::json` is covered.
        let response_body: TestPayload = res.json().await;
        assert_eq!(response_body, payload);
    }
}