Skip to main content

axum/test_helpers/
test_client.rs

1use super::{serve, Request, Response};
2use bytes::Bytes;
3use futures_util::future::BoxFuture;
4use http::header::{HeaderName, HeaderValue};
5use std::ops::Deref;
6use std::{convert::Infallible, future::IntoFuture, net::SocketAddr};
7use tokio::net::TcpListener;
8use tower::make::Shared;
9use tower_service::Service;
10
11pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
12where
13    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
14    S::Future: Send,
15{
16    let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
17    std_listener.set_nonblocking(true).unwrap();
18    let listener = TcpListener::from_std(std_listener).unwrap();
19
20    let addr = listener.local_addr().unwrap();
21    println!("Listening on {addr}");
22
23    tokio::spawn(async move {
24        serve(listener, Shared::new(svc))
25            .await
26            .expect("server error")
27    });
28
29    addr
30}
31
32pub struct TestClient {
33    client: reqwest::Client,
34    addr: SocketAddr,
35}
36
37impl TestClient {
38    pub fn new<S>(svc: S) -> Self
39    where
40        S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
41        S::Future: Send,
42    {
43        let addr = spawn_service(svc);
44
45        let client = reqwest::Client::builder()
46            .redirect(reqwest::redirect::Policy::none())
47            .build()
48            .unwrap();
49
50        TestClient { client, addr }
51    }
52
53    pub fn get(&self, url: &str) -> RequestBuilder {
54        RequestBuilder {
55            builder: self.client.get(format!("http://{}{url}", self.addr)),
56        }
57    }
58
59    pub fn head(&self, url: &str) -> RequestBuilder {
60        RequestBuilder {
61            builder: self.client.head(format!("http://{}{url}", self.addr)),
62        }
63    }
64
65    pub fn post(&self, url: &str) -> RequestBuilder {
66        RequestBuilder {
67            builder: self.client.post(format!("http://{}{url}", self.addr)),
68        }
69    }
70
71    #[allow(dead_code)]
72    pub fn put(&self, url: &str) -> RequestBuilder {
73        RequestBuilder {
74            builder: self.client.put(format!("http://{}{url}", self.addr)),
75        }
76    }
77
78    #[allow(dead_code)]
79    pub fn patch(&self, url: &str) -> RequestBuilder {
80        RequestBuilder {
81            builder: self.client.patch(format!("http://{}{url}", self.addr)),
82        }
83    }
84
85    #[allow(dead_code)]
86    #[must_use]
87    pub fn server_port(&self) -> u16 {
88        self.addr.port()
89    }
90}
91
92#[must_use]
93pub struct RequestBuilder {
94    builder: reqwest::RequestBuilder,
95}
96
97impl RequestBuilder {
98    pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
99        self.builder = self.builder.body(body);
100        self
101    }
102
103    pub fn json<T>(mut self, json: &T) -> Self
104    where
105        T: serde_core::Serialize,
106    {
107        self.builder = self.builder.json(json);
108        self
109    }
110
111    pub fn header<K, V>(mut self, key: K, value: V) -> Self
112    where
113        HeaderName: TryFrom<K>,
114        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
115        HeaderValue: TryFrom<V>,
116        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
117    {
118        self.builder = self.builder.header(key, value);
119        self
120    }
121
122    #[allow(dead_code)]
123    pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
124        self.builder = self.builder.multipart(form);
125        self
126    }
127}
128
129impl IntoFuture for RequestBuilder {
130    type Output = TestResponse;
131    type IntoFuture = BoxFuture<'static, Self::Output>;
132
133    fn into_future(self) -> Self::IntoFuture {
134        Box::pin(async {
135            TestResponse {
136                response: self.builder.send().await.unwrap(),
137            }
138        })
139    }
140}
141
142#[derive(Debug)]
143pub struct TestResponse {
144    response: reqwest::Response,
145}
146
147impl Deref for TestResponse {
148    type Target = reqwest::Response;
149
150    fn deref(&self) -> &Self::Target {
151        &self.response
152    }
153}
154
155impl TestResponse {
156    #[allow(dead_code)]
157    pub async fn bytes(self) -> Bytes {
158        self.response.bytes().await.unwrap()
159    }
160
161    pub async fn text(self) -> String {
162        self.response.text().await.unwrap()
163    }
164
165    #[allow(dead_code)]
166    pub async fn json<T>(self) -> T
167    where
168        T: serde_core::de::DeserializeOwned,
169    {
170        self.response.json().await.unwrap()
171    }
172
173    pub async fn chunk(&mut self) -> Option<Bytes> {
174        self.response.chunk().await.unwrap()
175    }
176
177    pub async fn chunk_text(&mut self) -> Option<String> {
178        let chunk = self.chunk().await?;
179        Some(String::from_utf8(chunk.to_vec()).unwrap())
180    }
181}