axum_test_helpers/
test_client.rs1use super::{serve, Request, Response};
2use bytes::Bytes;
3use futures_util::future::BoxFuture;
4use http::{
5 header::{HeaderName, HeaderValue},
6 StatusCode,
7};
8use std::{convert::Infallible, future::IntoFuture, net::SocketAddr, str::FromStr};
9use tokio::net::TcpListener;
10use tower::make::Shared;
11use tower_service::Service;
12
13pub fn spawn_service<S>(svc: S) -> SocketAddr
14where
15 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
16 S::Future: Send,
17{
18 let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
19 std_listener.set_nonblocking(true).unwrap();
20 let listener = TcpListener::from_std(std_listener).unwrap();
21
22 let addr = listener.local_addr().unwrap();
23 println!("Listening on {addr}");
24
25 tokio::spawn(async move {
26 serve(listener, Shared::new(svc))
27 .await
28 .expect("server error")
29 });
30
31 addr
32}
33
34pub struct TestClient {
35 client: reqwest::Client,
36 addr: SocketAddr,
37}
38
39impl TestClient {
40 pub fn new<S>(svc: S) -> Self
41 where
42 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
43 S::Future: Send,
44 {
45 let addr = spawn_service(svc);
46
47 let client = reqwest::Client::builder()
48 .redirect(reqwest::redirect::Policy::none())
49 .build()
50 .unwrap();
51
52 TestClient { client, addr }
53 }
54
55 pub fn get(&self, url: &str) -> RequestBuilder {
56 RequestBuilder {
57 builder: self.client.get(format!("http://{}{}", self.addr, url)),
58 }
59 }
60
61 pub fn head(&self, url: &str) -> RequestBuilder {
62 RequestBuilder {
63 builder: self.client.head(format!("http://{}{}", self.addr, url)),
64 }
65 }
66
67 pub fn post(&self, url: &str) -> RequestBuilder {
68 RequestBuilder {
69 builder: self.client.post(format!("http://{}{}", self.addr, url)),
70 }
71 }
72
73 #[allow(dead_code)]
74 pub fn put(&self, url: &str) -> RequestBuilder {
75 RequestBuilder {
76 builder: self.client.put(format!("http://{}{}", self.addr, url)),
77 }
78 }
79
80 #[allow(dead_code)]
81 pub fn patch(&self, url: &str) -> RequestBuilder {
82 RequestBuilder {
83 builder: self.client.patch(format!("http://{}{}", self.addr, url)),
84 }
85 }
86
87 pub fn delete(&self, url: &str) -> RequestBuilder {
88 RequestBuilder {
89 builder: self.client.delete(format!("http://{}{}", self.addr, url)),
90 }
91 }
92}
93
94pub struct RequestBuilder {
95 builder: reqwest::RequestBuilder,
96}
97
98impl RequestBuilder {
99 pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
100 self.builder = self.builder.body(body);
101 self
102 }
103
104 pub fn json<T>(mut self, json: &T) -> Self
105 where
106 T: serde::Serialize,
107 {
108 self.builder = self.builder.json(json);
109 self
110 }
111
112 pub fn header<K, V>(mut self, key: K, value: V) -> Self
113 where
114 HeaderName: TryFrom<K>,
115 <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
116 HeaderValue: TryFrom<V>,
117 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
118 {
119 let key: HeaderName = key.try_into().map_err(Into::into).unwrap();
121 let key = reqwest::header::HeaderName::from_bytes(key.as_ref()).unwrap();
122
123 let value: HeaderValue = value.try_into().map_err(Into::into).unwrap();
124 let value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap();
125
126 self.builder = self.builder.header(key, value);
127
128 self
129 }
130
131 #[allow(dead_code)]
132 pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
133 self.builder = self.builder.multipart(form);
134 self
135 }
136}
137
138impl IntoFuture for RequestBuilder {
139 type Output = TestResponse;
140 type IntoFuture = BoxFuture<'static, Self::Output>;
141
142 fn into_future(self) -> Self::IntoFuture {
143 Box::pin(async {
144 TestResponse {
145 response: self.builder.send().await.unwrap(),
146 }
147 })
148 }
149}
150
151#[derive(Debug)]
152pub struct TestResponse {
153 response: reqwest::Response,
154}
155
156impl TestResponse {
157 #[allow(dead_code)]
158 pub async fn bytes(self) -> Bytes {
159 self.response.bytes().await.unwrap()
160 }
161
162 pub async fn text(self) -> String {
163 self.response.text().await.unwrap()
164 }
165
166 #[allow(dead_code)]
167 pub async fn json<T>(self) -> T
168 where
169 T: serde::de::DeserializeOwned,
170 {
171 self.response.json().await.unwrap()
172 }
173
174 pub fn status(&self) -> StatusCode {
175 StatusCode::from_u16(self.response.status().as_u16()).unwrap()
176 }
177
178 pub fn headers(&self) -> http::HeaderMap {
179 let mut headers = http::HeaderMap::new();
181 for (key, value) in self.response.headers() {
182 let key = http::HeaderName::from_str(key.as_str()).unwrap();
183 let value = http::HeaderValue::from_bytes(value.as_bytes()).unwrap();
184 headers.insert(key, value);
185 }
186 headers
187 }
188
189 pub async fn chunk(&mut self) -> Option<Bytes> {
190 self.response.chunk().await.unwrap()
191 }
192
193 pub async fn chunk_text(&mut self) -> Option<String> {
194 let chunk = self.chunk().await?;
195 Some(String::from_utf8(chunk.to_vec()).unwrap())
196 }
197}