reqwest_conditional_middleware/
lib.rs

//! The only export of this crate is [`ConditionalMiddleware`], a wrapper that makes another
//! middleware conditional. It implements the [`Middleware`][reqwest_middleware::Middleware]
//! trait and forwards each request to the middleware it wraps only when a condition holds.
4//!
5//! The conditional wrapper holds a closure that will be run for each request. If the
6//! closure returns true, then the inner middleware will run. Otherwise it will be
7//! skipped and the current request will be passed along to the next middleware.
8//!
9//! # Example
10//!
//! Short-circuits a middleware stack and returns a `200 OK` response whenever the request
//! method is `GET`:
13//!
14//! ```
15//! use reqwest::{Request, Response};
16//! use reqwest_conditional_middleware::ConditionalMiddleware;
17//! use reqwest_middleware::{Middleware, Next, Result};
18//! use http::Extensions;
19//!
20//! struct AlwaysOk;
21//!
22//! #[async_trait::async_trait]
23//! impl Middleware for AlwaysOk {
24//!     async fn handle(
25//!         &self,
26//!         _req: Request,
27//!         _extensions: &mut Extensions,
28//!         _next: Next<'_>,
29//!     ) -> Result<Response> {
30//!         let builder = http::Response::builder().status(http::StatusCode::OK);
31//!         Ok(builder.body("").unwrap().into())
32//!     }
33//! }
34//!
35//! let conditional = ConditionalMiddleware::new(
36//!     AlwaysOk,
37//!     |req: &Request| req.method() == http::Method::GET
38//! );
39//!
40//! ```
41
42use async_trait::async_trait;
43use http::Extensions;
44use reqwest::{Request, Response};
45use reqwest_middleware::{Middleware, Next, Result};
46
/// Wraps a [`Middleware`][reqwest_middleware::Middleware] `T` so that it only runs for
/// requests on which the predicate `C` evaluates to `true`.
///
/// Requests for which `C` returns `false` skip `T` entirely and are handed straight to the
/// next middleware in the chain.
pub struct ConditionalMiddleware<T, C> {
    /// The wrapped middleware.
    inner: T,
    /// Per-request predicate deciding whether `inner` runs.
    condition: C,
}

// Manual `Debug` impl: the condition is normally a closure, which has no `Debug`
// implementation, so a derive would make the whole type un-debuggable. Only the wrapped
// middleware is shown; the condition is elided with `..`.
impl<T: std::fmt::Debug, C> std::fmt::Debug for ConditionalMiddleware<T, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConditionalMiddleware")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
53
54impl<T, C> ConditionalMiddleware<T, C>
55where
56    T: Middleware,
57    C: Fn(&Request) -> bool + Send + Sync + 'static,
58{
59    /// Creates a new wrapped middleware. The function C will be run for each request to
60    /// determine if the wrapped middleware should be run.
61    pub fn new(inner: T, condition: C) -> Self {
62        Self { inner, condition }
63    }
64}
65
66#[async_trait]
67impl<T, C> Middleware for ConditionalMiddleware<T, C>
68where
69    T: Middleware,
70    C: Fn(&Request) -> bool + Send + Sync + 'static,
71{
72    async fn handle(
73        &self,
74        req: Request,
75        extensions: &mut Extensions,
76        next: Next<'_>,
77    ) -> Result<Response> {
78        let should_handle = (self.condition)(&req);
79
80        if should_handle {
81            self.inner.handle(req, extensions, next).await
82        } else {
83            next.run(req, extensions).await
84        }
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use reqwest::{Request, Response};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Terminal middleware: answers every request with `200 OK` and body `"end"` without
    /// calling `next`, so the tests never perform real network I/O.
    struct End;

    #[async_trait]
    impl Middleware for End {
        async fn handle(
            &self,
            _req: Request,
            _extensions: &mut Extensions,
            _next: Next<'_>,
        ) -> Result<Response> {
            let builder = http::Response::builder().status(StatusCode::OK);
            let resp = builder.body("end").unwrap();
            Ok(resp.into())
        }
    }

    /// Pass-through middleware that counts how many times it is invoked.
    ///
    /// A counter (rather than a toggled flag) lets the tests detect the wrapped middleware
    /// running more than once, which an even number of toggles would hide.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self {
                calls: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Shared handle to the invocation counter; stays readable after `self` has been
        /// moved into a client.
        fn counter(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.calls)
        }
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        async fn handle(
            &self,
            req: Request,
            extensions: &mut Extensions,
            next: Next<'_>,
        ) -> Result<Response> {
            // One atomic read-modify-write, unlike locking twice to read and then write.
            self.calls.fetch_add(1, Ordering::SeqCst);
            next.run(req, extensions).await
        }
    }

    /// Sends a `method` request to `http://localhost` through the stack
    /// `ConditionalMiddleware(CountingMiddleware, condition) -> End` and returns the
    /// response body together with the number of times the wrapped middleware ran.
    async fn run_with_condition<C>(method: http::Method, condition: C) -> (String, usize)
    where
        C: Fn(&Request) -> bool + Send + Sync + 'static,
    {
        let inner = CountingMiddleware::new();
        let calls = inner.counter();
        let client =
            reqwest_middleware::ClientBuilder::new(reqwest::Client::builder().build().unwrap())
                .with(ConditionalMiddleware::new(inner, condition))
                .with(End)
                .build();
        let request = Request::new(method, "http://localhost".parse().unwrap());

        let body = client.execute(request).await.unwrap().text().await.unwrap();
        (body, calls.load(Ordering::SeqCst))
    }

    #[tokio::test]
    async fn test_runs_inner_middleware() {
        let (body, calls) = run_with_condition(http::Method::GET, |_req: &Request| true).await;

        assert_eq!("end", body);
        assert_eq!(1, calls);
    }

    #[tokio::test]
    async fn test_does_not_run_inner_middleware() {
        let (body, calls) = run_with_condition(http::Method::GET, |_req: &Request| false).await;

        assert_eq!("end", body);
        assert_eq!(0, calls);
    }

    #[tokio::test]
    async fn test_condition_is_evaluated_against_the_request() {
        // Non-capturing closure, so it is `Copy` and can be reused for both requests.
        let is_get = |req: &Request| req.method() == http::Method::GET;

        let (body, calls) = run_with_condition(http::Method::GET, is_get).await;
        assert_eq!("end", body);
        assert_eq!(1, calls);

        let (body, calls) = run_with_condition(http::Method::POST, is_get).await;
        assert_eq!("end", body);
        assert_eq!(0, calls);
    }
}