Skip to main content

hexeract_core/
middleware.rs

1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::context::HandlerContext;
7use crate::envelope::MessageEnvelope;
8use crate::error::HexeractError;
9
/// Type-erased handler output, passed through the middleware chain.
///
/// The terminal dispatcher boxes the concrete `C::Output` into this alias.
/// Callers downcast back to the typed output at the dispatch boundary
/// (for example with [`Box::downcast`]).
pub type BoxOutput = Box<dyn Any + Send + Sync>;

/// Pinned, heap-allocated, `Send` future used as the return type of
/// [`Terminal::dispatch`].
///
/// Public so that implementors of the public [`Terminal`] trait outside this
/// crate can name the return type instead of spelling out the full
/// `Pin<Box<dyn Future<Output = T> + Send + 'a>>`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

18/// Intercepts a dispatch before reaching its handler.
19///
20/// Middlewares are stacked onion-style: the first registered middleware
21/// wraps all the others, observing both the entry and the exit of every
22/// dispatch.
23///
24/// # Example
25///
26/// ```
27/// use hexeract_core::{BoxOutput, HandlerContext, HexeractError, MessageEnvelope, Middleware, Next};
28///
29/// struct LoggingMiddleware;
30///
31/// impl Middleware for LoggingMiddleware {
32///     async fn execute(
33///         &self,
34///         envelope: &MessageEnvelope,
35///         ctx: &HandlerContext,
36///         next: Next,
37///     ) -> Result<BoxOutput, HexeractError> {
38///         tracing::info!(type_name = envelope.type_name(), "dispatching");
39///         let result = next.run(envelope, ctx).await;
40///         tracing::info!(type_name = envelope.type_name(), "dispatched");
41///         result
42///     }
43/// }
44/// ```
45#[trait_variant::make(Send)]
46pub trait Middleware: Send + Sync + 'static {
47    /// Executes the middleware. The implementation must call `next.run(...)`
48    /// to proceed to the next middleware or terminal, unless it intentionally
49    /// short-circuits the chain.
50    async fn execute(
51        &self,
52        envelope: &MessageEnvelope,
53        ctx: &HandlerContext,
54        next: Next,
55    ) -> Result<BoxOutput, HexeractError>;
56}
57
58#[doc(hidden)]
59pub trait DynMiddleware: Send + Sync + 'static {
60    fn execute<'a>(
61        &'a self,
62        envelope: &'a MessageEnvelope,
63        ctx: &'a HandlerContext,
64        next: Next,
65    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
66}
67
68impl<M: Middleware> DynMiddleware for M {
69    fn execute<'a>(
70        &'a self,
71        envelope: &'a MessageEnvelope,
72        ctx: &'a HandlerContext,
73        next: Next,
74    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
75        Box::pin(<M as Middleware>::execute(self, envelope, ctx, next))
76    }
77}
78
79/// Terminal of the middleware chain. The mediator (issue #6) supplies a
80/// concrete implementation that downcasts the message and invokes the
81/// registered handler.
82///
83/// This trait is public so external dispatchers and test harnesses can
84/// build a pipeline without depending on the mediator. The API may evolve
85/// before v1.0.
86pub trait Terminal: Send + Sync + 'static {
87    /// Dispatches the message to its terminal destination.
88    fn dispatch<'a>(
89        &'a self,
90        envelope: &'a MessageEnvelope,
91        ctx: &'a HandlerContext,
92    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
93}
94
95/// Opaque continuation passed to a [`Middleware`]. Calling [`Next::run`]
96/// proceeds to the next middleware in the chain or to the [`Terminal`] if
97/// the chain is exhausted.
98///
99/// The middleware chain is held as a shared `Arc<[_]>` walked with an index
100/// cursor, so advancing the pipeline is a reference-count bump rather than a
101/// per-dispatch allocation of the chain.
102pub struct Next {
103    chain: Arc<[Arc<dyn DynMiddleware>]>,
104    index: usize,
105    terminal: Arc<dyn Terminal>,
106}
107
108impl Next {
109    /// Builds a new [`Next`] from a chain of middlewares and a terminal.
110    ///
111    /// Middlewares are executed in the order they appear: the first one wraps
112    /// the second, which wraps the third, and so on. The chain accepts any
113    /// `Into<Arc<[_]>>`, so a freshly built `Vec` or a pre-shared `Arc<[_]>`
114    /// (cloned once per dispatch as a reference-count bump) both work.
115    #[must_use]
116    pub fn new(
117        middlewares: impl Into<Arc<[Arc<dyn DynMiddleware>]>>,
118        terminal: Arc<dyn Terminal>,
119    ) -> Self {
120        Self {
121            chain: middlewares.into(),
122            index: 0,
123            terminal,
124        }
125    }
126
127    /// Advances the pipeline by one step.
128    ///
129    /// The context's cancellation token is observed before each step: a
130    /// middleware that cancels the token short-circuits the rest of the
131    /// chain at the next [`Next::run`] call, and the [`Terminal`] is never
132    /// reached. A step that is already executing is not interrupted.
133    ///
134    /// # Errors
135    ///
136    /// Returns [`HexeractError::Cancelled`] if the context's cancellation
137    /// token fired, or the [`HexeractError`] produced by the next middleware
138    /// in the chain or by the [`Terminal`] when the chain is exhausted.
139    pub async fn run(
140        self,
141        envelope: &MessageEnvelope,
142        ctx: &HandlerContext,
143    ) -> Result<BoxOutput, HexeractError> {
144        if ctx.is_cancelled() {
145            return Err(HexeractError::cancelled(envelope.type_name()));
146        }
147        if let Some(head) = self.chain.get(self.index).cloned() {
148            let next = Next {
149                chain: self.chain,
150                index: self.index + 1,
151                terminal: self.terminal,
152            };
153            head.execute(envelope, ctx, next).await
154        } else {
155            self.terminal.dispatch(envelope, ctx).await
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CorrelationId, MessageId};
    use std::sync::Mutex;

    /// Erases a concrete middleware into the form `Next::new` stores.
    fn dyn_mw<M: Middleware>(m: M) -> Arc<dyn DynMiddleware> {
        Arc::new(m)
    }

    struct DummyCmd;
    impl crate::command::Command for DummyCmd {
        type Output = i32;
    }

    fn fresh_env() -> MessageEnvelope {
        MessageEnvelope::for_command::<DummyCmd>(MessageId::new(), CorrelationId::new())
    }

    fn fresh_ctx() -> HandlerContext {
        HandlerContext::new(MessageId::new(), CorrelationId::new())
    }

    /// Terminal that always succeeds with a fixed `i32`.
    struct StaticTerminal {
        value: i32,
    }

    impl Terminal for StaticTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            let value = self.value;
            Box::pin(async move { Ok(Box::new(value) as BoxOutput) })
        }
    }

    /// Terminal that always fails. Lets tests prove the terminal was (or
    /// was not) reached.
    struct FailingTerminal;
    impl Terminal for FailingTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            Box::pin(async move { Err(HexeractError::Dispatch("terminal failure".into())) })
        }
    }

    /// Shared, append-only log of labels pushed by tracing middlewares.
    #[derive(Clone)]
    struct Recorder {
        trace: Arc<Mutex<Vec<&'static str>>>,
    }

    impl Recorder {
        fn new() -> Self {
            Self {
                trace: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn snapshot(&self) -> Vec<&'static str> {
            self.trace.lock().expect("poisoned").clone()
        }
    }

    /// Records `name` before delegating and `post_label` after.
    struct TracingMiddleware {
        name: &'static str,
        post_label: &'static str,
        recorder: Recorder,
    }

    impl Middleware for TracingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            // Each guard is dropped at the end of its statement, so no
            // `std::sync::Mutex` guard is held across the `.await`.
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.name);
            let result = next.run(envelope, ctx).await;
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.post_label);
            result
        }
    }

    fn tracing_mw(name: &'static str, post: &'static str, recorder: Recorder) -> TracingMiddleware {
        TracingMiddleware {
            name,
            post_label: post,
            recorder,
        }
    }

    #[tokio::test]
    async fn single_middleware_delegates_to_terminal() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(StaticTerminal { value: 42 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        let downcast = output.downcast::<i32>().expect("output must be i32");
        assert_eq!(*downcast, 42);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn chain_of_three_executes_in_onion_order() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
                dyn_mw(tracing_mw("C", "C_post", recorder.clone())),
            ],
            Arc::new(StaticTerminal { value: 7 }),
        );
        let _ = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(
            recorder.snapshot(),
            vec!["A", "B", "C", "C_post", "B_post", "A_post"]
        );
    }

    /// Covers the documented `Arc<[_]>` input path of `Next::new`: one shared
    /// chain must be reusable across dispatches, restarting from the first
    /// middleware each time.
    #[tokio::test]
    async fn shared_chain_can_be_reused_across_dispatches() {
        let recorder = Recorder::new();
        let chain: Arc<[Arc<dyn DynMiddleware>]> =
            Arc::from(vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))]);
        let terminal: Arc<dyn Terminal> = Arc::new(StaticTerminal { value: 5 });
        for _ in 0..2 {
            let output = Next::new(Arc::clone(&chain), Arc::clone(&terminal))
                .run(&fresh_env(), &fresh_ctx())
                .await
                .expect("dispatch should succeed");
            assert_eq!(*output.downcast::<i32>().expect("output must be i32"), 5);
        }
        assert_eq!(recorder.snapshot(), vec!["A", "A_post", "A", "A_post"]);
    }

    /// Returns a value without ever calling `next`.
    struct ShortCircuit;
    impl Middleware for ShortCircuit {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Ok(Box::new(99_i32) as BoxOutput)
        }
    }

    #[tokio::test]
    async fn short_circuit_middleware_skips_terminal() {
        let next = Next::new(vec![dyn_mw(ShortCircuit)], Arc::new(FailingTerminal));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
    }

    #[tokio::test]
    async fn error_from_terminal_propagates_through_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(FailingTerminal),
        );
        let result = next.run(&fresh_env(), &fresh_ctx()).await;
        assert!(matches!(result, Err(HexeractError::Dispatch(_))));
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    /// Always rejects the dispatch.
    struct ErrorMiddleware;
    impl Middleware for ErrorMiddleware {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Err(HexeractError::Dispatch("middleware refusal".into()))
        }
    }

    #[tokio::test]
    async fn error_from_middleware_propagates() {
        let next = Next::new(
            vec![dyn_mw(ErrorMiddleware)],
            Arc::new(StaticTerminal { value: 0 }),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("middleware should fail");
        match err {
            HexeractError::Dispatch(ref m) => assert_eq!(m, "middleware refusal"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    /// Cancels the context's token, then tries to continue the chain.
    struct CancellingMiddleware;
    impl Middleware for CancellingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            ctx.cancellation.cancel();
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn run_returns_cancelled_when_token_fired_before_dispatch() {
        let ctx = fresh_ctx();
        ctx.cancellation.cancel();
        let next = Next::new(vec![], Arc::new(FailingTerminal));
        let err = next
            .run(&fresh_env(), &ctx)
            .await
            .expect_err("cancelled dispatch must fail");
        assert!(
            matches!(err, HexeractError::Cancelled { type_name } if type_name.contains("DummyCmd"))
        );
    }

    #[tokio::test]
    async fn middleware_cancelling_token_short_circuits_the_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(CancellingMiddleware),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
            ],
            Arc::new(FailingTerminal),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("cancelled chain must fail");
        assert!(matches!(err, HexeractError::Cancelled { .. }));
        assert!(recorder.snapshot().is_empty());
    }

    fn assert_send<T: Send>(_: &T) {}

    #[tokio::test]
    async fn next_run_future_is_send() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 1 }));
        let env = fresh_env();
        let ctx = fresh_ctx();
        let future = next.run(&env, &ctx);
        assert_send(&future);
        let _ = future.await;
    }

    #[tokio::test]
    async fn empty_chain_invokes_terminal_directly() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 123 }));
        let output = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(*output.downcast::<i32>().unwrap(), 123);
    }

    /// Captures the envelope's type name, then continues the chain.
    struct EnvelopeInspector {
        observed: Arc<Mutex<Option<String>>>,
    }

    impl Middleware for EnvelopeInspector {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            *self.observed.lock().expect("poisoned") = Some(envelope.type_name().to_string());
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn middleware_reads_envelope_type_name() {
        let observed = Arc::new(Mutex::new(None));
        let mw = EnvelopeInspector {
            observed: Arc::clone(&observed),
        };
        let next = Next::new(vec![dyn_mw(mw)], Arc::new(StaticTerminal { value: 0 }));
        let _ = next.run(&fresh_env(), &fresh_ctx()).await;
        let observed = observed.lock().unwrap().clone();
        assert!(observed.unwrap().ends_with("::DummyCmd"));
    }
}