prax_query/middleware/
chain.rs

1//! Middleware chain and stack implementation.
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware};
5use std::sync::Arc;
6
7/// A chain of middleware that processes queries.
8///
9/// The chain executes middleware in order, with each middleware able to:
10/// - Modify the query context before passing to the next
11/// - Modify the response after receiving from the next
12/// - Short-circuit by not calling next
13pub struct MiddlewareChain {
14    middlewares: Vec<SharedMiddleware>,
15}
16
17impl MiddlewareChain {
18    /// Create an empty middleware chain.
19    pub fn new() -> Self {
20        Self {
21            middlewares: Vec::new(),
22        }
23    }
24
25    /// Create a chain with initial middleware.
26    pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
27        Self { middlewares }
28    }
29
30    /// Add middleware to the end of the chain.
31    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
32        self.middlewares.push(Arc::new(middleware));
33    }
34
35    /// Add middleware to the beginning of the chain.
36    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
37        self.middlewares.insert(0, Arc::new(middleware));
38    }
39
40    /// Get the number of middlewares in the chain.
41    pub fn len(&self) -> usize {
42        self.middlewares.len()
43    }
44
45    /// Check if the chain is empty.
46    pub fn is_empty(&self) -> bool {
47        self.middlewares.is_empty()
48    }
49
50    /// Execute the middleware chain.
51    pub fn execute<'a, F>(
52        &'a self,
53        ctx: QueryContext,
54        final_handler: F,
55    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
56    where
57        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
58    {
59        self.execute_at(0, ctx, final_handler)
60    }
61
62    fn execute_at<'a, F>(
63        &'a self,
64        index: usize,
65        ctx: QueryContext,
66        final_handler: F,
67    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
68    where
69        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
70    {
71        if index >= self.middlewares.len() {
72            // End of chain, call the final handler
73            return final_handler(ctx);
74        }
75
76        let middleware = &self.middlewares[index];
77
78        // Skip disabled middleware
79        if !middleware.enabled() {
80            return self.execute_at(index + 1, ctx, final_handler);
81        }
82
83        // Create the next handler that will call the rest of the chain
84        let next = {
85            // We need to move the final_handler but also use it in the closure
86            // This requires some careful handling
87            Box::pin(async move {
88                // This is a placeholder - the actual implementation needs
89                // to properly chain the middleware
90                middleware.handle(ctx, Next {
91                    inner: Box::new(move |ctx| {
92                        // Recursively call the rest of the chain
93                        // Note: This is simplified - real impl would be more complex
94                        final_handler(ctx)
95                    }),
96                }).await
97            })
98        };
99
100        next
101    }
102}
103
104impl Default for MiddlewareChain {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110/// A stack of middleware with builder pattern.
111///
112/// This is a more ergonomic wrapper around `MiddlewareChain`.
113pub struct MiddlewareStack {
114    chain: MiddlewareChain,
115}
116
117impl MiddlewareStack {
118    /// Create a new empty stack.
119    pub fn new() -> Self {
120        Self {
121            chain: MiddlewareChain::new(),
122        }
123    }
124
125    /// Add middleware to the stack (builder pattern).
126    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
127        self.chain.push(middleware);
128        self
129    }
130
131    /// Add middleware mutably.
132    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
133        self.chain.push(middleware);
134        self
135    }
136
137    /// Add middleware to the front of the stack.
138    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
139        self.chain.prepend(middleware);
140        self
141    }
142
143    /// Get the number of middlewares.
144    pub fn len(&self) -> usize {
145        self.chain.len()
146    }
147
148    /// Check if the stack is empty.
149    pub fn is_empty(&self) -> bool {
150        self.chain.is_empty()
151    }
152
153    /// Get the underlying chain.
154    pub fn into_chain(self) -> MiddlewareChain {
155        self.chain
156    }
157
158    /// Execute the stack with a final handler.
159    pub fn execute<'a, F>(
160        &'a self,
161        ctx: QueryContext,
162        final_handler: F,
163    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
164    where
165        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
166    {
167        self.chain.execute(ctx, final_handler)
168    }
169}
170
171impl Default for MiddlewareStack {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl From<MiddlewareStack> for MiddlewareChain {
178    fn from(stack: MiddlewareStack) -> Self {
179        stack.chain
180    }
181}
182
183/// Builder for creating middleware stacks.
184pub struct MiddlewareBuilder {
185    middlewares: Vec<SharedMiddleware>,
186}
187
188impl MiddlewareBuilder {
189    /// Create a new builder.
190    pub fn new() -> Self {
191        Self {
192            middlewares: Vec::new(),
193        }
194    }
195
196    /// Add middleware.
197    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
198        self.middlewares.push(Arc::new(middleware));
199        self
200    }
201
202    /// Add middleware conditionally.
203    pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
204        if condition {
205            self.with(middleware)
206        } else {
207            self
208        }
209    }
210
211    /// Build the middleware chain.
212    pub fn build(self) -> MiddlewareChain {
213        MiddlewareChain::with(self.middlewares)
214    }
215
216    /// Build into a stack.
217    pub fn build_stack(self) -> MiddlewareStack {
218        MiddlewareStack {
219            chain: self.build(),
220        }
221    }
222}
223
224impl Default for MiddlewareBuilder {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Middleware that simply forwards the query to the rest of the chain.
    struct PassThrough;

    impl Middleware for PassThrough {
        fn handle<'a>(
            &'a self,
            ctx: QueryContext,
            next: Next<'a>,
        ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
            Box::pin(async move { next.run(ctx).await })
        }
    }

    #[test]
    fn test_middleware_chain_empty() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.len(), 0);
        assert!(chain.is_empty());
    }

    #[test]
    fn test_middleware_stack_builder() {
        let stack = MiddlewareStack::new().with(PassThrough).with(PassThrough);
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        let builder = MiddlewareBuilder::new()
            .with(PassThrough)
            .with_if(true, PassThrough)
            .with_if(false, PassThrough);
        let chain = builder.build();
        assert_eq!(chain.len(), 2);
    }
}
283
284