1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
//! The only export of this crate is a struct [`ConditionalMiddleware`] for creating conditional middlewares.
//! This struct implements the [`Middleware`][reqwest_middleware::Middleware] trait
//! and forwards requests on to the middleware that it wraps.
//!
//! The conditional wrapper holds a closure that will be run for each request. If the
//! closure returns true, then the inner middleware will run. Otherwise it will be
//! skipped and the current request will be passed along to the next middleware.
//!
//! # Example
//!
//! Short-circuits a middleware stack and returns `200 OK` whenever the request method
//! is `GET`; requests with any other method skip `AlwaysOk` and continue down the stack.
//!
//! ```
//! use reqwest::{Request, Response};
//! use reqwest_conditional_middleware::ConditionalMiddleware;
//! use reqwest_middleware::{Middleware, Next, Result};
//! use task_local_extensions::Extensions;
//!
//! struct AlwaysOk;
//!
//! #[async_trait::async_trait]
//! impl Middleware for AlwaysOk {
//!     async fn handle(
//!         &self,
//!         _req: Request,
//!         _extensions: &mut Extensions,
//!         _next: Next<'_>,
//!     ) -> Result<Response> {
//!         let builder = http::Response::builder().status(http::StatusCode::OK);
//!         Ok(builder.body("").unwrap().into())
//!     }
//! }
//!
//! let conditional = ConditionalMiddleware::new(
//!     AlwaysOk,
//!     |req: &Request| req.method() == http::Method::GET
//! );
//!
//! ```

use async_trait::async_trait;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result};
use task_local_extensions::Extensions;

/// A wrapper that runs an inner [`Middleware`][reqwest_middleware::Middleware] `T` only for
/// requests on which the condition `C` evaluates to `true`.
///
/// Requests for which the condition returns `false` skip `T` and are passed straight on to
/// the next middleware in the stack.
///
/// The wrapper is [`Clone`] when both `T` and `C` are, and [`Debug`](std::fmt::Debug) when
/// `T` is. The condition is usually a closure, which has no `Debug` output, so it is elided.
#[derive(Clone)]
pub struct ConditionalMiddleware<T, C> {
    /// The middleware to run when `condition` returns `true`.
    inner: T,
    /// Predicate evaluated against each request to decide whether `inner` runs.
    condition: C,
}

impl<T: std::fmt::Debug, C> std::fmt::Debug for ConditionalMiddleware<T, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately no `C: Debug` bound: closures cannot implement `Debug`, and requiring it
        // would make this impl unusable for the common case.
        f.debug_struct("ConditionalMiddleware")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}

impl<T, C> ConditionalMiddleware<T, C>
where
    T: Middleware,
    C: Fn(&Request) -> bool + Send + Sync + 'static,
{
    /// Creates a new wrapped middleware. The function C will be run for each request to
    /// determine if the wrapped middleware should be run.
    pub fn new(inner: T, condition: C) -> Self {
        Self { inner, condition }
    }
}

#[async_trait]
impl<T, C> Middleware for ConditionalMiddleware<T, C>
where
    T: Middleware,
    C: Fn(&Request) -> bool + Send + Sync + 'static,
{
    /// Evaluates the condition against `req`. On `false` the request is forwarded unchanged
    /// to the rest of the stack. On `true` the wrapped middleware takes over; it receives
    /// `next` and decides for itself whether to call it.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response> {
        // The borrow of `req` ends before it is moved into either branch below.
        if !(self.condition)(&req) {
            return next.run(req, extensions).await;
        }
        self.inner.handle(req, extensions, next).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use reqwest::{Request, Response};
    use std::sync::{Arc, Mutex};

    /// Terminal middleware: answers every request with `200 OK` and body `"end"`
    /// without calling `next`.
    struct End;

    #[async_trait]
    impl Middleware for End {
        async fn handle(
            &self,
            _req: Request,
            _extensions: &mut Extensions,
            _next: Next<'_>,
        ) -> Result<Response> {
            let builder = http::Response::builder().status(StatusCode::OK);
            let resp = builder.body("end").unwrap();
            Ok(resp.into())
        }
    }

    /// Pass-through middleware that toggles a shared flag each time it runs, so tests can
    /// observe whether it was invoked.
    struct CheckMiddleware {
        check: Arc<Mutex<bool>>,
    }

    impl CheckMiddleware {
        fn new() -> Self {
            Self {
                check: Arc::new(Mutex::new(false)),
            }
        }

        fn flip(&self) {
            // Hold one guard for the whole read-modify-write; locking twice would let another
            // toggle slip in between the read and the write.
            let mut value = self.check.lock().unwrap();
            *value = !*value;
        }

        /// Returns a handle to the flag that outlives `self` once it is moved into a client.
        fn checker(&self) -> Arc<Mutex<bool>> {
            self.check.clone()
        }
    }

    #[async_trait]
    impl Middleware for CheckMiddleware {
        async fn handle(
            &self,
            req: Request,
            extensions: &mut Extensions,
            next: Next<'_>,
        ) -> Result<Response> {
            self.flip();
            next.run(req, extensions).await
        }
    }

    /// Sends a GET to `http://localhost` through `middleware` followed by [`End`] and returns
    /// the response body. `End` answers without calling `next`, so the request never reaches
    /// the underlying reqwest client.
    async fn get_through(middleware: impl Middleware) -> String {
        let request = Request::new(http::Method::GET, "http://localhost".parse().unwrap());
        let client =
            reqwest_middleware::ClientBuilder::new(reqwest::Client::builder().build().unwrap())
                .with(middleware)
                .with(End)
                .build();

        client.execute(request).await.unwrap().text().await.unwrap()
    }

    #[tokio::test]
    async fn test_runs_inner_middleware() {
        let check = CheckMiddleware::new();
        let test = check.checker();
        let conditional = ConditionalMiddleware::new(check, |_req: &Request| true);

        assert_eq!("end", get_through(conditional).await);
        assert!(*test.lock().unwrap());
    }

    #[tokio::test]
    async fn test_does_not_run_inner_middleware() {
        let check = CheckMiddleware::new();
        let test = check.checker();
        let conditional = ConditionalMiddleware::new(check, |_req: &Request| false);

        assert_eq!("end", get_through(conditional).await);
        assert!(!*test.lock().unwrap());
    }

    #[tokio::test]
    async fn test_condition_inspects_request() {
        // The request is always a GET, so only the GET-matching condition runs the inner
        // middleware.
        let cases = vec![(http::Method::GET, true), (http::Method::POST, false)];
        for (method, expected) in cases {
            let check = CheckMiddleware::new();
            let test = check.checker();
            let conditional =
                ConditionalMiddleware::new(check, move |req: &Request| req.method() == method);

            assert_eq!("end", get_through(conditional).await);
            assert_eq!(expected, *test.lock().unwrap());
        }
    }
}