use petgraph::algo::{is_cyclic_directed, toposort};
use petgraph::{Directed, Graph};
use std::collections::{HashMap, HashSet};
use crate::error::ValidationError;
use crate::task::TaskNamespace;
/// A directed dependency graph keyed by [`TaskNamespace`].
///
/// The graph is stored as a node set plus an adjacency map. An entry
/// `edges[a] = [b, c]` is what `get_dependencies(&a)` returns, i.e. the
/// nodes that `a` has edges to.
#[derive(Debug, Clone)]
pub struct DependencyGraph {
/// Every node known to the graph, whether added directly or as an edge endpoint.
nodes: HashSet<TaskNamespace>,
/// Adjacency list: maps a node to the list of nodes it points at
/// (its dependencies). Lists may contain repeated entries if the same
/// edge is added more than once.
edges: HashMap<TaskNamespace, Vec<TaskNamespace>>,
}
impl DependencyGraph {
pub fn new() -> Self {
Self {
nodes: HashSet::new(),
edges: HashMap::new(),
}
}
pub fn add_node(&mut self, node_id: TaskNamespace) {
self.nodes.insert(node_id.clone());
self.edges.entry(node_id).or_default();
}
pub fn add_edge(&mut self, from: TaskNamespace, to: TaskNamespace) {
self.nodes.insert(from.clone());
self.nodes.insert(to.clone());
self.edges.entry(from).or_default().push(to);
}
pub fn remove_node(&mut self, node_id: &TaskNamespace) {
self.nodes.remove(node_id);
self.edges.remove(node_id);
for deps in self.edges.values_mut() {
deps.retain(|dep| dep != node_id);
}
}
pub fn remove_edge(&mut self, from: &TaskNamespace, to: &TaskNamespace) {
if let Some(deps) = self.edges.get_mut(from) {
deps.retain(|dep| dep != to);
}
}
pub fn get_dependencies(&self, node_id: &TaskNamespace) -> Option<&Vec<TaskNamespace>> {
self.edges.get(node_id)
}
pub fn get_dependents(&self, node_id: &TaskNamespace) -> Vec<TaskNamespace> {
self.edges
.iter()
.filter_map(|(k, v)| {
if v.contains(node_id) {
Some(k.clone())
} else {
None
}
})
.collect()
}
pub fn has_cycles(&self) -> bool {
let mut graph = Graph::<TaskNamespace, (), Directed>::new();
let mut node_indices = HashMap::new();
for node in &self.nodes {
let index = graph.add_node(node.clone());
node_indices.insert(node.clone(), index);
}
for (from, deps) in &self.edges {
if let Some(&from_index) = node_indices.get(from) {
for dep in deps {
if let Some(&dep_index) = node_indices.get(dep) {
graph.add_edge(dep_index, from_index, ());
}
}
}
}
is_cyclic_directed(&graph)
}
pub fn topological_sort(&self) -> Result<Vec<TaskNamespace>, ValidationError> {
if self.has_cycles() {
return Err(ValidationError::CyclicDependency {
cycle: self
.find_cycle()
.unwrap_or_default()
.into_iter()
.map(|ns| ns.to_string())
.collect(),
});
}
let mut graph = Graph::<TaskNamespace, (), Directed>::new();
let mut node_indices = HashMap::new();
for node in &self.nodes {
let index = graph.add_node(node.clone());
node_indices.insert(node.clone(), index);
}
for (from, deps) in &self.edges {
if let Some(&from_index) = node_indices.get(from) {
for dep in deps {
if let Some(&dep_index) = node_indices.get(dep) {
graph.add_edge(dep_index, from_index, ());
}
}
}
}
match toposort(&graph, None) {
Ok(sorted) => {
let result = sorted.into_iter().map(|idx| graph[idx].clone()).collect();
Ok(result)
}
Err(_) => Err(ValidationError::CyclicDependency {
cycle: self
.find_cycle()
.unwrap_or_default()
.into_iter()
.map(|ns| ns.to_string())
.collect(),
}),
}
}
pub(crate) fn find_cycle(&self) -> Option<Vec<TaskNamespace>> {
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
let mut path = Vec::new();
for node in &self.nodes {
if !visited.contains(node) {
if let Some(cycle) = self.dfs_cycle(node, &mut visited, &mut rec_stack, &mut path) {
return Some(cycle);
}
}
}
None
}
fn dfs_cycle(
&self,
node: &TaskNamespace,
visited: &mut HashSet<TaskNamespace>,
rec_stack: &mut HashSet<TaskNamespace>,
path: &mut Vec<TaskNamespace>,
) -> Option<Vec<TaskNamespace>> {
visited.insert(node.clone());
rec_stack.insert(node.clone());
path.push(node.clone());
if let Some(deps) = self.edges.get(node) {
for dep in deps {
if !visited.contains(dep) {
if let Some(cycle) = self.dfs_cycle(dep, visited, rec_stack, path) {
return Some(cycle);
}
} else if rec_stack.contains(dep) {
let cycle_start = path.iter().position(|x| x == dep).unwrap_or(0);
let mut cycle = path[cycle_start..].to_vec();
cycle.push(dep.clone());
return Some(cycle);
}
}
}
rec_stack.remove(node);
path.pop();
None
}
}
impl Default for DependencyGraph {
fn default() -> Self {
Self::new()
}
}