1use crate::effect::Scope;
12use crate::environment::EnvironmentSpec;
13use crate::error::{EnvError, OrchError, StateError};
14use crate::id::OperatorId;
15use crate::operator::{OperatorInput, OperatorOutput};
16use crate::state::StoreOptions;
17use async_trait::async_trait;
18use std::sync::Arc;
19
20#[async_trait]
29pub trait DispatchNext: Send + Sync {
30 async fn dispatch(
32 &self,
33 operator: &OperatorId,
34 input: OperatorInput,
35 ) -> Result<OperatorOutput, OrchError>;
36}
37
38#[async_trait]
44pub trait DispatchMiddleware: Send + Sync {
45 async fn dispatch(
47 &self,
48 operator: &OperatorId,
49 input: OperatorInput,
50 next: &dyn DispatchNext,
51 ) -> Result<OperatorOutput, OrchError>;
52}
53
54#[async_trait]
60pub trait StoreWriteNext: Send + Sync {
61 async fn write(
63 &self,
64 scope: &Scope,
65 key: &str,
66 value: serde_json::Value,
67 options: Option<&StoreOptions>,
68 ) -> Result<(), StateError>;
69}
70
71#[async_trait]
73pub trait StoreReadNext: Send + Sync {
74 async fn read(&self, scope: &Scope, key: &str)
76 -> Result<Option<serde_json::Value>, StateError>;
77}
78
79#[async_trait]
83pub trait StoreMiddleware: Send + Sync {
84 async fn write(
86 &self,
87 scope: &Scope,
88 key: &str,
89 value: serde_json::Value,
90 options: Option<&StoreOptions>,
91 next: &dyn StoreWriteNext,
92 ) -> Result<(), StateError>;
93
94 async fn read(
96 &self,
97 scope: &Scope,
98 key: &str,
99 next: &dyn StoreReadNext,
100 ) -> Result<Option<serde_json::Value>, StateError> {
101 next.read(scope, key).await
102 }
103}
104
105#[async_trait]
111pub trait ExecNext: Send + Sync {
112 async fn run(
114 &self,
115 input: OperatorInput,
116 spec: &EnvironmentSpec,
117 ) -> Result<OperatorOutput, EnvError>;
118}
119
120#[async_trait]
124pub trait ExecMiddleware: Send + Sync {
125 async fn run(
127 &self,
128 input: OperatorInput,
129 spec: &EnvironmentSpec,
130 next: &dyn ExecNext,
131 ) -> Result<OperatorOutput, EnvError>;
132}
133
134pub struct DispatchStack {
147 layers: Vec<Arc<dyn DispatchMiddleware>>,
149}
150
151pub struct DispatchStackBuilder {
153 observers: Vec<Arc<dyn DispatchMiddleware>>,
154 transformers: Vec<Arc<dyn DispatchMiddleware>>,
155 guards: Vec<Arc<dyn DispatchMiddleware>>,
156}
157
158impl DispatchStack {
159 pub fn builder() -> DispatchStackBuilder {
161 DispatchStackBuilder {
162 observers: Vec::new(),
163 transformers: Vec::new(),
164 guards: Vec::new(),
165 }
166 }
167
168 pub async fn dispatch_with(
170 &self,
171 operator: &OperatorId,
172 input: OperatorInput,
173 terminal: &dyn DispatchNext,
174 ) -> Result<OperatorOutput, OrchError> {
175 if self.layers.is_empty() {
176 return terminal.dispatch(operator, input).await;
177 }
178 let chain = DispatchChain {
179 layers: &self.layers,
180 index: 0,
181 terminal,
182 };
183 chain.dispatch(operator, input).await
184 }
185}
186
187impl DispatchStackBuilder {
188 pub fn observe(mut self, mw: Arc<dyn DispatchMiddleware>) -> Self {
190 self.observers.push(mw);
191 self
192 }
193
194 pub fn transform(mut self, mw: Arc<dyn DispatchMiddleware>) -> Self {
196 self.transformers.push(mw);
197 self
198 }
199
200 pub fn guard(mut self, mw: Arc<dyn DispatchMiddleware>) -> Self {
202 self.guards.push(mw);
203 self
204 }
205
206 pub fn build(self) -> DispatchStack {
208 let mut layers = Vec::new();
209 layers.extend(self.observers);
210 layers.extend(self.transformers);
211 layers.extend(self.guards);
212 DispatchStack { layers }
213 }
214}
215
216struct DispatchChain<'a> {
217 layers: &'a [Arc<dyn DispatchMiddleware>],
218 index: usize,
219 terminal: &'a dyn DispatchNext,
220}
221
222#[async_trait]
223impl DispatchNext for DispatchChain<'_> {
224 async fn dispatch(
225 &self,
226 operator: &OperatorId,
227 input: OperatorInput,
228 ) -> Result<OperatorOutput, OrchError> {
229 if self.index >= self.layers.len() {
230 return self.terminal.dispatch(operator, input).await;
231 }
232 let next = DispatchChain {
233 layers: self.layers,
234 index: self.index + 1,
235 terminal: self.terminal,
236 };
237 self.layers[self.index]
238 .dispatch(operator, input, &next)
239 .await
240 }
241}
242
243pub struct StoreStack {
252 layers: Vec<Arc<dyn StoreMiddleware>>,
253}
254
255pub struct StoreStackBuilder {
257 observers: Vec<Arc<dyn StoreMiddleware>>,
258 transformers: Vec<Arc<dyn StoreMiddleware>>,
259 guards: Vec<Arc<dyn StoreMiddleware>>,
260}
261
262impl StoreStack {
263 pub fn builder() -> StoreStackBuilder {
265 StoreStackBuilder {
266 observers: Vec::new(),
267 transformers: Vec::new(),
268 guards: Vec::new(),
269 }
270 }
271
272 pub async fn write_with(
274 &self,
275 scope: &Scope,
276 key: &str,
277 value: serde_json::Value,
278 options: Option<&StoreOptions>,
279 terminal: &dyn StoreWriteNext,
280 ) -> Result<(), StateError> {
281 if self.layers.is_empty() {
282 return terminal.write(scope, key, value, options).await;
283 }
284 let chain = StoreWriteChain {
285 layers: &self.layers,
286 index: 0,
287 terminal,
288 options,
289 };
290 chain.write(scope, key, value, options).await
291 }
292
293 pub async fn read_with(
295 &self,
296 scope: &Scope,
297 key: &str,
298 terminal: &dyn StoreReadNext,
299 ) -> Result<Option<serde_json::Value>, StateError> {
300 if self.layers.is_empty() {
301 return terminal.read(scope, key).await;
302 }
303 let chain = StoreReadChain {
304 layers: &self.layers,
305 index: 0,
306 terminal,
307 };
308 chain.read(scope, key).await
309 }
310}
311
312impl StoreStackBuilder {
313 pub fn observe(mut self, mw: Arc<dyn StoreMiddleware>) -> Self {
315 self.observers.push(mw);
316 self
317 }
318
319 pub fn transform(mut self, mw: Arc<dyn StoreMiddleware>) -> Self {
321 self.transformers.push(mw);
322 self
323 }
324
325 pub fn guard(mut self, mw: Arc<dyn StoreMiddleware>) -> Self {
327 self.guards.push(mw);
328 self
329 }
330
331 pub fn build(self) -> StoreStack {
333 let mut layers = Vec::new();
334 layers.extend(self.observers);
335 layers.extend(self.transformers);
336 layers.extend(self.guards);
337 StoreStack { layers }
338 }
339}
340
341struct StoreWriteChain<'a> {
342 layers: &'a [Arc<dyn StoreMiddleware>],
343 index: usize,
344 terminal: &'a dyn StoreWriteNext,
345 options: Option<&'a StoreOptions>,
346}
347
348#[async_trait]
349impl StoreWriteNext for StoreWriteChain<'_> {
350 async fn write(
351 &self,
352 scope: &Scope,
353 key: &str,
354 value: serde_json::Value,
355 options: Option<&StoreOptions>,
356 ) -> Result<(), StateError> {
357 if self.index >= self.layers.len() {
358 return self.terminal.write(scope, key, value, options).await;
359 }
360 let next = StoreWriteChain {
361 layers: self.layers,
362 index: self.index + 1,
363 terminal: self.terminal,
364 options: self.options,
365 };
366 self.layers[self.index]
367 .write(scope, key, value, options, &next)
368 .await
369 }
370}
371
372struct StoreReadChain<'a> {
373 layers: &'a [Arc<dyn StoreMiddleware>],
374 index: usize,
375 terminal: &'a dyn StoreReadNext,
376}
377
378#[async_trait]
379impl StoreReadNext for StoreReadChain<'_> {
380 async fn read(
381 &self,
382 scope: &Scope,
383 key: &str,
384 ) -> Result<Option<serde_json::Value>, StateError> {
385 if self.index >= self.layers.len() {
386 return self.terminal.read(scope, key).await;
387 }
388 let next = StoreReadChain {
389 layers: self.layers,
390 index: self.index + 1,
391 terminal: self.terminal,
392 };
393 self.layers[self.index].read(scope, key, &next).await
394 }
395}
396
397pub struct ExecStack {
406 layers: Vec<Arc<dyn ExecMiddleware>>,
407}
408
409pub struct ExecStackBuilder {
411 observers: Vec<Arc<dyn ExecMiddleware>>,
412 transformers: Vec<Arc<dyn ExecMiddleware>>,
413 guards: Vec<Arc<dyn ExecMiddleware>>,
414}
415
416impl ExecStack {
417 pub fn builder() -> ExecStackBuilder {
419 ExecStackBuilder {
420 observers: Vec::new(),
421 transformers: Vec::new(),
422 guards: Vec::new(),
423 }
424 }
425
426 pub async fn run_with(
428 &self,
429 input: OperatorInput,
430 spec: &EnvironmentSpec,
431 terminal: &dyn ExecNext,
432 ) -> Result<OperatorOutput, EnvError> {
433 if self.layers.is_empty() {
434 return terminal.run(input, spec).await;
435 }
436 let chain = ExecChain {
437 layers: &self.layers,
438 index: 0,
439 terminal,
440 };
441 chain.run(input, spec).await
442 }
443}
444
445impl ExecStackBuilder {
446 pub fn observe(mut self, mw: Arc<dyn ExecMiddleware>) -> Self {
448 self.observers.push(mw);
449 self
450 }
451
452 pub fn transform(mut self, mw: Arc<dyn ExecMiddleware>) -> Self {
454 self.transformers.push(mw);
455 self
456 }
457
458 pub fn guard(mut self, mw: Arc<dyn ExecMiddleware>) -> Self {
460 self.guards.push(mw);
461 self
462 }
463
464 pub fn build(self) -> ExecStack {
466 let mut layers = Vec::new();
467 layers.extend(self.observers);
468 layers.extend(self.transformers);
469 layers.extend(self.guards);
470 ExecStack { layers }
471 }
472}
473
474struct ExecChain<'a> {
475 layers: &'a [Arc<dyn ExecMiddleware>],
476 index: usize,
477 terminal: &'a dyn ExecNext,
478}
479
480#[async_trait]
481impl ExecNext for ExecChain<'_> {
482 async fn run(
483 &self,
484 input: OperatorInput,
485 spec: &EnvironmentSpec,
486 ) -> Result<OperatorOutput, EnvError> {
487 if self.index >= self.layers.len() {
488 return self.terminal.run(input, spec).await;
489 }
490 let next = ExecChain {
491 layers: self.layers,
492 index: self.index + 1,
493 terminal: self.terminal,
494 };
495 self.layers[self.index].run(input, spec, &next).await
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn dispatch_middleware_is_object_safe() {
        struct TagMiddleware;

        #[async_trait]
        impl DispatchMiddleware for TagMiddleware {
            async fn dispatch(
                &self,
                operator: &OperatorId,
                mut input: OperatorInput,
                next: &dyn DispatchNext,
            ) -> Result<OperatorOutput, OrchError> {
                input.metadata = serde_json::json!({"tagged": true});
                next.dispatch(operator, input).await
            }
        }

        let _mw: Box<dyn DispatchMiddleware> = Box::new(TagMiddleware);
    }

    #[tokio::test]
    async fn store_middleware_is_object_safe() {
        struct AuditStore;

        #[async_trait]
        impl StoreMiddleware for AuditStore {
            async fn write(
                &self,
                scope: &Scope,
                key: &str,
                value: serde_json::Value,
                options: Option<&StoreOptions>,
                next: &dyn StoreWriteNext,
            ) -> Result<(), StateError> {
                next.write(scope, key, value, options).await
            }
        }

        let _mw: Box<dyn StoreMiddleware> = Box::new(AuditStore);
    }

    #[tokio::test]
    async fn exec_middleware_is_object_safe() {
        struct CredentialInjector;

        #[async_trait]
        impl ExecMiddleware for CredentialInjector {
            async fn run(
                &self,
                input: OperatorInput,
                spec: &EnvironmentSpec,
                next: &dyn ExecNext,
            ) -> Result<OperatorOutput, EnvError> {
                next.run(input, spec).await
            }
        }

        let _mw: Box<dyn ExecMiddleware> = Box::new(CredentialInjector);
    }

    #[tokio::test]
    async fn dispatch_stack_observer_always_runs() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let counter = Arc::new(AtomicU32::new(0));

        struct CountObserver(Arc<AtomicU32>);

        #[async_trait]
        impl DispatchMiddleware for CountObserver {
            async fn dispatch(
                &self,
                operator: &OperatorId,
                input: OperatorInput,
                next: &dyn DispatchNext,
            ) -> Result<OperatorOutput, OrchError> {
                self.0.fetch_add(1, Ordering::SeqCst);
                next.dispatch(operator, input).await
            }
        }

        struct HaltGuard;

        #[async_trait]
        impl DispatchMiddleware for HaltGuard {
            async fn dispatch(
                &self,
                _operator: &OperatorId,
                _input: OperatorInput,
                _next: &dyn DispatchNext,
            ) -> Result<OperatorOutput, OrchError> {
                Err(OrchError::DispatchFailed("budget exceeded".into()))
            }
        }

        let stack = DispatchStack::builder()
            .observe(Arc::new(CountObserver(counter.clone())))
            .guard(Arc::new(HaltGuard))
            .build();

        struct EchoTerminal;

        #[async_trait]
        impl DispatchNext for EchoTerminal {
            async fn dispatch(
                &self,
                _operator: &OperatorId,
                input: OperatorInput,
            ) -> Result<OperatorOutput, OrchError> {
                Ok(OperatorOutput::new(
                    input.message,
                    crate::ExitReason::Complete,
                ))
            }
        }

        let input = OperatorInput::new(
            crate::content::Content::text("test"),
            crate::operator::TriggerType::User,
        );
        let result = stack
            .dispatch_with(&OperatorId::from("a"), input, &EchoTerminal)
            .await;
        // The guard's own error must surface, not some other failure.
        assert!(matches!(result, Err(OrchError::DispatchFailed(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dispatch_stack_transform_then_terminal() {
        /// Transformer that stamps a marker into the input metadata.
        struct MetadataTagger;

        #[async_trait]
        impl DispatchMiddleware for MetadataTagger {
            async fn dispatch(
                &self,
                operator: &OperatorId,
                mut input: OperatorInput,
                next: &dyn DispatchNext,
            ) -> Result<OperatorOutput, OrchError> {
                input.metadata = serde_json::json!({"transformed": true});
                next.dispatch(operator, input).await
            }
        }

        /// Terminal that only succeeds when the transformer's marker arrives,
        /// so the test fails if the transformer is skipped.
        struct RequireTagged;

        #[async_trait]
        impl DispatchNext for RequireTagged {
            async fn dispatch(
                &self,
                _operator: &OperatorId,
                input: OperatorInput,
            ) -> Result<OperatorOutput, OrchError> {
                if input.metadata != serde_json::json!({"transformed": true}) {
                    return Err(OrchError::DispatchFailed(
                        "transformer did not run before terminal".into(),
                    ));
                }
                Ok(OperatorOutput::new(
                    input.message,
                    crate::ExitReason::Complete,
                ))
            }
        }

        let stack = DispatchStack::builder()
            .transform(Arc::new(MetadataTagger))
            .build();

        let input = OperatorInput::new(
            crate::content::Content::text("hello"),
            crate::operator::TriggerType::User,
        );
        let result = stack
            .dispatch_with(&OperatorId::from("a"), input, &RequireTagged)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn dispatch_stack_orders_observers_transformers_guards() {
        use std::sync::Mutex;

        type Log = Arc<Mutex<Vec<&'static str>>>;

        /// Middleware that appends its label to a shared log, then delegates.
        struct Record(Log, &'static str);

        #[async_trait]
        impl DispatchMiddleware for Record {
            async fn dispatch(
                &self,
                operator: &OperatorId,
                input: OperatorInput,
                next: &dyn DispatchNext,
            ) -> Result<OperatorOutput, OrchError> {
                // The guard is a statement temporary, released before the await.
                self.0.lock().unwrap().push(self.1);
                next.dispatch(operator, input).await
            }
        }

        struct RecordTerminal(Log);

        #[async_trait]
        impl DispatchNext for RecordTerminal {
            async fn dispatch(
                &self,
                _operator: &OperatorId,
                input: OperatorInput,
            ) -> Result<OperatorOutput, OrchError> {
                self.0.lock().unwrap().push("terminal");
                Ok(OperatorOutput::new(
                    input.message,
                    crate::ExitReason::Complete,
                ))
            }
        }

        let log: Log = Arc::new(Mutex::new(Vec::new()));
        // Registered in reverse role order to prove `build` orders by role,
        // not by registration call order.
        let stack = DispatchStack::builder()
            .guard(Arc::new(Record(log.clone(), "guard")))
            .transform(Arc::new(Record(log.clone(), "transform")))
            .observe(Arc::new(Record(log.clone(), "observe")))
            .build();

        let input = OperatorInput::new(
            crate::content::Content::text("order"),
            crate::operator::TriggerType::User,
        );
        let result = stack
            .dispatch_with(&OperatorId::from("a"), input, &RecordTerminal(log.clone()))
            .await;
        assert!(result.is_ok());
        assert_eq!(
            *log.lock().unwrap(),
            vec!["observe", "transform", "guard", "terminal"]
        );
    }

    #[tokio::test]
    async fn empty_dispatch_stack_calls_terminal_directly() {
        struct EchoTerminal;

        #[async_trait]
        impl DispatchNext for EchoTerminal {
            async fn dispatch(
                &self,
                _operator: &OperatorId,
                input: OperatorInput,
            ) -> Result<OperatorOutput, OrchError> {
                Ok(OperatorOutput::new(
                    input.message,
                    crate::ExitReason::Complete,
                ))
            }
        }

        let stack = DispatchStack::builder().build();
        let input = OperatorInput::new(
            crate::content::Content::text("bare"),
            crate::operator::TriggerType::User,
        );
        let result = stack
            .dispatch_with(&OperatorId::from("a"), input, &EchoTerminal)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn store_stack_write_through() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let write_count = Arc::new(AtomicU32::new(0));

        struct CountWrites(Arc<AtomicU32>);

        #[async_trait]
        impl StoreMiddleware for CountWrites {
            async fn write(
                &self,
                scope: &Scope,
                key: &str,
                value: serde_json::Value,
                options: Option<&StoreOptions>,
                next: &dyn StoreWriteNext,
            ) -> Result<(), StateError> {
                self.0.fetch_add(1, Ordering::SeqCst);
                next.write(scope, key, value, options).await
            }
        }

        struct NoOpStore;

        #[async_trait]
        impl StoreWriteNext for NoOpStore {
            async fn write(
                &self,
                _scope: &Scope,
                _key: &str,
                _value: serde_json::Value,
                _options: Option<&StoreOptions>,
            ) -> Result<(), StateError> {
                Ok(())
            }
        }

        let stack = StoreStack::builder()
            .observe(Arc::new(CountWrites(write_count.clone())))
            .build();

        let scope = Scope::Operator {
            workflow: crate::id::WorkflowId::from("w"),
            operator: OperatorId::from("a"),
        };
        stack
            .write_with(&scope, "k", serde_json::json!(1), None, &NoOpStore)
            .await
            .unwrap();
        assert_eq!(write_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn store_stack_read_defaults_to_passthrough() {
        /// Middleware that only overrides `write`; `read` uses the trait default.
        struct WriteOnly;

        #[async_trait]
        impl StoreMiddleware for WriteOnly {
            async fn write(
                &self,
                scope: &Scope,
                key: &str,
                value: serde_json::Value,
                options: Option<&StoreOptions>,
                next: &dyn StoreWriteNext,
            ) -> Result<(), StateError> {
                next.write(scope, key, value, options).await
            }
        }

        struct FixedRead;

        #[async_trait]
        impl StoreReadNext for FixedRead {
            async fn read(
                &self,
                _scope: &Scope,
                _key: &str,
            ) -> Result<Option<serde_json::Value>, StateError> {
                Ok(Some(serde_json::json!(42)))
            }
        }

        let stack = StoreStack::builder()
            .transform(Arc::new(WriteOnly))
            .build();

        let scope = Scope::Operator {
            workflow: crate::id::WorkflowId::from("w"),
            operator: OperatorId::from("a"),
        };
        let value = stack.read_with(&scope, "k", &FixedRead).await.unwrap();
        assert_eq!(value, Some(serde_json::json!(42)));
    }

    #[tokio::test]
    async fn exec_stack_passthrough() {
        struct LogExec;

        #[async_trait]
        impl ExecMiddleware for LogExec {
            async fn run(
                &self,
                input: OperatorInput,
                spec: &EnvironmentSpec,
                next: &dyn ExecNext,
            ) -> Result<OperatorOutput, EnvError> {
                next.run(input, spec).await
            }
        }

        struct EchoExec;

        #[async_trait]
        impl ExecNext for EchoExec {
            async fn run(
                &self,
                input: OperatorInput,
                _spec: &EnvironmentSpec,
            ) -> Result<OperatorOutput, EnvError> {
                Ok(OperatorOutput::new(
                    input.message,
                    crate::ExitReason::Complete,
                ))
            }
        }

        let stack = ExecStack::builder().observe(Arc::new(LogExec)).build();

        let input = OperatorInput::new(
            crate::content::Content::text("run"),
            crate::operator::TriggerType::User,
        );
        let spec = EnvironmentSpec::default();
        let result = stack.run_with(input, &spec, &EchoExec).await;
        assert!(result.is_ok());
    }
}
772}