// dag-flow 0.1.6
//
// DAG Flow is a simple DAG workflow engine.
// Documentation
use std::collections::HashMap;
use std::collections::VecDeque;
use std::hash::Hash;
use std::sync::Arc;

/// An immutable directed graph whose adjacency data lives behind an [`Arc`].
///
/// Each key of the inner map is a node; its [`NodeData`] value lists the
/// node's incoming and outgoing neighbors. Cloning a `Dag` clones the `Arc`
/// handle, not the map itself.
#[derive(Debug, Clone)]
pub struct Dag<N> {
    graph: Arc<HashMap<N, NodeData<N>>>,
}

impl<N> Dag<N> {
    /// Creates a graph with no nodes and no edges.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a fresh, empty [`DagBuilder`] for this node type.
    pub fn builder() -> DagBuilder<N> {
        DagBuilder::new()
    }
}

impl<N> Default for Dag<N> {
    /// Equivalent to [`Dag::new`]: an empty graph.
    fn default() -> Self {
        let graph = Arc::new(HashMap::new());
        Self { graph }
    }
}

impl<N> Dag<N>
where
    N: Clone + Eq + Hash,
{
    pub fn graph(&self) -> Arc<HashMap<N, NodeData<N>>> {
        self.graph.clone()
    }
}

/// Adjacency information stored for a single node.
#[derive(Debug, Clone)]
pub struct NodeData<N> {
    /// Nodes that have an edge pointing *to* this node.
    pub in_neighbors: Vec<N>,
    /// Nodes that this node has an edge pointing *to*.
    pub out_neighbors: Vec<N>,
}

impl<N> NodeData<N> {
    /// Creates node data with no incoming and no outgoing neighbors.
    pub fn new() -> Self {
        Self {
            in_neighbors: Vec::new(),
            out_neighbors: Vec::new(),
        }
    }

    /// Creates node data from explicit incoming and outgoing neighbor lists.
    ///
    /// The vectors are stored as given; no deduplication or validation is done.
    pub fn from(in_neighbors: Vec<N>, out_neighbors: Vec<N>) -> Self {
        NodeData {
            in_neighbors,
            out_neighbors,
        }
    }
}

impl<N> Default for NodeData<N> {
    /// Equivalent to [`NodeData::new`]: both neighbor lists are empty.
    fn default() -> Self {
        Self::new()
    }
}

/// Mutable accumulator of nodes and edges that is later validated into a [`Dag`].
///
/// The builder keeps the same node-to-[`NodeData`] map that the finished
/// graph exposes, but without the shared `Arc` wrapper.
#[derive(Debug, Clone)]
pub struct DagBuilder<N> {
    graph: HashMap<N, NodeData<N>>,
}

impl<N> DagBuilder<N> {
    /// Creates a builder that contains no nodes and no edges.
    pub fn new() -> Self {
        Self::default()
    }
}

impl<N> Default for DagBuilder<N> {
    /// Equivalent to [`DagBuilder::new`]: an empty builder.
    fn default() -> Self {
        let graph = HashMap::new();
        Self { graph }
    }
}

impl<N> DagBuilder<N>
where
    N: Eq + Hash,
{
    /// Registers `node` in the graph and returns the builder for chaining.
    ///
    /// If the node is already present its existing neighbor lists are left
    /// untouched; otherwise it is inserted with empty neighbor lists.
    pub fn add_node(&mut self, node: N) -> &mut Self {
        let nodes = &mut self.graph;
        // `or_default` only inserts when the key is missing.
        nodes.entry(node).or_default();
        self
    }
}

impl<N> DagBuilder<N>
where
    N: Clone + Eq + Hash,
{
    /// Adds a directed edge `from -> to` and returns the builder for chaining.
    ///
    /// * A self-loop (`from == to`) is silently ignored; the call leaves the
    ///   builder unchanged, so the node is not registered by it either.
    /// * An edge that is already recorded in `from`'s outgoing list is
    ///   ignored, so repeated calls do not create duplicates.
    /// * Endpoints that are not yet known are inserted automatically.
    ///
    /// No cycle detection happens here; that is deferred to `build`.
    pub fn add_edge(&mut self, Edge { from, to }: Edge<N>) -> &mut Self {
        if from == to {
            return self;
        }

        // Record `to` as a successor of `from`, bailing out on duplicates.
        match self.graph.get_mut(&from) {
            Some(source) => {
                if source.out_neighbors.contains(&to) {
                    return self;
                }
                source.out_neighbors.push(to.clone());
            }
            None => {
                self.graph
                    .entry(from.clone())
                    .or_default()
                    .out_neighbors
                    .push(to.clone());
            }
        }

        // Mirror the edge on the target side; this consumes both endpoints.
        self.graph.entry(to).or_default().in_neighbors.push(from);
        self
    }
}

impl<N> DagBuilder<N>
where
    N: Eq + Hash,
{
    /// Validates the accumulated graph and freezes it into a [`Dag`].
    ///
    /// The check is a Kahn-style topological pass: nodes without unresolved
    /// predecessors are repeatedly removed, and each removal lowers the
    /// pending count of its successors.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildDagError`] if any node still has unresolved
    /// predecessors after the pass, i.e. the edges contain a directed cycle.
    pub fn build(self) -> Result<Dag<N>, BuildDagError> {
        let Self { graph } = self;

        // Number of not-yet-processed predecessors per node, plus the initial
        // work list of nodes that have none.
        let mut pending: HashMap<&N, usize> = HashMap::with_capacity(graph.len());
        let mut ready: VecDeque<&N> = VecDeque::new();
        for (node, data) in &graph {
            let in_degree = data.in_neighbors.len();
            if in_degree == 0 {
                ready.push_back(node);
            }
            pending.insert(node, in_degree);
        }

        while let Some(current) = ready.pop_front() {
            for successor in &graph[current].out_neighbors {
                // A successor without an entry is skipped rather than treated as an error.
                if let Some(remaining) = pending.get_mut(successor) {
                    *remaining -= 1;
                    if *remaining == 0 {
                        ready.push_back(successor);
                    }
                }
            }
        }

        // Anything left with a positive count sits on (or behind) a cycle.
        let has_cycle = pending.into_values().any(|remaining| remaining > 0);
        if has_cycle {
            return Err(DagErrorKind::Cycle.into());
        }

        Ok(Dag {
            graph: Arc::new(graph),
        })
    }
}

/// A directed edge between two nodes.
#[derive(Debug, Default, Clone, Copy)]
pub struct Edge<N> {
    /// Node the edge starts at.
    pub from: N,
    /// Node the edge points to.
    pub to: N,
}

impl<N> Edge<N> {
    /// Creates an edge running from `from` to `to`.
    pub fn new(from: N, to: N) -> Self {
        Edge { from, to }
    }
}

/// Error returned when a [`DagBuilder`] cannot be turned into a [`Dag`].
///
/// The concrete reason is held in a private kind enum. The wrapper is marked
/// `#[error(transparent)]`, so its `Display` output is delegated to that
/// inner kind, and `#[from]` generates the conversion from the kind.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
#[error(transparent)]
pub struct BuildDagError(#[from] DagErrorKind);

/// Internal classification of the reasons building a graph can fail.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
enum DagErrorKind {
    /// The recorded edges contain at least one directed cycle.
    #[error("cycle detected in directed graph")]
    Cycle,
}