Skip to main content

entelix_graph/
merge_node.rs

//! `MergeNodeAdapter` — wraps a delta-producing `Runnable<S, U>` and
//! a merger closure into a `Runnable<S, S>` that fits unchanged into
//! the existing StateGraph node contract.
//!
//! The current StateGraph contract has each node implement
//! `Runnable<S, S>` and own its full-state replace logic. That works
//! but forces every node closure to thread the unchanged fields
//! through itself manually. LangGraph users coming from Python
//! expect a "delta-style" alternative where a node returns only its
//! contribution and the runtime merges it into the surrounding
//! state.
//!
//! `MergeNodeAdapter<S, U, F>` provides those ergonomics without
//! changing the node contract. It snapshots the inbound state,
//! runs the inner runnable to get an update of arbitrary type `U`,
//! and applies the user-supplied merger
//! `Fn(state: S, update: U) -> Result<S>` to produce the next
//! full state. Existing `add_node` (full-state replace) and the new
//! `add_node_with` (delta + merger) coexist.
//!
//! The merger has full access to both the inbound state and the
//! delta, so it composes naturally with the [`Reducer<T>`](crate::Reducer)
//! impls already shipped — the closure body is the place to call
//! `Append::<U>::new().reduce(...)` per field.
25
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use entelix_core::context::ExecutionContext;
30use entelix_core::error::Result;
31use entelix_runnable::Runnable;
32
33/// `Runnable<S, S>` that runs an inner `Runnable<S, U>` and merges
34/// the resulting `U` back into a fresh copy of the inbound `S` via
35/// the supplied closure.
36pub struct MergeNodeAdapter<S, U, F>
37where
38    S: Clone + Send + Sync + 'static,
39    U: Send + Sync + 'static,
40    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
41{
42    inner: Arc<dyn Runnable<S, U>>,
43    merger: F,
44}
45
46impl<S, U, F> MergeNodeAdapter<S, U, F>
47where
48    S: Clone + Send + Sync + 'static,
49    U: Send + Sync + 'static,
50    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
51{
52    /// Wrap `inner` with the supplied merger.
53    pub fn new<R>(inner: R, merger: F) -> Self
54    where
55        R: Runnable<S, U> + 'static,
56    {
57        Self {
58            inner: Arc::new(inner),
59            merger,
60        }
61    }
62}
63
64impl<S, U, F> std::fmt::Debug for MergeNodeAdapter<S, U, F>
65where
66    S: Clone + Send + Sync + 'static,
67    U: Send + Sync + 'static,
68    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
69{
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("MergeNodeAdapter")
72            .field("inner", &"<runnable>")
73            .field("merger", &"<closure>")
74            .finish()
75    }
76}
77
78#[async_trait]
79impl<S, U, F> Runnable<S, S> for MergeNodeAdapter<S, U, F>
80where
81    S: Clone + Send + Sync + 'static,
82    U: Send + Sync + 'static,
83    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
84{
85    async fn invoke(&self, input: S, ctx: &ExecutionContext) -> Result<S> {
86        // Clone *before* invoking the inner runnable so the merger
87        // sees the pre-call state regardless of what the inner
88        // runnable does with its argument. Cheap when S is `Arc`-
89        // backed; explicit and predictable for everything else.
90        let snapshot = input.clone();
91        let update = self.inner.invoke(input, ctx).await?;
92        (self.merger)(snapshot, update)
93    }
94}
95
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use entelix_core::error::Error;
    use entelix_runnable::RunnableLambda;

    use super::*;

    /// Full state that flows through the adapter under test.
    #[derive(Clone, Debug, PartialEq)]
    struct State {
        log: Vec<String>,
        counter: u32,
    }

    /// Partial update: the node reports only the log lines it wants
    /// appended and how much to bump the counter by.
    #[derive(Clone, Debug)]
    struct PlanDelta {
        new_entries: Vec<String>,
        increment: u32,
    }

    /// State with an empty log and a zero counter, used by the tests
    /// that only care about error propagation.
    fn blank_state() -> State {
        State {
            log: Vec::new(),
            counter: 0,
        }
    }

    /// Happy path: the merger sees the inbound state and the delta and
    /// its result is what `invoke` returns.
    #[tokio::test]
    async fn merger_combines_state_with_delta() {
        let planner = RunnableLambda::new(|s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: vec![format!("planned at counter={}", s.counter)],
                increment: 1,
            })
        });
        let adapter = MergeNodeAdapter::new(planner, |prior: State, delta: PlanDelta| {
            let mut log = prior.log;
            log.extend(delta.new_entries);
            Ok(State {
                log,
                counter: prior.counter + delta.increment,
            })
        });

        let seed = State {
            log: vec!["seed".into()],
            counter: 10,
        };
        let ctx = ExecutionContext::new();
        let merged = adapter.invoke(seed, &ctx).await.unwrap();

        let expected = State {
            log: vec!["seed".to_owned(), "planned at counter=10".to_owned()],
            counter: 11,
        };
        assert_eq!(merged, expected);
    }

    /// An `Err` returned by the merger is what `invoke` returns.
    #[tokio::test]
    async fn merger_can_fail_and_propagate_error() {
        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Ok::<_, _>(PlanDelta {
                new_entries: Vec::new(),
                increment: 0,
            })
        });
        let adapter = MergeNodeAdapter::new(planner, |_state: State, _update: PlanDelta| {
            Err(Error::invalid_request("merger refused"))
        });

        let ctx = ExecutionContext::new();
        let failure = adapter.invoke(blank_state(), &ctx).await.unwrap_err();
        assert!(failure.to_string().contains("merger refused"));
    }

    /// When the inner runnable fails, its error comes back and the
    /// merger is never called.
    #[tokio::test]
    async fn inner_failure_short_circuits_before_merger() {
        let merger_calls = Arc::new(AtomicU32::new(0));
        let calls_seen_by_merger = Arc::clone(&merger_calls);

        let planner = RunnableLambda::new(|_s: State, _ctx| async move {
            Err::<PlanDelta, _>(Error::invalid_request("planner failed"))
        });
        let adapter = MergeNodeAdapter::new(planner, move |state: State, _update: PlanDelta| {
            calls_seen_by_merger.fetch_add(1, Ordering::SeqCst);
            Ok(state)
        });

        let ctx = ExecutionContext::new();
        let failure = adapter.invoke(blank_state(), &ctx).await.unwrap_err();
        assert!(failure.to_string().contains("planner failed"));
        assert_eq!(
            merger_calls.load(Ordering::SeqCst),
            0,
            "merger must not run when inner runnable fails"
        );
    }
}