Skip to main content

entelix_graph/
dispatch.rs

1//! `Dispatch<T>` — fan-out primitive for parallel sub-task dispatch.
2//!
3//! LangGraph's `Send` API lets a conditional edge emit *multiple*
4//! follow-up invocations into the same node, one per `Send`. We name
5//! the Rust counterpart [`Dispatch<T>`] rather than `Send<T>` because
6//! Rust already reserves `Send` as the unsafe auto trait — a struct
7//! with the same identifier would force every downstream user to
8//! disambiguate `T: std::marker::Send` from `T: entelix_graph::Send`.
9//! [`scatter`] runs a `Vec<Dispatch<I>>` through a [`Runnable<I, O>`]
10//! in parallel and collects the outputs in submission order.
11//!
12//! Like [`Reducer<T>`](crate::Reducer), this is a standalone helper:
13//! `Dispatch` does not plug into `StateGraph::add_conditional_edges`
14//! as a "this node receives N sends" emit. Users fan out manually
15//! from inside their node closures via [`scatter`].
16
17use std::sync::Arc;
18
19use entelix_core::context::ExecutionContext;
20use entelix_core::error::{Error, Result};
21use entelix_runnable::Runnable;
22use futures::StreamExt;
23use futures::future::BoxFuture;
24use futures::stream::FuturesOrdered;
25
/// A single fan-out work item — the Rust analogue of one LangGraph
/// `Send(node, payload)`. It wraps the input that the runnable is
/// handed when the item is executed by [`scatter`].
#[derive(Clone, Debug)]
pub struct Dispatch<T> {
    /// Input value passed to the runnable for this item.
    pub payload: T,
}

impl<T> Dispatch<T> {
    /// Wrap `payload` in a dispatch. Usable in `const` contexts.
    pub const fn new(payload: T) -> Self {
        Dispatch { payload }
    }
}

