axum_test_helpers/
test_client.rs

1use super::{serve, Request, Response};
2use bytes::Bytes;
3use futures_util::future::BoxFuture;
4use http::{
5    header::{HeaderName, HeaderValue},
6    StatusCode,
7};
8use std::{convert::Infallible, future::IntoFuture, net::SocketAddr, str::FromStr};
9use tokio::net::TcpListener;
10use tower::make::Shared;
11use tower_service::Service;
12
13pub fn spawn_service<S>(svc: S) -> SocketAddr
14where
15    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
16    S::Future: Send,
17{
18    let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
19    std_listener.set_nonblocking(true).unwrap();
20    let listener = TcpListener::from_std(std_listener).unwrap();
21
22    let addr = listener.local_addr().unwrap();
23    println!("Listening on {addr}");
24
25    tokio::spawn(async move {
26        serve(listener, Shared::new(svc))
27            .await
28            .expect("server error")
29    });
30
31    addr
32}
33
34pub struct TestClient {
35    client: reqwest::Client,
36    addr: SocketAddr,
37}
38
39impl TestClient {
40    pub fn new<S>(svc: S) -> Self
41    where
42        S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
43        S::Future: Send,
44    {
45        let addr = spawn_service(svc);
46
47        let client = reqwest::Client::builder()
48            .redirect(reqwest::redirect::Policy::none())
49            .build()
50            .unwrap();
51
52        TestClient { client, addr }
53    }
54
55    pub fn get(&self, url: &str) -> RequestBuilder {
56        RequestBuilder {
57            builder: self.client.get(format!("http://{}{}", self.addr, url)),
58        }
59    }
60
61    pub fn head(&self, url: &str) -> RequestBuilder {
62        RequestBuilder {
63            builder: self.client.head(format!("http://{}{}", self.addr, url)),
64        }
65    }
66
67    pub fn post(&self, url: &str) -> RequestBuilder {
68        RequestBuilder {
69            builder: self.client.post(format!("http://{}{}", self.addr, url)),
70        }
71    }
72
73    #[allow(dead_code)]
74    pub fn put(&self, url: &str) -> RequestBuilder {
75        RequestBuilder {
76            builder: self.client.put(format!("http://{}{}", self.addr, url)),
77        }
78    }
79
80    #[allow(dead_code)]
81    pub fn patch(&self, url: &str) -> RequestBuilder {
82        RequestBuilder {
83            builder: self.client.patch(format!("http://{}{}", self.addr, url)),
84        }
85    }
86
87    pub fn delete(&self, url: &str) -> RequestBuilder {
88        RequestBuilder {
89            builder: self.client.delete(format!("http://{}{}", self.addr, url)),
90        }
91    }
92}
93
94pub struct RequestBuilder {
95    builder: reqwest::RequestBuilder,
96}
97
98impl RequestBuilder {
99    pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
100        self.builder = self.builder.body(body);
101        self
102    }
103
104    pub fn json<T>(mut self, json: &T) -> Self
105    where
106        T: serde::Serialize,
107    {
108        self.builder = self.builder.json(json);
109        self
110    }
111
112    pub fn header<K, V>(mut self, key: K, value: V) -> Self
113    where
114        HeaderName: TryFrom<K>,
115        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
116        HeaderValue: TryFrom<V>,
117        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
118    {
119        // reqwest still uses http 0.2
120        let key: HeaderName = key.try_into().map_err(Into::into).unwrap();
121        let key = reqwest::header::HeaderName::from_bytes(key.as_ref()).unwrap();
122
123        let value: HeaderValue = value.try_into().map_err(Into::into).unwrap();
124        let value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap();
125
126        self.builder = self.builder.header(key, value);
127
128        self
129    }
130
131    #[allow(dead_code)]
132    pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
133        self.builder = self.builder.multipart(form);
134        self
135    }
136}
137
138impl IntoFuture for RequestBuilder {
139    type Output = TestResponse;
140    type IntoFuture = BoxFuture<'static, Self::Output>;
141
142    fn into_future(self) -> Self::IntoFuture {
143        Box::pin(async {
144            TestResponse {
145                response: self.builder.send().await.unwrap(),
146            }
147        })
148    }
149}
150
151#[derive(Debug)]
152pub struct TestResponse {
153    response: reqwest::Response,
154}
155
156impl TestResponse {
157    #[allow(dead_code)]
158    pub async fn bytes(self) -> Bytes {
159        self.response.bytes().await.unwrap()
160    }
161
162    pub async fn text(self) -> String {
163        self.response.text().await.unwrap()
164    }
165
166    #[allow(dead_code)]
167    pub async fn json<T>(self) -> T
168    where
169        T: serde::de::DeserializeOwned,
170    {
171        self.response.json().await.unwrap()
172    }
173
174    pub fn status(&self) -> StatusCode {
175        StatusCode::from_u16(self.response.status().as_u16()).unwrap()
176    }
177
178    pub fn headers(&self) -> http::HeaderMap {
179        // reqwest still uses http 0.2 so have to convert into http 1.0
180        let mut headers = http::HeaderMap::new();
181        for (key, value) in self.response.headers() {
182            let key = http::HeaderName::from_str(key.as_str()).unwrap();
183            let value = http::HeaderValue::from_bytes(value.as_bytes()).unwrap();
184            headers.insert(key, value);
185        }
186        headers
187    }
188
189    pub async fn chunk(&mut self) -> Option<Bytes> {
190        self.response.chunk().await.unwrap()
191    }
192
193    pub async fn chunk_text(&mut self) -> Option<String> {
194        let chunk = self.chunk().await?;
195        Some(String::from_utf8(chunk.to_vec()).unwrap())
196    }
197}