prax_query/middleware/
chain.rs

1//! Middleware chain and stack implementation.
2
3use super::context::QueryContext;
4use super::types::{
5    BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware,
6};
7use std::sync::Arc;
8
9/// A chain of middleware that processes queries.
10///
11/// The chain executes middleware in order, with each middleware able to:
12/// - Modify the query context before passing to the next
13/// - Modify the response after receiving from the next
14/// - Short-circuit by not calling next
15pub struct MiddlewareChain {
16    middlewares: Vec<SharedMiddleware>,
17}
18
19impl MiddlewareChain {
20    /// Create an empty middleware chain.
21    pub fn new() -> Self {
22        Self {
23            middlewares: Vec::new(),
24        }
25    }
26
27    /// Create a chain with initial middleware.
28    pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
29        Self { middlewares }
30    }
31
32    /// Add middleware to the end of the chain.
33    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
34        self.middlewares.push(Arc::new(middleware));
35    }
36
37    /// Add middleware to the beginning of the chain.
38    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
39        self.middlewares.insert(0, Arc::new(middleware));
40    }
41
42    /// Get the number of middlewares in the chain.
43    pub fn len(&self) -> usize {
44        self.middlewares.len()
45    }
46
47    /// Check if the chain is empty.
48    pub fn is_empty(&self) -> bool {
49        self.middlewares.is_empty()
50    }
51
52    /// Execute the middleware chain.
53    pub fn execute<'a, F>(
54        &'a self,
55        ctx: QueryContext,
56        final_handler: F,
57    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
58    where
59        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
60    {
61        self.execute_at(0, ctx, final_handler)
62    }
63
64    fn execute_at<'a, F>(
65        &'a self,
66        index: usize,
67        ctx: QueryContext,
68        final_handler: F,
69    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
70    where
71        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
72    {
73        if index >= self.middlewares.len() {
74            // End of chain, call the final handler
75            return final_handler(ctx);
76        }
77
78        let middleware = &self.middlewares[index];
79
80        // Skip disabled middleware
81        if !middleware.enabled() {
82            return self.execute_at(index + 1, ctx, final_handler);
83        }
84
85        // Create the next handler that will call the rest of the chain
86        
87
88        ({
89            // We need to move the final_handler but also use it in the closure
90            // This requires some careful handling
91            Box::pin(async move {
92                // This is a placeholder - the actual implementation needs
93                // to properly chain the middleware
94                middleware
95                    .handle(
96                        ctx,
97                        Next {
98                            inner: Box::new(move |ctx| {
99                                // Recursively call the rest of the chain
100                                // Note: This is simplified - real impl would be more complex
101                                final_handler(ctx)
102                            }),
103                        },
104                    )
105                    .await
106            })
107        }) as _
108    }
109}
110
111impl Default for MiddlewareChain {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117/// A stack of middleware with builder pattern.
118///
119/// This is a more ergonomic wrapper around `MiddlewareChain`.
120pub struct MiddlewareStack {
121    chain: MiddlewareChain,
122}
123
124impl MiddlewareStack {
125    /// Create a new empty stack.
126    pub fn new() -> Self {
127        Self {
128            chain: MiddlewareChain::new(),
129        }
130    }
131
132    /// Add middleware to the stack (builder pattern).
133    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
134        self.chain.push(middleware);
135        self
136    }
137
138    /// Add middleware mutably.
139    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
140        self.chain.push(middleware);
141        self
142    }
143
144    /// Add middleware to the front of the stack.
145    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
146        self.chain.prepend(middleware);
147        self
148    }
149
150    /// Get the number of middlewares.
151    pub fn len(&self) -> usize {
152        self.chain.len()
153    }
154
155    /// Check if the stack is empty.
156    pub fn is_empty(&self) -> bool {
157        self.chain.is_empty()
158    }
159
160    /// Get the underlying chain.
161    pub fn into_chain(self) -> MiddlewareChain {
162        self.chain
163    }
164
165    /// Execute the stack with a final handler.
166    pub fn execute<'a, F>(
167        &'a self,
168        ctx: QueryContext,
169        final_handler: F,
170    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
171    where
172        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
173    {
174        self.chain.execute(ctx, final_handler)
175    }
176}
177
178impl Default for MiddlewareStack {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184impl From<MiddlewareStack> for MiddlewareChain {
185    fn from(stack: MiddlewareStack) -> Self {
186        stack.chain
187    }
188}
189
190/// Builder for creating middleware stacks.
191pub struct MiddlewareBuilder {
192    middlewares: Vec<SharedMiddleware>,
193}
194
195impl MiddlewareBuilder {
196    /// Create a new builder.
197    pub fn new() -> Self {
198        Self {
199            middlewares: Vec::new(),
200        }
201    }
202
203    /// Add middleware.
204    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
205        self.middlewares.push(Arc::new(middleware));
206        self
207    }
208
209    /// Add middleware conditionally.
210    pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
211        if condition {
212            self.with(middleware)
213        } else {
214            self
215        }
216    }
217
218    /// Build the middleware chain.
219    pub fn build(self) -> MiddlewareChain {
220        MiddlewareChain::with(self.middlewares)
221    }
222
223    /// Build into a stack.
224    pub fn build_stack(self) -> MiddlewareStack {
225        MiddlewareStack {
226            chain: self.build(),
227        }
228    }
229}
230
231impl Default for MiddlewareBuilder {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Middleware that forwards every query untouched; used to populate chains.
    struct PassThrough;

    impl Middleware for PassThrough {
        fn handle<'a>(
            &'a self,
            ctx: QueryContext,
            next: Next<'a>,
        ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
            Box::pin(async move { next.run(ctx).await })
        }
    }

    #[test]
    fn test_middleware_chain_empty() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.len(), 0);
        assert!(chain.is_empty());
    }

    #[test]
    fn test_middleware_stack_builder() {
        let stack = MiddlewareStack::new().with(PassThrough).with(PassThrough);
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        // Only the unconditional entry and the `with_if(true, ..)` entry survive.
        let chain = MiddlewareBuilder::new()
            .with(PassThrough)
            .with_if(true, PassThrough)
            .with_if(false, PassThrough)
            .build();
        assert_eq!(chain.len(), 2);
    }
}
289}