Skip to main content

reqwest_middleware/
client.rs

1use http::Extensions;
2use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
3use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
4use std::convert::TryFrom;
5use std::fmt::{self, Display};
6use std::sync::Arc;
7
8#[cfg(feature = "multipart")]
9use reqwest::multipart;
10
11use crate::error::Result;
12use crate::middleware::{Middleware, Next};
13use crate::RequestInitialiser;
14
15/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
16///
17/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
18pub struct ClientBuilder {
19    client: Client,
20    middleware_stack: Vec<Arc<dyn Middleware>>,
21    initialiser_stack: Vec<Arc<dyn RequestInitialiser>>,
22}
23
24impl ClientBuilder {
25    pub fn new(client: Client) -> Self {
26        ClientBuilder {
27            client,
28            middleware_stack: Vec::new(),
29            initialiser_stack: Vec::new(),
30        }
31    }
32
33    /// This method allows creating a ClientBuilder
34    /// from an existing ClientWithMiddleware instance
35    pub fn from_client(client_with_middleware: ClientWithMiddleware) -> Self {
36        Self {
37            client: client_with_middleware.inner,
38            middleware_stack: client_with_middleware.middleware_stack.into_vec(),
39            initialiser_stack: client_with_middleware.initialiser_stack.into_vec(),
40        }
41    }
42
43    /// Convenience method to attach middleware.
44    ///
45    /// If you need to keep a reference to the middleware after attaching, use [`with_arc`].
46    ///
47    /// [`with_arc`]: Self::with_arc
48    pub fn with<M>(self, middleware: M) -> Self
49    where
50        M: Middleware,
51    {
52        self.with_arc(Arc::new(middleware))
53    }
54
55    /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`.
56    ///
57    /// [`with`]: Self::with
58    pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
59        self.middleware_stack.push(middleware);
60        self
61    }
62
63    /// Convenience method to attach a request initialiser.
64    ///
65    /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`].
66    ///
67    /// [`with_arc_init`]: Self::with_arc_init
68    pub fn with_init<I>(self, initialiser: I) -> Self
69    where
70        I: RequestInitialiser,
71    {
72        self.with_arc_init(Arc::new(initialiser))
73    }
74
75    /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
76    ///
77    /// [`with_init`]: Self::with_init
78    pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
79        self.initialiser_stack.push(initialiser);
80        self
81    }
82
83    /// Returns a `ClientWithMiddleware` using this builder configuration.
84    pub fn build(self) -> ClientWithMiddleware {
85        ClientWithMiddleware {
86            inner: self.client,
87            middleware_stack: self.middleware_stack.into_boxed_slice(),
88            initialiser_stack: self.initialiser_stack.into_boxed_slice(),
89        }
90    }
91}
92
93/// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every
94/// request.
95#[derive(Clone, Default)]
96pub struct ClientWithMiddleware {
97    inner: reqwest::Client,
98    middleware_stack: Box<[Arc<dyn Middleware>]>,
99    initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
100}
101
102impl ClientWithMiddleware {
103    /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
104    pub fn new<T>(client: Client, middleware_stack: T) -> Self
105    where
106        T: Into<Box<[Arc<dyn Middleware>]>>,
107    {
108        ClientWithMiddleware {
109            inner: client,
110            middleware_stack: middleware_stack.into(),
111            // TODO(conradludgate) - allow downstream code to control this manually if desired
112            initialiser_stack: Box::new([]),
113        }
114    }
115
116    /// Convenience method to make a `GET` request to a URL.
117    ///
118    /// # Errors
119    ///
120    /// This method fails whenever the supplied `Url` cannot be parsed.
121    pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
122        self.request(Method::GET, url)
123    }
124
125    /// Convenience method to make a `POST` request to a URL.
126    ///
127    /// # Errors
128    ///
129    /// This method fails whenever the supplied `Url` cannot be parsed.
130    pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
131        self.request(Method::POST, url)
132    }
133
134    /// Convenience method to make a `PUT` request to a URL.
135    ///
136    /// # Errors
137    ///
138    /// This method fails whenever the supplied `Url` cannot be parsed.
139    pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder {
140        self.request(Method::PUT, url)
141    }
142
143    /// Convenience method to make a `PATCH` request to a URL.
144    ///
145    /// # Errors
146    ///
147    /// This method fails whenever the supplied `Url` cannot be parsed.
148    pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder {
149        self.request(Method::PATCH, url)
150    }
151
152    /// Convenience method to make a `DELETE` request to a URL.
153    ///
154    /// # Errors
155    ///
156    /// This method fails whenever the supplied `Url` cannot be parsed.
157    pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder {
158        self.request(Method::DELETE, url)
159    }
160
161    /// Convenience method to make a `HEAD` request to a URL.
162    ///
163    /// # Errors
164    ///
165    /// This method fails whenever the supplied `Url` cannot be parsed.
166    pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
167        self.request(Method::HEAD, url)
168    }
169
170    /// Start building a `Request` with the `Method` and `Url`.
171    ///
172    /// Returns a `RequestBuilder`, which will allow setting headers and
173    /// the request body before sending.
174    ///
175    /// # Errors
176    ///
177    /// This method fails whenever the supplied `Url` cannot be parsed.
178    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
179        let req = RequestBuilder {
180            inner: self.inner.request(method, url),
181            extensions: Extensions::new(),
182            middleware_stack: self.middleware_stack.clone(),
183            initialiser_stack: self.initialiser_stack.clone(),
184        };
185        self.initialiser_stack
186            .iter()
187            .fold(req, |req, i| i.init(req))
188    }
189
190    /// Executes a `Request`.
191    ///
192    /// A `Request` can be built manually with `Request::new()` or obtained
193    /// from a RequestBuilder with `RequestBuilder::build()`.
194    ///
195    /// You should prefer to use the `RequestBuilder` and
196    /// `RequestBuilder::send()`.
197    ///
198    /// # Errors
199    ///
200    /// This method fails if there was an error while sending request,
201    /// redirect loop was detected or redirect limit was exhausted.
202    pub async fn execute(&self, req: Request) -> Result<Response> {
203        let mut ext = Extensions::new();
204        self.execute_with_extensions(req, &mut ext).await
205    }
206
207    /// Executes a `Request` with initial [`Extensions`].
208    ///
209    /// A `Request` can be built manually with `Request::new()` or obtained
210    /// from a RequestBuilder with `RequestBuilder::build()`.
211    ///
212    /// You should prefer to use the `RequestBuilder` and
213    /// `RequestBuilder::send()`.
214    ///
215    /// # Errors
216    ///
217    /// This method fails if there was an error while sending request,
218    /// redirect loop was detected or redirect limit was exhausted.
219    pub async fn execute_with_extensions(
220        &self,
221        req: Request,
222        ext: &mut Extensions,
223    ) -> Result<Response> {
224        let next = Next::new(&self.inner, &self.middleware_stack);
225        next.run(req, ext).await
226    }
227}
228
229/// Create a `ClientWithMiddleware` without any middleware.
230impl From<Client> for ClientWithMiddleware {
231    fn from(client: Client) -> Self {
232        ClientWithMiddleware {
233            inner: client,
234            middleware_stack: Box::new([]),
235            initialiser_stack: Box::new([]),
236        }
237    }
238}
239
240impl fmt::Debug for ClientWithMiddleware {
241    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
242        // skipping middleware_stack field for now
243        f.debug_struct("ClientWithMiddleware")
244            .field("inner", &self.inner)
245            .finish_non_exhaustive()
246    }
247}
248
249// Implementing AsRef<Client> for ClientWithMiddleware.
250//
251// This allows to use ClientWithMiddleware as a reqwest::Client.
252impl AsRef<Client> for ClientWithMiddleware {
253    fn as_ref(&self) -> &Client {
254        &self.inner
255    }
256}
257
258#[cfg(not(target_arch = "wasm32"))]
259mod service {
260    use std::{
261        future::Future,
262        pin::Pin,
263        task::{Context, Poll},
264    };
265
266    use crate::Result;
267    use http::Extensions;
268    use reqwest::{Request, Response};
269
270    use crate::{middleware::BoxFuture, ClientWithMiddleware, Next};
271
272    // this is meant to be semi-private, same as reqwest's pending
273    pub struct Pending {
274        inner: BoxFuture<'static, Result<Response>>,
275    }
276
277    impl Unpin for Pending {}
278
279    impl Future for Pending {
280        type Output = Result<Response>;
281
282        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
283            self.inner.as_mut().poll(cx)
284        }
285    }
286
287    impl tower_service::Service<Request> for ClientWithMiddleware {
288        type Response = Response;
289        type Error = crate::Error;
290        type Future = Pending;
291
292        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
293            self.inner.poll_ready(cx).map_err(crate::Error::Reqwest)
294        }
295
296        fn call(&mut self, req: Request) -> Self::Future {
297            let inner = self.inner.clone();
298            let middlewares = self.middleware_stack.clone();
299            let mut extensions = Extensions::new();
300            Pending {
301                inner: Box::pin(async move {
302                    let next = Next::new(&inner, &middlewares);
303                    next.run(req, &mut extensions).await
304                }),
305            }
306        }
307    }
308
309    impl tower_service::Service<Request> for &'_ ClientWithMiddleware {
310        type Response = Response;
311        type Error = crate::Error;
312        type Future = Pending;
313
314        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
315            (&self.inner).poll_ready(cx).map_err(crate::Error::Reqwest)
316        }
317
318        fn call(&mut self, req: Request) -> Self::Future {
319            let inner = self.inner.clone();
320            let middlewares = self.middleware_stack.clone();
321            let mut extensions = Extensions::new();
322            Pending {
323                inner: Box::pin(async move {
324                    let next = Next::new(&inner, &middlewares);
325                    next.run(req, &mut extensions).await
326                }),
327            }
328        }
329    }
330}
331
332/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
333#[must_use = "RequestBuilder does nothing until you 'send' it"]
334pub struct RequestBuilder {
335    inner: reqwest::RequestBuilder,
336    middleware_stack: Box<[Arc<dyn Middleware>]>,
337    initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
338    extensions: Extensions,
339}
340
341impl RequestBuilder {
342    /// Assemble a builder starting from an existing `Client` and a `Request`.
343    pub fn from_parts(client: ClientWithMiddleware, request: Request) -> RequestBuilder {
344        let inner = reqwest::RequestBuilder::from_parts(client.inner, request);
345        RequestBuilder {
346            inner,
347            middleware_stack: client.middleware_stack,
348            initialiser_stack: client.initialiser_stack,
349            extensions: Extensions::new(),
350        }
351    }
352
353    /// Add a `Header` to this Request.
354    pub fn header<K, V>(self, key: K, value: V) -> Self
355    where
356        HeaderName: TryFrom<K>,
357        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
358        HeaderValue: TryFrom<V>,
359        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
360    {
361        RequestBuilder {
362            inner: self.inner.header(key, value),
363            ..self
364        }
365    }
366
367    /// Add a set of Headers to the existing ones on this Request.
368    ///
369    /// The headers will be merged in to any already set.
370    pub fn headers(self, headers: HeaderMap) -> Self {
371        RequestBuilder {
372            inner: self.inner.headers(headers),
373            ..self
374        }
375    }
376
377    #[cfg(not(target_arch = "wasm32"))]
378    pub fn version(self, version: reqwest::Version) -> Self {
379        RequestBuilder {
380            inner: self.inner.version(version),
381            ..self
382        }
383    }
384
385    /// Enable HTTP basic authentication.
386    ///
387    /// ```rust
388    /// # use anyhow::Error;
389    ///
390    /// # async fn run() -> Result<(), Error> {
391    /// let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
392    /// let resp = client.delete("http://httpbin.org/delete")
393    ///     .basic_auth("admin", Some("good password"))
394    ///     .send()
395    ///     .await?;
396    /// # Ok(())
397    /// # }
398    /// ```
399    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
400    where
401        U: Display,
402        P: Display,
403    {
404        RequestBuilder {
405            inner: self.inner.basic_auth(username, password),
406            ..self
407        }
408    }
409
410    /// Enable HTTP bearer authentication.
411    pub fn bearer_auth<T>(self, token: T) -> Self
412    where
413        T: Display,
414    {
415        RequestBuilder {
416            inner: self.inner.bearer_auth(token),
417            ..self
418        }
419    }
420
421    /// Set the request body.
422    pub fn body<T: Into<Body>>(self, body: T) -> Self {
423        RequestBuilder {
424            inner: self.inner.body(body),
425            ..self
426        }
427    }
428
429    /// Enables a request timeout.
430    ///
431    /// The timeout is applied from when the request starts connecting until the
432    /// response body has finished. It affects only this request and overrides
433    /// the timeout configured using `ClientBuilder::timeout()`.
434    #[cfg(not(target_arch = "wasm32"))]
435    pub fn timeout(self, timeout: std::time::Duration) -> Self {
436        RequestBuilder {
437            inner: self.inner.timeout(timeout),
438            ..self
439        }
440    }
441
442    #[cfg(feature = "multipart")]
443    #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
444    pub fn multipart(self, multipart: multipart::Form) -> Self {
445        RequestBuilder {
446            inner: self.inner.multipart(multipart),
447            ..self
448        }
449    }
450
451    #[cfg(feature = "query")]
452    /// Modify the query string of the URL.
453    ///
454    /// Modifies the URL of this request, adding the parameters provided.
455    /// This method appends and does not overwrite. This means that it can
456    /// be called multiple times and that existing query parameters are not
457    /// overwritten if the same key is used. The key will simply show up
458    /// twice in the query string.
459    /// Calling `.query(&[("foo", "a"), ("foo", "b")])` gives `"foo=a&foo=b"`.
460    ///
461    /// # Note
462    /// This method does not support serializing a single key-value
463    /// pair. Instead of using `.query(("key", "val"))`, use a sequence, such
464    /// as `.query(&[("key", "val")])`. It's also possible to serialize structs
465    /// and maps into a key-value pair.
466    ///
467    /// # Errors
468    /// This method will fail if the object you provide cannot be serialized
469    /// into a query string.
470    pub fn query<T: serde::Serialize + ?Sized>(self, query: &T) -> Self {
471        RequestBuilder {
472            inner: self.inner.query(query),
473            ..self
474        }
475    }
476
477    #[cfg(feature = "form")]
478    /// Send a form body.
479    ///
480    /// Sets the body to the url encoded serialization of the passed value,
481    /// and also sets the `Content-Type: application/x-www-form-urlencoded`
482    /// header.
483    ///
484    /// ```rust
485    /// # use anyhow::Error;
486    /// # use std::collections::HashMap;
487    /// #
488    /// # async fn run() -> Result<(), Error> {
489    /// let mut params = HashMap::new();
490    /// params.insert("lang", "rust");
491    ///
492    /// let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
493    /// let res = client.post("http://httpbin.org")
494    ///     .form(&params)
495    ///     .send()
496    ///     .await?;
497    /// # Ok(())
498    /// # }
499    /// ```
500    ///
501    /// # Errors
502    ///
503    /// This method fails if the passed value cannot be serialized into
504    /// url encoded format
505    pub fn form<T: serde::Serialize + ?Sized>(self, form: &T) -> Self {
506        RequestBuilder {
507            inner: self.inner.form(form),
508            ..self
509        }
510    }
511
512    /// Send a JSON body.
513    ///
514    /// # Optional
515    ///
516    /// This requires the optional `json` feature enabled.
517    ///
518    /// # Errors
519    ///
520    /// Serialization can fail if `T`'s implementation of `Serialize` decides to
521    /// fail, or if `T` contains a map with non-string keys.
522    #[cfg(feature = "json")]
523    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
524    pub fn json<T: serde::Serialize + ?Sized>(self, json: &T) -> Self {
525        RequestBuilder {
526            inner: self.inner.json(json),
527            ..self
528        }
529    }
530
531    /// Build a `Request`, which can be inspected, modified and executed with
532    /// `ClientWithMiddleware::execute()`.
533    pub fn build(self) -> reqwest::Result<Request> {
534        self.inner.build()
535    }
536
537    /// Build a `Request`, which can be inspected, modified and executed with
538    /// `ClientWithMiddleware::execute()`.
539    ///
540    /// This is similar to [`RequestBuilder::build()`], but also returns the
541    /// embedded `Client`.
542    pub fn build_split(self) -> (ClientWithMiddleware, reqwest::Result<Request>) {
543        let Self {
544            inner,
545            middleware_stack,
546            initialiser_stack,
547            ..
548        } = self;
549        let (inner, req) = inner.build_split();
550        let client = ClientWithMiddleware {
551            inner,
552            middleware_stack,
553            initialiser_stack,
554        };
555        (client, req)
556    }
557
558    /// Inserts the extension into this request builder
559    pub fn with_extension<T: Send + Sync + Clone + 'static>(mut self, extension: T) -> Self {
560        self.extensions.insert(extension);
561        self
562    }
563
564    /// Returns a mutable reference to the internal set of extensions for this request
565    pub fn extensions(&mut self) -> &mut Extensions {
566        &mut self.extensions
567    }
568
569    /// Constructs the Request and sends it to the target URL, returning a
570    /// future Response.
571    ///
572    /// # Errors
573    ///
574    /// This method fails if there was an error while sending request,
575    /// redirect loop was detected or redirect limit was exhausted.
576    ///
577    /// # Example
578    ///
579    /// ```no_run
580    /// # use anyhow::Error;
581    /// #
582    /// # async fn run() -> Result<(), Error> {
583    /// let response = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new())
584    ///     .get("https://hyper.rs")
585    ///     .send()
586    ///     .await?;
587    /// # Ok(())
588    /// # }
589    /// ```
590    pub async fn send(mut self) -> Result<Response> {
591        let mut extensions = std::mem::take(self.extensions());
592        let (client, req) = self.build_split();
593        client.execute_with_extensions(req?, &mut extensions).await
594    }
595
596    /// Attempt to clone the RequestBuilder.
597    ///
598    /// `None` is returned if the RequestBuilder can not be cloned,
599    /// i.e. if the request body is a stream.
600    ///
601    /// # Examples
602    ///
603    /// ```
604    /// # use reqwest::Error;
605    /// #
606    /// # fn run() -> Result<(), Error> {
607    /// let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
608    /// let builder = client.post("http://httpbin.org/post")
609    ///     .body("from a &str!");
610    /// let clone = builder.try_clone();
611    /// assert!(clone.is_some());
612    /// # Ok(())
613    /// # }
614    /// ```
615    pub fn try_clone(&self) -> Option<Self> {
616        self.inner.try_clone().map(|inner| RequestBuilder {
617            inner,
618            middleware_stack: self.middleware_stack.clone(),
619            initialiser_stack: self.initialiser_stack.clone(),
620            extensions: self.extensions.clone(),
621        })
622    }
623}
624
625impl fmt::Debug for RequestBuilder {
626    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
627        // skipping middleware_stack field for now
628        f.debug_struct("RequestBuilder")
629            .field("inner", &self.inner)
630            .finish_non_exhaustive()
631    }
632}