use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::context::ExecutionContext;
use entelix_core::error::Result;
use entelix_runnable::Runnable;
/// Adapter that turns a `Runnable<S, U>` into a `Runnable<S, S>`.
///
/// The wrapped runnable receives the incoming state and produces an update of
/// type `U`; the `merger` callback then combines a clone of the incoming state
/// with that update to produce the outgoing state.
///
/// Type parameters:
/// - `S`: the state type; it must be `Clone` because one copy goes to the inner
///   runnable while another is kept for the merge step.
/// - `U`: the update type returned by the inner runnable.
/// - `F`: the fallible merge function `Fn(S, U) -> Result<S>`.
pub struct MergeNodeAdapter<S, U, F>
where
    S: Clone + Send + Sync + 'static,
    U: Send + Sync + 'static,
    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
{
    /// Wrapped runnable, stored as a trait object so its concrete type does not
    /// appear in the adapter's signature.
    inner: Arc<dyn Runnable<S, U>>,
    /// Callback that merges the pre-invocation state with the inner output.
    merger: F,
}
impl<S, U, F> MergeNodeAdapter<S, U, F>
where
    S: Clone + Send + Sync + 'static,
    U: Send + Sync + 'static,
    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
{
    /// Creates an adapter around `inner`, using `merger` to combine results.
    ///
    /// `inner` may be any concrete `Runnable<S, U>`; it is moved into an `Arc`
    /// and stored as a trait object. `merger` is stored as-is and is called
    /// with the cloned input state and the inner runnable's output.
    pub fn new<R>(inner: R, merger: F) -> Self
    where
        R: Runnable<S, U> + 'static,
    {
        // Bind with an explicit type so the unsizing coercion to the trait
        // object happens here rather than implicitly in the struct literal.
        let inner: Arc<dyn Runnable<S, U>> = Arc::new(inner);
        Self { inner, merger }
    }
}
/// Manual `Debug` implementation that prints fixed placeholders for both
/// fields instead of formatting the stored trait object or closure.
impl<S, U, F> std::fmt::Debug for MergeNodeAdapter<S, U, F>
where
    S: Clone + Send + Sync + 'static,
    U: Send + Sync + 'static,
    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MergeNodeAdapter");
        // Placeholders only: neither field's value is inspected.
        builder.field("inner", &"<runnable>");
        builder.field("merger", &"<closure>");
        builder.finish()
    }
}
#[async_trait]
impl<S, U, F> Runnable<S, S> for MergeNodeAdapter<S, U, F>
where
    S: Clone + Send + Sync + 'static,
    U: Send + Sync + 'static,
    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
{
    /// Runs the inner runnable on `input`, then merges its output into a copy
    /// of the state taken before the call.
    ///
    /// # Errors
    ///
    /// Returns the inner runnable's error unchanged, in which case the merger
    /// is never called; otherwise returns whatever the merger returns,
    /// including its error.
    async fn invoke(&self, input: S, ctx: &ExecutionContext) -> Result<S> {
        // Clone first: `input` is moved into the inner runnable, and the
        // merger needs the state as it was before that call.
        let before = input.clone();
        self.inner
            .invoke(input, ctx)
            .await
            .and_then(|delta| (self.merger)(before, delta))
    }
}
#[cfg(test)]
// Tests unwrap on purpose: a panic is the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use entelix_core::error::Error;
    use entelix_runnable::RunnableLambda;

    use super::*;

    /// Minimal graph state used as the `S` parameter in these tests.
    #[derive(Clone, Debug, PartialEq)]
    struct State {
        log: Vec<String>,
        counter: u32,
    }

    /// Update payload produced by the inner runnable (the `U` parameter).
    #[derive(Clone, Debug)]
    struct PlanDelta {
        new_entries: Vec<String>,
        increment: u32,
    }

    /// State with an empty log and a zero counter, for tests that only care
    /// about error propagation.
    fn blank_state() -> State {
        State {
            log: Vec::new(),
            counter: 0,
        }
    }

    /// The merger receives the pre-invocation state plus the inner output and
    /// its return value becomes the adapter's output.
    #[tokio::test]
    async fn merger_combines_state_with_delta() {
        let planner = RunnableLambda::new(|s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: vec![format!("planned at counter={}", s.counter)],
                increment: 1,
            })
        });
        let adapter = MergeNodeAdapter::new(planner, |mut state: State, update: PlanDelta| {
            state.counter += update.increment;
            state.log.extend(update.new_entries);
            Ok(state)
        });

        let seed = State {
            log: vec![String::from("seed")],
            counter: 10,
        };
        let merged = adapter
            .invoke(seed, &ExecutionContext::new())
            .await
            .unwrap();

        let expected = State {
            log: vec![
                String::from("seed"),
                String::from("planned at counter=10"),
            ],
            counter: 11,
        };
        assert_eq!(merged, expected);
    }

    /// An error returned by the merger surfaces from `invoke`.
    #[tokio::test]
    async fn merger_can_fail_and_propagate_error() {
        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: Vec::new(),
                increment: 0,
            })
        });
        let adapter = MergeNodeAdapter::new(planner, |_state: State, _update: PlanDelta| {
            Err(Error::invalid_request("merger refused"))
        });

        let failure = adapter
            .invoke(blank_state(), &ExecutionContext::new())
            .await
            .unwrap_err();

        assert!(failure.to_string().contains("merger refused"));
    }

    /// When the inner runnable fails, its error is returned and the merger is
    /// never invoked.
    #[tokio::test]
    async fn inner_failure_short_circuits_before_merger() {
        let invocations = Arc::new(AtomicU32::new(0));
        let invocations_in_merger = Arc::clone(&invocations);

        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Err::<PlanDelta, _>(Error::invalid_request("planner failed"))
        });
        let adapter = MergeNodeAdapter::new(planner, move |state: State, _update: PlanDelta| {
            invocations_in_merger.fetch_add(1, Ordering::SeqCst);
            Ok(state)
        });

        let failure = adapter
            .invoke(blank_state(), &ExecutionContext::new())
            .await
            .unwrap_err();

        assert!(failure.to_string().contains("planner failed"));
        assert_eq!(
            invocations.load(Ordering::SeqCst),
            0,
            "merger must not run when inner runnable fails"
        );
    }
}