prax_query/middleware/
chain.rs1use super::context::QueryContext;
4use super::types::{
5 BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware,
6};
7use std::sync::Arc;
8
9pub struct MiddlewareChain {
16 middlewares: Vec<SharedMiddleware>,
17}
18
19impl MiddlewareChain {
20 pub fn new() -> Self {
22 Self {
23 middlewares: Vec::new(),
24 }
25 }
26
27 pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
29 Self { middlewares }
30 }
31
32 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
34 self.middlewares.push(Arc::new(middleware));
35 }
36
37 pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
39 self.middlewares.insert(0, Arc::new(middleware));
40 }
41
42 pub fn len(&self) -> usize {
44 self.middlewares.len()
45 }
46
47 pub fn is_empty(&self) -> bool {
49 self.middlewares.is_empty()
50 }
51
52 pub fn execute<'a, F>(
54 &'a self,
55 ctx: QueryContext,
56 final_handler: F,
57 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
58 where
59 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
60 {
61 self.execute_at(0, ctx, final_handler)
62 }
63
64 fn execute_at<'a, F>(
65 &'a self,
66 index: usize,
67 ctx: QueryContext,
68 final_handler: F,
69 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
70 where
71 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
72 {
73 if index >= self.middlewares.len() {
74 return final_handler(ctx);
76 }
77
78 let middleware = &self.middlewares[index];
79
80 if !middleware.enabled() {
82 return self.execute_at(index + 1, ctx, final_handler);
83 }
84
85 ({
88 Box::pin(async move {
91 middleware
94 .handle(
95 ctx,
96 Next {
97 inner: Box::new(move |ctx| {
98 final_handler(ctx)
101 }),
102 },
103 )
104 .await
105 })
106 }) as _
107 }
108}
109
110impl Default for MiddlewareChain {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116pub struct MiddlewareStack {
120 chain: MiddlewareChain,
121}
122
123impl MiddlewareStack {
124 pub fn new() -> Self {
126 Self {
127 chain: MiddlewareChain::new(),
128 }
129 }
130
131 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
133 self.chain.push(middleware);
134 self
135 }
136
137 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
139 self.chain.push(middleware);
140 self
141 }
142
143 pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
145 self.chain.prepend(middleware);
146 self
147 }
148
149 pub fn len(&self) -> usize {
151 self.chain.len()
152 }
153
154 pub fn is_empty(&self) -> bool {
156 self.chain.is_empty()
157 }
158
159 pub fn into_chain(self) -> MiddlewareChain {
161 self.chain
162 }
163
164 pub fn execute<'a, F>(
166 &'a self,
167 ctx: QueryContext,
168 final_handler: F,
169 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
170 where
171 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
172 {
173 self.chain.execute(ctx, final_handler)
174 }
175}
176
177impl Default for MiddlewareStack {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
183impl From<MiddlewareStack> for MiddlewareChain {
184 fn from(stack: MiddlewareStack) -> Self {
185 stack.chain
186 }
187}
188
189pub struct MiddlewareBuilder {
191 middlewares: Vec<SharedMiddleware>,
192}
193
194impl MiddlewareBuilder {
195 pub fn new() -> Self {
197 Self {
198 middlewares: Vec::new(),
199 }
200 }
201
202 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
204 self.middlewares.push(Arc::new(middleware));
205 self
206 }
207
208 pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
210 if condition {
211 self.with(middleware)
212 } else {
213 self
214 }
215 }
216
217 pub fn build(self) -> MiddlewareChain {
219 MiddlewareChain::with(self.middlewares)
220 }
221
222 pub fn build_stack(self) -> MiddlewareStack {
224 MiddlewareStack {
225 chain: self.build(),
226 }
227 }
228}
229
230impl Default for MiddlewareBuilder {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal middleware that only forwards the context to `next`.
    struct PassThrough;

    impl Middleware for PassThrough {
        fn handle<'a>(
            &'a self,
            ctx: QueryContext,
            next: Next<'a>,
        ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
            Box::pin(async move { next.run(ctx).await })
        }
    }

    #[test]
    fn test_middleware_chain_empty() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    #[test]
    fn test_middleware_stack_builder() {
        let stack = MiddlewareStack::new().with(PassThrough).with(PassThrough);
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        // The `with_if(false, ..)` entry must not be added.
        let chain = MiddlewareBuilder::new()
            .with(PassThrough)
            .with_if(true, PassThrough)
            .with_if(false, PassThrough)
            .build();
        assert_eq!(chain.len(), 2);
    }
}