use crate::effect::Scope;
use crate::environment::EnvironmentSpec;
use crate::error::{EnvError, OrchError, StateError};
use crate::id::OperatorId;
use crate::operator::{OperatorInput, OperatorOutput};
use crate::state::StoreOptions;
use async_trait::async_trait;
use std::sync::Arc;
/// Continuation handed to a [`DispatchMiddleware`]: invoking it runs whatever
/// comes after the current layer.
///
/// In this file it is implemented by the private chain link that walks the
/// layers of a `DispatchStack`; callers of `DispatchStack::dispatch_with`
/// supply their own implementation as the terminal step.
#[async_trait]
pub trait DispatchNext: Send + Sync {
/// Continues the dispatch for `operator` with `input`.
///
/// Returns the produced [`OperatorOutput`], or an [`OrchError`] on failure.
async fn dispatch(
&self,
operator: &OperatorId,
input: OperatorInput,
) -> Result<OperatorOutput, OrchError>;
}
/// A single layer wrapped around operator dispatch.
///
/// An implementation receives the call together with `next`, the remainder of
/// the pipeline. It may forward to `next` (optionally with a modified
/// `input`), or return without calling `next` at all, which stops the
/// pipeline at this layer.
#[async_trait]
pub trait DispatchMiddleware: Send + Sync {
/// Handles one dispatch call.
///
/// * `operator` - identifier of the operator being dispatched to.
/// * `input` - the operator input, owned so the layer may change it.
/// * `next` - continuation that runs the remaining layers and the terminal.
///
/// Returns the operator output or an [`OrchError`].
async fn dispatch(
&self,
operator: &OperatorId,
input: OperatorInput,
next: &dyn DispatchNext,
) -> Result<OperatorOutput, OrchError>;
}
/// Continuation for the write half of the store pipeline, as seen by a
/// [`StoreMiddleware`].
///
/// Callers of `StoreStack::write_with` provide an implementation as the
/// terminal step; the private write chain link in this file implements it too.
#[async_trait]
pub trait StoreWriteNext: Send + Sync {
/// Continues a write of `value` under `key` within `scope`.
///
/// `options` is passed along unchanged by this file's chain; its meaning is
/// defined by [`StoreOptions`]. Returns `Ok(())` or a [`StateError`].
async fn write(
&self,
scope: &Scope,
key: &str,
value: serde_json::Value,
options: Option<&StoreOptions>,
) -> Result<(), StateError>;
}
/// Continuation for the read half of the store pipeline, as seen by a
/// [`StoreMiddleware`].
#[async_trait]
pub trait StoreReadNext: Send + Sync {
/// Continues a read of `key` within `scope`.
///
/// Returns `Ok(Some(value))`, `Ok(None)`, or a [`StateError`]; what `None`
/// signifies is decided by the implementation (presumably "key absent").
async fn read(&self, scope: &Scope, key: &str)
-> Result<Option<serde_json::Value>, StateError>;
}
/// A single layer wrapped around store reads and writes.
///
/// `write` must be implemented. `read` has a default implementation that
/// forwards straight to `next`, so a layer that only cares about writes does
/// not need to override it.
#[async_trait]
pub trait StoreMiddleware: Send + Sync {
/// Handles one write.
///
/// * `scope`, `key`, `value`, `options` - the write request.
/// * `next` - continuation that runs the remaining layers and the terminal
///   writer; not calling it stops the write at this layer.
async fn write(
&self,
scope: &Scope,
key: &str,
value: serde_json::Value,
options: Option<&StoreOptions>,
next: &dyn StoreWriteNext,
) -> Result<(), StateError>;
/// Handles one read. The default implementation is a pure pass-through to
/// `next`.
async fn read(
&self,
scope: &Scope,
key: &str,
next: &dyn StoreReadNext,
) -> Result<Option<serde_json::Value>, StateError> {
// Default: do nothing at this layer, just continue down the chain.
next.read(scope, key).await
}
}
/// Continuation for the execution pipeline, as seen by an [`ExecMiddleware`].
///
/// Callers of `ExecStack::run_with` provide an implementation as the terminal
/// step; the private exec chain link in this file implements it too.
#[async_trait]
pub trait ExecNext: Send + Sync {
/// Continues the run with `input` under the environment described by `spec`.
///
/// Returns the operator output or an [`EnvError`].
async fn run(
&self,
input: OperatorInput,
spec: &EnvironmentSpec,
) -> Result<OperatorOutput, EnvError>;
}
/// A single layer wrapped around operator execution.
///
/// Receives the run request together with `next`, the remainder of the
/// pipeline, and decides whether and how to forward to it.
#[async_trait]
pub trait ExecMiddleware: Send + Sync {
/// Handles one run.
///
/// * `input` - the operator input, owned so the layer may change it.
/// * `spec` - the environment specification for this run.
/// * `next` - continuation that runs the remaining layers and the terminal.
///
/// Returns the operator output or an [`EnvError`].
async fn run(
&self,
input: OperatorInput,
spec: &EnvironmentSpec,
next: &dyn ExecNext,
) -> Result<OperatorOutput, EnvError>;
}
/// An ordered pipeline of [`DispatchMiddleware`] layers.
///
/// Created through `DispatchStack::builder()`; run with `dispatch_with`.
pub struct DispatchStack {
// Layers in call order: index 0 is invoked first.
layers: Vec<Arc<dyn DispatchMiddleware>>,
}
/// Builder for [`DispatchStack`].
///
/// Middleware is collected into three groups. `build` concatenates them as
/// observers, then transformers, then guards, keeping insertion order inside
/// each group.
pub struct DispatchStackBuilder {
// Added via `observe`; placed first in the built stack.
observers: Vec<Arc<dyn DispatchMiddleware>>,
// Added via `transform`; placed after all observers.
transformers: Vec<Arc<dyn DispatchMiddleware>>,
// Added via `guard`; placed last, closest to the terminal.
guards: Vec<Arc<dyn DispatchMiddleware>>,
}
impl DispatchStack {
    /// Returns a builder with no middleware registered in any group.
    pub fn builder() -> DispatchStackBuilder {
        DispatchStackBuilder {
            observers: vec![],
            transformers: vec![],
            guards: vec![],
        }
    }

