generic_async_http_client/hyper/
mod.rs

1use std::{
2    convert::{Infallible, TryFrom},
3    str::FromStr,
4};
5
6use serde::Serialize;
7
8pub use hyper::{
9    body::Incoming,
10    header::{HeaderName, HeaderValue},
11};
12use hyper::{
13    body::{Body as BodyTrait, Bytes, Frame, SizeHint},
14    header::{InvalidHeaderName, InvalidHeaderValue, CONTENT_TYPE},
15    http::{
16        method::{InvalidMethod, Method},
17        request::Builder,
18        uri::{Builder as UriBuilder, InvalidUri, PathAndQuery, Uri},
19        Error as HTTPError,
20    },
21    Error as HyperError, Request, Response,
22};
23use std::mem::take;
24
25mod connector;
26pub(crate) use connector::HyperClient;
27
28pub(crate) fn get_client() -> HyperClient {
29    HyperClient::default()
30}
31
/// A request under construction for the hyper backend.
#[derive(Debug)]
pub struct Req {
    /// Method, URI and headers. Invalid inputs are recorded inside the builder and are
    /// returned when `send_request` calls `Builder::body`.
    req: Builder,
    /// Payload that will be sent with the request.
    body: Body,
    /// Client used to send the request; `None` means `send_request` uses `get_client()`.
    pub(crate) client: Option<HyperClient>,
}
38
39impl<M, U> TryFrom<(M, U)> for crate::Request
40where
41    Method: TryFrom<M>,
42    <Method as TryFrom<M>>::Error: Into<HTTPError>,
43    Uri: TryFrom<U>,
44    <Uri as TryFrom<U>>::Error: Into<HTTPError>,
45{
46    type Error = Infallible;
47
48    fn try_from(value: (M, U)) -> Result<Self, Self::Error> {
49        let req = Builder::new().method(value.0).uri(value.1);
50
51        Ok(crate::Request(Req {
52            req,
53            body: Body::empty(),
54            client: None,
55        }))
56    }
57}
58impl Req {
59    fn init(method: Method, uri: &str) -> Req {
60        let req = Builder::new().method(method).uri(uri);
61
62        Req {
63            req,
64            body: Body::empty(),
65            client: None,
66        }
67    }
68    fn _query(&mut self, query: String) -> Result<(), Error> {
69        let old = self.req.uri_ref().expect("no uri");
70
71        let mut p_and_p = String::with_capacity(old.path().len() + query.len() + 1);
72        p_and_p.push_str(old.path());
73        p_and_p.push('?');
74        p_and_p.push_str(&query);
75
76        let path_and_query = PathAndQuery::from_str(&p_and_p)?;
77
78        let new = UriBuilder::new()
79            .scheme(old.scheme_str().unwrap())
80            .authority(old.authority().unwrap().as_str())
81            .path_and_query(path_and_query)
82            .build()?;
83
84        self.req = take(&mut self.req).uri(new);
85        Ok(())
86    }
87}
88
89impl crate::request::Requests for Req {
90    fn get(uri: &str) -> Req {
91        Self::init(Method::GET, uri)
92    }
93    fn post(uri: &str) -> Req {
94        Self::init(Method::POST, uri)
95    }
96    fn put(uri: &str) -> Req {
97        Self::init(Method::PUT, uri)
98    }
99    fn delete(uri: &str) -> Req {
100        Self::init(Method::DELETE, uri)
101    }
102    fn head(uri: &str) -> Req {
103        Self::init(Method::HEAD, uri)
104    }
105    fn options(uri: &str) -> Req {
106        Self::init(Method::OPTIONS, uri)
107    }
108    fn new(meth: &str, uri: &str) -> Result<Req, Error> {
109        Ok(Self::init(Method::from_str(meth)?, uri))
110    }
111    async fn send_request(mut self) -> Result<crate::Response, Error> {
112        let req = self.req.body(self.body)?;
113
114        let resp = if let Some(mut client) = self.client.take() {
115            client.request(req).await?
116        } else {
117            get_client().request(req).await?
118        };
119
120        #[cfg(not(all(feature = "mock_tests", test)))]
121        return Ok(crate::Response(Resp { resp }));
122        #[cfg(all(feature = "mock_tests", test))]
123        return Ok(crate::Response(Resp::Real(not_mocked::Resp { resp })));
124    }
125    fn json<T: Serialize + ?Sized>(&mut self, json: &T) -> Result<(), Error> {
126        let bytes = serde_json::to_string(&json)?;
127        self.set_header(CONTENT_TYPE, HeaderValue::from_static("application/json"))?;
128        self.body = bytes.into();
129        Ok(())
130    }
131    fn form<T: Serialize + ?Sized>(&mut self, data: &T) -> Result<(), Error> {
132        let query = serde_urlencoded::to_string(data)?;
133        self.set_header(
134            CONTENT_TYPE,
135            HeaderValue::from_static("application/x-www-form-urlencoded"),
136        )?;
137        self.body = query.into();
138        Ok(())
139    }
140    #[inline]
141    fn query<T: Serialize + ?Sized>(&mut self, query: &T) -> Result<(), Error> {
142        // codegen trampoline: https://github.com/rust-lang/rust/issues/77960
143        self._query(serde_qs::to_string(&query)?)
144    }
145    fn body<B: Into<Body>>(&mut self, body: B) -> Result<(), Error> {
146        self.body = body.into();
147        Ok(())
148    }
149    fn set_header(&mut self, name: HeaderName, value: HeaderValue) -> Result<(), Error> {
150        self.req.headers_mut().map(|hm| hm.insert(name, value));
151        Ok(())
152    }
153    fn add_header(&mut self, name: HeaderName, value: HeaderValue) -> Result<(), Error> {
154        self.req = take(&mut self.req).header(name, value);
155        Ok(())
156    }
157}
158use hyper::body::Buf;
159use serde::de::DeserializeOwned;
160
161mod not_mocked {
162    use super::*;
163    pub struct Resp {
164        pub(super) resp: Response<Incoming>,
165    }
166    impl crate::response::Responses for Resp {
167        fn status(&self) -> u16 {
168            self.resp.status().as_u16()
169        }
170        fn status_str(&self) -> &'static str {
171            self.resp.status().canonical_reason().unwrap_or("")
172        }
173        async fn json<D: DeserializeOwned>(&mut self) -> Result<D, Error> {
174            let reader = aggregate(self.resp.body_mut()).await?.reader();
175            Ok(serde_json::from_reader(reader)?)
176        }
177        async fn bytes(&mut self) -> Result<Vec<u8>, Error> {
178            let mut b = aggregate(self.resp.body_mut()).await?;
179            let capacity = b.remaining();
180            //Ok(b.copy_to_bytes(capacity).into())
181            let mut v = Vec::with_capacity(capacity);
182            let ptr = v.spare_capacity_mut().as_mut_ptr();
183            let dst = unsafe { std::slice::from_raw_parts_mut(ptr.cast::<u8>(), capacity) };
184            b.copy_to_slice(dst);
185            unsafe {
186                v.set_len(capacity);
187            }
188            Ok(v)
189        }
190        async fn string(&mut self) -> Result<String, Error> {
191            let b = self.bytes().await?;
192            Ok(String::from_utf8_lossy(&b).to_string())
193        }
194        fn get_header(&self, name: HeaderName) -> Option<&HeaderValue> {
195            self.resp.headers().get(name)
196        }
197        fn get_headers(&self, name: HeaderName) -> impl Iterator<Item = &HeaderValue> {
198            self.resp.headers().get_all(name).iter()
199        }
200        fn header_iter(&self) -> impl Iterator<Item = (&HeaderName, &HeaderValue)> {
201            self.resp.headers().into_iter()
202        }
203    }
204}
205
// Outside of mock tests, `Resp` is the real hyper-backed response.
#[cfg(not(all(feature = "mock_tests", test)))]
pub use not_mocked::Resp;
// In mock tests, responses go through `crate::mock::Resp`, which can wrap the real
// response (see `Resp::Real` in `send_request`).
#[cfg(all(feature = "mock_tests", test))]
pub type Resp = crate::mock::Resp<not_mocked::Resp>;
210
#[cfg(all(feature = "mock_tests", test))]
impl crate::mock::MockedRequest for Req {
    /// Compare the request body with `should_be`.
    /// on error, return full body
    fn assert_body_bytes(&mut self, should_be: &[u8]) -> Result<(), Vec<u8>> {
        if self.body.0 == should_be {
            Ok(())
        } else {
            Err(self.body.0.clone())
        }
    }
    /// All values of header `name`, or `None` if the header is not set.
    fn get_headers(&self, name: &str) -> Option<Vec<crate::mock::MockHeaderValue>> {
        let name = HeaderName::from_str(name).unwrap();
        let headers = self
            .req
            .headers_ref()
            .expect("builder should not have errors");
        let values: Vec<crate::mock::MockHeaderValue> = headers
            .get_all(name)
            .iter()
            .cloned()
            .map(Into::into)
            .collect();
        // A header that is present always has at least one value.
        (!values.is_empty()).then_some(values)
    }
    /// `(method, uri)` as strings; empty strings if the builder holds an error.
    fn endpoint(&self) -> crate::mock::Endpoint {
        let method = self
            .req
            .method_ref()
            .map(ToString::to_string)
            .unwrap_or_default();
        let uri = self
            .req
            .uri_ref()
            .map(ToString::to_string)
            .unwrap_or_default();
        (method, uri)
    }
}
247
248//(fragmented) memory returned by aggregate
249struct FracturedBuf(std::collections::VecDeque<Bytes>);
250impl Buf for FracturedBuf {
251    fn remaining(&self) -> usize {
252        self.0.iter().map(|buf| buf.remaining()).sum()
253    }
254    fn chunk(&self) -> &[u8] {
255        self.0.front().map(Buf::chunk).unwrap_or_default()
256    }
257    fn advance(&mut self, mut cnt: usize) {
258        let bufs = &mut self.0;
259        while cnt > 0 {
260            if let Some(front) = bufs.front_mut() {
261                let rem = front.remaining();
262                if rem > cnt {
263                    front.advance(cnt);
264                    return;
265                } else {
266                    front.advance(rem);
267                    cnt -= rem;
268                }
269            } else {
270                //no data -> panic?
271                return;
272            }
273            bufs.pop_front();
274        }
275    }
276}
277/// Helper for aggregate function. Polls a single frame from an incoming body
278struct Framed<'a>(&'a mut Incoming);
279
280impl futures::Future for Framed<'_> {
281    type Output = Option<Result<hyper::body::Frame<Bytes>, hyper::Error>>;
282
283    fn poll(
284        mut self: std::pin::Pin<&mut Self>,
285        ctx: &mut std::task::Context<'_>,
286    ) -> std::task::Poll<Self::Output> {
287        std::pin::Pin::new(&mut self.0).poll_frame(ctx)
288    }
289}
290/// read an incoming body to (fragmented) memory
291async fn aggregate(body: &mut Incoming) -> Result<FracturedBuf, Error> {
292    let mut v = std::collections::VecDeque::new();
293    while let Some(f) = Framed(body).await {
294        if let Ok(d) = f?.into_data() {
295            v.push_back(d);
296        }
297    }
298    Ok(FracturedBuf(v))
299}
300
/// Errors produced by the hyper backend.
#[derive(Debug)]
pub enum Error {
    /// Scheme-related failure. Not constructed in this file; presumably raised by the
    /// `connector` module (TODO confirm).
    Scheme,
    /// Error from the `http` request/URI builders.
    Http(HTTPError),
    /// Query-string serialization failed (`Req::query`).
    InvalidQueryString(serde_qs::Error),
    /// Method string could not be parsed.
    InvalidMethod(InvalidMethod),
    /// Error reported by hyper, e.g. while reading a response body frame.
    Hyper(HyperError),
    /// JSON serialization (`Req::json`) or deserialization (`Resp::json`) failed.
    Json(serde_json::Error),
    /// Invalid header value.
    InvalidHeaderValue(InvalidHeaderValue),
    /// Invalid header name.
    InvalidHeaderName(InvalidHeaderName),
    /// Invalid URI or path-and-query.
    InvalidUri(InvalidUri),
    /// Form serialization failed (`Req::form`).
    Urlencoded(serde_urlencoded::ser::Error),
    /// I/O error.
    Io(std::io::Error),
}
// `source()` is not overridden: `Display` already prints the wrapped error's message.
impl std::error::Error for Error {}
316use std::fmt;
317impl fmt::Display for Error {
318    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319        match self {
320            Error::Scheme => write!(f, "Scheme"),
321            Error::Http(i) => write!(f, "{}", i),
322            Error::InvalidQueryString(i) => write!(f, "{}", i),
323            Error::InvalidMethod(i) => write!(f, "{}", i),
324            Error::Hyper(i) => write!(f, "{}", i),
325            Error::Json(i) => write!(f, "{}", i),
326            Error::InvalidHeaderValue(i) => write!(f, "{}", i),
327            Error::InvalidHeaderName(i) => write!(f, "{}", i),
328            Error::InvalidUri(i) => write!(f, "{}", i),
329            Error::Urlencoded(i) => write!(f, "{}", i),
330            Error::Io(i) => write!(f, "{}", i),
331        }
332    }
333}
334impl From<Error> for crate::Error {
335    fn from(e: Error) -> Self {
336        match e {
337            Error::Io(error) => Self::Io(error),
338            Error::Hyper(h) => {
339                //It might be an IO error. If so, return it as such
340                if let Some(io) = std::error::Error::source(&h)
341                    .and_then(|err| err.downcast_ref::<std::io::Error>())
342                {
343                    let io_e = if let Some(code) = io.raw_os_error() {
344                        std::io::Error::from_raw_os_error(code)
345                    //}else if let Some(error) = io.into_inner() {
346                    //    std::io::Error::new(io.kind(), error)
347                    } else {
348                        io.kind().into()
349                    };
350                    Self::Io(io_e)
351                } else {
352                    Self::Other(Error::Hyper(h))
353                }
354            }
355            e => Self::Other(e),
356        }
357    }
358}
359//connect_to_uri
360impl From<std::io::Error> for Error {
361    fn from(e: std::io::Error) -> Self {
362        Self::Io(e)
363    }
364}
365//Req::form
366impl From<serde_urlencoded::ser::Error> for Error {
367    fn from(e: serde_urlencoded::ser::Error) -> Self {
368        Self::Urlencoded(e)
369    }
370}
371impl From<InvalidUri> for Error {
372    fn from(e: InvalidUri) -> Self {
373        Self::InvalidUri(e)
374    }
375}
376//TryFrom<> for HeaderName
377impl From<InvalidHeaderName> for Error {
378    fn from(e: InvalidHeaderName) -> Self {
379        Self::InvalidHeaderName(e)
380    }
381}
382//TryFrom<> for HeaderValue
383impl From<InvalidHeaderValue> for Error {
384    fn from(e: InvalidHeaderValue) -> Self {
385        Self::InvalidHeaderValue(e)
386    }
387}
388//Resp::json
389impl From<serde_json::Error> for Error {
390    fn from(e: serde_json::Error) -> Self {
391        Self::Json(e)
392    }
393}
394impl From<HyperError> for Error {
395    fn from(e: HyperError) -> Self {
396        Self::Hyper(e)
397    }
398}
399impl From<InvalidMethod> for Error {
400    fn from(e: InvalidMethod) -> Self {
401        Self::InvalidMethod(e)
402    }
403}
404impl From<HTTPError> for Error {
405    fn from(e: HTTPError) -> Self {
406        Self::Http(e)
407    }
408}
409//Req::query
410impl From<serde_qs::Error> for Error {
411    fn from(e: serde_qs::Error) -> Self {
412        Self::InvalidQueryString(e)
413    }
414}
415
/// Request body held entirely in memory.
///
/// Every conversion below takes ownership of, or copies, the bytes.
#[derive(Debug)]
pub struct Body(Vec<u8>);
impl Body {
    /// A body with no content.
    fn empty() -> Self {
        Self(Vec::new())
    }
}
impl From<String> for Body {
    #[inline]
    fn from(text: String) -> Self {
        Self(Vec::from(text))
    }
}
impl From<Vec<u8>> for Body {
    #[inline]
    fn from(bytes: Vec<u8>) -> Self {
        Self(bytes)
    }
}
impl From<&'static [u8]> for Body {
    #[inline]
    fn from(bytes: &'static [u8]) -> Self {
        Self(Vec::from(bytes))
    }
}
impl From<&'static str> for Body {
    #[inline]
    fn from(text: &'static str) -> Self {
        Self::from(text.as_bytes())
    }
}
447impl hyper::body::Body for Body {
448    type Data = Bytes;
449    type Error = Infallible;
450
451    fn poll_frame(
452        mut self: std::pin::Pin<&mut Self>,
453        _cx: &mut std::task::Context<'_>,
454    ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
455        if self.0.is_empty() {
456            std::task::Poll::Ready(None)
457        } else {
458            let v: Vec<u8> = std::mem::take(self.0.as_mut());
459            std::task::Poll::Ready(Some(Ok(Frame::data(v.into()))))
460        }
461    }
462    fn size_hint(&self) -> SizeHint {
463        SizeHint::with_exact(self.0.len() as u64)
464    }
465}