prax_query/middleware/chain.rs

1//! Middleware chain and stack implementation.
2
3use super::context::QueryContext;
4use super::types::{
5    BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware,
6};
7use std::sync::Arc;
8
9/// A chain of middleware that processes queries.
10///
11/// The chain executes middleware in order, with each middleware able to:
12/// - Modify the query context before passing to the next
13/// - Modify the response after receiving from the next
14/// - Short-circuit by not calling next
15pub struct MiddlewareChain {
16    middlewares: Vec<SharedMiddleware>,
17}
18
19impl MiddlewareChain {
20    /// Create an empty middleware chain.
21    pub fn new() -> Self {
22        Self {
23            middlewares: Vec::new(),
24        }
25    }
26
27    /// Create a chain with initial middleware.
28    pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
29        Self { middlewares }
30    }
31
32    /// Add middleware to the end of the chain.
33    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
34        self.middlewares.push(Arc::new(middleware));
35    }
36
37    /// Add middleware to the beginning of the chain.
38    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
39        self.middlewares.insert(0, Arc::new(middleware));
40    }
41
42    /// Get the number of middlewares in the chain.
43    pub fn len(&self) -> usize {
44        self.middlewares.len()
45    }
46
47    /// Check if the chain is empty.
48    pub fn is_empty(&self) -> bool {
49        self.middlewares.is_empty()
50    }
51
52    /// Execute the middleware chain.
53    pub fn execute<'a, F>(
54        &'a self,
55        ctx: QueryContext,
56        final_handler: F,
57    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
58    where
59        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
60    {
61        self.execute_at(0, ctx, final_handler)
62    }
63
64    fn execute_at<'a, F>(
65        &'a self,
66        index: usize,
67        ctx: QueryContext,
68        final_handler: F,
69    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
70    where
71        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
72    {
73        if index >= self.middlewares.len() {
74            // End of chain, call the final handler
75            return final_handler(ctx);
76        }
77
78        let middleware = &self.middlewares[index];
79
80        // Skip disabled middleware
81        if !middleware.enabled() {
82            return self.execute_at(index + 1, ctx, final_handler);
83        }
84
85        // Create the next handler that will call the rest of the chain
86
87        ({
88            // We need to move the final_handler but also use it in the closure
89            // This requires some careful handling
90            Box::pin(async move {
91                // This is a placeholder - the actual implementation needs
92                // to properly chain the middleware
93                middleware
94                    .handle(
95                        ctx,
96                        Next {
97                            inner: Box::new(move |ctx| {
98                                // Recursively call the rest of the chain
99                                // Note: This is simplified - real impl would be more complex
100                                final_handler(ctx)
101                            }),
102                        },
103                    )
104                    .await
105            })
106        }) as _
107    }
108}
109
110impl Default for MiddlewareChain {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116/// A stack of middleware with builder pattern.
117///
118/// This is a more ergonomic wrapper around `MiddlewareChain`.
119pub struct MiddlewareStack {
120    chain: MiddlewareChain,
121}
122
123impl MiddlewareStack {
124    /// Create a new empty stack.
125    pub fn new() -> Self {
126        Self {
127            chain: MiddlewareChain::new(),
128        }
129    }
130
131    /// Add middleware to the stack (builder pattern).
132    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
133        self.chain.push(middleware);
134        self
135    }
136
137    /// Add middleware mutably.
138    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
139        self.chain.push(middleware);
140        self
141    }
142
143    /// Add middleware to the front of the stack.
144    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
145        self.chain.prepend(middleware);
146        self
147    }
148
149    /// Get the number of middlewares.
150    pub fn len(&self) -> usize {
151        self.chain.len()
152    }
153
154    /// Check if the stack is empty.
155    pub fn is_empty(&self) -> bool {
156        self.chain.is_empty()
157    }
158
159    /// Get the underlying chain.
160    pub fn into_chain(self) -> MiddlewareChain {
161        self.chain
162    }
163
164    /// Execute the stack with a final handler.
165    pub fn execute<'a, F>(
166        &'a self,
167        ctx: QueryContext,
168        final_handler: F,
169    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
170    where
171        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
172    {
173        self.chain.execute(ctx, final_handler)
174    }
175}
176
177impl Default for MiddlewareStack {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183impl From<MiddlewareStack> for MiddlewareChain {
184    fn from(stack: MiddlewareStack) -> Self {
185        stack.chain
186    }
187}
188
189/// Builder for creating middleware stacks.
190pub struct MiddlewareBuilder {
191    middlewares: Vec<SharedMiddleware>,
192}
193
194impl MiddlewareBuilder {
195    /// Create a new builder.
196    pub fn new() -> Self {
197        Self {
198            middlewares: Vec::new(),
199        }
200    }
201
202    /// Add middleware.
203    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
204        self.middlewares.push(Arc::new(middleware));
205        self
206    }
207
208    /// Add middleware conditionally.
209    pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
210        if condition {
211            self.with(middleware)
212        } else {
213            self
214        }
215    }
216
217    /// Build the middleware chain.
218    pub fn build(self) -> MiddlewareChain {
219        MiddlewareChain::with(self.middlewares)
220    }
221
222    /// Build into a stack.
223    pub fn build_stack(self) -> MiddlewareStack {
224        MiddlewareStack {
225            chain: self.build(),
226        }
227    }
228}
229
230impl Default for MiddlewareBuilder {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Middleware that forwards every query to the next handler unchanged.
    struct PassThrough;

    impl Middleware for PassThrough {
        fn handle<'a>(
            &'a self,
            ctx: QueryContext,
            next: Next<'a>,
        ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
            Box::pin(async move { next.run(ctx).await })
        }
    }

    #[test]
    fn test_middleware_chain_empty() {
        let empty = MiddlewareChain::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_middleware_stack_builder() {
        let stack = MiddlewareStack::new().with(PassThrough).with(PassThrough);
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        // The `false` condition must not register its middleware.
        let chain = MiddlewareBuilder::new()
            .with(PassThrough)
            .with_if(true, PassThrough)
            .with_if(false, PassThrough)
            .build();
        assert_eq!(chain.len(), 2);
    }
}