    /// Runs a dispatch through every layer of the stack and finally through
    /// `terminal`.
    ///
    /// * `operator` - identifier of the operator being dispatched to.
    /// * `input` - the operator input handed to the first layer.
    /// * `terminal` - the step invoked once all layers have forwarded.
    ///
    /// Returns whatever the outermost layer (or `terminal`, when the stack is
    /// empty) returns.
    pub async fn dispatch_with(
        &self,
        operator: &OperatorId,
        input: OperatorInput,
        terminal: &dyn DispatchNext,
    ) -> Result<OperatorOutput, OrchError> {
        match self.layers.as_slice() {
            // Nothing registered: call the terminal directly, no chain needed.
            [] => terminal.dispatch(operator, input).await,
            layers => {
                let head = DispatchChain {
                    layers,
                    index: 0,
                    terminal,
                };
                head.dispatch(operator, input).await
            }
        }
    }
}
impl DispatchStackBuilder {
    /// Registers `mw` in the observer group (runs before transformers and
    /// guards in the built stack).
    pub fn observe(self, mw: Arc<dyn DispatchMiddleware>) -> Self {
        let mut builder = self;
        builder.observers.push(mw);
        builder
    }

    /// Registers `mw` in the transformer group (runs after observers, before
    /// guards).
    pub fn transform(self, mw: Arc<dyn DispatchMiddleware>) -> Self {
        let mut builder = self;
        builder.transformers.push(mw);
        builder
    }

    /// Registers `mw` in the guard group (runs after observers and
    /// transformers).
    pub fn guard(self, mw: Arc<dyn DispatchMiddleware>) -> Self {
        let mut builder = self;
        builder.guards.push(mw);
        builder
    }

