prax_query/middleware/
chain.rs1use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware};
5use std::sync::Arc;
6
7pub struct MiddlewareChain {
14 middlewares: Vec<SharedMiddleware>,
15}
16
17impl MiddlewareChain {
18 pub fn new() -> Self {
20 Self {
21 middlewares: Vec::new(),
22 }
23 }
24
25 pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
27 Self { middlewares }
28 }
29
30 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
32 self.middlewares.push(Arc::new(middleware));
33 }
34
35 pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
37 self.middlewares.insert(0, Arc::new(middleware));
38 }
39
40 pub fn len(&self) -> usize {
42 self.middlewares.len()
43 }
44
45 pub fn is_empty(&self) -> bool {
47 self.middlewares.is_empty()
48 }
49
50 pub fn execute<'a, F>(
52 &'a self,
53 ctx: QueryContext,
54 final_handler: F,
55 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
56 where
57 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
58 {
59 self.execute_at(0, ctx, final_handler)
60 }
61
62 fn execute_at<'a, F>(
63 &'a self,
64 index: usize,
65 ctx: QueryContext,
66 final_handler: F,
67 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
68 where
69 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
70 {
71 if index >= self.middlewares.len() {
72 return final_handler(ctx);
74 }
75
76 let middleware = &self.middlewares[index];
77
78 if !middleware.enabled() {
80 return self.execute_at(index + 1, ctx, final_handler);
81 }
82
83 let next = {
85 Box::pin(async move {
88 middleware.handle(ctx, Next {
91 inner: Box::new(move |ctx| {
92 final_handler(ctx)
95 }),
96 }).await
97 })
98 };
99
100 next
101 }
102}
103
104impl Default for MiddlewareChain {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110pub struct MiddlewareStack {
114 chain: MiddlewareChain,
115}
116
117impl MiddlewareStack {
118 pub fn new() -> Self {
120 Self {
121 chain: MiddlewareChain::new(),
122 }
123 }
124
125 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
127 self.chain.push(middleware);
128 self
129 }
130
131 pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
133 self.chain.push(middleware);
134 self
135 }
136
137 pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
139 self.chain.prepend(middleware);
140 self
141 }
142
143 pub fn len(&self) -> usize {
145 self.chain.len()
146 }
147
148 pub fn is_empty(&self) -> bool {
150 self.chain.is_empty()
151 }
152
153 pub fn into_chain(self) -> MiddlewareChain {
155 self.chain
156 }
157
158 pub fn execute<'a, F>(
160 &'a self,
161 ctx: QueryContext,
162 final_handler: F,
163 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
164 where
165 F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
166 {
167 self.chain.execute(ctx, final_handler)
168 }
169}
170
171impl Default for MiddlewareStack {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl From<MiddlewareStack> for MiddlewareChain {
178 fn from(stack: MiddlewareStack) -> Self {
179 stack.chain
180 }
181}
182
183pub struct MiddlewareBuilder {
185 middlewares: Vec<SharedMiddleware>,
186}
187
188impl MiddlewareBuilder {
189 pub fn new() -> Self {
191 Self {
192 middlewares: Vec::new(),
193 }
194 }
195
196 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
198 self.middlewares.push(Arc::new(middleware));
199 self
200 }
201
202 pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
204 if condition {
205 self.with(middleware)
206 } else {
207 self
208 }
209 }
210
211 pub fn build(self) -> MiddlewareChain {
213 MiddlewareChain::with(self.middlewares)
214 }
215
216 pub fn build_stack(self) -> MiddlewareStack {
218 MiddlewareStack {
219 chain: self.build(),
220 }
221 }
222}
223
224impl Default for MiddlewareBuilder {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Middleware that forwards the context unchanged to the rest of the chain.
    struct PassThrough;

    impl Middleware for PassThrough {
        fn handle<'a>(
            &'a self,
            ctx: QueryContext,
            next: Next<'a>,
        ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
            Box::pin(async move { next.run(ctx).await })
        }
    }

    #[test]
    fn test_middleware_chain_empty() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.len(), 0);
        assert!(chain.is_empty());
    }

    #[test]
    fn test_middleware_stack_builder() {
        let stack = MiddlewareStack::new().with(PassThrough).with(PassThrough);
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        // Only the unconditional entry and the `true` branch should be kept.
        let builder = MiddlewareBuilder::new().with(PassThrough);
        let builder = builder.with_if(true, PassThrough);
        let builder = builder.with_if(false, PassThrough);
        let chain = builder.build();
        assert_eq!(chain.len(), 2);
    }
}