// gesha-core 0.0.12
//
// Core functionality for the Gesha project.
// Documentation
use crate::testing::Error::TaskNotFound;
use crate::testing::TestCase;
use crate::{Error, Result, testing};
use futures::future::join_all;
use std::collections::HashMap;
use std::future::Future;
use tokio::task::{Id, JoinError, JoinHandle};
use tracing::Instrument;

/// Spawns one tokio task per test case and returns a [`Joiner`] that can await them all.
///
/// For every case, `f` is called with the case and the resulting future is spawned,
/// instrumented with the caller's current tracing span. A clone of the case is recorded
/// under the spawned task's [`Id`], so that a task which fails to join (a [`JoinError`])
/// can later be reported together with the case's `schema` path.
///
/// # Panics
///
/// Panics if called outside the context of a Tokio runtime, because it uses
/// [`tokio::spawn`].
pub fn run_parallel<A, B, F, Fut>(cases: Vec<TestCase<A>>, mut f: F) -> Joiner<A, B>
where
    // `FnMut` rather than `Fn`: `f` is invoked sequentially, once per case, so a stateful
    // closure is fine. Every `Fn` closure still satisfies this bound.
    F: FnMut(TestCase<A>) -> Fut,
    // Required by `tokio::spawn`.
    Fut: Future<Output = Result<B>> + Send + 'static,
    B: Send + 'static,
{
    let (handles, map) = cases
        .into_iter()
        .map(|case| {
            // `f` takes ownership of the case, so keep a copy for error reporting.
            let cloned = case.clone();
            let handle = tokio::spawn(f(case).in_current_span());
            (handle.id(), cloned, handle)
        })
        .fold((vec![], TestCaseMap::new()), TestCaseMap::accumulate);

    Joiner { handles, map }
}

pub struct Joiner<A, B> {
    map: TestCaseMap<A>,
    handles: Vec<JoinHandle<Result<B>>>,
}

impl<A, B> Joiner<A, B> {
    pub async fn join_all<F, Y>(mut self, f: F) -> (Vec<Y>, Vec<Error>)
    where
        F: Fn(&mut Vec<Y>, &mut Vec<Error>, Result<B>),
    {
        let (outputs, errors) = join_all(self.handles)
            .await
            .into_iter()
            .map(|result| self.map.flatten(result))
            .fold((vec![], vec![]), |(mut outputs, mut errors), result| {
                f(&mut outputs, &mut errors, result);
                (outputs, errors)
            });

        (outputs, errors)
    }

    pub async fn collect_errors(self) -> Vec<Error> {
        let (_, errors) = self
            .join_all(|_: &mut Vec<()>, errors, result| match result {
                Ok(_) => {}
                Err(e) => errors.push(e),
            })
            .await;

        errors
    }
}

#[derive(Default)]
struct TestCaseMap<A>(HashMap<Id, TestCase<A>>);

impl<A> TestCaseMap<A> {
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    pub fn push(mut self, id: Id, case: TestCase<A>) -> Self {
        self.0.insert(id, case);
        self
    }

    pub fn extract(&mut self, id: Id) -> Result<TestCase<A>> {
        self.0
            .remove(&id)
            .ok_or_else(|| TaskNotFound { id: id.to_string() }.into())
    }

    pub fn accumulate<B>(
        (mut handles, map): (Vec<B>, Self),
        (id, case, handle): (Id, TestCase<A>, B),
    ) -> (Vec<B>, Self) {
        handles.push(handle);
        (handles, map.push(id, case))
    }

    pub fn flatten<B>(&mut self, result: std::result::Result<Result<B>, JoinError>) -> Result<B> {
        match result {
            Ok(x) => x,
            Err(cause) => Err(testing::Error::JoinError {
                schema_path: self.extract(cause.id())?.schema,
                cause,
            })?,
        }
    }
}