http_client/
hyper.rs

//! http-client implementation for hyper
2
3use std::convert::{Infallible, TryFrom};
4use std::fmt::Debug;
5use std::io;
6use std::str::FromStr;
7
8use futures_util::stream::TryStreamExt;
9use http_types::headers::{HeaderName, HeaderValue};
10use http_types::StatusCode;
11use hyper::body::HttpBody;
12use hyper::client::connect::Connect;
13use hyper_tls::HttpsConnector;
14
15use crate::Config;
16
17use super::{async_trait, Error, HttpClient, Request, Response};
18
19type HyperRequest = hyper::Request<hyper::Body>;
20
21// Avoid leaking Hyper generics into HttpClient by hiding it behind a dynamic trait object pointer.
22trait HyperClientObject: Debug + Send + Sync + 'static {
23    fn dyn_request(&self, req: hyper::Request<hyper::Body>) -> hyper::client::ResponseFuture;
24}
25
26impl<C: Clone + Connect + Debug + Send + Sync + 'static> HyperClientObject for hyper::Client<C> {
27    fn dyn_request(&self, req: HyperRequest) -> hyper::client::ResponseFuture {
28        self.request(req)
29    }
30}
31
32/// Hyper-based HTTP Client.
33#[derive(Debug)]
34pub struct HyperClient {
35    client: Box<dyn HyperClientObject>,
36    config: Config,
37}
38
39impl HyperClient {
40    /// Create a new client instance.
41    pub fn new() -> Self {
42        let https = HttpsConnector::new();
43        let client = hyper::Client::builder().build(https);
44
45        Self {
46            client: Box::new(client),
47            config: Config::default(),
48        }
49    }
50
51    /// Create from externally initialized and configured client.
52    pub fn from_client<C>(client: hyper::Client<C>) -> Self
53    where
54        C: Clone + Connect + Debug + Send + Sync + 'static,
55    {
56        Self {
57            client: Box::new(client),
58            config: Config::default(),
59        }
60    }
61}
62
63impl Default for HyperClient {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait]
70impl HttpClient for HyperClient {
71    async fn send(&self, req: Request) -> Result<Response, Error> {
72        let req = HyperHttpRequest::try_from(req).await?.into_inner();
73
74        let conn_fut = self.client.dyn_request(req);
75        let response = if let Some(timeout) = self.config.timeout {
76            match tokio::time::timeout(timeout, conn_fut).await {
77                Err(_elapsed) => Err(Error::from_str(400, "Client timed out")),
78                Ok(Ok(try_res)) => Ok(try_res),
79                Ok(Err(e)) => Err(e.into()),
80            }?
81        } else {
82            conn_fut.await?
83        };
84
85        let res = HttpTypesResponse::try_from(response).await?.into_inner();
86        Ok(res)
87    }
88
89    /// Override the existing configuration with new configuration.
90    ///
91    /// Config options may not impact existing connections.
92    fn set_config(&mut self, config: Config) -> http_types::Result<()> {
93        let connector = HttpsConnector::new();
94        let mut builder = hyper::Client::builder();
95
96        if !config.http_keep_alive {
97            builder.pool_max_idle_per_host(1);
98        }
99
100        self.client = Box::new(builder.build(connector));
101        self.config = config;
102
103        Ok(())
104    }
105
106    /// Get the current configuration.
107    fn config(&self) -> &Config {
108        &self.config
109    }
110}
111
112impl TryFrom<Config> for HyperClient {
113    type Error = Infallible;
114
115    fn try_from(config: Config) -> Result<Self, Self::Error> {
116        let connector = HttpsConnector::new();
117        let mut builder = hyper::Client::builder();
118
119        if !config.http_keep_alive {
120            builder.pool_max_idle_per_host(1);
121        }
122
123        Ok(Self {
124            client: Box::new(builder.build(connector)),
125            config,
126        })
127    }
128}
129
130struct HyperHttpRequest(HyperRequest);
131
132impl HyperHttpRequest {
133    async fn try_from(mut value: Request) -> Result<Self, Error> {
134        // UNWRAP: This unwrap is unjustified in `http-types`, need to check if it's actually safe.
135        let uri = hyper::Uri::try_from(&format!("{}", value.url())).unwrap();
136
137        // `HyperClient` depends on the scheme being either "http" or "https"
138        match uri.scheme_str() {
139            Some("http") | Some("https") => (),
140            _ => return Err(Error::from_str(StatusCode::BadRequest, "invalid scheme")),
141        };
142
143        let mut request = hyper::Request::builder();
144
145        // UNWRAP: Default builder is safe
146        let req_headers = request.headers_mut().unwrap();
147        for (name, values) in &value {
148            // UNWRAP: http-types and http have equivalent validation rules
149            let name = hyper::header::HeaderName::from_str(name.as_str()).unwrap();
150
151            for value in values.iter() {
152                // UNWRAP: http-types and http have equivalent validation rules
153                let value =
154                    hyper::header::HeaderValue::from_bytes(value.as_str().as_bytes()).unwrap();
155                req_headers.append(&name, value);
156            }
157        }
158
159        let body = value.body_bytes().await?;
160        let body = hyper::Body::from(body);
161
162        let request = request
163            .method(value.method())
164            .version(value.version().map(|v| v.into()).unwrap_or_default())
165            .uri(uri)
166            .body(body)?;
167
168        Ok(HyperHttpRequest(request))
169    }
170
171    fn into_inner(self) -> hyper::Request<hyper::Body> {
172        self.0
173    }
174}
175
176struct HttpTypesResponse(Response);
177
178impl HttpTypesResponse {
179    async fn try_from(value: hyper::Response<hyper::Body>) -> Result<Self, Error> {
180        let (parts, body) = value.into_parts();
181
182        let size_hint = body.size_hint().upper().map(|s| s as usize);
183        let body = body.map_err(|err| io::Error::new(io::ErrorKind::Other, err.to_string()));
184        let body = http_types::Body::from_reader(body.into_async_read(), size_hint);
185
186        let mut res = Response::new(parts.status);
187        res.set_version(Some(parts.version.into()));
188
189        for (name, value) in parts.headers {
190            let value = value.as_bytes().to_owned();
191            let value = HeaderValue::from_bytes(value)?;
192
193            if let Some(name) = name {
194                let name = name.as_str();
195                let name = HeaderName::from_str(name)?;
196                res.append_header(name, value);
197            }
198        }
199
200        res.set_body(body);
201        Ok(HttpTypesResponse(res))
202    }
203
204    fn into_inner(self) -> Response {
205        self.0
206    }
207}
208
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use http_types::{Method, Request, Url};
    use hyper::service::{make_service_fn, service_fn};
    use tokio::sync::oneshot;

    use super::HyperClient;
    use crate::{Error, HttpClient};

    /// Hyper service that answers every request with the request's own body.
    async fn echo(
        request: hyper::Request<hyper::Body>,
    ) -> Result<hyper::Response<hyper::Body>, hyper::Error> {
        let body = request.into_body();
        Ok(hyper::Response::new(body))
    }

    /// Round-trips a body through a local echo server and checks it comes back unchanged.
    #[tokio::test]
    async fn basic_functionality() {
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        // Resolves on an explicit signal, or as soon as the sender is dropped.
        let shutdown = async move {
            let _ = shutdown_rx.await;
        };

        let port = portpicker::pick_unused_port().unwrap();
        let addr = ([127, 0, 0, 1], port).into();
        let make_service =
            make_service_fn(|_conn| async { Ok::<_, hyper::Error>(service_fn(echo)) });
        let server = hyper::Server::bind(&addr)
            .serve(make_service)
            .with_graceful_shutdown(shutdown);

        let http_client = HyperClient::new();
        let url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
        let mut request = Request::new(Method::Get, url);
        request.set_body("hello");

        let client = async move {
            // Give the server a moment to start accepting connections.
            tokio::time::delay_for(Duration::from_millis(100)).await;
            let mut response = http_client.send(request).await?;
            shutdown_tx.send(()).unwrap();
            assert_eq!(response.body_string().await?, "hello");

            Ok::<(), Error>(())
        };

        let (client_result, server_result) = tokio::join!(client, server);
        assert!(client_result.is_ok());
        assert!(server_result.is_ok());
    }
}
255}