1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::context::HandlerContext;
7use crate::envelope::MessageEnvelope;
8use crate::error::HexeractError;
9
/// Type-erased result of a dispatch.
///
/// Middleware and terminals return their output boxed as `dyn Any`; callers
/// recover the concrete type with `downcast`.
pub type BoxOutput = Box<dyn Any + Send + Sync>;

/// Heap-allocated, pinned, `Send` future borrowing data for `'a`.
///
/// Used as the return type of the object-safe [`DynMiddleware::execute`] and
/// [`Terminal::dispatch`] methods, which cannot use `async fn` directly.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
17
/// An asynchronous interceptor that wraps the rest of a dispatch chain.
///
/// Each middleware receives the message envelope, the handler context and a
/// [`Next`] handle for everything after it. It may:
/// - run code before and/or after awaiting [`Next::run`] (onion order);
/// - forward, inspect or replace the result;
/// - return without calling `next` at all, which short-circuits every later
///   middleware and the terminal.
///
/// `#[trait_variant::make(Send)]` makes the future returned by `execute` `Send`.
/// That is what lets the blanket [`DynMiddleware`] impl box it as a `BoxFuture`.
#[trait_variant::make(Send)]
pub trait Middleware: Send + Sync + 'static {
    /// Handles `envelope`, optionally delegating to the remainder of the chain
    /// through `next`.
    ///
    /// # Errors
    ///
    /// Returns any [`HexeractError`] the middleware produces itself, or the
    /// error obtained from `next.run(..)` if it chooses to forward it.
    async fn execute(
        &self,
        envelope: &MessageEnvelope,
        ctx: &HandlerContext,
        next: Next,
    ) -> Result<BoxOutput, HexeractError>;
}
57
58#[doc(hidden)]
/// Object-safe counterpart of [`Middleware`].
///
/// `Middleware::execute` is an `async fn`, so `Middleware` cannot be used as a
/// trait object. This trait returns an explicitly boxed future instead, which
/// allows heterogeneous middleware to be stored as `Arc<dyn DynMiddleware>`.
/// Every `M: Middleware` gets an implementation through the blanket impl below.
#[doc(hidden)]
pub trait DynMiddleware: Send + Sync + 'static {
    /// Boxed-future form of [`Middleware::execute`].
    fn execute<'a>(
        &'a self,
        envelope: &'a MessageEnvelope,
        ctx: &'a HandlerContext,
        next: Next,
    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
}
67
68impl<M: Middleware> DynMiddleware for M {
69 fn execute<'a>(
70 &'a self,
71 envelope: &'a MessageEnvelope,
72 ctx: &'a HandlerContext,
73 next: Next,
74 ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
75 Box::pin(<M as Middleware>::execute(self, envelope, ctx, next))
76 }
77}
78
/// The innermost handler reached once every middleware in a [`Next`] chain has
/// delegated onward.
///
/// `dispatch` returns an explicitly boxed future rather than being an
/// `async fn`, so the trait is object-safe and can be stored as
/// `Arc<dyn Terminal>`. `BoxFuture` is private to this module, so implementors
/// elsewhere write the full
/// `Pin<Box<dyn Future<Output = Result<BoxOutput, HexeractError>> + Send + 'a>>`.
pub trait Terminal: Send + Sync + 'static {
    /// Processes `envelope` and produces the type-erased dispatch output.
    ///
    /// # Errors
    ///
    /// Failures are reported as [`HexeractError`]. The error is returned to the
    /// middleware that called [`Next::run`], which may forward or replace it.
    fn dispatch<'a>(
        &'a self,
        envelope: &'a MessageEnvelope,
        ctx: &'a HandlerContext,
    ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>>;
}
94
/// Handle to the remainder of a middleware chain.
///
/// A `Next` is consumed by [`Next::run`]. That call either invokes the
/// middleware at `index`, handing it a new `Next` positioned one step further,
/// or dispatches to the terminal once the chain is exhausted.
pub struct Next {
    /// Middleware in invocation order. It is shared (moved, not copied) between
    /// every `Next` derived from the same [`Next::new`] call.
    chain: Arc<[Arc<dyn DynMiddleware>]>,
    /// Position in `chain` of the next middleware to invoke.
    index: usize,
    /// Handler invoked after the last middleware delegates onward.
    terminal: Arc<dyn Terminal>,
}
107
108impl Next {
109 #[must_use]
116 pub fn new(
117 middlewares: impl Into<Arc<[Arc<dyn DynMiddleware>]>>,
118 terminal: Arc<dyn Terminal>,
119 ) -> Self {
120 Self {
121 chain: middlewares.into(),
122 index: 0,
123 terminal,
124 }
125 }
126
127 pub async fn run(
134 self,
135 envelope: &MessageEnvelope,
136 ctx: &HandlerContext,
137 ) -> Result<BoxOutput, HexeractError> {
138 if let Some(head) = self.chain.get(self.index).cloned() {
139 let next = Next {
140 chain: self.chain,
141 index: self.index + 1,
142 terminal: self.terminal,
143 };
144 head.execute(envelope, ctx, next).await
145 } else {
146 self.terminal.dispatch(envelope, ctx).await
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CorrelationId, MessageId};
    use std::sync::Mutex;

    /// Erases a concrete middleware into the trait object stored in a chain.
    fn dyn_mw<M: Middleware>(m: M) -> Arc<dyn DynMiddleware> {
        Arc::new(m)
    }

    /// Minimal command type, used only to build envelopes.
    struct DummyCmd;
    impl crate::command::Command for DummyCmd {
        type Output = i32;
    }

    fn fresh_env() -> MessageEnvelope {
        MessageEnvelope::for_command::<DummyCmd>(MessageId::new(), CorrelationId::new())
    }

    fn fresh_ctx() -> HandlerContext {
        HandlerContext::new(MessageId::new(), CorrelationId::new())
    }

    /// Terminal that always succeeds with a fixed `i32`.
    struct StaticTerminal {
        value: i32,
    }

    impl Terminal for StaticTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            let value = self.value;
            Box::pin(async move { Ok(Box::new(value) as BoxOutput) })
        }
    }

    /// Terminal that always fails. A test reaching it means the chain did not
    /// short-circuit.
    struct FailingTerminal;
    impl Terminal for FailingTerminal {
        fn dispatch<'a>(
            &'a self,
            _envelope: &'a MessageEnvelope,
            _ctx: &'a HandlerContext,
        ) -> BoxFuture<'a, Result<BoxOutput, HexeractError>> {
            Box::pin(async move { Err(HexeractError::Dispatch("terminal failure".into())) })
        }
    }

    /// Shared, append-only log of labels pushed by tracing middleware.
    #[derive(Clone)]
    struct Recorder {
        trace: Arc<Mutex<Vec<&'static str>>>,
    }

    impl Recorder {
        fn new() -> Self {
            Self {
                trace: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Appends `label`. The guard is dropped before this sync fn returns,
        /// so it is never held across an `.await` in the caller.
        fn record(&self, label: &'static str) {
            self.trace.lock().expect("poisoned").push(label);
        }

        fn snapshot(&self) -> Vec<&'static str> {
            self.trace.lock().expect("poisoned").clone()
        }
    }

    /// Records `name` before delegating and `post_label` after the rest of
    /// the chain has completed.
    struct TracingMiddleware {
        name: &'static str,
        post_label: &'static str,
        recorder: Recorder,
    }

    impl Middleware for TracingMiddleware {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            self.recorder.record(self.name);
            let result = next.run(envelope, ctx).await;
            self.recorder.record(self.post_label);
            result
        }
    }

    fn tracing_mw(name: &'static str, post: &'static str, recorder: Recorder) -> TracingMiddleware {
        TracingMiddleware {
            name,
            post_label: post,
            recorder,
        }
    }

    #[tokio::test]
    async fn single_middleware_delegates_to_terminal() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(StaticTerminal { value: 42 }),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("dispatch should succeed");
        let downcast = output.downcast::<i32>().expect("output must be i32");
        assert_eq!(*downcast, 42);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn chain_of_three_executes_in_onion_order() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
                dyn_mw(tracing_mw("C", "C_post", recorder.clone())),
            ],
            Arc::new(StaticTerminal { value: 7 }),
        );
        let _ = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(
            recorder.snapshot(),
            vec!["A", "B", "C", "C_post", "B_post", "A_post"]
        );
    }

    /// Returns `99` without ever calling `next`.
    struct ShortCircuit;
    impl Middleware for ShortCircuit {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Ok(Box::new(99_i32) as BoxOutput)
        }
    }

    #[tokio::test]
    async fn short_circuit_middleware_skips_terminal() {
        let next = Next::new(vec![dyn_mw(ShortCircuit)], Arc::new(FailingTerminal));
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
    }

    #[tokio::test]
    async fn short_circuit_prevents_downstream_middleware() {
        // Upstream middleware must still unwind; middleware placed after the
        // short-circuit (and the terminal) must never run.
        let recorder = Recorder::new();
        let next = Next::new(
            vec![
                dyn_mw(tracing_mw("A", "A_post", recorder.clone())),
                dyn_mw(ShortCircuit),
                dyn_mw(tracing_mw("B", "B_post", recorder.clone())),
            ],
            Arc::new(FailingTerminal),
        );
        let output = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect("short-circuit must succeed");
        assert_eq!(*output.downcast::<i32>().unwrap(), 99);
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    #[tokio::test]
    async fn error_from_terminal_propagates_through_chain() {
        let recorder = Recorder::new();
        let next = Next::new(
            vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))],
            Arc::new(FailingTerminal),
        );
        let result = next.run(&fresh_env(), &fresh_ctx()).await;
        assert!(matches!(result, Err(HexeractError::Dispatch(_))));
        assert_eq!(recorder.snapshot(), vec!["A", "A_post"]);
    }

    /// Fails unconditionally without calling `next`.
    struct ErrorMiddleware;
    impl Middleware for ErrorMiddleware {
        async fn execute(
            &self,
            _envelope: &MessageEnvelope,
            _ctx: &HandlerContext,
            _next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            Err(HexeractError::Dispatch("middleware refusal".into()))
        }
    }

    #[tokio::test]
    async fn error_from_middleware_propagates() {
        let next = Next::new(
            vec![dyn_mw(ErrorMiddleware)],
            Arc::new(StaticTerminal { value: 0 }),
        );
        let err = next
            .run(&fresh_env(), &fresh_ctx())
            .await
            .expect_err("middleware should fail");
        match err {
            HexeractError::Dispatch(ref m) => assert_eq!(m, "middleware refusal"),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    fn assert_send<T: Send>(_: &T) {}

    #[tokio::test]
    async fn next_run_future_is_send() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 1 }));
        let env = fresh_env();
        let ctx = fresh_ctx();
        let future = next.run(&env, &ctx);
        assert_send(&future);
        let _ = future.await;
    }

    #[tokio::test]
    async fn empty_chain_invokes_terminal_directly() {
        let next = Next::new(vec![], Arc::new(StaticTerminal { value: 123 }));
        let output = next.run(&fresh_env(), &fresh_ctx()).await.unwrap();
        assert_eq!(*output.downcast::<i32>().unwrap(), 123);
    }

    #[tokio::test]
    async fn shared_chain_can_drive_multiple_dispatches() {
        // `Next::new` accepts a pre-built shared slice; the same chain and
        // terminal can back several independent dispatches.
        let recorder = Recorder::new();
        let chain: Arc<[Arc<dyn DynMiddleware>]> =
            Arc::from(vec![dyn_mw(tracing_mw("A", "A_post", recorder.clone()))]);
        let terminal: Arc<dyn Terminal> = Arc::new(StaticTerminal { value: 5 });
        for _ in 0..2 {
            let output = Next::new(Arc::clone(&chain), Arc::clone(&terminal))
                .run(&fresh_env(), &fresh_ctx())
                .await
                .expect("dispatch should succeed");
            assert_eq!(*output.downcast::<i32>().unwrap(), 5);
        }
        assert_eq!(recorder.snapshot(), vec!["A", "A_post", "A", "A_post"]);
    }

    /// Captures the envelope's type name, then delegates.
    struct EnvelopeInspector {
        observed: Arc<Mutex<Option<String>>>,
    }

    impl Middleware for EnvelopeInspector {
        async fn execute(
            &self,
            envelope: &MessageEnvelope,
            ctx: &HandlerContext,
            next: Next,
        ) -> Result<BoxOutput, HexeractError> {
            *self.observed.lock().expect("poisoned") = Some(envelope.type_name().to_string());
            next.run(envelope, ctx).await
        }
    }

    #[tokio::test]
    async fn middleware_reads_envelope_type_name() {
        let observed = Arc::new(Mutex::new(None));
        let mw = EnvelopeInspector {
            observed: Arc::clone(&observed),
        };
        let next = Next::new(vec![dyn_mw(mw)], Arc::new(StaticTerminal { value: 0 }));
        let _ = next.run(&fresh_env(), &fresh_ctx()).await;
        let observed = observed.lock().unwrap().clone();
        assert!(observed.unwrap().ends_with("::DummyCmd"));
    }
}