    /// Finishes the builder, flattening the groups into one layer list in the
    /// order observers, transformers, guards.
    pub fn build(self) -> DispatchStack {
        let layers = self
            .observers
            .into_iter()
            .chain(self.transformers)
            .chain(self.guards)
            .collect();
        DispatchStack { layers }
    }
}
/// One position in a running dispatch pipeline.
///
/// Acts as the `next` continuation given to the layer at `index - 1`; it
/// borrows everything, so creating a link allocates nothing.
struct DispatchChain<'a> {
// All layers of the stack being run.
layers: &'a [Arc<dyn DispatchMiddleware>],
// Index of the layer this link will invoke; `>= layers.len()` means "call
// the terminal".
index: usize,
// Final step, invoked after the last layer forwards.
terminal: &'a dyn DispatchNext,
}
#[async_trait]
impl DispatchNext for DispatchChain<'_> {
    /// Invokes the layer at `self.index`, handing it a link that points at
    /// the following position; once the layers are exhausted the terminal is
    /// called instead.
    async fn dispatch(
        &self,
        operator: &OperatorId,
        input: OperatorInput,
    ) -> Result<OperatorOutput, OrchError> {
        match self.layers.get(self.index) {
            Some(layer) => {
                // Continuation for `layer`: same stack, one position further.
                let rest = DispatchChain {
                    layers: self.layers,
                    index: self.index + 1,
                    terminal: self.terminal,
                };
                layer.dispatch(operator, input, &rest).await
            }
            // Past the last layer: hand over to the terminal step.
            None => self.terminal.dispatch(operator, input).await,
        }
    }
}
/// An ordered pipeline of [`StoreMiddleware`] layers shared by the read and
/// write paths.
///
/// Created through `StoreStack::builder()`; run with `write_with` and
/// `read_with`.
pub struct StoreStack {
// Layers in call order: index 0 is invoked first.
layers: Vec<Arc<dyn StoreMiddleware>>,
}
/// Builder for [`StoreStack`].
///
/// Middleware is collected into three groups. `build` concatenates them as
/// observers, then transformers, then guards, keeping insertion order inside
/// each group.
pub struct StoreStackBuilder {
// Added via `observe`; placed first in the built stack.
observers: Vec<Arc<dyn StoreMiddleware>>,
// Added via `transform`; placed after all observers.
transformers: Vec<Arc<dyn StoreMiddleware>>,
// Added via `guard`; placed last, closest to the terminal.
guards: Vec<Arc<dyn StoreMiddleware>>,
}
impl StoreStack {
    /// Returns a builder with no middleware registered in any group.
    pub fn builder() -> StoreStackBuilder {
        StoreStackBuilder {
            observers: vec![],
            transformers: vec![],
            guards: vec![],
        }
    }

    /// Runs a write through every layer of the stack and finally through
    /// `terminal`.
    ///
    /// * `scope`, `key`, `value`, `options` - the write request handed to the
    ///   first layer.
    /// * `terminal` - the writer invoked once all layers have forwarded.
    ///
    /// Returns the result of the outermost layer, or of `terminal` when the
    /// stack is empty.
    pub async fn write_with(
        &self,
        scope: &Scope,
        key: &str,
        value: serde_json::Value,
        options: Option<&StoreOptions>,
        terminal: &dyn StoreWriteNext,
    ) -> Result<(), StateError> {
        match self.layers.as_slice() {
            // Nothing registered: write straight to the terminal.
            [] => terminal.write(scope, key, value, options).await,
            layers => {
                let head = StoreWriteChain {
                    layers,
                    index: 0,
                    terminal,
                    options,
                };
                head.write(scope, key, value, options).await
            }
        }
    }

    /// Runs a read through every layer of the stack and finally through
    /// `terminal`.
    ///
    /// Returns the result of the outermost layer, or of `terminal` when the
    /// stack is empty.
    pub async fn read_with(
        &self,
        scope: &Scope,
        key: &str,
        terminal: &dyn StoreReadNext,
    ) -> Result<Option<serde_json::Value>, StateError> {
        match self.layers.as_slice() {
            // Nothing registered: read straight from the terminal.
            [] => terminal.read(scope, key).await,
            layers => {
                let head = StoreReadChain {
                    layers,
                    index: 0,
                    terminal,
                };
                head.read(scope, key).await
            }
        }
    }
}
impl StoreStackBuilder {
    /// Registers `mw` in the observer group (runs before transformers and
    /// guards in the built stack).
    pub fn observe(self, mw: Arc<dyn StoreMiddleware>) -> Self {
        let mut builder = self;
        builder.observers.push(mw);
        builder
    }