impl<T> From<T> for Dispatch<T> {
    /// Conversion form of [`Dispatch::new`]: any value becomes the
    /// payload of a fresh dispatch.
    fn from(value: T) -> Self {
        Self::new(value)
    }
}
47
48/// Fan a `Vec<Dispatch<I>>` through a `Runnable<I, O>` in parallel
49/// and collect the outputs in submission order.
50///
51/// `concurrency` caps the number of in-flight invocations; the
52/// remainder are queued. Setting it to `0` is rejected at runtime
53/// because a zero cap deadlocks the consumer.
54///
55/// Failure is fail-fast: as soon as one branch returns `Err`, the
56/// remaining in-flight futures are dropped and the error surfaces.
57/// All branches run under a [`ExecutionContext::child`] scope —
58/// cancelling the parent cascades to siblings, and on scatter exit
59/// (success, error, or panic) the scope token is fired so any branch
60/// observing `ctx.cancellation()` cooperatively unwinds. The parent
61/// context is left untouched.
62pub async fn scatter<R, I, O>(
63    runnable: Arc<R>,
64    sends: Vec<Dispatch<I>>,
65    ctx: &ExecutionContext,
66    concurrency: usize,
67) -> Result<Vec<O>>
68where
69    R: Runnable<I, O> + 'static,
70    I: Send + Sync + 'static,
71    O: Send + Sync + 'static,
72{
73    if concurrency == 0 {
74        return Err(Error::config("scatter concurrency must be > 0"));
75    }
76    // Scope-bound child context. `_guard` cancels the scope token on
77    // every exit path including panic-unwind, so still-racing
78    // branches see `ctx.cancellation()` fire cooperatively.
79    let scope_ctx = ctx.child();
80    let _guard = ScopeCancelGuard {
81        token: scope_ctx.cancellation().clone(),
82    };
83    // FuturesOrdered requires every queued future to share one type.
84    // Two `async move` blocks produce two anonymous types, so we
85    // erase via `BoxFuture` (heap allocation per dispatch — cheap
86    // relative to the model call this typically wraps).
87    let mut in_flight: FuturesOrdered<BoxFuture<'static, Result<O>>> = FuturesOrdered::new();
88    let mut iter = sends.into_iter();
89    let make_future = |send: Dispatch<I>| -> BoxFuture<'static, Result<O>> {
90        let runnable = Arc::clone(&runnable);
91        let ctx_clone = scope_ctx.clone();
92        Box::pin(async move { runnable.invoke(send.payload, &ctx_clone).await })
93    };
94    for _ in 0..concurrency {
95        let Some(send) = iter.next() else { break };
96        in_flight.push_back(make_future(send));
97    }
98    let mut out = Vec::new();
99    while let Some(result) = in_flight.next().await {
100        match result {
101            Ok(v) => out.push(v),
102            Err(e) => return Err(e),
103        }
104        if let Some(send) = iter.next() {
105            in_flight.push_back(make_future(send));
106        }
107    }
108    Ok(out)
109}
110
111/// RAII fire-on-drop for the scope cancellation token. Ensures that
112/// every exit path from [`scatter`] — early return, fail-fast `Err`,
113/// or panic-unwind — signals siblings to wind down.
114struct ScopeCancelGuard {
115    token: entelix_core::cancellation::CancellationToken,
116}
117
118impl Drop for ScopeCancelGuard {
119    fn drop(&mut self) {
120        self.token.cancel();
121    }
122}
123
#[cfg(test)]
// Tests: `unwrap` is the intended failure mode, a panic fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use entelix_runnable::RunnableLambda;

    use super::*;

    /// Outputs come back in the order the dispatches were submitted,
    /// even with a concurrency cap smaller than the input.
    #[tokio::test]
    async fn scatter_returns_results_in_submission_order() {
        let runnable = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
            Ok::<_, _>(n * 2)
        }));
        let sends = vec![
            Dispatch::new(1_u32),
            Dispatch::new(2),
            Dispatch::new(3),
            Dispatch::new(4),
        ];
        let out = scatter(runnable, sends, &ExecutionContext::new(), 2)
            .await
            .unwrap();
        assert_eq!(out, vec![2, 4, 6, 8]);
    }

    /// A zero concurrency cap is rejected with an error that names
    /// the offending parameter.
    #[tokio::test]
    async fn scatter_zero_concurrency_is_rejected() {
        let runnable = Arc::new(RunnableLambda::new(
            |n: u32, _ctx| async move { Ok::<_, _>(n) },
        ));
        let err = scatter(
            runnable,
            vec![Dispatch::new(1_u32)],
            &ExecutionContext::new(),
            0,
        )
        .await
        .unwrap_err();
        assert!(format!("{err}").contains("concurrency"));
    }

    /// The number of branches running at once never exceeds the cap.
    #[tokio::test]
    async fn scatter_caps_in_flight_invocations() {
        let peak = Arc::new(AtomicUsize::new(0));
        let in_flight = Arc::new(AtomicUsize::new(0));
        // Every in-flight count a branch saw on entry; reported in
        // the failure message to make a cap violation diagnosable.
        let history = Arc::new(Mutex::new(Vec::<usize>::new()));
        let peak_for_lambda = Arc::clone(&peak);
        let in_flight_for_lambda = Arc::clone(&in_flight);
        let history_for_lambda = Arc::clone(&history);

        let runnable = Arc::new(RunnableLambda::new(move |n: u32, _ctx| {
            let peak = Arc::clone(&peak_for_lambda);
            let in_flight = Arc::clone(&in_flight_for_lambda);
            let history = Arc::clone(&history_for_lambda);
            async move {
                let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                history.lock().unwrap().push(now);
                peak.fetch_max(now, Ordering::SeqCst);
                // Give sibling branches a chance to overlap with this one.
                tokio::task::yield_now().await;
                in_flight.fetch_sub(1, Ordering::SeqCst);
                Ok::<_, _>(n)
            }
        }));
        let sends: Vec<_> = (0..6_u32).map(Dispatch::new).collect();
        let out = scatter(runnable, sends, &ExecutionContext::new(), 2)
            .await
            .unwrap();
        assert_eq!(out, (0..6_u32).collect::<Vec<_>>());
        let observed = history.lock().unwrap().clone();
        assert!(
            peak.load(Ordering::SeqCst) <= 2,
            "peak in-flight exceeded 2; in-flight counts seen on entry: {observed:?}"
        );
    }

    /// A failing branch makes the whole scatter fail with that error.
    #[tokio::test]
    async fn scatter_fail_fast_on_first_error() {
        let runnable = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
            if n == 3 {
                Err(entelix_core::Error::invalid_request("boom"))
            } else {
                Ok::<_, _>(n)
            }
        }));
        let sends: Vec<_> = (1..=5_u32).map(Dispatch::new).collect();
        let err = scatter(runnable, sends, &ExecutionContext::new(), 2)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("boom"));
    }

    /// After a fail-fast exit the scope token handed to the branches
    /// is cancelled, while the parent's token is not.
    #[tokio::test]
    async fn fail_fast_cancels_scope_token_for_siblings() {
        let parent_ctx = ExecutionContext::new();
        let parent_token = parent_ctx.cancellation().clone();

        // Branch futures are dropped when scatter returns, so they
        // cannot report what they observe afterwards. Instead each
        // branch that runs stashes a clone of its scope token here and
        // the test inspects those tokens once scatter has exited.
        let scope_tokens = Arc::new(Mutex::new(Vec::<
            entelix_core::cancellation::CancellationToken,
        >::new()));
        let scope_tokens_for_lambda = Arc::clone(&scope_tokens);

        let runnable = Arc::new(RunnableLambda::new(move |n: u32, ctx: ExecutionContext| {
            let scope_tokens = Arc::clone(&scope_tokens_for_lambda);
            async move {
                scope_tokens
                    .lock()
                    .unwrap()
                    .push(ctx.cancellation().clone());
                if n == 1 {
                    // First branch fails immediately, triggering the
                    // fail-fast exit.
                    return Err(entelix_core::Error::invalid_request("boom"));
                }
                // Siblings stay pending long enough to still be in
                // flight when the failure lands.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                Ok::<_, _>(n)
            }
        }));
        let sends: Vec<_> = (1..=4_u32).map(Dispatch::new).collect();
        let result = scatter(runnable, sends, &parent_ctx, 4).await;
        assert!(result.is_err(), "branch 1 fails, so scatter must fail");

        let tokens = scope_tokens.lock().unwrap();
        assert!(
            !tokens.is_empty(),
            "at least the failing branch must have run"
        );
        assert!(
            tokens.iter().all(|token| token.is_cancelled()),
            "scope token must be cancelled once scatter has exited"
        );
        // Parent context untouched.
        assert!(
            !parent_token.is_cancelled(),
            "scatter must not bubble cancellation to parent"
        );
    }

    /// Cancelling the parent token is visible to running branches
    /// through their scope context.
    #[tokio::test]
    async fn parent_cancel_cascades_to_branches() {
        let parent_ctx = ExecutionContext::new();
        let parent_token = parent_ctx.cancellation().clone();

        let runnable = Arc::new(RunnableLambda::new(
            |_n: u32, ctx: ExecutionContext| async move {
                tokio::select! {
                    () = ctx.cancellation().cancelled() => {
                        Err(entelix_core::Error::Cancelled)
                    }
                    () = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
                        Ok::<_, _>(0)
                    }
                }
            },
        ));
        let parent_token_for_canceller = parent_token.clone();
        let canceller = tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            parent_token_for_canceller.cancel();
        });

        let sends: Vec<_> = (0..4_u32).map(Dispatch::new).collect();
        let err = scatter(runnable, sends, &parent_ctx, 4).await.unwrap_err();
        canceller.await.unwrap();
        assert!(matches!(err, entelix_core::Error::Cancelled));
    }

    /// No dispatches in, no outputs out, and no error.
    #[tokio::test]
    async fn empty_sends_returns_empty_output() {
        let runnable = Arc::new(RunnableLambda::new(
            |n: u32, _ctx| async move { Ok::<_, _>(n) },
        ));
        let out = scatter::<_, u32, u32>(runnable, Vec::new(), &ExecutionContext::new(), 4)
            .await
            .unwrap();
        assert!(out.is_empty());
    }
}