1use std::sync::Arc;
27
28use async_trait::async_trait;
29use entelix_core::context::ExecutionContext;
30use entelix_core::error::Result;
31use entelix_runnable::Runnable;
32
33pub struct MergeNodeAdapter<S, U, F>
37where
38 S: Clone + Send + Sync + 'static,
39 U: Send + Sync + 'static,
40 F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
41{
42 inner: Arc<dyn Runnable<S, U>>,
43 merger: F,
44}
45
46impl<S, U, F> MergeNodeAdapter<S, U, F>
47where
48 S: Clone + Send + Sync + 'static,
49 U: Send + Sync + 'static,
50 F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
51{
52 pub fn new<R>(inner: R, merger: F) -> Self
54 where
55 R: Runnable<S, U> + 'static,
56 {
57 Self {
58 inner: Arc::new(inner),
59 merger,
60 }
61 }
62}
63
64impl<S, U, F> std::fmt::Debug for MergeNodeAdapter<S, U, F>
65where
66 S: Clone + Send + Sync + 'static,
67 U: Send + Sync + 'static,
68 F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
69{
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("MergeNodeAdapter")
72 .field("inner", &"<runnable>")
73 .field("merger", &"<closure>")
74 .finish()
75 }
76}
77
78#[async_trait]
79impl<S, U, F> Runnable<S, S> for MergeNodeAdapter<S, U, F>
80where
81 S: Clone + Send + Sync + 'static,
82 U: Send + Sync + 'static,
83 F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
84{
85 async fn invoke(&self, input: S, ctx: &ExecutionContext) -> Result<S> {
86 let snapshot = input.clone();
91 let update = self.inner.invoke(input, ctx).await?;
92 (self.merger)(snapshot, update)
93 }
94}
95
#[cfg(test)]
// In tests a panicking `unwrap`/`unwrap_err` *is* the failure signal, so the
// crate-wide `unwrap_used` lint is relaxed for this module only.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use entelix_core::error::Error;
    use entelix_runnable::RunnableLambda;

    use super::*;

    /// Full state threaded through the adapter under test.
    #[derive(Clone, Debug, PartialEq)]
    struct State {
        log: Vec<String>,
        counter: u32,
    }

    /// Partial update emitted by the inner "planner" runnable. It carries
    /// only what changed, not a full `State`.
    #[derive(Clone, Debug)]
    struct PlanDelta {
        new_entries: Vec<String>,
        increment: u32,
    }

    /// A `State` with an empty log and a zero counter.
    fn empty_state() -> State {
        State {
            log: Vec::new(),
            counter: 0,
        }
    }

    #[tokio::test]
    async fn merger_combines_state_with_delta() {
        let planner = RunnableLambda::new(|s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: vec![format!("planned at counter={}", s.counter)],
                increment: 1,
            })
        });
        let adapter = MergeNodeAdapter::new(planner, |mut state: State, update: PlanDelta| {
            state.log.extend(update.new_entries);
            state.counter += update.increment;
            Ok(state)
        });

        let initial = State {
            log: vec!["seed".into()],
            counter: 10,
        };
        let result = adapter
            .invoke(initial, &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(
            result.log,
            vec!["seed".to_owned(), "planned at counter=10".to_owned()]
        );
        assert_eq!(result.counter, 11);
    }

    #[tokio::test]
    async fn merger_can_fail_and_propagate_error() {
        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: Vec::new(),
                increment: 0,
            })
        });
        let adapter = MergeNodeAdapter::new(planner, |_state: State, _update: PlanDelta| {
            Err(Error::invalid_request("merger refused"))
        });

        let err = adapter
            .invoke(empty_state(), &ExecutionContext::new())
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("merger refused"));
    }

    #[tokio::test]
    async fn inner_failure_short_circuits_before_merger() {
        let merger_calls = Arc::new(AtomicU32::new(0));
        let merger_calls_clone = Arc::clone(&merger_calls);

        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Err::<PlanDelta, _>(Error::invalid_request("planner failed"))
        });
        let adapter = MergeNodeAdapter::new(planner, move |state: State, _update: PlanDelta| {
            merger_calls_clone.fetch_add(1, Ordering::SeqCst);
            Ok(state)
        });

        let err = adapter
            .invoke(empty_state(), &ExecutionContext::new())
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("planner failed"));
        assert_eq!(
            merger_calls.load(Ordering::SeqCst),
            0,
            "merger must not run when inner runnable fails"
        );
    }

    #[test]
    fn debug_output_uses_placeholders_for_inner_and_merger() {
        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: Vec::new(),
                increment: 0,
            })
        });
        let adapter =
            MergeNodeAdapter::new(planner, |state: State, _update: PlanDelta| Ok(state));

        assert_eq!(
            format!("{adapter:?}"),
            r#"MergeNodeAdapter { inner: "<runnable>", merger: "<closure>" }"#
        );
    }
}