dag-flow 0.1.6

DAG Flow is a simple workflow engine that runs tasks organized as a directed acyclic graph (DAG).
Documentation
use std::collections::HashMap;
use std::collections::VecDeque;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::RwLock;

use futures::FutureExt;
use futures::StreamExt;
use futures::stream::FuturesUnordered;

use crate::context::Context;
use crate::task::DynTask;
use crate::task::Task;

mod dag;
use dag::BuildDagError;
use dag::Dag;
use dag::Edge;
use dag::NodeData;

/// A DAG workflow engine: a dependency graph over task ids plus the tasks.
///
/// Usually obtained from [`Engine::builder`] followed by
/// [`EngineBuilder::build`], and executed with [`Engine::run`].
///
/// Cloning shares the task table (it sits behind an [`Arc`]); the graph is
/// cloned through `Dag`'s own `Clone` implementation.
#[derive(Clone)]
pub struct Engine<'a, I, D> {
    // Dependency graph; `EngineBuilder::build` adds one node per task and one
    // `Edge::new(dependency, task_id)` per declared dependency.
    dag: Dag<I>,
    // Registered tasks keyed by id; the outer `Arc` lets engine clones share them.
    tasks: Arc<HashMap<I, Arc<DynTask<'a, I, D>>>>,
}

impl<'a, I, D> Engine<'a, I, D> {
    /// Creates an engine with no registered tasks and a fresh `Dag`.
    ///
    /// Most callers want [`Engine::builder`] instead, which lets tasks be
    /// registered before the graph is assembled.
    pub fn new() -> Self {
        let dag = Dag::new();
        let tasks = Arc::default();
        Self { dag, tasks }
    }

    /// Returns an empty [`EngineBuilder`] for registering tasks.
    pub fn builder() -> EngineBuilder<'a, I, D> {
        EngineBuilder::default()
    }
}

impl<I, D> Default for Engine<'_, I, D> {
    fn default() -> Self {
        Self::new()
    }
}

impl<'a, 'cx, I, D> Engine<'a, I, D>
where
    'a: 'cx,
    I: Clone + Eq + Hash + Send + 'cx,
    D: Clone + Send + Sync + 'cx,
{
    /// Wires every registered task into `context`, then drives the auto tasks.
    ///
    /// Nodes are visited in topological order (Kahn's algorithm on each
    /// node's in-degree). For each node that has a registered task, a shared
    /// future is stored in `context` under the node's id. That future runs
    /// the task with `(dependency_id, dependency_future)` pairs for each
    /// dependency that already has an entry in `context`. Dependencies with
    /// no entry are left out.
    ///
    /// Wiring does not execute anything; the stored futures are lazy `async`
    /// blocks. Afterwards, only the futures of tasks whose `is_auto()` returns
    /// `true` are awaited, concurrently. Other tasks run only if an awaited
    /// future depends on them. The outputs of the awaited futures are
    /// discarded here.
    pub async fn run(&self, context: Context<'cx, I, Option<D>>) {
        let graph = self.dag.graph();

        // Number of dependencies of each node that have not been visited yet.
        let mut in_degrees: HashMap<_, _> = graph
            .iter()
            .map(|(node, NodeData { in_neighbors, .. })| (node, in_neighbors.len()))
            .collect();

        // Seed the traversal with the roots: nodes that depend on nothing.
        let mut queue: VecDeque<_> = in_degrees
            .iter()
            .filter_map(|(&node, &in_degree)| (in_degree == 0).then_some(node))
            .collect();

        while let Some(node) = queue.pop_front() {
            if let Some(task) = self.tasks.get(node).cloned() {
                // Topological order guarantees every dependency was visited
                // before this node, so its future (if any) is already stored.
                let inputs = graph[node]
                    .in_neighbors
                    .iter()
                    .filter_map(|in_neighbor| {
                        context
                            .get(in_neighbor)
                            .map(|data| (in_neighbor.clone(), data))
                    })
                    .collect();

                context.set(
                    node.clone(),
                    async move { task.run(inputs).await }.boxed().shared(),
                );
            }

            // Release dependents whose last outstanding dependency was this node.
            for out_neighbor in &graph[node].out_neighbors {
                let Some(in_degree) = in_degrees.get_mut(out_neighbor) else {
                    continue;
                };

                *in_degree -= 1;
                if *in_degree == 0 {
                    queue.push_back(out_neighbor);
                }
            }
        }

        // Await the auto tasks concurrently. Each output is dropped as soon as
        // it completes instead of being buffered into an unused `Vec`.
        let mut pending: FuturesUnordered<_> = graph
            .iter()
            .filter_map(|(node, _)| {
                self.tasks
                    .get(node)
                    .filter(|task| task.is_auto())
                    .and_then(|_| context.get(node))
            })
            .collect();

        while pending.next().await.is_some() {}
    }
}

/// Collects tasks and assembles them into an [`Engine`].
///
/// All registration methods take `&self`; the task table lives behind an
/// `Arc<RwLock<..>>`. Cloning a builder clones that `Arc`, so every clone
/// reads and writes the same task table.
#[derive(Clone)]
pub struct EngineBuilder<'a, I, D> {
    // The nested shared-lock map type trips `clippy::type_complexity`;
    // the lint is silenced here rather than introducing a type alias.
    #[allow(clippy::type_complexity)]
    tasks: Arc<RwLock<HashMap<I, Box<DynTask<'a, I, D>>>>>,
}

