Skip to main content

hexeract_core/
middleware.rs

1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::context::HandlerContext;
7use crate::envelope::MessageEnvelope;
8use crate::error::HexeractError;
9
/// Handler result with its concrete type erased while it travels through the
/// middleware chain.
///
/// The terminal dispatcher boxes the concrete `C::Output` value into this
/// alias; the dispatch boundary later downcasts it back to the typed output.
pub type BoxOutput = Box<dyn Any + Sync + Send>;
15
/// Pinned, heap-allocated `Send` future that may borrow data for `'a`; used as
/// the dyn-compatible return type of the object-safe traits in this module.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
17
18/// Intercepts a dispatch before reaching its handler.
19///
20/// Middlewares are stacked onion-style: the first registered middleware
21/// wraps all the others, observing both the entry and the exit of every
22/// dispatch.
23///
24/// # Example
25///
26/// ```
27/// use hexeract_core::{BoxOutput, HandlerContext, HexeractError, MessageEnvelope, Middleware, Next};
28///
29/// struct LoggingMiddleware;
30///
31/// impl Middleware for LoggingMiddleware {
32///     async fn execute(
33///         &self,
34///         envelope: &MessageEnvelope,
35///         ctx: &HandlerContext,
36///         next: Next,
37///     ) -> Result<BoxOutput, HexeractError> {
38///         tracing::info!(type_name = envelope.type_name(), "dispatching");
39///         let result = next.run(envelope, ctx).await;
40///         tracing::info!(type_name = envelope.type_name(), "dispatched");
41///         result
42///     }
43/// }
44/// ```
45#[trait_variant::make(Send)]
46pub trait Middleware: Send + Sync + 'static {
47    /// Executes the middleware. The implementation must call `next.run(...)`
48    /// to proceed to the next middleware or terminal, unless it intentionally
49    /// short-circuits the chain.
50    async fn execute(
51        &self,
52        envelope: &MessageEnvelope,
53        ctx: &HandlerContext,
54        next: Next,
55    ) -> Result<BoxOutput, HexeractError>;
56}
57
58#[doc(hidden)]
59pub trait DynMiddleware: Send + Sync + 'static {
60    fn execute<'a>(
61        &'a self,
62        envelope: &'a MessageEnvelope,
63        ctx: &'a HandlerContext,
64        next: Next,
65    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
66}
67
68impl<M: Middleware> DynMiddleware for M {
69    fn execute<'a>(
70        &'a self,
71        envelope: &'a MessageEnvelope,
72        ctx: &'a HandlerContext,
73        next: Next,
74    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
75        Box::pin(<M as Middleware>::execute(self, envelope, ctx, next))
76    }
77}
78
79/// Terminal of the middleware chain. The mediator (issue #6) supplies a
80/// concrete implementation that downcasts the message and invokes the
81/// registered handler.
82///
83/// This trait is public so external dispatchers and test harnesses can
84/// build a pipeline without depending on the mediator. The API may evolve
85/// before v1.0.
86pub trait Terminal: Send + Sync + 'static {
87    /// Dispatches the message to its terminal destination.
88    fn dispatch<'a>(
89        &'a self,
90        envelope: &'a MessageEnvelope,
91        ctx: &'a HandlerContext,
92    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
93}
94
95/// Opaque continuation passed to a [`Middleware`]. Calling [`Next::run`]
96/// proceeds to the next middleware in the chain or to the [`Terminal`] if
97/// the chain is exhausted.
98///
99/// The middleware chain is held as a shared `Arc<[_]>` walked with an index
100/// cursor, so advancing the pipeline is a reference-count bump rather than a
101/// per-dispatch allocation of the chain.
102pub struct Next {
103    chain: Arc<[Arc<dyn DynMiddleware>]>,
104    index: usize,
105    terminal: Arc<dyn Terminal>,
106}
107
108impl Next {
109    /// Builds a new [`Next`] from a chain of middlewares and a terminal.
110    ///
111    /// Middlewares are executed in the order they appear: the first one wraps
112    /// the second, which wraps the third, and so on. The chain accepts any
113    /// `Into<Arc<[_]>>`, so a freshly built `Vec` or a pre-shared `Arc<[_]>`
114    /// (cloned once per dispatch as a reference-count bump) both work.
115    #[must_use]
116    pub fn new(
117        middlewares: impl Into<Arc<[Arc<dyn DynMiddleware>]>>,
118        terminal: Arc<dyn Terminal>,
119    ) -> Self {
120        Self {
121            chain: middlewares.into(),
122            index: 0,
123            terminal,
124        }
125    }
126
127    /// Advances the pipeline by one step.
128    ///
129    /// # Errors
130    ///
131    /// Returns the [`HexeractError`] produced by the next middleware in the
132    /// chain or by the [`Terminal`] when the chain is exhausted.
133    pub async fn run(
134        self,
135        envelope: &MessageEnvelope,
136        ctx: &HandlerContext,
137    ) -> Result<BoxOutput, HexeractError> {
138        if let Some(head) = self.chain.get(self.index).cloned() {
139            let next = Next {
140                chain: self.chain,
141                index: self.index + 1,
142                terminal: self.terminal,
143            };
144            head.execute(envelope, ctx, next).await
145        } else {
146            self.terminal.dispatch(envelope, ctx).await
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CorrelationId, MessageId};
    use std::sync::Mutex;

    /// Erases a concrete middleware into the trait object `Next::new` expects.
    fn dyn_mw<M: Middleware>(m: M) -> Arc<dyn DynMiddleware> {
        Arc::new(m)
    }

    struct DummyCmd;
    impl crate::command::Command for DummyCmd {
        type Output = i32;
    }

    fn fresh_env() -> MessageEnvelope {
        MessageEnvelope::for_command::<DummyCmd>(MessageId::new(), CorrelationId::new())
    }

    fn fresh_ctx() -> HandlerContext {
        HandlerContext::new(MessageId::new(), CorrelationId::new())
    }

    /// Terminal that always succeeds with a fixed `i32`.
    struct StaticTerminal {
        value: i32,
    }

    impl Terminal for StaticTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            let value = self.value;
            Box::pin(async move { Ok(Box::new(value) as BoxOutput) })
        }
    }

    /// Terminal that always fails; reaching it in a "skipped" test is a bug.
    struct FailingTerminal;
    impl Terminal for FailingTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            Box::pin(async move { Err(HexeractError::Dispatch("terminal failure".into())) })
        }
    }

    /// Shared, append-only log of labels pushed by `TracingMiddleware`s.
    #[derive(Clone)]
    struct Recorder {
        trace: Arc<Mutex<Vec<&'static str>>>,
    }

    impl Recorder {
        fn new() -> Self {
            Self {
                trace: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn snapshot(&self) -> Vec<&'static str> {
            self.trace.lock().expect("poisoned").clone()
        }
    }

    /// Records `name` on entry and `post_label` on exit.
    struct TracingMiddleware {
        name: &'static str,
        post_label: &'static str,
        recorder: Recorder,
    }

    impl Middleware for TracingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            // The guard is a temporary dropped at the end of the statement,
            // so no lock is held across the `.await` below.
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.name);
            let result = next.run(envelope, ctx).await;
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.post_label);
            result
        }
    }

    fn tracing_mw(name: &'static str, post: &'static str, recorder: Recorder) -> TracingMiddleware {
        TracingMiddleware {
            name,
            post_label: post,
            recorder,
        }
    }

    #[tokio::test]
    async fn single_middleware_delegates_to_terminal() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(StaticTerminal { value: 42 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        let downcast = output.downcast::<i32>().expect("output must be i32");
        assert_eq!(*downcast, 42);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn chain_of_three_executes_in_onion_order() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
                dyn_mw(tracing_mw("C", "C_post", recorder.clone())),
            ],
            Arc::new(StaticTerminal { value: 7 }),
        );
        let _ = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(
            recorder.snapshot(),
            vec!["A", "B", "C", "C_post", "B_post", "A_post"]
        );
    }

    /// Returns `99` without calling `next`.
    struct ShortCircuit;
    impl Middleware for ShortCircuit {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Ok(Box::new(99_i32) as BoxOutput)
        }
    }

    #[tokio::test]
    async fn short_circuit_middleware_skips_terminal() {
        let next = Next::new(vec![dyn_mw(ShortCircuit)], Arc::new(FailingTerminal));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
    }

    #[tokio::test]
    async fn short_circuit_skips_downstream_middlewares() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(ShortCircuit),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
            ],
            Arc::new(FailingTerminal),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
        // Upstream "A" still wraps the dispatch; downstream "B" never runs.
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn error_from_terminal_propagates_through_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(FailingTerminal),
        );
        let result = next.run(&fresh_env(), &fresh_ctx()).await;
        assert!(matches!(result, Err(HexeractError::Dispatch(_))));
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    struct ErrorMiddleware;
    impl Middleware for ErrorMiddleware {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Err(HexeractError::Dispatch("middleware refusal".into()))
        }
    }

    #[tokio::test]
    async fn error_from_middleware_propagates() {
        let next = Next::new(
            vec![dyn_mw(ErrorMiddleware)],
            Arc::new(StaticTerminal { value: 0 }),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("middleware should fail");
        match err {
            HexeractError::Dispatch(ref m) => assert_eq!(m, "middleware refusal"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    fn assert_send<T: Send>(_: &T) {}

    #[tokio::test]
    async fn next_run_future_is_send() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 1 }));
        let env = fresh_env();
        let ctx = fresh_ctx();
        let future = next.run(&env, &ctx);
        assert_send(&future);
        let _ = future.await;
    }

    #[tokio::test]
    async fn empty_chain_invokes_terminal_directly() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 123 }));
        let output = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(*output.downcast::<i32>().unwrap(), 123);
    }

    #[tokio::test]
    async fn shared_chain_is_reusable_across_dispatches() {
        let recorder = Recorder::new();
        // A pre-shared chain, as `Next::new` documents: each dispatch clones
        // the `Arc<[_]>` instead of rebuilding the middleware list.
        let chain: Arc<[Arc<dyn DynMiddleware>]> =
            Arc::from(vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))]);
        let terminal: Arc<dyn Terminal> = Arc::new(StaticTerminal { value: 5 });

        for _ in 0..2 {
            let next = Next::new(Arc::clone(&chain), Arc::clone(&terminal));
            let output = next
                .run(&fresh_env(), &fresh_ctx())
                .await
                .expect("dispatch should succeed");
            assert_eq!(*output.downcast::<i32>().unwrap(), 5);
        }
        assert_eq!(recorder.snapshot(), vec!["A", "A_post", "A", "A_post"]);
    }

    /// Captures the envelope's type name, then continues the chain.
    struct EnvelopeInspector {
        observed: Arc<Mutex<Option<String>>>,
    }

    impl Middleware for EnvelopeInspector {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            *self.observed.lock().expect("poisoned") = Some(envelope.type_name().to_string());
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn middleware_reads_envelope_type_name() {
        let observed = Arc::new(Mutex::new(None));
        let mw = EnvelopeInspector {
            observed: Arc::clone(&observed),
        };
        let next = Next::new(vec![dyn_mw(mw)], Arc::new(StaticTerminal { value: 0 }));
        next.run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        let observed = observed.lock().unwrap().clone();
        assert!(observed.unwrap().ends_with("::DummyCmd"));
    }
}