Skip to main content

hexeract_core/
middleware.rs

1use std::any::Any;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crate::context::HandlerContext;
8use crate::envelope::MessageEnvelope;
9use crate::error::HexeractError;
10
/// Handler result after type erasure.
///
/// The pipeline is shared by every message type, so each stage of the
/// middleware chain passes results around as this boxed `Any`. The terminal
/// dispatcher boxes the concrete `C::Output` into it, and the dispatch
/// boundary downcasts it back to the typed output.
pub type BoxOutput = Box<dyn Any + Send + Sync>;
16
/// Boxed, type-erased future that is `Send` and may borrow data for `'a`.
///
/// This is the return type of [`Terminal::dispatch`] and of the object-safe
/// middleware adapter in this module. It is public so that [`Terminal`]
/// implementors outside this crate can name it. Otherwise they would have to
/// spell out `Pin<Box<dyn Future<Output = T> + Send + 'a>>` by hand.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
18
19/// Intercepts a dispatch before reaching its handler.
20///
21/// Middlewares are stacked onion-style: the first registered middleware
22/// wraps all the others, observing both the entry and the exit of every
23/// dispatch.
24///
25/// # Example
26///
27/// ```
28/// use hexeract_core::{BoxOutput, HandlerContext, HexeractError, MessageEnvelope, Middleware, Next};
29///
30/// struct LoggingMiddleware;
31///
32/// impl Middleware for LoggingMiddleware {
33///     async fn execute(
34///         &self,
35///         envelope: &MessageEnvelope,
36///         ctx: &HandlerContext,
37///         next: Next,
38///     ) -> Result<BoxOutput, HexeractError> {
39///         tracing::info!(type_name = envelope.type_name(), "dispatching");
40///         let result = next.run(envelope, ctx).await;
41///         tracing::info!(type_name = envelope.type_name(), "dispatched");
42///         result
43///     }
44/// }
45/// ```
46#[trait_variant::make(Send)]
47pub trait Middleware: Send + Sync + 'static {
48    /// Executes the middleware. The implementation must call `next.run(...)`
49    /// to proceed to the next middleware or terminal, unless it intentionally
50    /// short-circuits the chain.
51    async fn execute(
52        &self,
53        envelope: &MessageEnvelope,
54        ctx: &HandlerContext,
55        next: Next,
56    ) -> Result<BoxOutput, HexeractError>;
57}
58
59#[doc(hidden)]
60pub trait DynMiddleware: Send + Sync + 'static {
61    fn execute<'a>(
62        &'a self,
63        envelope: &'a MessageEnvelope,
64        ctx: &'a HandlerContext,
65        next: Next,
66    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
67}
68
69impl<M: Middleware> DynMiddleware for M {
70    fn execute<'a>(
71        &'a self,
72        envelope: &'a MessageEnvelope,
73        ctx: &'a HandlerContext,
74        next: Next,
75    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
76        Box::pin(<M as Middleware>::execute(self, envelope, ctx, next))
77    }
78}
79
80/// Terminal of the middleware chain. The mediator (issue #6) supplies a
81/// concrete implementation that downcasts the message and invokes the
82/// registered handler.
83///
84/// This trait is public so external dispatchers and test harnesses can
85/// build a pipeline without depending on the mediator. The API may evolve
86/// before v1.0.
87pub trait Terminal: Send + Sync + 'static {
88    /// Dispatches the message to its terminal destination.
89    fn dispatch<'a>(
90        &'a self,
91        envelope: &'a MessageEnvelope,
92        ctx: &'a HandlerContext,
93    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
94}
95
96/// Opaque continuation passed to a [`Middleware`]. Calling [`Next::run`]
97/// proceeds to the next middleware in the chain or to the [`Terminal`] if
98/// the chain is empty.
99pub struct Next {
100    chain: VecDeque<Arc<dyn DynMiddleware>>,
101    terminal: Arc<dyn Terminal>,
102}
103
104impl Next {
105    /// Builds a new [`Next`] from a chain of middlewares and a terminal.
106    ///
107    /// Middlewares are executed in the order they appear in the slice: the
108    /// first one wraps the second, which wraps the third, and so on.
109    #[must_use]
110    pub fn new(middlewares: Vec<Arc<dyn DynMiddleware>>, terminal: Arc<dyn Terminal>) -> Self {
111        Self {
112            chain: middlewares.into(),
113            terminal,
114        }
115    }
116
117    /// Advances the pipeline by one step.
118    ///
119    /// # Errors
120    ///
121    /// Returns the [`HexeractError`] produced by the next middleware in the
122    /// chain or by the [`Terminal`] when the chain is exhausted.
123    pub async fn run(
124        mut self,
125        envelope: &MessageEnvelope,
126        ctx: &HandlerContext,
127    ) -> Result<BoxOutput, HexeractError> {
128        if let Some(head) = self.chain.pop_front() {
129            head.execute(envelope, ctx, self).await
130        } else {
131            self.terminal.dispatch(envelope, ctx).await
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CorrelationId, MessageId};
    use std::sync::Mutex;

    /// Erases a concrete middleware into the trait object stored by [`Next`].
    fn dyn_mw<M: Middleware>(m: M) -> Arc<dyn DynMiddleware> {
        Arc::new(m)
    }

    /// Command used only to build a typed envelope.
    struct DummyCmd;
    impl crate::command::Command for DummyCmd {
        type Output = i32;
    }

    fn fresh_env() -> MessageEnvelope {
        MessageEnvelope::for_command::<DummyCmd>(MessageId::new(), CorrelationId::new())
    }

    fn fresh_ctx() -> HandlerContext {
        HandlerContext::new(MessageId::new(), CorrelationId::new())
    }

    /// Downcasts a pipeline output to `i32`, panicking if it is another type.
    fn as_i32(output: BoxOutput) -> i32 {
        *output.downcast::<i32>().expect("output must be i32")
    }

    /// Terminal that always succeeds with a fixed value.
    struct StaticTerminal {
        value: i32,
    }

    impl Terminal for StaticTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            let value = self.value;
            Box::pin(async move { Ok(Box::new(value) as BoxOutput) })
        }
    }

    /// Terminal that always fails. In short-circuit tests, reaching it means
    /// the chain was not actually cut.
    struct FailingTerminal;
    impl Terminal for FailingTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            Box::pin(async move { Err(HexeractError::Dispatch("terminal failure".into())) })
        }
    }

    /// Shared, ordered log of labels pushed by [`TracingMiddleware`].
    #[derive(Clone)]
    struct Recorder {
        trace: Arc<Mutex<Vec<&'static str>>>,
    }

    impl Recorder {
        fn new() -> Self {
            Self {
                trace: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Appends a label. The guard is a temporary dropped at the end of the
        /// statement, so the lock is never held across an `.await`.
        fn push(&self, label: &'static str) {
            self.trace.lock().expect("poisoned").push(label);
        }

        fn snapshot(&self) -> Vec<&'static str> {
            self.trace.lock().expect("poisoned").clone()
        }
    }

    /// Records `name` before delegating and `post_label` afterwards,
    /// regardless of whether the rest of the chain succeeded.
    struct TracingMiddleware {
        name: &'static str,
        post_label: &'static str,
        recorder: Recorder,
    }

    impl Middleware for TracingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            self.recorder.push(self.name);
            let result = next.run(envelope, ctx).await;
            self.recorder.push(self.post_label);
            result
        }
    }

    fn tracing_mw(name: &'static str, post: &'static str, recorder: Recorder) -> TracingMiddleware {
        TracingMiddleware {
            name,
            post_label: post,
            recorder,
        }
    }

    #[tokio::test]
    async fn single_middleware_delegates_to_terminal() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(StaticTerminal { value: 42 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        assert_eq!(as_i32(output), 42);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn chain_of_three_executes_in_onion_order() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
                dyn_mw(tracing_mw("C", "C_post", recorder.clone())),
            ],
            Arc::new(StaticTerminal { value: 7 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        // The terminal's value must come back out through all three layers.
        assert_eq!(as_i32(output), 7);
        assert_eq!(
            recorder.snapshot(),
            vec!["A", "B", "C", "C_post", "B_post", "A_post"]
        );
    }

    /// Answers immediately with `99` without calling `next`.
    struct ShortCircuit;
    impl Middleware for ShortCircuit {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Ok(Box::new(99_i32) as BoxOutput)
        }
    }

    #[tokio::test]
    async fn short_circuit_middleware_skips_terminal() {
        let next = Next::new(vec![dyn_mw(ShortCircuit)], Arc::new(FailingTerminal));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(as_i32(output), 99);
    }

    #[tokio::test]
    async fn short_circuit_in_middle_skips_inner_layers() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(ShortCircuit),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
            ],
            Arc::new(FailingTerminal),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(as_i32(output), 99);
        // Outer layer still observes entry and exit; inner layer never runs.
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn error_from_terminal_propagates_through_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(FailingTerminal),
        );
        match next.run(&fresh_env(), &fresh_ctx()).await {
            Err(HexeractError::Dispatch(ref m)) => assert_eq!(m, "terminal failure"),
            other => panic!("unexpected result: {other:?}"),
        }
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    /// Refuses every dispatch without calling `next`.
    struct ErrorMiddleware;
    impl Middleware for ErrorMiddleware {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Err(HexeractError::Dispatch("middleware refusal".into()))
        }
    }

    #[tokio::test]
    async fn error_from_middleware_propagates() {
        let next = Next::new(
            vec![dyn_mw(ErrorMiddleware)],
            Arc::new(StaticTerminal { value: 0 }),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("middleware should fail");
        match err {
            HexeractError::Dispatch(ref m) => assert_eq!(m, "middleware refusal"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    fn assert_send<T: Send>(_: &T) {}

    #[tokio::test]
    async fn next_run_future_is_send() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 1 }));
        let env = fresh_env();
        let ctx = fresh_ctx();
        let future = next.run(&env, &ctx);
        assert_send(&future);
        let output = future.await.expect("dispatch should succeed");
        assert_eq!(as_i32(output), 1);
    }

    #[tokio::test]
    async fn empty_chain_invokes_terminal_directly() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 123 }));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        assert_eq!(as_i32(output), 123);
    }

    /// Captures the envelope's type name, then delegates normally.
    struct EnvelopeInspector {
        observed: Arc<Mutex<Option<String>>>,
    }

    impl Middleware for EnvelopeInspector {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            *self.observed.lock().expect("poisoned") = Some(envelope.type_name().to_string());
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn middleware_reads_envelope_type_name() {
        let observed = Arc::new(Mutex::new(None));
        let mw = EnvelopeInspector {
            observed: Arc::clone(&observed),
        };
        let next = Next::new(vec![dyn_mw(mw)], Arc::new(StaticTerminal { value: 0 }));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        assert_eq!(as_i32(output), 0);
        let type_name = observed
            .lock()
            .expect("poisoned")
            .clone()
            .expect("middleware must record the type name");
        assert!(
            type_name.ends_with("::DummyCmd"),
            "unexpected type name: {type_name}"
        );
    }
}