use std::collections::{HashMap, HashSet, VecDeque};
/// A directed graph over hashable node values, stored as adjacency lists
/// together with a per-node count of incoming edges.
pub struct DAG<T>
where
    T: Clone + Eq + std::hash::Hash,
{
    /// Nodes registered with the graph.
    nodes: HashSet<T>,
    /// Outgoing adjacency lists: for each source node, the targets of its
    /// edges in the order the edges were added.
    edges: HashMap<T, Vec<T>>,
    /// Number of incoming edges recorded for each node.
    in_degree: HashMap<T, usize>,
}
impl<T: Clone + Eq + std::hash::Hash + Ord> DAG<T> {
    /// Creates an empty graph with no nodes and no edges.
    pub fn new() -> Self {
        Self {
            nodes: HashSet::new(),
            edges: HashMap::new(),
            in_degree: HashMap::new(),
        }
    }

    /// Registers `node` in the graph.
    ///
    /// Re-adding a node that is already present is a no-op: its outgoing
    /// edges and recorded in-degree are left untouched.
    pub fn add_node(&mut self, node: T) {
        self.nodes.insert(node.clone());
        self.edges.entry(node.clone()).or_default();
        self.in_degree.entry(node).or_insert(0);
    }

    /// Adds the directed edge `from -> to`.
    ///
    /// Both endpoints are registered as nodes if they are not already
    /// present, so a graph can be built from edges alone. Parallel edges are
    /// kept: every call records one more edge and one more unit of in-degree
    /// on `to`. A self-loop (`from == to`) is accepted here and later
    /// reported as a cycle by the ordering methods.
    pub fn add_edge(&mut self, from: T, to: T) {
        // Register both endpoints first so `nodes`, `edges` and `in_degree`
        // always describe the same set of nodes; the cycle check in the
        // ordering methods relies on `nodes.len()` being the true node count.
        self.add_node(from.clone());
        self.add_node(to.clone());
        *self.in_degree.entry(to.clone()).or_insert(0) += 1;
        self.edges.entry(from).or_default().push(to);
    }

    /// Returns every node in a deterministic topological order.
    ///
    /// The nodes with no incoming edges are emitted first in ascending
    /// order. Each time a node is emitted, the successors it releases (those
    /// whose remaining in-degree drops to zero) are appended to the work
    /// queue in ascending order.
    ///
    /// # Errors
    ///
    /// Returns [`DAGError::CycleDetected`] when not every node could be
    /// ordered, i.e. the graph contains at least one cycle.
    pub fn toposort(&self) -> Result<Vec<T>, DAGError> {
        let mut in_deg = self.in_degree.clone();
        let mut queue: VecDeque<T> = Self::zero_in_degree_sorted(&in_deg).into();
        let mut result = Vec::with_capacity(self.nodes.len());

        while let Some(node) = queue.pop_front() {
            if let Some(neighbors) = self.edges.get(&node) {
                let mut released = Vec::new();
                for next in neighbors {
                    let deg = in_deg
                        .get_mut(next)
                        .expect("add_edge records an in-degree for every edge target");
                    *deg -= 1;
                    if *deg == 0 {
                        released.push(next.clone());
                    }
                }
                // Sort each released batch so the output does not depend on
                // edge insertion order.
                released.sort();
                queue.extend(released);
            }
            result.push(node);
        }

        // Nodes on a cycle never reach in-degree zero and are never emitted.
        if result.len() != self.nodes.len() {
            Err(DAGError::CycleDetected)
        } else {
            Ok(result)
        }
    }

    /// Splits the nodes into layers that can be processed in parallel.
    ///
    /// The first group holds every node with no incoming edges. Each
    /// following group holds the nodes whose last remaining predecessor was
    /// in the group before it. Every group is sorted in ascending order.
    ///
    /// # Errors
    ///
    /// Returns [`DAGError::CycleDetected`] when not every node could be
    /// placed in a group, i.e. the graph contains at least one cycle.
    pub fn parallel_groups(&self) -> Result<Vec<Vec<T>>, DAGError> {
        let mut in_deg = self.in_degree.clone();
        let mut current = Self::zero_in_degree_sorted(&in_deg);
        let mut groups = Vec::new();
        let mut processed = 0;

        while !current.is_empty() {
            let mut next = Vec::new();
            for node in &current {
                if let Some(neighbors) = self.edges.get(node) {
                    for n in neighbors {
                        let deg = in_deg
                            .get_mut(n)
                            .expect("add_edge records an in-degree for every edge target");
                        *deg -= 1;
                        if *deg == 0 {
                            next.push(n.clone());
                        }
                    }
                }
            }
            next.sort();
            processed += current.len();
            // Hand the finished layer to `groups` and continue with the next.
            groups.push(std::mem::replace(&mut current, next));
        }

        if processed != self.nodes.len() {
            Err(DAGError::CycleDetected)
        } else {
            Ok(groups)
        }
    }

    /// Collects, in ascending order, the nodes whose in-degree is zero.
    fn zero_in_degree_sorted(in_deg: &HashMap<T, usize>) -> Vec<T> {
        let mut ready: Vec<T> = in_deg
            .iter()
            .filter(|&(_, &deg)| deg == 0)
            .map(|(node, _)| node.clone())
            .collect();
        ready.sort();
        ready
    }
}
impl<T: Clone + Eq + std::hash::Hash + Ord> Default for DAG<T> {
fn default() -> Self {
Self::new()
}
}
/// Errors returned by the graph ordering methods (`toposort` and
/// `parallel_groups`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DAGError {
    /// The ordering methods could not place every node of the graph.
    CycleDetected,
}

impl std::fmt::Display for DAGError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CycleDetected => f.write_str("Cycle detected in DAG"),
        }
    }
}

impl std::error::Error for DAGError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph over string labels: every entry of `nodes` is added
    /// first, then every `(from, to)` pair of `edges`, in the given order.
    fn graph(
        nodes: &[&'static str],
        edges: &[(&'static str, &'static str)],
    ) -> DAG<&'static str> {
        let mut dag = DAG::new();
        for &label in nodes {
            dag.add_node(label);
        }
        for &(from, to) in edges {
            dag.add_edge(from, to);
        }
        dag
    }

    #[test]
    fn test_toposort_linear() {
        let dag = graph(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);
        let order = dag.toposort().unwrap();
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_toposort_diamond() {
        let dag = graph(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")],
        );
        let order = dag.toposort().unwrap();
        // Only the endpoints are pinned; the middle of the order is not
        // asserted by this test.
        assert_eq!(order.first(), Some(&"a"));
        assert_eq!(order.last(), Some(&"d"));
    }

    #[test]
    fn test_cycle_detection() {
        let dag = graph(&["a", "b"], &[("a", "b"), ("b", "a")]);
        assert!(dag.toposort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        let dag = graph(&["a", "b", "c"], &[("a", "b"), ("a", "c")]);
        let layers = dag.parallel_groups().unwrap();
        assert_eq!(layers.len(), 2);
        assert_eq!(layers[0], vec!["a"]);
    }
}