Skip to main content

ureq/
middleware.rs

//! Chained interception to modify the request or response.
2
3use std::fmt;
4use std::sync::Arc;
5
6use crate::http;
7use crate::run::run;
8use crate::{Agent, Body, Error, SendBody};
9
10/// Chained processing of request (and response).
11///
12/// # Middleware as `fn`
13///
14/// The middleware trait is implemented for all functions that have the signature
15///
16/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
17///
18/// That means the easiest way to implement middleware is by providing a `fn`, like so
19///
20/// ```
21/// use ureq::{Body, SendBody};
22/// use ureq::middleware::MiddlewareNext;
23/// use ureq::http::{Request, Response};
24///
25/// fn my_middleware(req: Request<SendBody>, next: MiddlewareNext)
26///     -> Result<Response<Body>, ureq::Error> {
27///
28///     // do middleware things to request
29///
30///     // continue the middleware chain
31///     let res = next.handle(req)?;
32///
33///     // do middleware things to response
34///
35///     Ok(res)
36/// }
37/// ```
38///
39/// # Adding headers
40///
41/// A common use case is to add headers to the outgoing request. Here an example of how.
42///
43/// ```no_run
44/// use ureq::{Body, SendBody, Agent, config::Config};
45/// use ureq::middleware::MiddlewareNext;
46/// use ureq::http::{Request, Response, header::HeaderValue};
47///
48/// # #[cfg(feature = "json")]
49/// # {
50/// fn my_middleware(mut req: Request<SendBody>, next: MiddlewareNext)
51///     -> Result<Response<Body>, ureq::Error> {
52///
53///     req.headers_mut().insert("X-My-Header", HeaderValue::from_static("value_42"));
54///
55///     // set my bespoke header and continue the chain
56///     next.handle(req)
57/// }
58///
59/// let mut config = Config::builder()
60///     .middleware(my_middleware)
61///     .build();
62///
63/// let agent: Agent = config.into();
64///
65/// let result: serde_json::Value =
66///     agent.get("http://httpbin.org/headers").call()?.body_mut().read_json()?;
67///
68/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
69/// # } Ok::<_, ureq::Error>(())
70/// ```
71///
72/// # State
73///
74/// To maintain state between middleware invocations, we need to do something more elaborate than
75/// the simple `fn` and implement the `Middleware` trait directly.
76///
77/// ## Example with mutex lock
78///
79/// In the `examples` directory there is an additional example `count-bytes.rs` which uses
80/// a mutex lock like shown below.
81///
82/// ```
83/// use std::sync::{Arc, Mutex};
84///
85/// use ureq::{Body, SendBody};
86/// use ureq::middleware::{Middleware, MiddlewareNext};
87/// use ureq::http::{Request, Response};
88///
89/// struct MyState {
90///     // whatever is needed
91/// }
92///
93/// struct MyMiddleware(Arc<Mutex<MyState>>);
94///
95/// impl Middleware for MyMiddleware {
96///     fn handle(&self, request: Request<SendBody>, next: MiddlewareNext)
97///         -> Result<Response<Body>, ureq::Error> {
98///
99///         // These extra brackets ensures we release the Mutex lock before continuing the
100///         // chain. There could also be scenarios where we want to maintain the lock through
101///         // the invocation, which would block other requests from proceeding concurrently
102///         // through the middleware.
103///         {
104///             let mut state = self.0.lock().unwrap();
105///             // do stuff with state
106///         }
107///
108///         // continue middleware chain
109///         next.handle(request)
110///     }
111/// }
112/// ```
113///
114/// ## Example with atomic
115///
116/// This example shows how we can increase a counter for each request going
117/// through the agent.
118///
119/// ```
120/// use ureq::{Body, SendBody, Agent, config::Config};
121/// use ureq::middleware::{Middleware, MiddlewareNext};
122/// use ureq::http::{Request, Response};
123/// use std::sync::atomic::{AtomicU64, Ordering};
124/// use std::sync::Arc;
125///
126/// // Middleware that stores a counter state. This example uses an AtomicU64
127/// // since the middleware is potentially shared by multiple threads running
128/// // requests at the same time.
129/// struct MyCounter(Arc<AtomicU64>);
130///
131/// impl Middleware for MyCounter {
132///     fn handle(&self, req: Request<SendBody>, next: MiddlewareNext)
133///         -> Result<Response<Body>, ureq::Error> {
134///
135///         // increase the counter for each invocation
136///         self.0.fetch_add(1, Ordering::Relaxed);
137///
138///         // continue the middleware chain
139///         next.handle(req)
140///     }
141/// }
142///
143/// let shared_counter = Arc::new(AtomicU64::new(0));
144///
145/// let mut config = Config::builder()
146///     .middleware(MyCounter(shared_counter.clone()))
147///     .build();
148///
149/// let agent: Agent = config.into();
150///
151/// agent.get("http://httpbin.org/get").call()?;
152/// agent.get("http://httpbin.org/get").call()?;
153///
154/// // Check we did indeed increase the counter twice.
155/// assert_eq!(shared_counter.load(Ordering::Relaxed), 2);
156///
157/// # Ok::<_, ureq::Error>(())
158/// ```
159pub trait Middleware: Send + Sync + 'static {
160    /// Handle of the middleware logic.
161    fn handle(
162        &self,
163        request: http::Request<SendBody>,
164        next: MiddlewareNext,
165    ) -> Result<http::Response<Body>, Error>;
166}
167
168#[derive(Clone, Default)]
169pub(crate) struct MiddlewareChain {
170    chain: Arc<Vec<Box<dyn Middleware>>>,
171}
172
173impl MiddlewareChain {
174    pub(crate) fn add(&mut self, mw: impl Middleware) {
175        let Some(chain) = Arc::get_mut(&mut self.chain) else {
176            panic!("Can't add to a MiddlewareChain that is already cloned")
177        };
178
179        chain.push(Box::new(mw));
180    }
181}
182
183/// Continuation of a [`Middleware`] chain.
184pub struct MiddlewareNext<'a> {
185    agent: &'a Agent,
186    index: usize,
187}
188
189impl<'a> MiddlewareNext<'a> {
190    pub(crate) fn new(agent: &'a Agent) -> Self {
191        MiddlewareNext { agent, index: 0 }
192    }
193
194    /// Continue the middleware chain.
195    ///
196    /// The middleware must call this in order to run the request. Not calling
197    /// it is a valid choice for not wanting the request to execute.
198    pub fn handle(
199        mut self,
200        request: http::Request<SendBody>,
201    ) -> Result<http::Response<Body>, Error> {
202        if let Some(mw) = self.agent.config().middleware.chain.get(self.index) {
203            // This middleware exists, run it.
204            self.index += 1;
205            mw.handle(request, self)
206        } else {
207            // When chain is over, call the main run().
208            let (parts, body) = request.into_parts();
209            let request = http::Request::from_parts(parts, ());
210            run(self.agent, request, body)
211        }
212    }
213}
214
215impl<F> Middleware for F
216where
217    F: Fn(http::Request<SendBody>, MiddlewareNext) -> Result<http::Response<Body>, Error>
218        + Send
219        + Sync
220        + 'static,
221{
222    fn handle(
223        &self,
224        request: http::Request<SendBody>,
225        next: MiddlewareNext,
226    ) -> Result<http::Response<Body>, Error> {
227        (self)(request, next)
228    }
229}
230
231impl fmt::Debug for MiddlewareChain {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        f.debug_struct("MiddlewareChain")
234            .field("len", &self.chain.len())
235            .finish()
236    }
237}