    /// Registers `mw` in the transformer group (runs after observers, before
    /// guards).
    pub fn transform(self, mw: Arc<dyn StoreMiddleware>) -> Self {
        let mut builder = self;
        builder.transformers.push(mw);
        builder
    }

    /// Registers `mw` in the guard group (runs after observers and
    /// transformers).
    pub fn guard(self, mw: Arc<dyn StoreMiddleware>) -> Self {
        let mut builder = self;
        builder.guards.push(mw);
        builder
    }

    /// Finishes the builder, flattening the groups into one layer list in the
    /// order observers, transformers, guards.
    pub fn build(self) -> StoreStack {
        // Start from the observers and append the other two groups in order.
        let mut layers = self.observers;
        layers.extend(self.transformers);
        layers.extend(self.guards);
        StoreStack { layers }
    }
}
/// One position in a running store-write pipeline.
///
/// Acts as the `next` continuation given to the previous layer; it borrows
/// everything it refers to.
struct StoreWriteChain<'a> {
// All layers of the stack being run.
layers: &'a [Arc<dyn StoreMiddleware>],
// Index of the layer this link will invoke; `>= layers.len()` means "call
// the terminal".
index: usize,
// Final writer, invoked after the last layer forwards.
terminal: &'a dyn StoreWriteNext,
// Options captured when the chain was started and copied from link to link.
// NOTE(review): the `write` impl forwards its own `options` argument, not
// this field, so the field looks redundant — confirm before removing.
options: Option<&'a StoreOptions>,
}
#[async_trait]
impl StoreWriteNext for StoreWriteChain<'_> {
    /// Invokes the layer at `self.index`, handing it a link that points at
    /// the following position; once the layers are exhausted the terminal
    /// writer is called instead.
    ///
    /// The `options` argument of this call (not the stored field) is what
    /// gets forwarded.
    async fn write(
        &self,
        scope: &Scope,
        key: &str,
        value: serde_json::Value,
        options: Option<&StoreOptions>,
    ) -> Result<(), StateError> {
        match self.layers.get(self.index) {
            Some(layer) => {
                // Continuation for `layer`: same stack, one position further.
                let rest = StoreWriteChain {
                    layers: self.layers,
                    index: self.index + 1,
                    terminal: self.terminal,
                    options: self.options,
                };
                layer.write(scope, key, value, options, &rest).await
            }
            // Past the last layer: hand over to the terminal writer.
            None => self.terminal.write(scope, key, value, options).await,
        }
    }
}
/// One position in a running store-read pipeline.
///
/// Acts as the `next` continuation given to the previous layer; it borrows
/// everything it refers to.
struct StoreReadChain<'a> {
// All layers of the stack being run.
layers: &'a [Arc<dyn StoreMiddleware>],
// Index of the layer this link will invoke; `>= layers.len()` means "call
// the terminal".
index: usize,
// Final reader, invoked after the last layer forwards.
terminal: &'a dyn StoreReadNext,
}
#[async_trait]
impl StoreReadNext for StoreReadChain<'_> {
    /// Invokes the layer at `self.index`, handing it a link that points at
    /// the following position; once the layers are exhausted the terminal
    /// reader is called instead.
    async fn read(
        &self,
        scope: &Scope,
        key: &str,
    ) -> Result<Option<serde_json::Value>, StateError> {
        match self.layers.get(self.index) {
            Some(layer) => {
                // Continuation for `layer`: same stack, one position further.
                let rest = StoreReadChain {
                    layers: self.layers,
                    index: self.index + 1,
                    terminal: self.terminal,
                };
                layer.read(scope, key, &rest).await
            }
            // Past the last layer: hand over to the terminal reader.
            None => self.terminal.read(scope, key).await,
        }
    }
}
/// An ordered pipeline of [`ExecMiddleware`] layers.
///
/// Created through `ExecStack::builder()`; run with `run_with`.
pub struct ExecStack {
// Layers in call order: index 0 is invoked first.
layers: Vec<Arc<dyn ExecMiddleware>>,
}
/// Builder for [`ExecStack`].
///
/// Middleware is collected into three groups. `build` concatenates them as
/// observers, then transformers, then guards, keeping insertion order inside
/// each group.
pub struct ExecStackBuilder {
// Added via `observe`; placed first in the built stack.
observers: Vec<Arc<dyn ExecMiddleware>>,
// Added via `transform`; placed after all observers.
transformers: Vec<Arc<dyn ExecMiddleware>>,
// Added via `guard`; placed last, closest to the terminal.
guards: Vec<Arc<dyn ExecMiddleware>>,
}
impl ExecStack {
    /// Returns a builder with no middleware registered in any group.
    pub fn builder() -> ExecStackBuilder {
        ExecStackBuilder {
            observers: vec![],
            transformers: vec![],
            guards: vec![],
        }
    }

