client_util/
request.rs

1#[cfg(feature = "multipart")]
2#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
3mod multipart;
4use bytes::Bytes;
5use futures_util::TryFutureExt;
6use http::uri::PathAndQuery;
7use http::HeaderValue;
8use http::Request;
9use http::Response;
10use http::{header::CONTENT_TYPE, Uri};
11use http_body_util::combinators::BoxBody;
12use http_body_util::{Empty, Full};
13#[cfg(feature = "multipart")]
14pub use multipart::*;
15#[cfg(feature = "serde")]
16use serde::Serialize;
17use std::convert::Infallible;
18use std::future::Future;
19use std::ops::Deref;
20use std::ops::DerefMut;
21use std::str::FromStr;
22
23use crate::body::{empty, full};
24use crate::client::ClientBody;
25
26#[derive(Debug, thiserror::Error)]
27pub enum BuildRequestError {
28    #[cfg(feature = "json")]
29    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
30    #[error("failed to serialize json body: {0}")]
31    BuildJsonBody(#[from] serde_json::Error),
32    #[cfg(feature = "form")]
33    #[cfg_attr(docsrs, doc(cfg(feature = "form")))]
34    #[error("failed to build form body: {0}")]
35    BuildForm(#[from] BuildFormError),
36    #[cfg(feature = "multipart")]
37    #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
38    #[error("failed to build multipart body: {0}")]
39    BuildMultipart(#[from] BuildMultipartError),
40    #[error("failed to build request path: {0}")]
41    BuildPath(#[from] BuildPathError),
42    #[cfg(feature = "query")]
43    #[cfg_attr(docsrs, doc(cfg(feature = "query")))]
44    #[error("failed to build request query: {0}")]
45    BuildQuery(#[from] BuildQueryError),
46    #[error("invalid uri: {0}")]
47    InvalidUri(#[from] http::uri::InvalidUri),
48    #[error("invalid header value: {0}")]
49    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
50    #[error("failed to build request: {0}")]
51    HttpError(#[from] http::Error),
52}
53
54impl From<Infallible> for BuildRequestError {
55    fn from(_: Infallible) -> Self {
56        unreachable!()
57    }
58}
59
60#[derive(Debug, thiserror::Error)]
61pub enum BuildPathError {
62    #[error("invalid uri: {0}")]
63    InvalidUri(#[from] http::uri::InvalidUri),
64    #[error("invalid uri parts: {0}")]
65    InvalidUriParts(#[from] http::uri::InvalidUriParts),
66}
67
68#[cfg(feature = "query")]
69#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
70#[derive(Debug, thiserror::Error)]
71pub enum BuildQueryError {
72    #[error("invalid uri: {0}")]
73    InvalidUri(#[from] http::uri::InvalidUri),
74    #[error("invalid uri parts: {0}")]
75    InvalidUriParts(#[from] http::uri::InvalidUriParts),
76    #[error("failed to serialize query string: {0}")]
77    SerializeQuery(#[from] serde_urlencoded::ser::Error),
78}
79
80#[cfg(feature = "query")]
81#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
82#[derive(Debug, thiserror::Error)]
83pub enum BuildMultipartError {
84    #[error("invalid boundary header: {0}")]
85    InvalidBoundaryHeader(#[from] http::header::InvalidHeaderValue),
86    #[error("invalid mime type: {0}")]
87    InvalidMime(#[from] mime::FromStrError),
88}
89
90#[cfg(feature = "form")]
91#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
92#[derive(Debug, thiserror::Error)]
93pub enum BuildFormError {
94    #[error("failed to serialize form body: {0}")]
95    SerializeForm(#[from] serde_urlencoded::ser::Error),
96    #[error("invalid content type header: {0}")]
97    InvalidContentTypeHeader(#[from] http::header::InvalidHeaderValue),
98}
99
100#[cfg(feature = "json")]
101#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
102#[derive(Debug, thiserror::Error)]
103pub enum BuildJsonBodyError {
104    #[error("failed to serialize json body: {0}")]
105    SerdeJson(#[from] serde_json::Error),
106}
107pub struct RequestBuilder {
108    parts: http::request::Parts,
109}
110
111impl Deref for RequestBuilder {
112    type Target = http::request::Parts;
113    fn deref(&self) -> &Self::Target {
114        &self.parts
115    }
116}
117
118impl DerefMut for RequestBuilder {
119    fn deref_mut(&mut self) -> &mut Self::Target {
120        &mut self.parts
121    }
122}
123
124macro_rules! http_methods {
125    ($fn: ident  $method: expr) => {
126        pub fn $fn<T>(uri: T) -> Result<Self, BuildRequestError>
127        where
128            T: TryInto<Uri>,
129            <T as TryInto<Uri>>::Error: Into<BuildRequestError>,
130        {
131            let mut this = Self::new();
132            this.parts.method = $method;
133            this.parts.uri = uri.try_into().map_err(Into::into)?;
134            Ok(this)
135        }
136    };
137}
138
139impl Default for RequestBuilder {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl RequestBuilder {
146    pub fn new() -> Self {
147        let (parts, _) = http::Request::new(()).into_parts();
148        Self { parts }
149    }
150    pub fn uri(mut self, uri: Uri) -> Self {
151        self.parts.uri = uri;
152        self
153    }
154    http_methods!(get http::Method::GET);
155    http_methods!(post http::Method::POST);
156    http_methods!(put http::Method::PUT);
157    http_methods!(delete http::Method::DELETE);
158    http_methods!(head http::Method::HEAD);
159    http_methods!(patch http::Method::PATCH);
160    http_methods!(options http::Method::OPTIONS);
161    http_methods!(trace http::Method::TRACE);
162    http_methods!(connect http::Method::CONNECT);
163
164    pub fn method(mut self, method: http::Method) -> Self {
165        self.parts.method = method;
166        self
167    }
168    pub fn version(mut self, version: http::Version) -> Self {
169        self.parts.version = version;
170        self
171    }
172    pub fn body<B>(self, body: B) -> Result<Request<B>, BuildRequestError>
173    where
174        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
175        B::Error: Into<crate::error::BoxError>,
176    {
177        let request = Request::from_parts(self.parts, body);
178        Ok(request)
179    }
180    #[cfg(feature = "json")]
181    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
182    pub fn json<T: Serialize + ?Sized>(
183        self,
184        body: &T,
185    ) -> Result<Request<Full<Bytes>>, BuildRequestError> {
186        let json_body = serde_json::to_vec(&body)?;
187        let mut parts = self.parts;
188        parts.headers.insert(
189            CONTENT_TYPE,
190            HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
191        );
192        let request = Request::from_parts(parts, Full::new(Bytes::from(json_body)));
193        Ok(request)
194    }
195    #[cfg(feature = "multipart")]
196    #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
197    pub fn multipart(
198        self,
199        mut form: multipart::Form,
200    ) -> Result<Request<crate::Body>, BuildRequestError> {
201        let mut parts = self.parts;
202        let boundary = form.boundary();
203        parts.headers.insert(
204            CONTENT_TYPE,
205            HeaderValue::from_str(&format!(
206                "{}; boundary={}",
207                mime::MULTIPART_FORM_DATA,
208                boundary
209            ))
210            .map_err(BuildMultipartError::from)?,
211        );
212        if let Some(length) = form.compute_length() {
213            parts.headers.insert(
214                http::header::CONTENT_LENGTH,
215                HeaderValue::from_str(&length.to_string())
216                    .expect("content length is always valid HeaderValue"),
217            );
218        }
219        let body = form.stream();
220        Ok(Request::from_parts(parts, body))
221    }
222    /// Set the request body as form data.
223    #[cfg(feature = "form")]
224    #[cfg_attr(docsrs, doc(cfg(feature = "form")))]
225    pub fn form<T: Serialize + ?Sized>(
226        mut self,
227        form: &T,
228    ) -> Result<Request<Full<Bytes>>, BuildRequestError> {
229        let body = serde_urlencoded::to_string(form).map_err(BuildFormError::from)?;
230        self.parts.headers.insert(
231            CONTENT_TYPE,
232            HeaderValue::from_static(mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()),
233        );
234        Ok(Request::from_parts(self.parts, full(body)))
235    }
236    pub fn plain_text(self, body: impl Into<Bytes>) -> Request<Full<Bytes>> {
237        Request::from_parts(self.parts, full(body))
238    }
239    pub fn empty(self) -> Request<Empty<Bytes>> {
240        Request::from_parts(self.parts, empty())
241    }
242    #[cfg(feature = "query")]
243    #[cfg_attr(docsrs, doc(cfg(feature = "query")))]
244    pub fn query<Q: Serialize + ?Sized>(mut self, query: &Q) -> Result<Self, BuildRequestError> {
245        self.parts.uri = build_query_uri(self.parts.uri, query)?;
246        Ok(self)
247    }
248    pub fn path(mut self, path: impl AsRef<str>) -> Result<Self, BuildRequestError> {
249        let path = path.as_ref();
250        self.parts.uri = build_path_uri(self.parts.uri, path)?;
251        Ok(self)
252    }
253    pub fn headers(mut self, header_map: http::header::HeaderMap) -> Self {
254        self.parts.headers.extend(header_map);
255        self
256    }
257    pub fn header<V>(
258        mut self,
259        key: impl http::header::IntoHeaderName,
260        value: V,
261    ) -> Result<Self, BuildRequestError>
262    where
263        V: TryInto<HeaderValue>,
264        V::Error: Into<BuildRequestError>,
265    {
266        self.parts
267            .headers
268            .insert(key, value.try_into().map_err(Into::into)?);
269        Ok(self)
270    }
271    #[cfg(feature = "auth")]
272    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
273    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
274    where
275        U: std::fmt::Display,
276        P: std::fmt::Display,
277        Self: Sized,
278    {
279        let header_value = crate::util::basic_auth(username, password);
280        self.header(http::header::AUTHORIZATION, header_value)
281            .expect("base64 should always be a valid header value")
282    }
283    #[cfg(feature = "auth")]
284    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
285    pub fn bearer_auth<T>(self, token: T) -> Self
286    where
287        T: std::fmt::Display,
288    {
289        let header_value = crate::util::bearer_auth(token);
290        self.header(http::header::AUTHORIZATION, header_value)
291            .expect("base64 should always be a valid header value")
292    }
293}
294
295/// Extension trait for [`http::Request`].
296pub trait RequestExt<B>: Sized {
297    fn with_version(self, version: http::Version) -> Request<B>;
298    fn with_method(self, method: http::Method) -> Request<B>;
299    fn with_header<K>(self, key: K, value: http::header::HeaderValue) -> Request<B>
300    where
301        K: http::header::IntoHeaderName;
302    fn with_headers(self, header_map: http::header::HeaderMap) -> Request<B>;
303    #[cfg(feature = "auth")]
304    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
305    fn with_basic_auth<U, P>(self, username: U, password: Option<P>) -> Request<B>
306    where
307        U: std::fmt::Display,
308        P: std::fmt::Display,
309        Self: Sized,
310    {
311        let header_value = crate::util::basic_auth(username, password);
312        self.with_header(http::header::AUTHORIZATION, header_value)
313    }
314
315    #[cfg(feature = "auth")]
316    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
317    fn with_bearer_auth<T>(self, token: T) -> Request<B>
318    where
319        T: std::fmt::Display,
320    {
321        let header_value = crate::util::bearer_auth(token);
322        self.with_header(http::header::AUTHORIZATION, header_value)
323    }
324
325    fn send<S, R>(self, client: S) -> impl Future<Output = crate::Result<S::Response>> + Send
326    where
327        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
328        B::Error: Into<crate::error::BoxError>,
329        S: tower_service::Service<Request<ClientBody>, Response = Response<R>> + Send + Sync,
330        R: http_body::Body + Send + Sync + 'static,
331        <S as tower_service::Service<Request<ClientBody>>>::Error: Into<crate::error::BoxError>,
332        <S as tower_service::Service<Request<ClientBody>>>::Future: Send;
333}
334
335impl<B> RequestExt<B> for Request<B>
336where
337    B: Send,
338{
339    /// Set the request HTTP version.
340    #[inline]
341    fn with_version(mut self, version: http::Version) -> Request<B> {
342        *self.version_mut() = version;
343        self
344    }
345
346    /// Set the request method.
347    #[inline]
348    fn with_method(mut self, method: http::Method) -> Request<B> {
349        *self.method_mut() = method;
350        self
351    }
352
353    /*
354    // I think we may don't need to modify schema?
355    fn schema(self, schema: Scheme) -> crate::Result<Request<B>> {
356        let (mut parts, body) = self.into_parts();
357        let mut uri_parts = parts.uri.into_parts();
358        uri_parts.scheme = Some(schema);
359        parts.uri = Uri::from_parts(uri_parts).map_err(crate::Error::with_context(
360            "reconstruct uri with new schema",
361        ))?;
362        Ok(Request::from_parts(parts, body))
363    }
364    */
365    /// Set a request header.
366    #[inline]
367    fn with_header<K>(mut self, key: K, value: http::header::HeaderValue) -> Request<B>
368    where
369        K: http::header::IntoHeaderName,
370    {
371        self.headers_mut().insert(key, value);
372        self
373    }
374
375    /// Extend multiple request headers.
376    #[inline]
377    fn with_headers(mut self, header_map: http::header::HeaderMap) -> Request<B> {
378        self.headers_mut().extend(header_map);
379        self
380    }
381
382    /// Send the request to a service.
383    ///
384    /// If you enabled any decompression feature, the response body will be automatically decompressed.
385    #[allow(unused_mut)]
386    fn send<S, R>(self, mut client: S) -> impl Future<Output = crate::Result<S::Response>> + Send
387    where
388        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
389        B::Error: Into<crate::error::BoxError>,
390        S: tower_service::Service<Request<ClientBody>, Response = Response<R>> + Send + Sync,
391        R: http_body::Body + Send + Sync + 'static,
392        <S as tower_service::Service<Request<ClientBody>>>::Error: Into<crate::error::BoxError>,
393        <S as tower_service::Service<Request<ClientBody>>>::Future: Send,
394    {
395        use http_body_util::BodyExt;
396        let request = self.map(|b| BoxBody::new(b.map_err(|e| e.into())));
397        client
398            .call(request)
399            .map_err(|e| crate::Error::SendRequest(e.into()))
400    }
401}
402
403#[cfg(feature = "query")]
404#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
405fn build_query_uri<Q: Serialize + ?Sized>(uri: Uri, query: &Q) -> Result<Uri, BuildQueryError> {
406    use std::str::FromStr;
407    let new_query = serde_urlencoded::to_string(query)?;
408    if new_query.is_empty() {
409        return Ok(uri);
410    }
411    let mut uri_parts = uri.into_parts();
412    let new_pq = if let Some(pq) = uri_parts.path_and_query {
413        let mut new_pq_string = String::with_capacity(new_query.len() + pq.as_str().len() + 2);
414        new_pq_string.push_str(pq.path());
415        new_pq_string.push('?');
416        if let Some(old_query) = pq.query() {
417            new_pq_string.push_str(old_query);
418            new_pq_string.push('&');
419        }
420        new_pq_string.push_str(&new_query);
421
422        http::uri::PathAndQuery::from_str(&new_pq_string)?
423    } else {
424        http::uri::PathAndQuery::from_str(&new_query)?
425    };
426    uri_parts.path_and_query = Some(new_pq);
427    let new_uri = Uri::from_parts(uri_parts)?;
428    Ok(new_uri)
429}
430
431fn build_path_uri(uri: Uri, path: &str) -> Result<Uri, BuildPathError> {
432    let mut parts = uri.into_parts();
433    let Some(pq) = parts.path_and_query else {
434        parts.path_and_query = Some(PathAndQuery::from_str(path)?);
435        return Ok(Uri::from_parts(parts)?);
436    };
437    let query = pq.query();
438    let pq = if let Some(query) = query {
439        PathAndQuery::from_maybe_shared(format!("{path}?{query}"))?
440    } else {
441        PathAndQuery::from_str(path)?
442    };
443    parts.path_and_query = Some(pq);
444    let uri = Uri::from_parts(parts)?;
445    Ok(uri)
446}
447/*
448    I copied and modified those tests from reqwest: https://github.com/seanmonstar/reqwest/blob/master/src/async_impl/request.rs
449*/
450#[cfg(test)]
451mod tests {
452
453    use super::*;
454    use std::collections::BTreeMap;
455
456    #[test]
457    fn add_query_append() -> crate::Result<()> {
458        let req = RequestBuilder::get("https://google.com/")?
459            .query(&[("foo", "bar")])?
460            .query(&[("qux", 3)])?
461            .empty();
462
463        assert_eq!(req.uri().query(), Some("foo=bar&qux=3"));
464        Ok(())
465    }
466
467    #[test]
468    fn add_query_append_same() -> crate::Result<()> {
469        let req = RequestBuilder::get("https://google.com/")?
470            .query(&[("foo", "a"), ("foo", "b")])?
471            .empty();
472
473        assert_eq!(req.uri().query(), Some("foo=a&foo=b"));
474        Ok(())
475    }
476
477    #[test]
478    fn add_query_struct() -> crate::Result<()> {
479        #[derive(serde::Serialize)]
480        struct Params {
481            foo: String,
482            qux: i32,
483        }
484
485        let params = Params {
486            foo: "bar".into(),
487            qux: 3,
488        };
489        let req = RequestBuilder::get("https://google.com/")?
490            .query(&params)?
491            .empty();
492
493        assert_eq!(req.uri().query(), Some("foo=bar&qux=3"));
494        Ok(())
495    }
496
497    #[test]
498    fn add_query_map() -> crate::Result<()> {
499        let mut params = BTreeMap::new();
500        params.insert("foo", "bar");
501        params.insert("qux", "three");
502
503        let req = RequestBuilder::get("https://google.com/")?
504            .query(&params)?
505            .empty();
506        assert_eq!(req.uri().query(), Some("foo=bar&qux=three"));
507        Ok(())
508    }
509
510    #[test]
511    fn test_replace_headers() -> crate::Result<()> {
512        use http::HeaderMap;
513
514        let mut headers = HeaderMap::new();
515        headers.insert("foo", "bar".parse().unwrap());
516        headers.append("foo", "baz".parse().unwrap());
517
518        let req = RequestBuilder::get("https://hyper.rs")?
519            .header("im-a", "keeper")?
520            .header("foo", "pop me")?
521            .headers(headers)
522            .empty();
523
524        assert_eq!(req.headers()["im-a"], "keeper");
525
526        let foo = req.headers().get_all("foo").iter().collect::<Vec<_>>();
527        assert_eq!(foo.len(), 2);
528        assert_eq!(foo[0], "bar");
529        assert_eq!(foo[1], "baz");
530        Ok(())
531    }
532
533    #[test]
534    #[cfg(feature = "auth")]
535    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
536    fn test_basic_auth_sensitive_header() -> crate::Result<()> {
537        let some_url = "https://localhost/";
538
539        let req = RequestBuilder::get(some_url)?
540            .basic_auth("Aladdin", Some("open sesame"))
541            .empty();
542
543        assert_eq!(req.uri().to_string(), "https://localhost/");
544        assert_eq!(
545            req.headers()["authorization"],
546            "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="
547        );
548        assert!(req.headers()["authorization"].is_sensitive());
549        Ok(())
550    }
551
552    #[test]
553    #[cfg(feature = "auth")]
554    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
555    fn test_bearer_auth_sensitive_header() -> crate::Result<()> {
556        let some_url = "https://localhost/";
557
558        let req = RequestBuilder::get(some_url)?
559            .bearer_auth("Hold my bear")
560            .empty();
561
562        assert_eq!(req.uri().to_string(), "https://localhost/");
563        assert_eq!(req.headers()["authorization"], "Bearer Hold my bear");
564        assert!(req.headers()["authorization"].is_sensitive());
565        Ok(())
566    }
567
568    #[test]
569    fn test_explicit_sensitive_header() -> crate::Result<()> {
570        let some_url = "https://localhost/";
571
572        let mut header = http::HeaderValue::from_static("in plain sight");
573        header.set_sensitive(true);
574
575        let req = RequestBuilder::get(some_url)?.header("hiding", header)?;
576
577        assert_eq!(req.uri.to_string(), "https://localhost/");
578        assert_eq!(req.headers["hiding"], "in plain sight");
579        assert!(req.headers["hiding"].is_sensitive());
580        Ok(())
581    }
582
583    #[test]
584    fn convert_from_http_request() -> crate::Result<()> {
585        let req = Request::builder()
586            .method("GET")
587            .uri("http://localhost/")
588            .header("User-Agent", "my-awesome-agent/1.0")
589            .body("test test test")
590            .unwrap();
591        let test_data = b"test test test";
592        assert_eq!(req.body().as_bytes(), &test_data[..]);
593        let headers = req.headers();
594        assert_eq!(headers.get("User-Agent").unwrap(), "my-awesome-agent/1.0");
595        assert_eq!(req.method(), http::Method::GET);
596        assert_eq!(req.uri().to_string(), "http://localhost/");
597        Ok(())
598    }
599
600    #[test]
601    fn set_http_request_version() -> crate::Result<()> {
602        let req = Request::builder()
603            .method("GET")
604            .uri("http://localhost/")
605            .header("User-Agent", "my-awesome-agent/1.0")
606            .version(http::Version::HTTP_11)
607            .body("test test test")
608            .unwrap();
609        let test_data = b"test test test";
610        assert_eq!(req.body().as_bytes(), &test_data[..]);
611        let headers = req.headers();
612        assert_eq!(headers.get("User-Agent").unwrap(), "my-awesome-agent/1.0");
613        assert_eq!(req.method(), http::Method::GET);
614        assert_eq!(req.uri().to_string(), "http://localhost/");
615        assert_eq!(req.version(), http::Version::HTTP_11);
616        Ok(())
617    }
618}