1use std::any::Any;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use crate::context::HandlerContext;
8use crate::envelope::MessageEnvelope;
9use crate::error::HexeractError;
10
/// Type-erased value produced by a middleware chain or its terminal.
///
/// Callers recover the concrete value with [`Box::downcast`].
pub type BoxOutput = Box<dyn Any + Sync + Send>;

/// Pinned, heap-allocated, `Send` future that may borrow data for `'life`
/// and resolves to `Out`.
type BoxFuture<'life, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'life>>;
18
/// An asynchronous interceptor wrapped around message dispatch.
///
/// Middleware run in the order they were passed to [`Next::new`]. Each one
/// receives the message envelope, the handler context, and a [`Next`]
/// continuation representing the rest of the chain.
///
/// `trait_variant::make(Send)` rewrites the `async fn` below so that its
/// returned future is `Send`. The blanket `DynMiddleware` impl relies on this
/// to box that future as a `Send` trait object.
#[trait_variant::make(Send)]
pub trait Middleware: Send + Sync + 'static {
    /// Handles one message.
    ///
    /// To continue down the chain, call `next.run(envelope, ctx).await`. Code
    /// placed after that call observes the downstream result. Returning
    /// without calling `next.run` short-circuits the chain: neither later
    /// middleware nor the terminal is invoked, and the value returned here
    /// becomes the chain's output.
    ///
    /// # Errors
    ///
    /// Returns any `HexeractError` the implementation produces, including one
    /// propagated from `next.run`.
    async fn execute(
        &self,
        envelope: &MessageEnvelope,
        ctx: &HandlerContext,
        next: Next,
    ) -> Result<BoxOutput, HexeractError>;
}
58
59#[doc(hidden)]
60pub trait DynMiddleware: Send + Sync + 'static {
61 fn execute<'a>(
62 &'a self,
63 envelope: &'a MessageEnvelope,
64 ctx: &'a HandlerContext,
65 next: Next,
66 ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
67}
68
69impl<M: Middleware> DynMiddleware for M {
70 fn execute<'a>(
71 &'a self,
72 envelope: &'a MessageEnvelope,
73 ctx: &'a HandlerContext,
74 next: Next,
75 ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
76 Box::pin(<M as Middleware>::execute(self, envelope, ctx, next))
77 }
78}
79
80pub trait Terminal: Send + Sync + 'static {
88 fn dispatch<'a>(
90 &'a self,
91 envelope: &'a MessageEnvelope,
92 ctx: &'a HandlerContext,
93 ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
94}
95
96pub struct Next {
100 chain: VecDeque<Arc<dyn DynMiddleware>>,
101 terminal: Arc<dyn Terminal>,
102}
103
104impl Next {
105 #[must_use]
110 pub fn new(middlewares: Vec<Arc<dyn DynMiddleware>>, terminal: Arc<dyn Terminal>) -> Self {
111 Self {
112 chain: middlewares.into(),
113 terminal,
114 }
115 }
116
117 pub async fn run(
124 mut self,
125 envelope: &MessageEnvelope,
126 ctx: &HandlerContext,
127 ) -> Result<BoxOutput, HexeractError> {
128 if let Some(head) = self.chain.pop_front() {
129 head.execute(envelope, ctx, self).await
130 } else {
131 self.terminal.dispatch(envelope, ctx).await
132 }
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CorrelationId, MessageId};
    use std::sync::Mutex;

    /// Erases a concrete middleware into the form `Next::new` accepts.
    fn dyn_mw<M: Middleware>(m: M) -> Arc<dyn DynMiddleware> {
        Arc::new(m)
    }

    struct DummyCmd;
    impl crate::command::Command for DummyCmd {
        type Output = i32;
    }

    fn fresh_env() -> MessageEnvelope {
        MessageEnvelope::for_command::<DummyCmd>(MessageId::new(), CorrelationId::new())
    }

    fn fresh_ctx() -> HandlerContext {
        HandlerContext::new(MessageId::new(), CorrelationId::new())
    }

    /// Terminal that always succeeds with a fixed `i32`.
    struct StaticTerminal {
        value: i32,
    }

    impl Terminal for StaticTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            let value = self.value;
            Box::pin(async move { Ok(Box::new(value) as BoxOutput) })
        }
    }

    /// Terminal that always fails. Successful results prove it was not reached.
    struct FailingTerminal;
    impl Terminal for FailingTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            Box::pin(async move { Err(HexeractError::Dispatch("terminal failure".into())) })
        }
    }

    /// Shared log of labels pushed by `TracingMiddleware`.
    #[derive(Clone)]
    struct Recorder {
        trace: Arc<Mutex<Vec<&'static str>>>,
    }

    impl Recorder {
        fn new() -> Self {
            Self {
                trace: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn snapshot(&self) -> Vec<&'static str> {
            self.trace.lock().expect("poisoned").clone()
        }
    }

    /// Records `name` before delegating and `post_label` after.
    struct TracingMiddleware {
        name: &'static str,
        post_label: &'static str,
        recorder: Recorder,
    }

    impl Middleware for TracingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            // Each lock guard is a temporary dropped at the end of its
            // statement, so no guard is held across the `.await` below.
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.name);
            let result = next.run(envelope, ctx).await;
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.post_label);
            result
        }
    }

    fn tracing_mw(name: &'static str, post: &'static str, recorder: Recorder) -> TracingMiddleware {
        TracingMiddleware {
            name,
            post_label: post,
            recorder,
        }
    }

    #[tokio::test]
    async fn single_middleware_delegates_to_terminal() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(StaticTerminal { value: 42 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        let downcast = output.downcast::<i32>().expect("output must be i32");
        assert_eq!(*downcast, 42);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn chain_of_three_executes_in_onion_order() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
                dyn_mw(tracing_mw("C", "C_post", recorder.clone())),
            ],
            Arc::new(StaticTerminal { value: 7 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        // The terminal's value must survive the trip back out through every layer.
        assert_eq!(*output.downcast::<i32>().unwrap(), 7);
        assert_eq!(
            recorder.snapshot(),
            vec!["A", "B", "C", "C_post", "B_post", "A_post"]
        );
    }

    /// Returns a fixed value without ever calling `next.run`.
    struct ShortCircuit;
    impl Middleware for ShortCircuit {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Ok(Box::new(99_i32) as BoxOutput)
        }
    }

    #[tokio::test]
    async fn short_circuit_middleware_skips_terminal() {
        let next = Next::new(vec![dyn_mw(ShortCircuit)], Arc::new(FailingTerminal));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
    }

    #[tokio::test]
    async fn short_circuit_skips_downstream_middleware() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(ShortCircuit),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
            ],
            Arc::new(FailingTerminal),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
        // B never runs, but A still sees the short-circuited result on the way out.
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn error_from_terminal_propagates_through_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(FailingTerminal),
        );
        let result = next.run(&fresh_env(), &fresh_ctx()).await;
        assert!(matches!(result, Err(HexeractError::Dispatch(_))));
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    /// Fails immediately without calling `next.run`.
    struct ErrorMiddleware;
    impl Middleware for ErrorMiddleware {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Err(HexeractError::Dispatch("middleware refusal".into()))
        }
    }

    #[tokio::test]
    async fn error_from_middleware_propagates() {
        let next = Next::new(
            vec![dyn_mw(ErrorMiddleware)],
            Arc::new(StaticTerminal { value: 0 }),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("middleware should fail");
        match err {
            HexeractError::Dispatch(ref m) => assert_eq!(m, "middleware refusal"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    fn assert_send<T: Send>(_: &T) {}

    #[tokio::test]
    async fn next_run_future_is_send() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 1 }));
        let env = fresh_env();
        let ctx = fresh_ctx();
        let future = next.run(&env, &ctx);
        assert_send(&future);
        let _ = future.await;
    }

    #[tokio::test]
    async fn empty_chain_invokes_terminal_directly() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 123 }));
        let output = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(*output.downcast::<i32>().unwrap(), 123);
    }

    /// Captures the envelope's type name, then delegates.
    struct EnvelopeInspector {
        observed: Arc<Mutex<Option<String>>>,
    }

    impl Middleware for EnvelopeInspector {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            *self.observed.lock().expect("poisoned") = Some(envelope.type_name().to_string());
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn middleware_reads_envelope_type_name() {
        let observed = Arc::new(Mutex::new(None));
        let mw = EnvelopeInspector {
            observed: Arc::clone(&observed),
        };
        let next = Next::new(vec![dyn_mw(mw)], Arc::new(StaticTerminal { value: 0 }));
        // Check the dispatch result too. Discarding it would let the test pass
        // even if the chain failed after the inspector ran.
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 0);
        let observed = observed
            .lock()
            .expect("poisoned")
            .clone()
            .expect("inspector must have run");
        assert!(observed.ends_with("::DummyCmd"));
    }
}