    /// Runs an execution through every layer of the stack and finally through
    /// `terminal`.
    ///
    /// * `input` - the operator input handed to the first layer.
    /// * `spec` - the environment specification passed along the pipeline.
    /// * `terminal` - the step invoked once all layers have forwarded.
    ///
    /// Returns the result of the outermost layer, or of `terminal` when the
    /// stack is empty.
    pub async fn run_with(
        &self,
        input: OperatorInput,
        spec: &EnvironmentSpec,
        terminal: &dyn ExecNext,
    ) -> Result<OperatorOutput, EnvError> {
        match self.layers.as_slice() {
            // Nothing registered: run the terminal directly.
            [] => terminal.run(input, spec).await,
            layers => {
                let head = ExecChain {
                    layers,
                    index: 0,
                    terminal,
                };
                head.run(input, spec).await
            }
        }
    }
}
impl ExecStackBuilder {
    /// Registers `mw` in the observer group (runs before transformers and
    /// guards in the built stack).
    pub fn observe(self, mw: Arc<dyn ExecMiddleware>) -> Self {
        let mut builder = self;
        builder.observers.push(mw);
        builder
    }

    /// Registers `mw` in the transformer group (runs after observers, before
    /// guards).
    pub fn transform(self, mw: Arc<dyn ExecMiddleware>) -> Self {
        let mut builder = self;
        builder.transformers.push(mw);
        builder
    }

    /// Registers `mw` in the guard group (runs after observers and
    /// transformers).
    pub fn guard(self, mw: Arc<dyn ExecMiddleware>) -> Self {
        let mut builder = self;
        builder.guards.push(mw);
        builder
    }

