mod error;
pub mod iters;
#[cfg(test)]
mod tests;
use std::{collections::{HashMap, HashSet}, hash::Hash};
pub use error::DagError;
use iters::{ChildrenIter, ChildrenIterMut, EdgesIter, EdgesIterMut, ParentsIter};
/// A directed graph keyed by `NodeId` that stores `NodeData` on every node and
/// `EdgeData` on every edge; edge insertion rejects edges that would close a cycle.
#[derive(Debug, Clone)]
pub struct Dag<NodeId, NodeData, EdgeData> {
    /// Data attached to each node present in the graph.
    nodes: HashMap<NodeId, NodeData>,
    /// Outgoing adjacency: `edges[from][to]` is the data of the edge `from -> to`.
    edges: HashMap<NodeId, HashMap<NodeId, EdgeData>>,
    /// Incoming adjacency: `back_edges[to]` holds every `from` with an edge `from -> to`.
    back_edges: HashMap<NodeId, HashSet<NodeId>>,
}

/// An empty graph, equivalent to `Dag::new()`.
///
/// Implemented by hand because `#[derive(Default)]` would needlessly require
/// `NodeId`, `NodeData` and `EdgeData` to implement `Default`.
impl<NodeId, NodeData, EdgeData> Default for Dag<NodeId, NodeData, EdgeData> {
    fn default() -> Self {
        Dag {
            nodes: HashMap::new(),
            edges: HashMap::new(),
            back_edges: HashMap::new(),
        }
    }
}
impl<NodeId, NodeData, EdgeData> Dag<NodeId, NodeData, EdgeData>
where
    NodeId: Copy + Hash + Eq,
{
    /// Creates an empty graph with no nodes and no edges.
    pub fn new() -> Self {
        Dag {
            nodes: HashMap::new(),
            edges: HashMap::new(),
            back_edges: HashMap::new(),
        }
    }

    /// Returns `true` if `node_id` lies on a directed cycle, i.e. it can be
    /// reached again from itself by following one or more edges.
    ///
    /// A node reached twice through different paths (a "diamond" such as
    /// `a -> b -> d`, `a -> c -> d`) is *not* a cycle; `visited` only keeps such
    /// nodes from being explored twice.
    fn in_cycle(&self, node_id: NodeId) -> bool {
        let mut visited = HashSet::new();
        let mut stack: Vec<NodeId> = self.children(node_id).map(|(id, _)| id).collect();
        while let Some(current) = stack.pop() {
            if current == node_id {
                return true;
            }
            // `insert` returns `false` when `current` has already been explored.
            if visited.insert(current) {
                stack.extend(self.children(current).map(|(id, _)| id));
            }
        }
        false
    }

    /// Returns `Ok(())` if `node_id` is in the graph, otherwise
    /// `Err(DagError::NodeNotFound(node_id))`.
    fn ensure_node(&self, node_id: NodeId) -> Result<(), DagError<NodeId, EdgeData>> {
        if self.nodes.contains_key(&node_id) {
            Ok(())
        } else {
            Err(DagError::NodeNotFound(node_id))
        }
    }

    /// Returns `true` if a node with id `node_id` is in the graph.
    pub fn contains_node(&self, node_id: NodeId) -> bool {
        self.nodes.contains_key(&node_id)
    }

    /// Returns `true` if `node_id` has no incoming edges.
    ///
    /// This is also `true` for an id that is not in the graph.
    pub fn is_root(&self, node_id: NodeId) -> bool {
        self.parents(node_id).count() == 0
    }

    /// Inserts a node, or replaces the data of an existing node.
    ///
    /// Returns the previous node data if the node already existed. Edges that
    /// already touch the node are kept.
    pub fn insert_node(&mut self, node_id: NodeId, node_data: NodeData) -> Option<NodeData> {
        // Every node owns (possibly empty) adjacency entries; the edge methods rely on this.
        self.edges.entry(node_id).or_default();
        self.back_edges.entry(node_id).or_default();
        self.nodes.insert(node_id, node_data)
    }

    /// Returns `true` if the edge `from -> to` exists.
    pub fn contains_edge(&self, from: NodeId, to: NodeId) -> bool {
        self.edges
            .get(&from)
            .map_or(false, |children| children.contains_key(&to))
    }

    /// Inserts the edge `from -> to` carrying `edge_data`, replacing the data of
    /// an existing edge between the same nodes.
    ///
    /// Returns the previous edge data, if any.
    ///
    /// # Errors
    ///
    /// * `DagError::NodeNotFound` if either endpoint is not in the graph.
    /// * `DagError::HasCycle(from, to, edge_data)` if the edge would close a
    ///   cycle (including a self-loop `from == to`). The graph is left unchanged
    ///   and `edge_data` is handed back inside the error.
    pub fn insert_edge(
        &mut self,
        from: NodeId,
        to: NodeId,
        edge_data: EdgeData,
    ) -> Result<Option<EdgeData>, DagError<NodeId, EdgeData>> {
        self.ensure_node(from)?;
        self.ensure_node(to)?;
        let previous = self
            .edges
            .get_mut(&from)
            .expect("every node has an entry in `edges`")
            .insert(to, edge_data);
        // The graph was acyclic before this call, so any cycle now must run through `from`.
        if self.in_cycle(from) {
            let children = self
                .edges
                .get_mut(&from)
                .expect("every node has an entry in `edges`");
            // Undo the insertion, restoring any edge data it replaced; either way
            // the value taken out is the `edge_data` passed in by the caller.
            let data = match previous {
                Some(old) => children.insert(to, old),
                None => children.remove(&to),
            }
            .expect("the edge was inserted just above");
            return Err(DagError::HasCycle(from, to, data));
        }
        self.back_edges
            .get_mut(&to)
            .expect("every node has an entry in `back_edges`")
            .insert(from);
        Ok(previous)
    }

    /// Removes the edge `from -> to`, returning its data if it existed.
    ///
    /// # Errors
    ///
    /// `DagError::NodeNotFound` if either endpoint is not in the graph.
    pub fn remove_edge(
        &mut self,
        from: NodeId,
        to: NodeId,
    ) -> Result<Option<EdgeData>, DagError<NodeId, EdgeData>> {
        self.ensure_node(from)?;
        self.ensure_node(to)?;
        let removed = self
            .edges
            .get_mut(&from)
            .expect("every node has an entry in `edges`")
            .remove(&to);
        self.back_edges
            .get_mut(&to)
            .expect("every node has an entry in `back_edges`")
            .remove(&from);
        Ok(removed)
    }

    /// Removes a node together with all of its incoming and outgoing edges.
    ///
    /// Returns the node's data and the data of every removed edge (outgoing
    /// edges first, then incoming ones). Returns `(None, vec![])` if the node
    /// is not in the graph.
    pub fn remove_node(&mut self, node_id: NodeId) -> (Option<NodeData>, Vec<EdgeData>) {
        let node_data = match self.nodes.remove(&node_id) {
            Some(data) => data,
            None => return (None, Vec::new()),
        };
        let mut edge_data = Vec::new();
        // Outgoing edges: drop the node's adjacency map and unlink it from each child.
        if let Some(children) = self.edges.remove(&node_id) {
            for (child_id, data) in children {
                if let Some(parents) = self.back_edges.get_mut(&child_id) {
                    parents.remove(&node_id);
                }
                edge_data.push(data);
            }
        }
        // Incoming edges: drop the node's parent set and remove the edge from each parent.
        if let Some(parents) = self.back_edges.remove(&node_id) {
            for parent_id in parents {
                if let Some(data) = self
                    .edges
                    .get_mut(&parent_id)
                    .and_then(|children| children.remove(&node_id))
                {
                    edge_data.push(data);
                }
            }
        }
        (Some(node_data), edge_data)
    }

    /// Iterates over the children of `node_id` together with the data of the
    /// connecting edges.
    pub fn children(&self, node_id: NodeId) -> ChildrenIter<'_, NodeId, EdgeData> {
        ChildrenIter {
            iter: self.edges.get(&node_id).map(|map| map.iter()),
        }
    }

    /// Like [`Dag::children`], but yields mutable references to the edge data.
    pub fn children_mut(&mut self, node_id: NodeId) -> ChildrenIterMut<'_, NodeId, EdgeData> {
        ChildrenIterMut {
            iter: self.edges.get_mut(&node_id).map(|map| map.iter_mut()),
        }
    }

    /// Iterates over the ids of the parents of `node_id`.
    pub fn parents(&self, node_id: NodeId) -> ParentsIter<'_, NodeId> {
        ParentsIter {
            iter: self.back_edges.get(&node_id).map(|set| set.iter()),
        }
    }

    /// Returns the number of nodes in the graph.
    pub fn nodes_len(&self) -> usize {
        self.nodes.len()
    }

    /// Iterates over all nodes and their data, in unspecified order.
    pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &'_ NodeData)> {
        self.nodes.iter().map(|(id, data)| (*id, data))
    }

    /// Iterates over all nodes with mutable access to their data.
    pub fn nodes_mut(&mut self) -> impl Iterator<Item = (NodeId, &'_ mut NodeData)> {
        self.nodes.iter_mut().map(|(id, data)| (*id, data))
    }

    /// Iterates over all edges of the graph.
    pub fn edges(&self) -> EdgesIter<'_, NodeId, EdgeData> {
        EdgesIter {
            from_iter: self.edges.iter(),
            to_iter: None,
        }
    }

    /// Iterates over all edges of the graph with mutable access to edge data.
    pub fn edges_mut(&mut self) -> EdgesIterMut<'_, NodeId, EdgeData> {
        EdgesIterMut {
            from_iter: self.edges.iter_mut(),
            to_iter: None,
        }
    }

    /// Iterates over the nodes that have no outgoing edges.
    pub fn leaves(&self) -> impl Iterator<Item = (NodeId, &'_ NodeData)> {
        self.nodes().filter(|(id, _)| self.children(*id).len() == 0)
    }

    /// Iterates over the nodes that have no incoming edges.
    pub fn roots(&self) -> impl Iterator<Item = (NodeId, &'_ NodeData)> {
        self.nodes().filter(|(id, _)| self.parents(*id).len() == 0)
    }

    /// Returns the data of `node_id`, or `None` if it is not in the graph.
    pub fn get_node(&self, node_id: NodeId) -> Option<&NodeData> {
        self.nodes.get(&node_id)
    }

    /// Returns mutable access to the data of `node_id`, if present.
    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut NodeData> {
        self.nodes.get_mut(&node_id)
    }

    /// Returns the data of the edge `from -> to`, or `Ok(None)` if both nodes
    /// exist but are not connected.
    ///
    /// # Errors
    ///
    /// `DagError::NodeNotFound` if either endpoint is not in the graph.
    pub fn get_edge(
        &self,
        from: NodeId,
        to: NodeId,
    ) -> Result<Option<&EdgeData>, DagError<NodeId, EdgeData>> {
        self.ensure_node(from)?;
        self.ensure_node(to)?;
        Ok(self.edges.get(&from).and_then(|children| children.get(&to)))
    }

    /// Like [`Dag::get_edge`], but returns mutable access to the edge data.
    ///
    /// # Errors
    ///
    /// `DagError::NodeNotFound` if either endpoint is not in the graph.
    pub fn get_edge_mut(
        &mut self,
        from: NodeId,
        to: NodeId,
    ) -> Result<Option<&mut EdgeData>, DagError<NodeId, EdgeData>> {
        self.ensure_node(from)?;
        self.ensure_node(to)?;
        Ok(self
            .edges
            .get_mut(&from)
            .and_then(|children| children.get_mut(&to)))
    }
}