axum/test_helpers/
test_client.rs1use super::{serve, Request, Response};
2use bytes::Bytes;
3use futures_util::future::BoxFuture;
4use http::header::{HeaderName, HeaderValue};
5use std::ops::Deref;
6use std::{convert::Infallible, future::IntoFuture, net::SocketAddr};
7use tokio::net::TcpListener;
8use tower::make::Shared;
9use tower_service::Service;
10
11pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
12where
13 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
14 S::Future: Send,
15{
16 let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
17 std_listener.set_nonblocking(true).unwrap();
18 let listener = TcpListener::from_std(std_listener).unwrap();
19
20 let addr = listener.local_addr().unwrap();
21 println!("Listening on {addr}");
22
23 tokio::spawn(async move {
24 serve(listener, Shared::new(svc))
25 .await
26 .expect("server error")
27 });
28
29 addr
30}
31
32pub struct TestClient {
33 client: reqwest::Client,
34 addr: SocketAddr,
35}
36
37impl TestClient {
38 pub fn new<S>(svc: S) -> Self
39 where
40 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
41 S::Future: Send,
42 {
43 let addr = spawn_service(svc);
44
45 let client = reqwest::Client::builder()
46 .redirect(reqwest::redirect::Policy::none())
47 .build()
48 .unwrap();
49
50 TestClient { client, addr }
51 }
52
53 pub fn get(&self, url: &str) -> RequestBuilder {
54 RequestBuilder {
55 builder: self.client.get(format!("http://{}{url}", self.addr)),
56 }
57 }
58
59 pub fn head(&self, url: &str) -> RequestBuilder {
60 RequestBuilder {
61 builder: self.client.head(format!("http://{}{url}", self.addr)),
62 }
63 }
64
65 pub fn post(&self, url: &str) -> RequestBuilder {
66 RequestBuilder {
67 builder: self.client.post(format!("http://{}{url}", self.addr)),
68 }
69 }
70
71 #[allow(dead_code)]
72 pub fn put(&self, url: &str) -> RequestBuilder {
73 RequestBuilder {
74 builder: self.client.put(format!("http://{}{url}", self.addr)),
75 }
76 }
77
78 #[allow(dead_code)]
79 pub fn patch(&self, url: &str) -> RequestBuilder {
80 RequestBuilder {
81 builder: self.client.patch(format!("http://{}{url}", self.addr)),
82 }
83 }
84
85 #[allow(dead_code)]
86 #[must_use]
87 pub fn server_port(&self) -> u16 {
88 self.addr.port()
89 }
90}
91
92#[must_use]
93pub struct RequestBuilder {
94 builder: reqwest::RequestBuilder,
95}
96
97impl RequestBuilder {
98 pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
99 self.builder = self.builder.body(body);
100 self
101 }
102
103 pub fn json<T>(mut self, json: &T) -> Self
104 where
105 T: serde_core::Serialize,
106 {
107 self.builder = self.builder.json(json);
108 self
109 }
110
111 pub fn header<K, V>(mut self, key: K, value: V) -> Self
112 where
113 HeaderName: TryFrom<K>,
114 <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
115 HeaderValue: TryFrom<V>,
116 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
117 {
118 self.builder = self.builder.header(key, value);
119 self
120 }
121
122 #[allow(dead_code)]
123 pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
124 self.builder = self.builder.multipart(form);
125 self
126 }
127}
128
129impl IntoFuture for RequestBuilder {
130 type Output = TestResponse;
131 type IntoFuture = BoxFuture<'static, Self::Output>;
132
133 fn into_future(self) -> Self::IntoFuture {
134 Box::pin(async {
135 TestResponse {
136 response: self.builder.send().await.unwrap(),
137 }
138 })
139 }
140}
141
142#[derive(Debug)]
143pub struct TestResponse {
144 response: reqwest::Response,
145}
146
147impl Deref for TestResponse {
148 type Target = reqwest::Response;
149
150 fn deref(&self) -> &Self::Target {
151 &self.response
152 }
153}
154
155impl TestResponse {
156 #[allow(dead_code)]
157 pub async fn bytes(self) -> Bytes {
158 self.response.bytes().await.unwrap()
159 }
160
161 pub async fn text(self) -> String {
162 self.response.text().await.unwrap()
163 }
164
165 #[allow(dead_code)]
166 pub async fn json<T>(self) -> T
167 where
168 T: serde_core::de::DeserializeOwned,
169 {
170 self.response.json().await.unwrap()
171 }
172
173 pub async fn chunk(&mut self) -> Option<Bytes> {
174 self.response.chunk().await.unwrap()
175 }
176
177 pub async fn chunk_text(&mut self) -> Option<String> {
178 let chunk = self.chunk().await?;
179 Some(String::from_utf8(chunk.to_vec()).unwrap())
180 }
181}