    /// Finishes the builder, flattening the groups into one layer list in the
    /// order observers, transformers, guards.
    pub fn build(self) -> ExecStack {
        let layers = self
            .observers
            .into_iter()
            .chain(self.transformers)
            .chain(self.guards)
            .collect();
        ExecStack { layers }
    }
}
/// One position in a running execution pipeline.
///
/// Acts as the `next` continuation given to the previous layer; it borrows
/// everything it refers to.
struct ExecChain<'a> {
// All layers of the stack being run.
layers: &'a [Arc<dyn ExecMiddleware>],
// Index of the layer this link will invoke; `>= layers.len()` means "call
// the terminal".
index: usize,
// Final step, invoked after the last layer forwards.
terminal: &'a dyn ExecNext,
}
#[async_trait]
impl ExecNext for ExecChain<'_> {
    /// Invokes the layer at `self.index`, handing it a link that points at
    /// the following position; once the layers are exhausted the terminal is
    /// called instead.
    async fn run(
        &self,
        input: OperatorInput,
        spec: &EnvironmentSpec,
    ) -> Result<OperatorOutput, EnvError> {
        match self.layers.get(self.index) {
            Some(layer) => {
                // Continuation for `layer`: same stack, one position further.
                let rest = ExecChain {
                    layers: self.layers,
                    index: self.index + 1,
                    terminal: self.terminal,
                };
                layer.run(input, spec, &rest).await
            }
            // Past the last layer: hand over to the terminal step.
            None => self.terminal.run(input, spec).await,
        }
    }
}
// Unit tests: trait-object checks for the three middleware traits, followed
// by end-to-end runs of each stack against small in-test terminals.
#[cfg(test)]
mod tests {
use super::*;
// Compile-time style check: a `DispatchMiddleware` implementor can be boxed
// as a trait object. The middleware itself is never invoked here.
#[tokio::test]
async fn dispatch_middleware_is_object_safe() {
struct TagMiddleware;
#[async_trait]
impl DispatchMiddleware for TagMiddleware {
async fn dispatch(
&self,
operator: &OperatorId,
mut input: OperatorInput,
next: &dyn DispatchNext,
) -> Result<OperatorOutput, OrchError> {
input.metadata = serde_json::json!({"tagged": true});
next.dispatch(operator, input).await
}
}
let _mw: Box<dyn DispatchMiddleware> = Box::new(TagMiddleware);
}
// Same trait-object check for `StoreMiddleware`; only `write` is provided,
// so this also exercises the default `read` being available.
#[tokio::test]
async fn store_middleware_is_object_safe() {
struct AuditStore;
#[async_trait]
impl StoreMiddleware for AuditStore {
async fn write(
&self,
scope: &Scope,
key: &str,
value: serde_json::Value,
options: Option<&StoreOptions>,
next: &dyn StoreWriteNext,
) -> Result<(), StateError> {
next.write(scope, key, value, options).await
}
}
let _mw: Box<dyn StoreMiddleware> = Box::new(AuditStore);
}
// Same trait-object check for `ExecMiddleware`.
#[tokio::test]
async fn exec_middleware_is_object_safe() {
struct CredentialInjector;
#[async_trait]
impl ExecMiddleware for CredentialInjector {
async fn run(
&self,
input: OperatorInput,
spec: &EnvironmentSpec,
next: &dyn ExecNext,
) -> Result<OperatorOutput, EnvError> {
next.run(input, spec).await
}
}
let _mw: Box<dyn ExecMiddleware> = Box::new(CredentialInjector);
}
// An observer is placed ahead of a guard in the built stack, so it still runs
// (counter == 1) even though the guard rejects the call and the terminal is
// never reached.
#[tokio::test]
async fn dispatch_stack_observer_always_runs() {
use std::sync::atomic::{AtomicU32, Ordering};
let counter = Arc::new(AtomicU32::new(0));
// Counts invocations, then forwards unchanged.
struct CountObserver(Arc<AtomicU32>);
#[async_trait]
impl DispatchMiddleware for CountObserver {
async fn dispatch(
&self,
operator: &OperatorId,
input: OperatorInput,
next: &dyn DispatchNext,
) -> Result<OperatorOutput, OrchError> {
self.0.fetch_add(1, Ordering::SeqCst);
next.dispatch(operator, input).await
}
}
// Always fails without calling `next`.
struct HaltGuard;
#[async_trait]
impl DispatchMiddleware for HaltGuard {
async fn dispatch(
&self,
_operator: &OperatorId,
_input: OperatorInput,
_next: &dyn DispatchNext,
) -> Result<OperatorOutput, OrchError> {
Err(OrchError::DispatchFailed("budget exceeded".into()))
}
}
let stack = DispatchStack::builder()
.observe(Arc::new(CountObserver(counter.clone())))
.guard(Arc::new(HaltGuard))
.build();
// Terminal that returns the input message as the output.
struct EchoTerminal;
#[async_trait]
impl DispatchNext for EchoTerminal {
async fn dispatch(
&self,
_operator: &OperatorId,
input: OperatorInput,
) -> Result<OperatorOutput, OrchError> {
Ok(OperatorOutput::new(
input.message,
crate::ExitReason::Complete,
))
}
}
let input = OperatorInput::new(
crate::content::Content::text("test"),
crate::operator::TriggerType::User,
);
let result = stack
.dispatch_with(&OperatorId::from("a"), input, &EchoTerminal)
.await;
// The guard's error surfaces, and the observer ran exactly once.
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// A single transformer layer forwards to the terminal and the call succeeds.
// Only success is asserted; the metadata written by the layer is not checked.
#[tokio::test]
async fn dispatch_stack_transform_then_terminal() {
// Despite the name, this layer only overwrites `input.metadata`.
struct Uppercaser;
#[async_trait]
impl DispatchMiddleware for Uppercaser {
async fn dispatch(
&self,
operator: &OperatorId,
mut input: OperatorInput,
next: &dyn DispatchNext,
) -> Result<OperatorOutput, OrchError> {
input.metadata = serde_json::json!({"transformed": true});
next.dispatch(operator, input).await
}
}
// Terminal that returns the input message as the output.
struct EchoTerminal;
#[async_trait]
impl DispatchNext for EchoTerminal {
async fn dispatch(
&self,
_operator: &OperatorId,
input: OperatorInput,
) -> Result<OperatorOutput, OrchError> {
Ok(OperatorOutput::new(
input.message,
crate::ExitReason::Complete,
))
}
}
let stack = DispatchStack::builder()
.transform(Arc::new(Uppercaser))
.build();
let input = OperatorInput::new(
crate::content::Content::text("hello"),
crate::operator::TriggerType::User,
);
let result = stack
.dispatch_with(&OperatorId::from("a"), input, &EchoTerminal)
.await;
assert!(result.is_ok());
}
// A write passes through one counting layer exactly once and reaches the
// terminal writer, which accepts it.
#[tokio::test]
async fn store_stack_write_through() {
use std::sync::atomic::{AtomicU32, Ordering};
let write_count = Arc::new(AtomicU32::new(0));
// Counts writes, then forwards unchanged.
struct CountWrites(Arc<AtomicU32>);
#[async_trait]
impl StoreMiddleware for CountWrites {
async fn write(
&self,
scope: &Scope,
key: &str,
value: serde_json::Value,
options: Option<&StoreOptions>,
next: &dyn StoreWriteNext,
) -> Result<(), StateError> {
self.0.fetch_add(1, Ordering::SeqCst);
next.write(scope, key, value, options).await
}
}
// Terminal writer that discards the value and reports success.
struct NoOpStore;
#[async_trait]
impl StoreWriteNext for NoOpStore {
async fn write(
&self,
_scope: &Scope,
_key: &str,
_value: serde_json::Value,
_options: Option<&StoreOptions>,
) -> Result<(), StateError> {
Ok(())
}
}
let stack = StoreStack::builder()
.observe(Arc::new(CountWrites(write_count.clone())))
.build();
let scope = Scope::Operator {
workflow: crate::id::WorkflowId::from("w"),
operator: OperatorId::from("a"),
};
stack
.write_with(&scope, "k", serde_json::json!(1), None, &NoOpStore)
.await
.unwrap();
assert_eq!(write_count.load(Ordering::SeqCst), 1);
}
// A run passes through one pass-through layer and the terminal succeeds.
#[tokio::test]
async fn exec_stack_passthrough() {
// Forwards unchanged; no logging is actually performed.
struct LogExec;
#[async_trait]
impl ExecMiddleware for LogExec {
async fn run(
&self,
input: OperatorInput,
spec: &EnvironmentSpec,
next: &dyn ExecNext,
) -> Result<OperatorOutput, EnvError> {
next.run(input, spec).await
}
}
// Terminal that returns the input message as the output.
struct EchoExec;
#[async_trait]
impl ExecNext for EchoExec {
async fn run(
&self,
input: OperatorInput,
_spec: &EnvironmentSpec,
) -> Result<OperatorOutput, EnvError> {
Ok(OperatorOutput::new(
input.message,
crate::ExitReason::Complete,
))
}
}
let stack = ExecStack::builder().observe(Arc::new(LogExec)).build();
let input = OperatorInput::new(
crate::content::Content::text("run"),
crate::operator::TriggerType::User,
);
let spec = EnvironmentSpec::default();
let result = stack.run_with(input, &spec, &EchoExec).await;
assert!(result.is_ok());
}
}