// httpclient/request/builder.rs

use std::future::IntoFuture;
use std::str::FromStr;
use std::sync::Arc;

use futures::future::BoxFuture;
use http::header::{Entry, HeaderName, ACCEPT, AUTHORIZATION, CONTENT_TYPE, COOKIE};
use http::uri::PathAndQuery;
use http::{header, HeaderMap, HeaderValue, Method, Uri, Version};
use serde::Serialize;
use serde_json::Value;

use crate::error::ProtocolResult;
use crate::middleware::Next;
use crate::multipart::Form;
use crate::{Client, Error, InMemoryBody, InMemoryResponse, Middleware, Request, Response};

17pub static ACCEPT_JSON: HeaderValue = HeaderValue::from_static("application/json");
18pub static CONTENT_JSON: HeaderValue = HeaderValue::from_static("application/json; charset=utf-8");
19pub static CONTENT_URL_ENCODED: HeaderValue = HeaderValue::from_static("application/x-www-form-urlencoded");
20
21/// Provide a custom request builder for several reasons:
22/// - The required reason is have it implement IntoFuture, so that it can be directly awaited.
23/// - The secondary reasons is directly storing client & middlewares on the RequestBuilder. In
24///   theory it could be stored on Request.extensions, but that's less explicit.
25/// - It's also nice to not require implementing an Extension trait to get all the convenience methods
26///   on http::request::RequestBuilder
27///
28/// Middlewares are used in order (first to last).
29#[derive(Debug)]
30pub struct RequestBuilder<'a, C = Client, B = InMemoryBody> {
31    client: &'a C,
32
33    pub version: Version,
34    pub method: Method,
35    pub uri: Uri,
36    pub headers: HeaderMap,
37    pub body: Option<B>,
38    pub middlewares: Vec<Arc<dyn Middleware>>,
39}
40
41impl<'a> RequestBuilder<'a, ()> {
42    pub fn get(url: &str) -> RequestBuilder<'a, ()> {
43        RequestBuilder::new(&(), Method::GET, Uri::from_str(url).expect("Invalid URL"))
44    }
45    pub fn post(url: &str) -> RequestBuilder<'a, ()> {
46        RequestBuilder::new(&(), Method::POST, Uri::from_str(url).expect("Invalid URL"))
47    }
48    pub fn put(url: &str) -> RequestBuilder<'a, ()> {
49        RequestBuilder::new(&(), Method::PUT, Uri::from_str(url).expect("Invalid URL"))
50    }
51    pub fn delete(url: &str) -> RequestBuilder<'a, ()> {
52        RequestBuilder::new(&(), Method::DELETE, Uri::from_str(url).expect("Invalid URL"))
53    }
54    pub fn head(url: &str) -> RequestBuilder<'a, ()> {
55        RequestBuilder::new(&(), Method::HEAD, Uri::from_str(url).expect("Invalid URL"))
56    }
57}
58
59impl<'a, C> RequestBuilder<'a, C> {
60    pub fn new(client: &'a C, method: Method, uri: Uri) -> RequestBuilder<'a, C, InMemoryBody> {
61        RequestBuilder {
62            client,
63            version: Default::default(),
64            method,
65            uri,
66            headers: Default::default(),
67            body: Default::default(),
68            middlewares: Default::default(),
69        }
70    }
71
72    #[must_use]
73    pub fn form<S: Serialize>(mut self, obj: S) -> Self {
74        match self.body {
75            None => {
76                let body = serde_qs::to_string(&obj).unwrap();
77                self.body = Some(InMemoryBody::Text(body));
78                self.headers.entry(CONTENT_TYPE).or_insert(CONTENT_URL_ENCODED.clone());
79                self.headers.entry(ACCEPT).or_insert(HeaderValue::from_static("html/text"));
80                self
81            }
82            Some(InMemoryBody::Text(ref mut body)) => {
83                let new_body = serde_qs::to_string(&obj).unwrap();
84                body.push('&');
85                body.push_str(&new_body);
86                self
87            }
88            _ => {
89                panic!("Cannot add form to non-form body");
90            }
91        }
92    }
93
94    /// Overwrite the current body with the provided JSON object.
95    #[must_use]
96    pub fn set_json<S: Serialize>(mut self, obj: S) -> Self {
97        self.body = Some(InMemoryBody::Json(serde_json::to_value(obj).unwrap()));
98        self.headers.entry(CONTENT_TYPE).or_insert(CONTENT_JSON.clone());
99        self.headers.entry(ACCEPT).or_insert(ACCEPT_JSON.clone());
100        self
101    }
102
103    /// Add the provided JSON object to the current body.
104    #[must_use]
105    pub fn json<S: Serialize>(mut self, obj: S) -> Self {
106        match self.body {
107            None => self.set_json(obj),
108            Some(InMemoryBody::Json(Value::Object(ref mut body))) => {
109                if let Value::Object(obj) = serde_json::to_value(obj).unwrap() {
110                    body.extend(obj);
111                } else {
112                    panic!("Tried to push a non-object to a json body.");
113                }
114                self
115            }
116            _ => panic!("Tried to call .json() on a non-json body. Use .set_json if you need to force a json body."),
117        }
118    }
119
120    /// Sets content-type to `application/octet-stream` and the body to the supplied bytes.
121    #[must_use]
122    pub fn bytes(mut self, bytes: Vec<u8>) -> Self {
123        // self.headers.insert(CONTENT_LENGTH, HeaderValue::from(bytes.len()));
124        self.body = Some(InMemoryBody::Bytes(bytes));
125        self.headers.entry(CONTENT_TYPE).or_insert(HeaderValue::from_static("application/octet-stream"));
126        self
127    }
128
129    /// Sets content-type to `text/plain` and the body to the supplied text.
130    #[must_use]
131    pub fn text(mut self, text: String) -> Self {
132        // self.headers.insert(CONTENT_LENGTH, HeaderValue::from(text.len()));
133        self.body = Some(InMemoryBody::Text(text));
134        self.headers.entry(CONTENT_TYPE).or_insert(HeaderValue::from_static("text/plain"));
135        self
136    }
137
138    #[must_use]
139    pub fn multipart<B>(mut self, form: Form<B>) -> Self
140    where
141        Form<B>: Into<Vec<u8>>,
142    {
143        let content_type = form.full_content_type();
144        self.headers.entry(CONTENT_TYPE).or_insert(content_type.parse().unwrap());
145        let body: Vec<u8> = form.into();
146        // let len = body.len();
147        match String::from_utf8(body) {
148            Ok(text) => self.body = Some(InMemoryBody::Text(text)),
149            Err(bytes) => self.body = Some(InMemoryBody::Bytes(bytes.into_bytes())),
150        }
151        // self.headers.insert(CONTENT_LENGTH, HeaderValue::from(len));
152        self
153    }
154}
155
156impl<'a> RequestBuilder<'a> {
157    /// There are two ways to trigger the request. Immediately using `.await` will call the `IntoFuture` implementation
158    /// which also awaits the body. If you want to await them separately, use this method `.send()`
159    pub async fn send(self) -> ProtocolResult<Response> {
160        let client = self.client;
161        let (request, middlewares) = self.into_req_and_middleware();
162        let next = Next {
163            client,
164            middlewares: &middlewares,
165        };
166        next.run(request).await
167    }
168}
169
170impl<'a, C, B: Default> RequestBuilder<'a, C, B> {
171    pub fn build(self) -> Request<B> {
172        let mut b = Request::builder().method(self.method).uri(self.uri).version(self.version);
173        *b.headers_mut().unwrap() = self.headers;
174        b.body(self.body.unwrap_or_default()).expect("Failed to build request in .build")
175    }
176
177    pub fn into_req_and_middleware(self) -> (Request<B>, Vec<Arc<dyn Middleware>>) {
178        let mut request = http::Request::builder().method(self.method).uri(self.uri).version(self.version);
179        *request.headers_mut().unwrap() = self.headers;
180        let request = request.body(self.body.unwrap_or_default().into()).unwrap();
181        (request, self.middlewares)
182    }
183}
184
185impl<'a, C, B> RequestBuilder<'a, C, B> {
186    pub fn for_client(client: &'a C) -> RequestBuilder<'a, C> {
187        RequestBuilder {
188            client,
189            version: Default::default(),
190            method: Default::default(),
191            uri: Default::default(),
192            headers: Default::default(),
193            body: Default::default(),
194            middlewares: Default::default(),
195        }
196    }
197
198    #[must_use]
199    pub fn method(mut self, method: Method) -> Self {
200        self.method = method;
201        self
202    }
203
204    #[must_use]
205    pub fn url(mut self, uri: &str) -> Self {
206        self.uri = Uri::from_str(uri).expect("Invalid URI");
207        self
208    }
209
210    #[must_use]
211    pub fn set_headers<S: AsRef<str>, I: Iterator<Item = (S, S)>>(mut self, headers: I) -> Self {
212        self.headers = HeaderMap::new();
213        self.headers(headers)
214    }
215
216    #[must_use]
217    pub fn headers<S: AsRef<str>, I: Iterator<Item = (S, S)>>(mut self, headers: I) -> Self {
218        self.headers
219            .extend(headers.map(|(k, v)| (HeaderName::from_str(k.as_ref()).unwrap(), HeaderValue::from_str(v.as_ref()).unwrap())));
220        self
221    }
222
223    #[must_use]
224    pub fn header<K: TryInto<HeaderName>>(mut self, key: K, value: &str) -> Self
225    where
226        <K as TryInto<HeaderName>>::Error: std::fmt::Debug,
227    {
228        let header = key.try_into().expect("Failed to convert key to HeaderName");
229        self.headers.insert(header, HeaderValue::from_str(value).unwrap());
230        self
231    }
232
233    #[must_use]
234    pub fn cookie(mut self, key: &str, value: &str) -> Self {
235        match self.headers.entry(COOKIE) {
236            Entry::Occupied(mut e) => {
237                let v = e.get_mut();
238                *v = HeaderValue::from_str(&format!("{}; {}={}", v.to_str().unwrap(), key, value)).unwrap();
239            }
240            Entry::Vacant(_) => {
241                let value = HeaderValue::from_str(&format!("{key}={value}")).unwrap();
242                self.headers.insert(COOKIE, value);
243            }
244        }
245        self
246    }
247
248    #[must_use]
249    pub fn bearer_auth(mut self, token: &str) -> Self {
250        self.headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}")).unwrap());
251        self
252    }
253
254    #[must_use]
255    pub fn token_auth(mut self, token: &str) -> Self {
256        self.headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Token {token}")).unwrap());
257        self
258    }
259
260    #[must_use]
261    pub fn basic_auth(mut self, token: &str) -> Self {
262        self.headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Basic {token}")).unwrap());
263        self
264    }
265
266    /// Overwrite the query with the provided value.
267    #[must_use]
268    pub fn set_query<S: Serialize>(mut self, obj: S) -> Self {
269        let qs = serde_qs::to_string(&obj).expect("Failed to serialize query in .set_query");
270        let mut parts = std::mem::take(&mut self.uri).into_parts();
271        let pq = parts.path_and_query.unwrap();
272        let pq = PathAndQuery::from_str(&format!("{}?{}", pq.path(), qs)).unwrap();
273        parts.path_and_query = Some(pq);
274        self.uri = Uri::from_parts(parts).unwrap();
275        self
276    }
277
278    /// Add a url query parameter, but keep existing parameters.
279    /// # Examples
280    /// ```
281    /// use httpclient::{Client, RequestBuilder, Method};
282    /// let client = Client::new();
283    /// let mut r = RequestBuilder::new(&client, Method::GET, "http://example.com/foo?a=1".parse().unwrap());
284    /// r = r.query("b", "2");
285    /// assert_eq!(r.uri.to_string(), "http://example.com/foo?a=1&b=2");
286    /// ```
287    #[must_use]
288    pub fn query(mut self, k: &str, v: &str) -> Self {
289        let mut parts = std::mem::take(&mut self.uri).into_parts();
290        let pq = parts.path_and_query.unwrap();
291        let pq = PathAndQuery::from_str(
292            match pq.query() {
293                Some(q) => format!("{}?{}&{}={}", pq.path(), q, urlencoding::encode(k), urlencoding::encode(v)),
294                None => format!("{}?{}={}", pq.path(), urlencoding::encode(k), urlencoding::encode(v)),
295            }
296            .as_str(),
297        )
298        .unwrap();
299        parts.path_and_query = Some(pq);
300        self.uri = Uri::from_parts(parts).unwrap();
301        self
302    }
303
304    #[must_use]
305    pub fn content_type(mut self, content_type: &str) -> Self {
306        self.headers.insert(header::CONTENT_TYPE, content_type.parse().unwrap());
307        self
308    }
309
310    /// Warning: Does not set content-type!
311    #[must_use]
312    pub fn body(mut self, body: B) -> Self {
313        self.body = Some(body);
314        self
315    }
316
317    #[must_use]
318    pub fn set_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
319        self.middlewares = middlewares;
320        self
321    }
322
323    #[must_use]
324    pub fn middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
325        self.middlewares.push(middleware);
326        self
327    }
328}
329
330impl<'a> IntoFuture for RequestBuilder<'a, Client> {
331    type Output = crate::InMemoryResult<InMemoryResponse>;
332    type IntoFuture = BoxFuture<'a, Self::Output>;
333
334    fn into_future(self) -> Self::IntoFuture {
335        Box::pin(async move {
336            let res = self.send().await;
337            let res = match res {
338                Ok(res) => res,
339                Err(e) => return Err(e.into()),
340            };
341            let (parts, body) = res.into_parts();
342            let mut body = match body.into_memory().await {
343                Ok(body) => body,
344                Err(e) => return Err(e.into()),
345            };
346            if let InMemoryBody::Bytes(bytes) = body {
347                body = match String::from_utf8(bytes) {
348                    Ok(text) => InMemoryBody::Text(text),
349                    Err(e) => InMemoryBody::Bytes(e.into_bytes()),
350                };
351            }
352            let status = &parts.status;
353            if status.is_client_error() || status.is_server_error() {
354                // Prevents us from showing bytes to end users in error situations.
355                Err(Error::HttpError(InMemoryResponse::from_parts(parts, body)))
356            } else {
357                Ok(InMemoryResponse::from_parts(parts, body))
358            }
359        })
360    }
361}
362
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;

    /// Outer query struct; its single field serializes as a nested `inside[...]` key.
    #[derive(Serialize, Deserialize)]
    pub struct TopLevel {
        inside: Nested,
    }

    /// Inner query struct.
    #[derive(Serialize, Deserialize)]
    pub struct Nested {
        a: usize,
    }

    /// `set_query` serializes nested structs with bracket notation.
    #[test]
    fn test_query() {
        let client = Client::new();
        let params = TopLevel { inside: Nested { a: 1 } };
        let request = client.get("/api").set_query(params).build();
        let rendered = request.uri().to_string();
        assert_eq!(rendered, "/api?inside[a]=1");
    }
}