use async_trait::async_trait;
use futures::future::try_join_all;
use futures::{future::BoxFuture, FutureExt};
use std::fmt;
#[cfg(feature = "tracing")]
use tracing::{instrument, Instrument};
use crate::error::AnchorChainError;
use crate::node::Node;
/// Boxed, thread-safe (`Send + Sync`) async callback that receives the collected
/// outputs of a [`ParallelNode`]'s children and folds them into a single result.
///
/// The returned future is `'static`, so it must own everything it needs.
type CombinationFunction<I, O> =
    Box<dyn Fn(Vec<I>) -> BoxFuture<'static, Result<O, AnchorChainError>> + Sync + Send>;
/// A node that hands a clone of one input to each of several child nodes,
/// awaits them together, and merges their outputs with an async combination
/// function.
///
/// * `I` - input type shared (by cloning) with every child.
/// * `O` - output type produced by each child.
/// * `C` - final output type produced by the combination function.
pub struct ParallelNode<I, O, C>
where
    I: Send + Sync + Clone,
    O: Sync + Send,
    C: Sync + Send,
{
    /// Child nodes; each one receives its own clone of the input.
    pub nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
    /// Async combiner that receives the children's outputs, in `nodes` order.
    pub function: CombinationFunction<O, C>,
}
impl<I, O, C> ParallelNode<I, O, C>
where
    I: Clone + Send + Sync,
    O: Send + Sync,
    C: Send + Sync,
{
    /// Builds a `ParallelNode` from its child nodes and the function that
    /// merges their outputs.
    ///
    /// # Arguments
    /// * `nodes` - child nodes, each of which will receive a clone of the input.
    /// * `function` - async combiner invoked with the children's outputs.
    pub fn new(
        nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
        function: CombinationFunction<O, C>,
    ) -> Self {
        Self { function, nodes }
    }
}
#[async_trait]
impl<I, O, C> Node for ParallelNode<I, O, C>
where
    I: Clone + Send + Sync + fmt::Debug,
    O: Send + Sync + fmt::Debug,
    C: Send + Sync + fmt::Debug,
{
    type Input = I;
    type Output = C;

    /// Drives every child node concurrently, each on its own clone of `input`,
    /// then passes their outputs (in `nodes` order) to the combination function.
    ///
    /// # Errors
    /// Fails as soon as any child node returns an error. Also fails with
    /// whatever error the combination function's future resolves to.
    #[cfg_attr(feature = "tracing", instrument)]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        // One future per child. The input is cloned eagerly so each future owns
        // its copy. The child's `process` call itself is deferred until the
        // future is polled inside the join below.
        let branches: Vec<_> = self
            .nodes
            .iter()
            .map(|child| {
                let owned_input = input.clone();
                async move { child.process(owned_input).await }
            })
            .collect();

        let joined = try_join_all(branches);
        #[cfg(feature = "tracing")]
        let joined = joined.instrument(tracing::info_span!("Joining parallel node futures"));
        let outputs: Vec<O> = joined.await?;

        let combining = (self.function)(outputs);
        #[cfg(feature = "tracing")]
        let combining =
            combining.instrument(tracing::info_span!("Combining parallel node outputs"));
        combining.await
    }
}
/// Hand-written `Debug`: the combination function is a boxed closure without a
/// `Debug` impl, so a fixed placeholder is printed in its place.
impl<I, O, C> fmt::Debug for ParallelNode<I, O, C>
where
    I: fmt::Debug + Clone + Send + Sync,
    O: fmt::Debug + Send + Sync,
    C: fmt::Debug + Send + Sync,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ParallelNode");
        out.field("nodes", &self.nodes);
        out.field("function", &format_args!("<function/closure>"));
        out.finish()
    }
}
/// Adapts a synchronous, fallible function into a boxed async callback of shape
/// `Fn(I) -> BoxFuture<'static, Result<O, AnchorChainError>>`. That is the
/// shape [`ParallelNode`] expects for its combination function when `I = Vec<_>`.
///
/// Invoking the returned closure does not call `f`; `f` runs when the returned
/// future is first polled.
///
/// `f` is stored once behind an `Arc`, so each invocation only bumps a
/// reference count instead of cloning the closure (and everything it captures).
/// Consequently `F` no longer needs to implement `Clone`.
pub fn to_boxed_future<F, I, O>(
    f: F,
) -> Box<dyn Fn(I) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>
where
    F: Fn(I) -> Result<O, AnchorChainError> + Send + Sync + 'static,
    I: Send + 'static,
{
    use std::sync::Arc;

    // Shared between all invocations; `F: Send + Sync` makes `Arc<F>` both too,
    // which keeps the returned closure and its futures `Send`.
    let shared = Arc::new(f);
    Box::new(move |input| {
        let handle = Arc::clone(&shared);
        // Deferred: `f` executes when the future is polled, not when it is built.
        async move { (*handle)(input) }.boxed()
    })
}