impl<I, D> EngineBuilder<'_, I, D> {
    pub fn new() -> Self {
        Self {
            tasks: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

impl<I, D> Default for EngineBuilder<'_, I, D> {
    fn default() -> Self {
        Self::new()
    }
}

impl<I, D> EngineBuilder<'_, I, D>
where
    I: Eq + Hash,
{
    /// Returns `true` if a task is registered under the same id as `task`.
    pub fn exists_task<T>(&self, task: &T) -> bool
    where
        T: Task<I, D>,
    {
        let id = task.id();
        self.exists_task_by_id(&id)
    }

    /// Returns `true` if a task is registered under `id`.
    ///
    /// # Panics
    ///
    /// Panics if the builder's lock is poisoned.
    pub fn exists_task_by_id(&self, id: &I) -> bool {
        let registered = self.tasks.read().unwrap();
        registered.contains_key(id)
    }

    /// Unregisters the task that shares `task`'s id, if there is one.
    ///
    /// Returns `self` so calls can be chained.
    pub fn remove_task<T>(&self, task: &T) -> &Self
    where
        T: Task<I, D>,
    {
        let id = task.id();
        self.remove_task_by_id(&id)
    }

    /// Unregisters the task stored under `id`; unknown ids are ignored.
    ///
    /// Returns `self` so calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the builder's lock is poisoned.
    pub fn remove_task_by_id(&self, id: &I) -> &Self {
        {
            let mut registered = self.tasks.write().unwrap();
            registered.remove(id);
        }
        self
    }
}

impl<'a, I, D> EngineBuilder<'a, I, D>
where
    I: Eq + Hash,
{
    /// Registers `task` under the id returned by its `id()` method.
    ///
    /// A task already registered under the same id is replaced. Returns
    /// `self` so calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the builder's lock is poisoned.
    pub fn add_task<T>(&self, task: T) -> &Self
    where
        T: Task<I, D> + 'a,
    {
        // Run user code (`Task::id`, `DynTask::new_box`) *before* taking the
        // write lock. Otherwise a panic there would poison the lock for every
        // handle of this builder, and a reentrant call into the builder would
        // deadlock or panic.
        let id = task.id();
        let task = DynTask::new_box(task);

        // Bind any replaced task so it is dropped after the guard (a
        // temporary of this statement) has already released the lock.
        let _replaced = self.tasks.write().unwrap().insert(id, task);

        self
    }
}

impl<'a, I, D> EngineBuilder<'a, I, D>
where
    I: Clone + Eq + Hash,
{
    /// Consumes the builder and assembles the registered tasks into an [`Engine`].
    ///
    /// Every registered task becomes a graph node. Each id yielded by a
    /// task's `dependencies()` becomes an edge from that dependency to the task.
    ///
    /// If other clones of this builder are still alive, the task table they
    /// share is drained. The returned engine owns the tasks, and those clones
    /// are left with an empty table.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildEngineError`] if the `Dag` builder rejects the graph
    /// (see `BuildDagError`).
    ///
    /// # Panics
    ///
    /// Panics if the builder's lock is poisoned.
    pub fn build(self) -> Result<Engine<'a, I, D>, BuildEngineError> {
        const POISONED: &str = "EngineBuilder task table lock is poisoned";

        // `Arc::try_unwrap` fails while other builder clones share the table
        // (the derived `Clone` only bumps the refcount). Rather than panicking
        // in that case, move the tasks out through the lock.
        let tasks = match Arc::try_unwrap(self.tasks) {
            Ok(lock) => lock.into_inner().expect(POISONED),
            Err(shared) => {
                let mut guard = shared.write().expect(POISONED);
                std::mem::take(&mut *guard)
            }
        };

        let mut builder = Dag::builder();

        // One node per registered task...
        for id in tasks.keys().cloned() {
            builder.add_node(id);
        }

        // ...and one `dependency -> task` edge per declared dependency.
        for (id, task) in &tasks {
            for dependency in task.dependencies() {
                builder.add_edge(Edge::new(dependency, id.clone()));
            }
        }

        let dag = builder.build().map_err(EngineErrorKind::DagBuildFailed)?;

        Ok(Engine {
            dag,
            tasks: Arc::new(
                tasks
                    .into_iter()
                    .map(|(id, task)| (id, task.into()))
                    .collect(),
            ),
        })
    }
}

/// Error returned by [`EngineBuilder::build`].
///
/// Wraps a private error kind; `#[error(transparent)]` forwards `Display`
/// and `source` to that inner kind.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
#[error(transparent)]
pub struct BuildEngineError(#[from] EngineErrorKind);

/// Internal causes behind a [`BuildEngineError`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
enum EngineErrorKind {
    /// The `Dag` builder rejected the task dependency graph; the wrapped
    /// `BuildDagError` carries the reason.
    #[error("failed to build DAG: {0}")]
    DagBuildFailed(#[from] BuildDagError),
}