shotover 0.7.2

Shotover API for building custom transforms
Documentation
use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol};
use crate::config::chain::TransformChainConfig;
use crate::message::Messages;
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig};
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use futures::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::task::{Context, Poll};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;

/// Builder for [`ParallelMap`].
///
/// Holds one [`TransformChainBuilder`] per parallel sub-chain; `build` turns each
/// of them into a [`TransformChain`] owned by the resulting transform.
struct ParallelMapBuilder {
    /// One builder per sub-chain that the built transform will own.
    chains: Vec<TransformChainBuilder>,
    /// Copied verbatim into the built transform's `ordered` flag.
    ordered: bool,
}

/// Transform that fans incoming requests out over several sub-chains.
///
/// Requests are handed out one per sub-chain and each batch is awaited together,
/// so at most `chains.len()` requests are in flight at a time.
struct ParallelMap {
    /// The sub-chains requests are distributed over.
    chains: Vec<TransformChain>,
    /// Selects which `UOFutures` variant is used to collect each batch:
    /// `FuturesOrdered` when `true`, `FuturesUnordered` when `false`.
    ordered: bool,
}

/// A set of in-flight futures that is either ordered or unordered, chosen at runtime.
///
/// Wraps the two `futures` collections behind a single type so calling code does
/// not have to be duplicated for each ordering mode.
enum UOFutures<T: Future> {
    /// Backed by [`FuturesOrdered`].
    Ordered(FuturesOrdered<T>),
    /// Backed by [`FuturesUnordered`].
    Unordered(FuturesUnordered<T>),
}

impl<T: Future> UOFutures<T> {
    /// Creates an empty collection: the `Ordered` variant when `ordered` is `true`,
    /// the `Unordered` variant otherwise.
    fn new(ordered: bool) -> Self {
        if !ordered {
            return Self::Unordered(FuturesUnordered::new());
        }
        Self::Ordered(FuturesOrdered::new())
    }

    /// Adds `future` to whichever underlying collection this value wraps.
    ///
    /// For the ordered variant the future is appended at the back of the queue.
    fn push(&mut self, future: T) {
        match self {
            Self::Ordered(queue) => queue.push_back(future),
            Self::Unordered(set) => set.push(future),
        }
    }
}

impl<T: Future> Stream for UOFutures<T> {
    type Item = T::Output;

    /// Polls the wrapped collection for its next completed future's output.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` hands back a plain `&mut`, and each inner collection is then
        // polled through `poll_next_unpin`, which re-pins it internally.
        let this = self.get_mut();
        match this {
            Self::Ordered(inner) => inner.poll_next_unpin(cx),
            Self::Unordered(inner) => inner.poll_next_unpin(cx),
        }
    }
}

/// Serializable configuration for the `ParallelMap` transform.
///
/// Unknown fields are rejected during deserialization.
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct ParallelMapConfig {
    /// Number of sub-chain builders created from `chain`.
    pub parallelism: u32,
    /// Configuration of the sub-chain; a builder is created from it `parallelism` times.
    pub chain: TransformChainConfig,
    /// Becomes the builder's `ordered` flag.
    pub ordered_results: bool,
}

/// Name reported by both the transform and its builder.
const NAME: &str = "ParallelMap";
#[typetag::serde(name = "ParallelMap")]
#[async_trait(?Send)]
impl TransformConfig for ParallelMapConfig {
    /// Creates `parallelism` sub-chain builders from `chain` and wraps them in a
    /// `ParallelMapBuilder`.
    ///
    /// Every sub-chain is named `"parallel_map_chain"` and inherits the caller's
    /// up-chain protocol. Fails with the first error returned while building a
    /// sub-chain.
    async fn get_builder(
        &self,
        transform_context: TransformContextConfig,
    ) -> Result<Box<dyn TransformBuilder>> {
        let up_chain_protocol = transform_context.up_chain_protocol;

        let mut chains = Vec::new();
        for _ in 0..self.parallelism {
            let builder = self
                .chain
                .get_builder(TransformContextConfig {
                    chain_name: "parallel_map_chain".into(),
                    up_chain_protocol,
                })
                .await?;
            chains.push(builder);
        }

        let builder = ParallelMapBuilder {
            chains,
            ordered: self.ordered_results,
        };
        Ok(Box::new(builder))
    }

    /// This transform accepts any up-chain protocol.
    fn up_chain_protocol(&self) -> UpChainProtocol {
        UpChainProtocol::Any
    }

    /// This transform reports itself as terminating.
    fn down_chain_protocol(&self) -> DownChainProtocol {
        DownChainProtocol::Terminating
    }
}

