reqwest_conditional_middleware/
lib.rs1use async_trait::async_trait;
43use http::Extensions;
44use reqwest::{Request, Response};
45use reqwest_middleware::{Middleware, Next, Result};
46
47pub struct ConditionalMiddleware<T, C> {
50 inner: T,
51 condition: C,
52}
53
54impl<T, C> ConditionalMiddleware<T, C>
55where
56 T: Middleware,
57 C: Fn(&Request) -> bool + Send + Sync + 'static,
58{
59 pub fn new(inner: T, condition: C) -> Self {
62 Self { inner, condition }
63 }
64}
65
66#[async_trait]
67impl<T, C> Middleware for ConditionalMiddleware<T, C>
68where
69 T: Middleware,
70 C: Fn(&Request) -> bool + Send + Sync + 'static,
71{
72 async fn handle(
73 &self,
74 req: Request,
75 extensions: &mut Extensions,
76 next: Next<'_>,
77 ) -> Result<Response> {
78 let should_handle = (self.condition)(&req);
79
80 if should_handle {
81 self.inner.handle(req, extensions, next).await
82 } else {
83 next.run(req, extensions).await
84 }
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use reqwest::{Request, Response};
    use std::sync::{Arc, Mutex};

    /// Terminal middleware that ends the chain with a `200 OK` whose body is `"end"`.
    /// It never calls `next`, so these tests never reach the network.
    struct End;

    #[async_trait]
    impl Middleware for End {
        async fn handle(
            &self,
            _req: Request,
            _extensions: &mut Extensions,
            _next: Next<'_>,
        ) -> Result<Response> {
            let resp = http::Response::builder()
                .status(StatusCode::OK)
                .body("end")
                .unwrap();
            Ok(resp.into())
        }
    }

    /// Pass-through middleware that toggles a shared flag each time it handles a request.
    struct CheckMiddleware {
        check: Arc<Mutex<bool>>,
    }

    impl CheckMiddleware {
        fn new() -> Self {
            Self {
                check: Arc::new(Mutex::new(false)),
            }
        }

        /// Toggles the flag.
        ///
        /// A single guard covers both the read and the write, so a concurrent flip cannot slip in
        /// between them. Locking twice, as before, lost that atomicity.
        fn flip(&self) {
            let mut flag = self.check.lock().unwrap();
            *flag = !*flag;
        }

        /// Returns a handle to the shared flag so a test can inspect it after the middleware has
        /// been moved into the client.
        fn checker(&self) -> Arc<Mutex<bool>> {
            Arc::clone(&self.check)
        }
    }

    #[async_trait]
    impl Middleware for CheckMiddleware {
        async fn handle(
            &self,
            req: Request,
            extensions: &mut Extensions,
            next: Next<'_>,
        ) -> Result<Response> {
            self.flip();
            next.run(req, extensions).await
        }
    }

    /// Sends one GET through the chain `ConditionalMiddleware(CheckMiddleware) -> End`.
    ///
    /// The predicate always returns `condition`. The helper returns the response body and
    /// whether the inner `CheckMiddleware` flipped its flag.
    async fn run_with_condition(condition: bool) -> (String, bool) {
        let check = CheckMiddleware::new();
        let flag = check.checker();
        let conditional = ConditionalMiddleware::new(check, move |_req: &Request| condition);
        let request = reqwest::Request::new(http::Method::GET, "http://localhost".parse().unwrap());

        let client =
            reqwest_middleware::ClientBuilder::new(reqwest::Client::builder().build().unwrap())
                .with(conditional)
                .with(End)
                .build();

        let body = client.execute(request).await.unwrap().text().await.unwrap();
        let ran = *flag.lock().unwrap();
        (body, ran)
    }

    #[tokio::test]
    async fn test_runs_inner_middleware() {
        let (resp, ran) = run_with_condition(true).await;

        assert_eq!("end", resp);
        assert!(ran);
    }

    #[tokio::test]
    async fn test_does_not_run_inner_middleware() {
        let (resp, ran) = run_with_condition(false).await;

        assert_eq!("end", resp);
        assert!(!ran);
    }
}
182}