1use std::convert::{Infallible, TryFrom};
4use std::fmt::Debug;
5use std::io;
6use std::str::FromStr;
7
8use futures_util::stream::TryStreamExt;
9use http_types::headers::{HeaderName, HeaderValue};
10use http_types::StatusCode;
11use hyper::body::HttpBody;
12use hyper::client::connect::Connect;
13use hyper_tls::HttpsConnector;
14
15use crate::Config;
16
17use super::{async_trait, Error, HttpClient, Request, Response};
18
19type HyperRequest = hyper::Request<hyper::Body>;
20
21trait HyperClientObject: Debug + Send + Sync + 'static {
23 fn dyn_request(&self, req: hyper::Request<hyper::Body>) -> hyper::client::ResponseFuture;
24}
25
26impl<C: Clone + Connect + Debug + Send + Sync + 'static> HyperClientObject for hyper::Client<C> {
27 fn dyn_request(&self, req: HyperRequest) -> hyper::client::ResponseFuture {
28 self.request(req)
29 }
30}
31
/// `HttpClient` implementation backed by hyper.
#[derive(Debug)]
pub struct HyperClient {
    // Type-erased hyper client, so clients with any connector type can be stored here.
    client: Box<dyn HyperClientObject>,
    // Backend settings; `timeout` is read when sending and `http_keep_alive`
    // when the underlying client is rebuilt from a config.
    config: Config,
}
38
39impl HyperClient {
40 pub fn new() -> Self {
42 let https = HttpsConnector::new();
43 let client = hyper::Client::builder().build(https);
44
45 Self {
46 client: Box::new(client),
47 config: Config::default(),
48 }
49 }
50
51 pub fn from_client<C>(client: hyper::Client<C>) -> Self
53 where
54 C: Clone + Connect + Debug + Send + Sync + 'static,
55 {
56 Self {
57 client: Box::new(client),
58 config: Config::default(),
59 }
60 }
61}
62
63impl Default for HyperClient {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait]
70impl HttpClient for HyperClient {
71 async fn send(&self, req: Request) -> Result<Response, Error> {
72 let req = HyperHttpRequest::try_from(req).await?.into_inner();
73
74 let conn_fut = self.client.dyn_request(req);
75 let response = if let Some(timeout) = self.config.timeout {
76 match tokio::time::timeout(timeout, conn_fut).await {
77 Err(_elapsed) => Err(Error::from_str(400, "Client timed out")),
78 Ok(Ok(try_res)) => Ok(try_res),
79 Ok(Err(e)) => Err(e.into()),
80 }?
81 } else {
82 conn_fut.await?
83 };
84
85 let res = HttpTypesResponse::try_from(response).await?.into_inner();
86 Ok(res)
87 }
88
89 fn set_config(&mut self, config: Config) -> http_types::Result<()> {
93 let connector = HttpsConnector::new();
94 let mut builder = hyper::Client::builder();
95
96 if !config.http_keep_alive {
97 builder.pool_max_idle_per_host(1);
98 }
99
100 self.client = Box::new(builder.build(connector));
101 self.config = config;
102
103 Ok(())
104 }
105
106 fn config(&self) -> &Config {
108 &self.config
109 }
110}
111
112impl TryFrom<Config> for HyperClient {
113 type Error = Infallible;
114
115 fn try_from(config: Config) -> Result<Self, Self::Error> {
116 let connector = HttpsConnector::new();
117 let mut builder = hyper::Client::builder();
118
119 if !config.http_keep_alive {
120 builder.pool_max_idle_per_host(1);
121 }
122
123 Ok(Self {
124 client: Box::new(builder.build(connector)),
125 config,
126 })
127 }
128}
129
130struct HyperHttpRequest(HyperRequest);
131
132impl HyperHttpRequest {
133 async fn try_from(mut value: Request) -> Result<Self, Error> {
134 let uri = hyper::Uri::try_from(&format!("{}", value.url())).unwrap();
136
137 match uri.scheme_str() {
139 Some("http") | Some("https") => (),
140 _ => return Err(Error::from_str(StatusCode::BadRequest, "invalid scheme")),
141 };
142
143 let mut request = hyper::Request::builder();
144
145 let req_headers = request.headers_mut().unwrap();
147 for (name, values) in &value {
148 let name = hyper::header::HeaderName::from_str(name.as_str()).unwrap();
150
151 for value in values.iter() {
152 let value =
154 hyper::header::HeaderValue::from_bytes(value.as_str().as_bytes()).unwrap();
155 req_headers.append(&name, value);
156 }
157 }
158
159 let body = value.body_bytes().await?;
160 let body = hyper::Body::from(body);
161
162 let request = request
163 .method(value.method())
164 .version(value.version().map(|v| v.into()).unwrap_or_default())
165 .uri(uri)
166 .body(body)?;
167
168 Ok(HyperHttpRequest(request))
169 }
170
171 fn into_inner(self) -> hyper::Request<hyper::Body> {
172 self.0
173 }
174}
175
176struct HttpTypesResponse(Response);
177
178impl HttpTypesResponse {
179 async fn try_from(value: hyper::Response<hyper::Body>) -> Result<Self, Error> {
180 let (parts, body) = value.into_parts();
181
182 let size_hint = body.size_hint().upper().map(|s| s as usize);
183 let body = body.map_err(|err| io::Error::new(io::ErrorKind::Other, err.to_string()));
184 let body = http_types::Body::from_reader(body.into_async_read(), size_hint);
185
186 let mut res = Response::new(parts.status);
187 res.set_version(Some(parts.version.into()));
188
189 for (name, value) in parts.headers {
190 let value = value.as_bytes().to_owned();
191 let value = HeaderValue::from_bytes(value)?;
192
193 if let Some(name) = name {
194 let name = name.as_str();
195 let name = HeaderName::from_str(name)?;
196 res.append_header(name, value);
197 }
198 }
199
200 res.set_body(body);
201 Ok(HttpTypesResponse(res))
202 }
203
204 fn into_inner(self) -> Response {
205 self.0
206 }
207}
208
#[cfg(test)]
mod tests {
    use crate::{Error, HttpClient};
    use http_types::{Method, Request, Url};
    use hyper::service::{make_service_fn, service_fn};
    use std::time::Duration;
    use tokio::sync::oneshot::channel;

