1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::context::HandlerContext;
7use crate::envelope::MessageEnvelope;
8use crate::error::HexeractError;
9
/// Type-erased value produced by a middleware chain or its terminal.
///
/// Callers recover the concrete value with `downcast::<T>()`. The `Send + Sync` bounds allow
/// the boxed value to be moved between tasks and threads.
pub type BoxOutput = Box<dyn Any + Send + Sync>;
15
/// Heap-allocated, pinned, `Send` future that may borrow data for `'a`.
///
/// This is the return type of `Terminal::dispatch` and `DynMiddleware::execute`. It is public
/// so that code implementing those public traits outside this module can name the type
/// instead of spelling out `Pin<Box<dyn Future<Output = T> + Send + 'a>>` by hand.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
17
#[trait_variant::make(Send)]
/// An async interceptor wrapped around message dispatch.
///
/// Each middleware receives the message envelope, the handler context, and a [`Next`]
/// handle to the remainder of the chain. It may run code before and after awaiting
/// [`Next::run`]. It may also return without calling it, which short-circuits the chain: the
/// later middlewares and the terminal are then never invoked.
///
/// The `trait_variant::make(Send)` attribute requires the future returned by `execute` to
/// be `Send`, so chains can run on multi-threaded executors.
pub trait Middleware: Send + Sync + 'static {
    /// Handles `envelope`, optionally delegating to the rest of the chain through `next`.
    ///
    /// `next` is taken by value and consumed by [`Next::run`].
    ///
    /// # Errors
    ///
    /// Any [`HexeractError`] the implementation chooses to return, including one forwarded
    /// from `next.run(..)`.
    async fn execute(
        &self,
        envelope: &MessageEnvelope,
        ctx: &HandlerContext,
        next: Next,
    ) -> Result<BoxOutput, HexeractError>;
}
57
#[doc(hidden)]
/// Object-safe counterpart of [`Middleware`].
///
/// [`Middleware::execute`] returns an opaque future, so `Middleware` cannot be used as a
/// trait object. This trait returns a boxed future instead, which lets [`Next`] store
/// heterogeneous middlewares as `Arc<dyn DynMiddleware>`. Every `Middleware` implements it
/// through a blanket impl, so it is normally not implemented by hand.
pub trait DynMiddleware: Send + Sync + 'static {
    /// Runs this middleware and returns its future, boxed and pinned.
    fn execute<'a>(
        &'a self,
        envelope: &'a MessageEnvelope,
        ctx: &'a HandlerContext,
        next: Next,
    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
}
67
68impl<M: Middleware> DynMiddleware for M {
69 fn execute<'a>(
70 &'a self,
71 envelope: &'a MessageEnvelope,
72 ctx: &'a HandlerContext,
73 next: Next,
74 ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
75 Box::pin(<M as Middleware>::execute(self, envelope, ctx, next))
76 }
77}
78
/// Innermost step of a middleware chain.
///
/// [`Next::run`] invokes it once no middleware remains, or immediately when the chain is
/// empty. It returns a boxed future directly, so it can be stored as `Arc<dyn Terminal>`.
pub trait Terminal: Send + Sync + 'static {
    /// Handles `envelope` at the end of the chain and returns its type-erased output.
    ///
    /// # Errors
    ///
    /// Any [`HexeractError`] produced by the implementation. [`Next::run`] returns it as-is
    /// to the middleware that delegated.
    fn dispatch<'a>(
        &'a self,
        envelope: &'a MessageEnvelope,
        ctx: &'a HandlerContext,
    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
}
94
95pub struct Next {
103 chain: Arc<[Arc<dyn DynMiddleware>]>,
104 index: usize,
105 terminal: Arc<dyn Terminal>,
106}
107
108impl Next {
109 #[must_use]
116 pub fn new(
117 middlewares: impl Into<Arc<[Arc<dyn DynMiddleware>]>>,
118 terminal: Arc<dyn Terminal>,
119 ) -> Self {
120 Self {
121 chain: middlewares.into(),
122 index: 0,
123 terminal,
124 }
125 }
126
127 pub async fn run(
140 self,
141 envelope: &MessageEnvelope,
142 ctx: &HandlerContext,
143 ) -> Result<BoxOutput, HexeractError> {
144 if ctx.is_cancelled() {
145 return Err(HexeractError::cancelled(envelope.type_name()));
146 }
147 if let Some(head) = self.chain.get(self.index).cloned() {
148 let next = Next {
149 chain: self.chain,
150 index: self.index + 1,
151 terminal: self.terminal,
152 };
153 head.execute(envelope, ctx, next).await
154 } else {
155 self.terminal.dispatch(envelope, ctx).await
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    //! Unit tests for `Next` chain traversal: execution order, short-circuiting, error
    //! propagation, cancellation, and the `Send` bound on the future of `Next::run`.

    use super::*;
    use crate::ids::{CorrelationId, MessageId};
    use std::sync::Mutex;

    /// Erases a concrete middleware into the `Arc<dyn DynMiddleware>` form taken by `Next::new`.
    fn dyn_mw<M: Middleware>(m: M) -> Arc<dyn DynMiddleware> {
        Arc::new(m)
    }

    /// Minimal command type used only to build test envelopes.
    struct DummyCmd;
    impl crate::command::Command for DummyCmd {
        type Output = i32;
    }

    /// Builds an envelope for `DummyCmd` with newly created message and correlation ids.
    fn fresh_env() -> MessageEnvelope {
        MessageEnvelope::for_command::<DummyCmd>(MessageId::new(), CorrelationId::new())
    }

    /// Builds a handler context with newly created message and correlation ids.
    fn fresh_ctx() -> HandlerContext {
        HandlerContext::new(MessageId::new(), CorrelationId::new())
    }

    /// Terminal that always succeeds with a boxed copy of `value`.
    struct StaticTerminal {
        value: i32,
    }

    impl Terminal for StaticTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            let value = self.value;
            Box::pin(async move { Ok(Box::new(value) as BoxOutput) })
        }
    }

    /// Terminal that always fails with `HexeractError::Dispatch("terminal failure")`.
    ///
    /// Tests use it both to check error propagation and to detect whether the terminal
    /// was reached at all.
    struct FailingTerminal;
    impl Terminal for FailingTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            Box::pin(async move { Err(HexeractError::Dispatch("terminal failure".into())) })
        }
    }

    /// Shared, ordered log of labels. Clones share the same underlying vector.
    #[derive(Clone)]
    struct Recorder {
        trace: Arc<Mutex<Vec<&'static str>>>,
    }

    impl Recorder {
        /// Creates a recorder with an empty log.
        fn new() -> Self {
            Self {
                trace: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns a copy of the labels recorded so far, in insertion order.
        fn snapshot(&self) -> Vec<&'static str> {
            self.trace.lock().expect("poisoned").clone()
        }
    }

    /// Middleware that records `name` before delegating and `post_label` after the rest
    /// of the chain returns.
    struct TracingMiddleware {
        name: &'static str,
        post_label: &'static str,
        recorder: Recorder,
    }

    impl Middleware for TracingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            // Each mutex guard is a temporary dropped at the end of its statement. It is
            // therefore never held across the `.await`, which keeps this future `Send`.
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.name);
            let result = next.run(envelope, ctx).await;
            self.recorder
                .trace
                .lock()
                .expect("poisoned")
                .push(self.post_label);
            result
        }
    }

    /// Shorthand constructor for `TracingMiddleware`.
    fn tracing_mw(name: &'static str, post: &'static str, recorder: Recorder) -> TracingMiddleware {
        TracingMiddleware {
            name,
            post_label: post,
            recorder,
        }
    }

    #[tokio::test]
    async fn single_middleware_delegates_to_terminal() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(StaticTerminal { value: 42 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        let downcast = output.downcast::<i32>().expect("output must be i32");
        assert_eq!(*downcast, 42);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn chain_of_three_executes_in_onion_order() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
                dyn_mw(tracing_mw("C", "C_post", recorder.clone())),
            ],
            Arc::new(StaticTerminal { value: 7 }),
        );
        let _ = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        // Pre-labels appear in chain order and post-labels in reverse: each middleware
        // wraps everything after it.
        assert_eq!(
            recorder.snapshot(),
            vec!["A", "B", "C", "C_post", "B_post", "A_post"]
        );
    }

    /// Middleware that never calls `next` and returns `99` directly.
    struct ShortCircuit;
    impl Middleware for ShortCircuit {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Ok(Box::new(99_i32) as BoxOutput)
        }
    }

    #[tokio::test]
    async fn short_circuit_middleware_skips_terminal() {
        // `FailingTerminal` would turn this into an error, so success proves it never ran.
        let next = Next::new(vec![dyn_mw(ShortCircuit)], Arc::new(FailingTerminal));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
    }

    #[tokio::test]
    async fn error_from_terminal_propagates_through_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(FailingTerminal),
        );
        let result = next.run(&fresh_env(), &fresh_ctx()).await;
        assert!(matches!(result, Err(HexeractError::Dispatch(_))));
        // The middleware's post-step still runs when the terminal fails.
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    /// Middleware that always fails without delegating.
    struct ErrorMiddleware;
    impl Middleware for ErrorMiddleware {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Err(HexeractError::Dispatch("middleware refusal".into()))
        }
    }

    #[tokio::test]
    async fn error_from_middleware_propagates() {
        let next = Next::new(
            vec![dyn_mw(ErrorMiddleware)],
            Arc::new(StaticTerminal { value: 0 }),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("middleware should fail");
        match err {
            HexeractError::Dispatch(ref m) => assert_eq!(m, "middleware refusal"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    /// Middleware that cancels the context's token and then delegates.
    struct CancellingMiddleware;
    impl Middleware for CancellingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            ctx.cancellation.cancel();
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn run_returns_cancelled_when_token_fired_before_dispatch() {
        let ctx = fresh_ctx();
        ctx.cancellation.cancel();
        // `FailingTerminal` would yield `Dispatch`. Getting `Cancelled` instead proves the
        // check in `Next::run` fired before the terminal was reached.
        let next = Next::new(vec![], Arc::new(FailingTerminal));
        let err = next
            .run(&fresh_env(), &ctx)
            .await
            .expect_err("cancelled dispatch must fail");
        assert!(
            matches!(err, HexeractError::Cancelled { type_name } if type_name.contains("DummyCmd"))
        );
    }

    #[tokio::test]
    async fn middleware_cancelling_token_short_circuits_the_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(CancellingMiddleware),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
            ],
            Arc::new(FailingTerminal),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("cancelled chain must fail");
        assert!(matches!(err, HexeractError::Cancelled { .. }));
        // "B" never records anything: the cancellation check runs before it is invoked.
        assert!(recorder.snapshot().is_empty());
    }

    /// Compile-time assertion helper: only accepts references to `Send` values.
    fn assert_send<T: Send>(_: &T) {}

    #[tokio::test]
    async fn next_run_future_is_send() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 1 }));
        let env = fresh_env();
        let ctx = fresh_ctx();
        // Fails to compile if the future returned by `Next::run` is not `Send`.
        let future = next.run(&env, &ctx);
        assert_send(&future);
        let _ = future.await;
    }

    #[tokio::test]
    async fn empty_chain_invokes_terminal_directly() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 123 }));
        let output = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(*output.downcast::<i32>().unwrap(), 123);
    }

    /// Middleware that captures `envelope.type_name()` and then delegates.
    struct EnvelopeInspector {
        observed: Arc<Mutex<Option<String>>>,
    }

    impl Middleware for EnvelopeInspector {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            *self.observed.lock().expect("poisoned") = Some(envelope.type_name().to_string());
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn middleware_reads_envelope_type_name() {
        let observed = Arc::new(Mutex::new(None));
        let mw = EnvelopeInspector {
            observed: Arc::clone(&observed),
        };
        let next = Next::new(vec![dyn_mw(mw)], Arc::new(StaticTerminal { value: 0 }));
        let _ = next.run(&fresh_env(), &fresh_ctx()).await;
        let observed = observed.lock().unwrap().clone();
        // The recorded name is path-qualified, so only its `::DummyCmd` suffix is checked.
        assert!(observed.unwrap().ends_with("::DummyCmd"));
    }
}