ureq/middleware.rs
1//! Chained interception to modify the request or response.
2
3use std::fmt;
4use std::sync::Arc;
5
6use crate::http;
7use crate::run::run;
8use crate::{Agent, Body, Error, SendBody};
9
10/// Chained processing of request (and response).
11///
12/// # Middleware as `fn`
13///
14/// The middleware trait is implemented for all functions that have the signature
15///
16/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
17///
18/// That means the easiest way to implement middleware is by providing a `fn`, like so
19///
20/// ```
21/// use ureq::{Body, SendBody};
22/// use ureq::middleware::MiddlewareNext;
23/// use ureq::http::{Request, Response};
24///
25/// fn my_middleware(req: Request<SendBody>, next: MiddlewareNext)
26/// -> Result<Response<Body>, ureq::Error> {
27///
28/// // do middleware things to request
29///
30/// // continue the middleware chain
31/// let res = next.handle(req)?;
32///
33/// // do middleware things to response
34///
35/// Ok(res)
36/// }
37/// ```
38///
39/// # Adding headers
40///
/// A common use case is to add headers to the outgoing request. Here is an example of how to do that.
42///
43/// ```no_run
44/// use ureq::{Body, SendBody, Agent, config::Config};
45/// use ureq::middleware::MiddlewareNext;
46/// use ureq::http::{Request, Response, header::HeaderValue};
47///
48/// # #[cfg(feature = "json")]
49/// # {
50/// fn my_middleware(mut req: Request<SendBody>, next: MiddlewareNext)
51/// -> Result<Response<Body>, ureq::Error> {
52///
53/// req.headers_mut().insert("X-My-Header", HeaderValue::from_static("value_42"));
54///
55/// // set my bespoke header and continue the chain
56/// next.handle(req)
57/// }
58///
59/// let mut config = Config::builder()
60/// .middleware(my_middleware)
61/// .build();
62///
63/// let agent: Agent = config.into();
64///
65/// let result: serde_json::Value =
66/// agent.get("http://httpbin.org/headers").call()?.body_mut().read_json()?;
67///
68/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
69/// # } Ok::<_, ureq::Error>(())
70/// ```
71///
72/// # State
73///
74/// To maintain state between middleware invocations, we need to do something more elaborate than
75/// the simple `fn` and implement the `Middleware` trait directly.
76///
77/// ## Example with mutex lock
78///
79/// In the `examples` directory there is an additional example `count-bytes.rs` which uses
80/// a mutex lock like shown below.
81///
82/// ```
83/// use std::sync::{Arc, Mutex};
84///
85/// use ureq::{Body, SendBody};
86/// use ureq::middleware::{Middleware, MiddlewareNext};
87/// use ureq::http::{Request, Response};
88///
89/// struct MyState {
90/// // whatever is needed
91/// }
92///
93/// struct MyMiddleware(Arc<Mutex<MyState>>);
94///
95/// impl Middleware for MyMiddleware {
96/// fn handle(&self, request: Request<SendBody>, next: MiddlewareNext)
97/// -> Result<Response<Body>, ureq::Error> {
98///
/// // These extra braces ensure we release the Mutex lock before continuing the
100/// // chain. There could also be scenarios where we want to maintain the lock through
101/// // the invocation, which would block other requests from proceeding concurrently
102/// // through the middleware.
103/// {
104/// let mut state = self.0.lock().unwrap();
105/// // do stuff with state
106/// }
107///
108/// // continue middleware chain
109/// next.handle(request)
110/// }
111/// }
112/// ```
113///
114/// ## Example with atomic
115///
116/// This example shows how we can increase a counter for each request going
117/// through the agent.
118///
119/// ```
120/// use ureq::{Body, SendBody, Agent, config::Config};
121/// use ureq::middleware::{Middleware, MiddlewareNext};
122/// use ureq::http::{Request, Response};
123/// use std::sync::atomic::{AtomicU64, Ordering};
124/// use std::sync::Arc;
125///
126/// // Middleware that stores a counter state. This example uses an AtomicU64
127/// // since the middleware is potentially shared by multiple threads running
128/// // requests at the same time.
129/// struct MyCounter(Arc<AtomicU64>);
130///
131/// impl Middleware for MyCounter {
132/// fn handle(&self, req: Request<SendBody>, next: MiddlewareNext)
133/// -> Result<Response<Body>, ureq::Error> {
134///
135/// // increase the counter for each invocation
136/// self.0.fetch_add(1, Ordering::Relaxed);
137///
138/// // continue the middleware chain
139/// next.handle(req)
140/// }
141/// }
142///
143/// let shared_counter = Arc::new(AtomicU64::new(0));
144///
145/// let mut config = Config::builder()
146/// .middleware(MyCounter(shared_counter.clone()))
147/// .build();
148///
149/// let agent: Agent = config.into();
150///
151/// agent.get("http://httpbin.org/get").call()?;
152/// agent.get("http://httpbin.org/get").call()?;
153///
154/// // Check we did indeed increase the counter twice.
155/// assert_eq!(shared_counter.load(Ordering::Relaxed), 2);
156///
157/// # Ok::<_, ureq::Error>(())
158/// ```
159pub trait Middleware: Send + Sync + 'static {
160 /// Handle of the middleware logic.
161 fn handle(
162 &self,
163 request: http::Request<SendBody>,
164 next: MiddlewareNext,
165 ) -> Result<http::Response<Body>, Error>;
166}
167
168#[derive(Clone, Default)]
169pub(crate) struct MiddlewareChain {
170 chain: Arc<Vec<Box<dyn Middleware>>>,
171}
172
173impl MiddlewareChain {
174 pub(crate) fn add(&mut self, mw: impl Middleware) {
175 let Some(chain) = Arc::get_mut(&mut self.chain) else {
176 panic!("Can't add to a MiddlewareChain that is already cloned")
177 };
178
179 chain.push(Box::new(mw));
180 }
181}
182
183/// Continuation of a [`Middleware`] chain.
184pub struct MiddlewareNext<'a> {
185 agent: &'a Agent,
186 index: usize,
187}
188
189impl<'a> MiddlewareNext<'a> {
190 pub(crate) fn new(agent: &'a Agent) -> Self {
191 MiddlewareNext { agent, index: 0 }
192 }
193
194 /// Continue the middleware chain.
195 ///
196 /// The middleware must call this in order to run the request. Not calling
197 /// it is a valid choice for not wanting the request to execute.
198 pub fn handle(
199 mut self,
200 request: http::Request<SendBody>,
201 ) -> Result<http::Response<Body>, Error> {
202 if let Some(mw) = self.agent.config().middleware.chain.get(self.index) {
203 // This middleware exists, run it.
204 self.index += 1;
205 mw.handle(request, self)
206 } else {
207 // When chain is over, call the main run().
208 let (parts, body) = request.into_parts();
209 let request = http::Request::from_parts(parts, ());
210 run(self.agent, request, body)
211 }
212 }
213}
214
215impl<F> Middleware for F
216where
217 F: Fn(http::Request<SendBody>, MiddlewareNext) -> Result<http::Response<Body>, Error>
218 + Send
219 + Sync
220 + 'static,
221{
222 fn handle(
223 &self,
224 request: http::Request<SendBody>,
225 next: MiddlewareNext,
226 ) -> Result<http::Response<Body>, Error> {
227 (self)(request, next)
228 }
229}
230
231impl fmt::Debug for MiddlewareChain {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 f.debug_struct("MiddlewareChain")
234 .field("len", &self.chain.len())
235 .finish()
236 }
237}