prax_query/middleware/
chain.rs1use super::context::QueryContext;
4use super::types::{
5 BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware,
6};
7use std::sync::Arc;
8
9pub struct MiddlewareChain {
16 middlewares: Vec<SharedMiddleware>,
17}
18
19impl MiddlewareChain {
20 pub fn new() -> Self {
22 Self {
23 middlewares: Vec::new(),
24 }
25 }
26
27 pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
29 Self { middlewares }
30 }
31
32 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
34 self.middlewares.push(Arc::new(middleware));
35 }
36
37 pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
39 self.middlewares.insert(0, Arc::new(middleware));
40 }
41
42 pub fn len(&self) -> usize {
44 self.middlewares.len()
45 }
46
47 pub fn is_empty(&self) -> bool {
49 self.middlewares.is_empty()
50 }
51
52 pub fn execute<'a, F>(
54 &'a self,
55 ctx: QueryContext,
56 final_handler: F,
57 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
58 where
59 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
60 {
61 self.execute_at(0, ctx, final_handler)
62 }
63
64 fn execute_at<'a, F>(
65 &'a self,
66 index: usize,
67 ctx: QueryContext,
68 final_handler: F,
69 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
70 where
71 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
72 {
73 if index >= self.middlewares.len() {
74 return final_handler(ctx);
76 }
77
78 let middleware = &self.middlewares[index];
79
80 if !middleware.enabled() {
82 return self.execute_at(index + 1, ctx, final_handler);
83 }
84
85 let next = {
87 Box::pin(async move {
90 middleware
93 .handle(
94 ctx,
95 Next {
96 inner: Box::new(move |ctx| {
97 final_handler(ctx)
100 }),
101 },
102 )
103 .await
104 })
105 };
106
107 next
108 }
109}
110
111impl Default for MiddlewareChain {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117pub struct MiddlewareStack {
121 chain: MiddlewareChain,
122}
123
124impl MiddlewareStack {
125 pub fn new() -> Self {
127 Self {
128 chain: MiddlewareChain::new(),
129 }
130 }
131
132 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
134 self.chain.push(middleware);
135 self
136 }
137
138 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
140 self.chain.push(middleware);
141 self
142 }
143
144 pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
146 self.chain.prepend(middleware);
147 self
148 }
149
150 pub fn len(&self) -> usize {
152 self.chain.len()
153 }
154
155 pub fn is_empty(&self) -> bool {
157 self.chain.is_empty()
158 }
159
160 pub fn into_chain(self) -> MiddlewareChain {
162 self.chain
163 }
164
165 pub fn execute<'a, F>(
167 &'a self,
168 ctx: QueryContext,
169 final_handler: F,
170 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
171 where
172 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
173 {
174 self.chain.execute(ctx, final_handler)
175 }
176}
177
178impl Default for MiddlewareStack {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184impl From<MiddlewareStack> for MiddlewareChain {
185 fn from(stack: MiddlewareStack) -> Self {
186 stack.chain
187 }
188}
189
190pub struct MiddlewareBuilder {
192 middlewares: Vec<SharedMiddleware>,
193}
194
195impl MiddlewareBuilder {
196 pub fn new() -> Self {
198 Self {
199 middlewares: Vec::new(),
200 }
201 }
202
203 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
205 self.middlewares.push(Arc::new(middleware));
206 self
207 }
208
209 pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
211 if condition {
212 self.with(middleware)
213 } else {
214 self
215 }
216 }
217
218 pub fn build(self) -> MiddlewareChain {
220 MiddlewareChain::with(self.middlewares)
221 }
222
223 pub fn build_stack(self) -> MiddlewareStack {
225 MiddlewareStack {
226 chain: self.build(),
227 }
228 }
229}
230
231impl Default for MiddlewareBuilder {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Test middleware that only forwards the query to the next handler.
    struct PassThrough;

    impl Middleware for PassThrough {
        fn handle<'a>(
            &'a self,
            ctx: QueryContext,
            next: Next<'a>,
        ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
            Box::pin(async move { next.run(ctx).await })
        }
    }

    #[test]
    fn test_middleware_chain_empty() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.len(), 0);
        assert!(chain.is_empty());
    }

    #[test]
    fn test_middleware_stack_builder() {
        let stack = MiddlewareStack::new().with(PassThrough).with(PassThrough);

        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        // The `with_if(false, ..)` entry must be dropped, leaving exactly two middleware.
        let chain = MiddlewareBuilder::new()
            .with(PassThrough)
            .with_if(true, PassThrough)
            .with_if(false, PassThrough)
            .build();

        assert_eq!(chain.len(), 2);
    }
}
289}