    use super::HyperClient;

    /// Test-server handler: replies with the incoming request body as the
    /// response body.
    async fn echo(
        req: hyper::Request<hyper::Body>,
    ) -> Result<hyper::Response<hyper::Body>, hyper::Error> {
        Ok(hyper::Response::new(req.into_body()))
    }

    /// Sends a body to a local hyper echo server through `HyperClient` and
    /// checks the same body comes back.
    #[tokio::test]
    async fn basic_functionality() {
        // One-shot channel the client side uses to tell the server to stop.
        let (send, recv) = channel::<()>();

        // Shutdown signal for the server; if the sender is dropped (`Err`),
        // this also resolves and shuts the server down.
        let recv = async move { recv.await.unwrap_or(()) };

        // Bind the echo server on a free localhost port.
        let addr = ([127, 0, 0, 1], portpicker::pick_unused_port().unwrap()).into();
        let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) });
        let server = hyper::Server::bind(&addr)
            .serve(service)
            .with_graceful_shutdown(recv);

        let client = HyperClient::new();
        let url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
        let mut req = Request::new(Method::Get, url);
        req.set_body("hello");

        let client = async move {
            // Presumably a safety margin so the server future is running
            // before the request is sent; timing-based, not a guarantee.
            tokio::time::delay_for(Duration::from_millis(100)).await;
            let mut resp = client.send(req).await?;
            // Shutdown is signalled before the body is read; this presumably
            // relies on graceful shutdown letting the in-flight connection finish.
            send.send(()).unwrap();
            assert_eq!(resp.body_string().await?, "hello");

            Result::<(), Error>::Ok(())
        };

        // Drive client and server concurrently on the same task.
        let (client_res, server_res) = tokio::join!(client, server);
        assert!(client_res.is_ok());
        assert!(server_res.is_ok());
    }
}