use std::collections::HashMap;
use std::sync::Arc;
use entelix_core::{ExecutionContext, Result};
use futures::future::try_join_all;
use crate::runnable::Runnable;
/// One named branch of a [`RunnableParallel`]: the key its output is stored
/// under, paired with a shared handle to the runnable that produces it.
type Branch<I, O> = (String, Arc<dyn Runnable<I, O>>);

/// Fans a single input out to several named runnables and gathers their
/// outputs into a map keyed by branch name.
///
/// Build one with [`RunnableParallel::new`] (or `Default`) and add branches
/// with [`RunnableParallel::branch`].
pub struct RunnableParallel<I: Clone + Send + 'static, O: Send + 'static> {
    /// Registered `(name, runnable)` pairs.
    branches: Vec<Branch<I, O>>,
}
impl<I, O> RunnableParallel<I, O>
where
    I: Clone + Send + 'static,
    O: Send + 'static,
{
    /// Creates a parallel runnable with no branches.
    #[must_use]
    pub fn new() -> Self {
        Self {
            branches: Vec::new(),
        }
    }

    /// Registers `runnable` under `name` and returns the updated builder.
    ///
    /// Outputs are collected into a map keyed by branch name, so names are
    /// unique. If a branch named `name` already exists, it is replaced in
    /// place. Previously both branches were kept, both were executed, and the
    /// earlier one's output was silently discarded when results were
    /// collected.
    #[must_use]
    pub fn branch<R>(mut self, name: impl Into<String>, runnable: R) -> Self
    where
        R: Runnable<I, O> + 'static,
    {
        let name = name.into();
        let runnable: Arc<dyn Runnable<I, O>> = Arc::new(runnable);
        // Look the name up first so the borrow ends before we mutate.
        let existing = self
            .branches
            .iter()
            .position(|(registered, _)| *registered == name);
        match existing {
            // `index` came from `position` on this same vector, so it is in bounds.
            Some(index) => self.branches[index].1 = runnable,
            None => self.branches.push((name, runnable)),
        }
        self
    }

    /// Returns the number of registered branches.
    ///
    /// Because names are unique, this equals the number of entries in the
    /// map a successful `invoke` produces.
    #[must_use]
    pub fn len(&self) -> usize {
        self.branches.len()
    }

    /// Returns `true` when no branches have been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.branches.is_empty()
    }
}
impl<I, O> Default for RunnableParallel<I, O>
where
I: Clone + Send + 'static,
O: Send + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl<I, O> Runnable<I, HashMap<String, O>> for RunnableParallel<I, O>
where
I: Clone + Send + Sync + 'static,
O: Send + 'static,
{
async fn invoke(&self, input: I, ctx: &ExecutionContext) -> Result<HashMap<String, O>> {
let futures = self.branches.iter().map(|(name, runnable)| {
let input = input.clone();
let runnable = Arc::clone(runnable);
let name = name.clone();
let ctx = ctx.clone();
async move {
let out = runnable.invoke(input, &ctx).await?;
Ok::<_, entelix_core::Error>((name, out))
}
});
let pairs = try_join_all(futures).await?;
Ok(pairs.into_iter().collect())
}
}