#[async_trait]
impl Transform for ParallelMap {
    /// Returns the transform's name, `"ParallelMap"`.
    fn get_name(&self) -> &'static str {
        NAME
    }

    /// Distributes the pending requests over the sub-chains and gathers the responses.
    ///
    /// Requests are drained from `chain_state` and processed in batches of at most
    /// `self.chains.len()`: each request in a batch goes to a different sub-chain,
    /// the batch is awaited as a whole, and its responses are appended to the
    /// result. Whether responses within a batch are collected through
    /// `FuturesOrdered` or `FuturesUnordered` is controlled by `self.ordered`.
    ///
    /// # Errors
    ///
    /// * If there are requests to process but no sub-chains to process them with.
    /// * If any sub-chain fails; the first error found in the batch's collected
    ///   results is returned and the remaining requests are dropped.
    async fn transform<'shorter, 'longer: 'shorter>(
        &mut self,
        chain_state: &'shorter mut ChainState<'longer>,
    ) -> Result<Messages> {
        // With no sub-chains, the batching loop below could never consume a
        // request and would spin forever, so reject that situation up front.
        if self.chains.is_empty() && !chain_state.requests.is_empty() {
            return Err(anyhow::anyhow!(
                "{NAME} has no chains to process requests with, parallelism must be at least 1"
            ));
        }

        let mut results = Vec::with_capacity(chain_state.requests.len());
        let mut message_iter = chain_state.requests.drain(..);
        while message_iter.len() != 0 {
            // Hand out at most one request per sub-chain for this batch.
            let mut in_flight = UOFutures::new(self.ordered);
            for chain in self.chains.iter_mut() {
                if let Some(message) = message_iter.next() {
                    in_flight.push(async {
                        chain
                            .process_request(&mut ChainState::new_with_addr(
                                vec![message],
                                chain_state.local_addr,
                            ))
                            .await
                    });
                }
            }

            // Wait for the whole batch, then surface the first failure to the
            // caller. Iterating over the `Result` itself would yield nothing on
            // `Err` and silently discard the error, so use `?` instead.
            let batch = in_flight
                .collect::<Vec<_>>()
                .await
                .into_iter()
                .collect::<anyhow::Result<Vec<Messages>>>()?;

            // Flatten the per-request response lists into one response list.
            results.extend(batch.into_iter().flatten());
        }
        Ok(results)
    }
}

impl TransformBuilder for ParallelMapBuilder {
    /// Builds a `ParallelMap` containing one chain per stored chain builder,
    /// each built with its own clone of `transform_context`.
    fn build(&self, transform_context: TransformContextBuilder) -> Box<dyn Transform> {
        let mut chains = Vec::with_capacity(self.chains.len());
        for builder in &self.chains {
            chains.push(builder.build(transform_context.clone()));
        }

        Box::new(ParallelMap {
            chains,
            ordered: self.ordered,
        })
    }

    /// Returns the builder's name, `"ParallelMap"`.
    fn get_name(&self) -> &'static str {
        NAME
    }

    /// Collects the validation errors of every sub-chain.
    ///
    /// Each sub-chain error line is indented by two spaces. If there is at least
    /// one error, a `"ParallelMap:"` header line is placed in front; otherwise the
    /// returned list is empty.
    fn validate(&self) -> Vec<String> {
        let mut errors = Vec::new();
        for chain in &self.chains {
            for error in chain.validate().iter() {
                errors.push(format!("  {error}"));
            }
        }

        if errors.is_empty() {
            return errors;
        }

        errors.insert(0, format!("{}:", self.get_name()));
        errors
    }

    /// This builder always reports itself as terminating.
    fn is_terminating(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod parallel_map_tests {
    use crate::transforms::TransformBuilder;
    use crate::transforms::chain::TransformChainBuilder;
    use crate::transforms::debug::printer::DebugPrinter;
    use crate::transforms::null::NullSink;
    use crate::transforms::parallel_map::ParallelMapBuilder;
    use pretty_assertions::assert_eq;

    /// Builds a chain of two `DebugPrinter`s followed by a `NullSink`; the tests
    /// below expect this chain to produce no validation errors.
    fn valid_chain(name: &'static str) -> TransformChainBuilder {
        TransformChainBuilder::new(
            vec![
                Box::<DebugPrinter>::default(),
                Box::<DebugPrinter>::default(),
                Box::<NullSink>::default(),
            ],
            name,
        )
    }

    /// An empty sub-chain must be reported, nested under the `ParallelMap:` header.
    #[tokio::test]
    async fn test_validate_invalid_chain() {
        let empty_chain = TransformChainBuilder::new(vec![], "test-chain-2");

        let transform = ParallelMapBuilder {
            chains: vec![valid_chain("test-chain-1"), empty_chain],
            ordered: true,
        };

        let expected = vec![
            "ParallelMap:",
            "  test-chain-2 chain:",
            "    Chain cannot be empty",
        ];
        assert_eq!(transform.validate(), expected);
    }

    /// Two valid sub-chains must produce no validation errors at all.
    #[tokio::test]
    async fn test_validate_valid_chain() {
        let transform = ParallelMapBuilder {
            chains: vec![valid_chain("test-chain-1"), valid_chain("test-chain-2")],
            ordered: true,
        };

        assert_eq!(transform.validate(), Vec::<String>::new());
    }
}