//! Scatter: fan out a batch of dispatches to a single runnable with
//! bounded concurrency, ordered results, fail-fast errors, and a
//! per-scatter cancellation scope.

use std::sync::Arc;

use entelix_core::context::ExecutionContext;
use entelix_core::error::{Error, Result};
use entelix_runnable::Runnable;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::FuturesOrdered;
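
/// A single unit of work to scatter: one input payload destined for one
/// branch invocation of the target [`Runnable`].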
#[derive(Clone, Debug)]
pub struct Dispatch<T> {
    pub payload: T,
}

impl<T> Dispatch<T> {
    /// Wraps a payload for dispatch.
    pub const fn new(payload: T) -> Self {
        Self { payload }
    }
}

impl<T> From<T> for Dispatch<T> {
    fn from(payload: T) -> Self {
        Self { payload }
    }
}
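
/// Invokes `runnable` once per [`Dispatch`], keeping at most `concurrency`
/// invocations in flight, and returns the outputs in submission order.
///
/// Fails fast: the first error aborts the remaining branches, cancels the
/// scatter's child cancellation scope, and is returned to the caller. The
/// parent context's token is never cancelled by this function.
///
/// A minimal usage sketch, assuming the `RunnableLambda` adapter used in
/// the tests below:
///
/// ```ignore
/// let doubler = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
///     Ok::<_, Error>(n * 2)
/// }));
/// let sends = vec![Dispatch::new(1_u32), Dispatch::new(2), Dispatch::new(3)];
/// // At most two invocations run at a time; results stay in input order.
/// let out = scatter(doubler, sends, &ExecutionContext::new(), 2).await?;
/// assert_eq!(out, vec![2, 4, 6]);
/// ```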
pub async fn scatter<R, I, O>(
    runnable: Arc<R>,
    sends: Vec<Dispatch<I>>,
    ctx: &ExecutionContext,
    concurrency: usize,
) -> Result<Vec<O>>
where
    R: Runnable<I, O> + 'static,
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    if concurrency == 0 {
        return Err(Error::config("scatter concurrency must be > 0"));
    }
    // All branches run under a child context; the guard cancels the child
    // token on every exit path (success, error, or panic) without touching
    // the parent's token.
    let scope_ctx = ctx.child();
    let _guard = ScopeCancelGuard {
        token: scope_ctx.cancellation().clone(),
    };
    let total = sends.len();
    let mut in_flight: FuturesOrdered<BoxFuture<'static, Result<O>>> = FuturesOrdered::new();
    let mut iter = sends.into_iter();
    let make_future = |send: Dispatch<I>| -> BoxFuture<'static, Result<O>> {
        let runnable = Arc::clone(&runnable);
        let ctx_clone = scope_ctx.clone();
        Box::pin(async move { runnable.invoke(send.payload, &ctx_clone).await })
    };
    // Prime the window with up to `concurrency` branches.
    for _ in 0..concurrency {
        let Some(send) = iter.next() else { break };
        in_flight.push_back(make_future(send));
    }
    // Drain completions in submission order (FuturesOrdered guarantees
    // this), backfilling the window after each one. The `?` on an error
    // returns immediately: dropping `in_flight` aborts the remaining
    // branches and the guard cancels the scope token.
    let mut out = Vec::with_capacity(total);
    while let Some(result) = in_flight.next().await {
        out.push(result?);
        if let Some(send) = iter.next() {
            in_flight.push_back(make_future(send));
        }
    }
    Ok(out)
}
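
/// Drop guard that cancels the scatter's child cancellation token, so the
/// scope is torn down on early error returns and panics as well as on
/// normal completion.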
struct ScopeCancelGuard {
    token: entelix_core::cancellation::CancellationToken,
}

impl Drop for ScopeCancelGuard {
    fn drop(&mut self) {
        self.token.cancel();
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use entelix_runnable::RunnableLambda;

    use super::*;

    #[tokio::test]
    async fn scatter_returns_results_in_submission_order() {
        let runnable = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
            Ok::<_, Error>(n * 2)
        }));
        let sends = vec![
            Dispatch::new(1_u32),
            Dispatch::new(2),
            Dispatch::new(3),
            Dispatch::new(4),
        ];
        let out = scatter(runnable, sends, &ExecutionContext::new(), 2)
            .await
            .unwrap();
        assert_eq!(out, vec![2, 4, 6, 8]);
    }

    #[tokio::test]
    async fn scatter_zero_concurrency_is_rejected() {
        let runnable = Arc::new(RunnableLambda::new(
            |n: u32, _ctx| async move { Ok::<_, Error>(n) },
        ));
        let err = scatter(
            runnable,
            vec![Dispatch::new(1_u32)],
            &ExecutionContext::new(),
            0,
        )
        .await
        .unwrap_err();
        assert!(format!("{err}").contains("concurrency"));
    }

    #[tokio::test]
    async fn scatter_caps_in_flight_invocations() {
        let peak = Arc::new(AtomicUsize::new(0));
        let in_flight = Arc::new(AtomicUsize::new(0));
        let history = Arc::new(Mutex::new(Vec::<usize>::new()));
        let peak_for_lambda = Arc::clone(&peak);
        let in_flight_for_lambda = Arc::clone(&in_flight);
        let history_for_lambda = Arc::clone(&history);
        let runnable = Arc::new(RunnableLambda::new(move |n: u32, _ctx| {
            let peak = Arc::clone(&peak_for_lambda);
            let in_flight = Arc::clone(&in_flight_for_lambda);
            let history = Arc::clone(&history_for_lambda);
            async move {
                let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                history.lock().unwrap().push(now);
                peak.fetch_max(now, Ordering::SeqCst);
                // Yield so sibling branches get a chance to run while this
                // one is still counted as in flight.
                tokio::task::yield_now().await;
                in_flight.fetch_sub(1, Ordering::SeqCst);
                Ok::<_, Error>(n)
            }
        }));
        let sends: Vec<_> = (0..6_u32).map(Dispatch::new).collect();
        scatter(runnable, sends, &ExecutionContext::new(), 2)
            .await
            .unwrap();
        assert!(
            peak.load(Ordering::SeqCst) <= 2,
            "peak in-flight exceeded 2"
        );
        assert!(
            history.lock().unwrap().iter().all(|&n| n <= 2),
            "recorded in-flight count exceeded 2"
        );
    }

    #[tokio::test]
    async fn scatter_fail_fast_on_first_error() {
        let runnable = Arc::new(RunnableLambda::new(|n: u32, _ctx| async move {
            if n == 3 {
                Err(Error::invalid_request("boom"))
            } else {
                Ok::<_, Error>(n)
            }
        }));
        let sends: Vec<_> = (1..=5_u32).map(Dispatch::new).collect();
        let err = scatter(runnable, sends, &ExecutionContext::new(), 2)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("boom"));
    }

    #[tokio::test]
    async fn fail_fast_cancels_scope_token_for_siblings() {
        let parent_ctx = ExecutionContext::new();
        let parent_token = parent_ctx.cancellation().clone();
        let completed = Arc::new(Mutex::new(Vec::<u32>::new()));
        let completed_for_lambda = Arc::clone(&completed);
        let runnable = Arc::new(RunnableLambda::new(move |n: u32, _ctx: ExecutionContext| {
            let completed = Arc::clone(&completed_for_lambda);
            async move {
                if n == 1 {
                    return Err(Error::invalid_request("boom"));
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                completed.lock().unwrap().push(n);
                Ok::<_, Error>(n)
            }
        }));
        let sends: Vec<_> = (1..=4_u32).map(Dispatch::new).collect();
        let err = scatter(runnable, sends, &parent_ctx, 4).await.unwrap_err();
        assert!(format!("{err}").contains("boom"));
        // The failing branch aborts the scatter before any sibling finishes
        // its sleep, so no sibling records a completion.
        assert!(completed.lock().unwrap().is_empty());
        assert!(
            !parent_token.is_cancelled(),
            "scatter must not bubble cancellation to parent"
        );
    }

    #[tokio::test]
    async fn parent_cancel_cascades_to_branches() {
        let parent_ctx = ExecutionContext::new();
        let parent_token = parent_ctx.cancellation().clone();
        let runnable = Arc::new(RunnableLambda::new(
            |_n: u32, ctx: ExecutionContext| async move {
                // Each branch waits for scope cancellation; the sleep arm is
                // a fallback so a missed cancellation fails the test instead
                // of hanging.
                tokio::select! {
                    () = ctx.cancellation().cancelled() => {
                        Err(Error::Cancelled)
                    }
                    () = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
                        Ok::<_, Error>(0)
                    }
                }
            },
        ));
        let parent_token_for_canceller = parent_token.clone();
        let canceller = tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            parent_token_for_canceller.cancel();
        });
        let sends: Vec<_> = (0..4_u32).map(Dispatch::new).collect();
        let err = scatter(runnable, sends, &parent_ctx, 4).await.unwrap_err();
        canceller.await.unwrap();
        assert!(matches!(err, Error::Cancelled));
    }

    #[tokio::test]
    async fn empty_sends_returns_empty_output() {
        let runnable = Arc::new(RunnableLambda::new(
            |n: u32, _ctx| async move { Ok::<_, Error>(n) },
        ));
        let out = scatter::<_, u32, u32>(runnable, Vec::new(), &ExecutionContext::new(), 4)
            .await
            .unwrap();
        assert!(out.is_empty());
    }
}