entelix_graph/
dispatch.rs1use std::sync::Arc;
18
19use entelix_core::context::ExecutionContext;
20use entelix_core::error::{Error, Result};
21use entelix_runnable::Runnable;
22use futures::StreamExt;
23use futures::future::BoxFuture;
24use futures::stream::FuturesOrdered;
25
/// One unit of work for [`scatter`]: the input handed to a single
/// invocation of the runnable.
#[derive(Debug, Clone)]
pub struct Dispatch<T> {
    /// Value passed as the input of the invocation.
    pub payload: T,
}
34
35impl<T> Dispatch<T> {
36 pub const fn new(payload: T) -> Self {
38 Self { payload }
39 }
40}
41
42impl<T> From<T> for Dispatch<T> {
43 fn from(payload: T) -> Self {
44 Self { payload }
45 }
46}
47
48pub async fn scatter<R, I, O>(
63 runnable: Arc<R>,
64 sends: Vec<Dispatch<I>>,
65 ctx: &ExecutionContext,
66 concurrency: usize,
67) -> Result<Vec<O>>
68where
69 R: Runnable<I, O> + 'static,
70 I: Send + Sync + 'static,
71 O: Send + Sync + 'static,
72{
73 if concurrency == 0 {
74 return Err(Error::config("scatter concurrency must be > 0"));
75 }
76 let scope_ctx = ctx.child();
80 let _guard = ScopeCancelGuard {
81 token: scope_ctx.cancellation().clone(),
82 };
83 let mut in_flight: FuturesOrdered<BoxFuture<'static, Result<O>>> = FuturesOrdered::new();
88 let mut iter = sends.into_iter();
89 let make_future = |send: Dispatch<I>| -> BoxFuture<'static, Result<O>> {
90 let runnable = Arc::clone(&runnable);
91 let ctx_clone = scope_ctx.clone();
92 Box::pin(async move { runnable.invoke(send.payload, &ctx_clone).await })
93 };
94 for _ in 0..concurrency {
95 let Some(send) = iter.next() else { break };
96 in_flight.push_back(make_future(send));
97 }
98 let mut out = Vec::new();
99 while let Some(result) = in_flight.next().await {
100 match result {
101 Ok(v) => out.push(v),
102 Err(e) => return Err(e),
103 }
104 if let Some(send) = iter.next() {
105 in_flight.push_back(make_future(send));
106 }
107 }
108 Ok(out)
109}
110
111struct ScopeCancelGuard {
115 token: entelix_core::cancellation::CancellationToken,
116}
117
118impl Drop for ScopeCancelGuard {
119 fn drop(&mut self) {
120 self.token.cancel();
121 }
122}
123
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use entelix_runnable::RunnableLambda;

    use super::*;

    /// Outputs are returned in the order the dispatches were submitted.
    #[tokio::test]
    async fn scatter_returns_results_in_submission_order() {
        let doubler = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
            Ok::<_, _>(n * 2)
        }));
        let sends: Vec<_> = (1..=4_u32).map(Dispatch::new).collect();
        let root = ExecutionContext::new();

        let out = scatter(doubler, sends, &root, 2).await.unwrap();

        assert_eq!(out, vec![2, 4, 6, 8]);
    }

    /// A concurrency of zero is rejected with an error mentioning "concurrency".
    #[tokio::test]
    async fn scatter_zero_concurrency_is_rejected() {
        let identity = Arc::new(RunnableLambda::new(
            |n: u32, _ctx| async move { Ok::<_, _>(n) },
        ));
        let root = ExecutionContext::new();

        let outcome = scatter(identity, vec![Dispatch::new(1_u32)], &root, 0).await;

        let err = outcome.unwrap_err();
        assert!(err.to_string().contains("concurrency"));
    }

    /// The number of simultaneously running branches never exceeds the cap.
    #[tokio::test]
    async fn scatter_caps_in_flight_invocations() {
        let peak = Arc::new(AtomicUsize::new(0));
        let running = Arc::new(AtomicUsize::new(0));
        // NOTE(review): every observed in-flight count is recorded here, but
        // only `peak` is asserted on.
        let samples = Arc::new(Mutex::new(Vec::<usize>::new()));

        let runnable = {
            let peak = Arc::clone(&peak);
            let running = Arc::clone(&running);
            let samples = Arc::clone(&samples);
            Arc::new(RunnableLambda::new(move |n: u32, _ctx| {
                let peak = Arc::clone(&peak);
                let running = Arc::clone(&running);
                let samples = Arc::clone(&samples);
                async move {
                    let now = running.fetch_add(1, Ordering::SeqCst) + 1;
                    samples.lock().unwrap().push(now);
                    peak.fetch_max(now, Ordering::SeqCst);
                    // Yield so other branches get a chance to overlap.
                    tokio::task::yield_now().await;
                    running.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, _>(n)
                }
            }))
        };
        let sends: Vec<_> = (0..6_u32).map(Dispatch::new).collect();
        let root = ExecutionContext::new();

        let _ = scatter(runnable, sends, &root, 2).await.unwrap();

        let highest = peak.load(Ordering::SeqCst);
        assert!(highest <= 2, "peak in-flight exceeded 2");
    }

    /// The first failing branch's error is what `scatter` returns.
    #[tokio::test]
    async fn scatter_fail_fast_on_first_error() {
        let runnable = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
            match n {
                3 => Err(entelix_core::Error::invalid_request("boom")),
                other => Ok::<_, _>(other),
            }
        }));
        let sends: Vec<_> = (1..=5_u32).map(Dispatch::new).collect();
        let root = ExecutionContext::new();

        let err = scatter(runnable, sends, &root, 2).await.unwrap_err();

        assert!(err.to_string().contains("boom"));
    }

    /// A failing branch must not cancel the caller's own token.
    #[tokio::test]
    async fn fail_fast_cancels_scope_token_for_siblings() {
        let parent_ctx = ExecutionContext::new();
        let parent_token = parent_ctx.cancellation().clone();

        // NOTE(review): branches that get past the sleep record whether they
        // saw cancellation, but this list is never asserted on.
        let observed_cancel = Arc::new(Mutex::new(Vec::<bool>::new()));
        let recorder = Arc::clone(&observed_cancel);

        let runnable = Arc::new(RunnableLambda::new(move |n: u32, ctx: ExecutionContext| {
            let observed = Arc::clone(&recorder);
            async move {
                if n == 1 {
                    // The first branch fails immediately and triggers fail-fast.
                    Err(entelix_core::Error::invalid_request("boom"))
                } else {
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    observed.lock().unwrap().push(ctx.is_cancelled());
                    Ok::<_, _>(n)
                }
            }
        }));
        let sends: Vec<_> = (1..=4_u32).map(Dispatch::new).collect();

        let _ = scatter(runnable, sends, &parent_ctx, 4).await;

        assert!(
            !parent_token.is_cancelled(),
            "scatter must not bubble cancellation to parent"
        );
    }

    /// Cancelling the caller's token is observed by the running branches.
    #[tokio::test]
    async fn parent_cancel_cascades_to_branches() {
        let parent_ctx = ExecutionContext::new();
        let parent_token = parent_ctx.cancellation().clone();

        // Each branch waits for either cancellation or a long timeout.
        let runnable = Arc::new(RunnableLambda::new(
            |_n: u32, ctx: ExecutionContext| async move {
                tokio::select! {
                    () = ctx.cancellation().cancelled() => {
                        Err(entelix_core::Error::Cancelled)
                    }
                    () = tokio::time::sleep(Duration::from_secs(5)) => {
                        Ok::<_, _>(0)
                    }
                }
            },
        ));

        // Cancel the parent shortly after the branches have started.
        let trigger = parent_token.clone();
        let canceller = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            trigger.cancel();
        });

        let sends: Vec<_> = (0..4_u32).map(Dispatch::new).collect();
        let err = scatter(runnable, sends, &parent_ctx, 4).await.unwrap_err();
        canceller.await.unwrap();

        assert!(matches!(err, entelix_core::Error::Cancelled));
    }

    /// With nothing to dispatch, the output is empty and no error occurs.
    #[tokio::test]
    async fn empty_sends_returns_empty_output() {
        let identity = Arc::new(RunnableLambda::new(
            |n: u32, _ctx| async move { Ok::<_, _>(n) },
        ));
        let root = ExecutionContext::new();

        let out = scatter::<_, u32, u32>(identity, Vec::new(), &root, 4)
            .await
            .unwrap();

        assert!(out.is_empty());
    }
}