use crate::error::{Result, ValidationError};
use crate::tasks::Task;
use petgraph::algo::toposort;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::Direction;
use std::collections::HashMap;
use std::sync::Arc;
/// A graph of tasks keyed by their string ids, meant to be acyclic.
///
/// Edges point from a dependency to the task that depends on it. The graph can
/// also hold placeholder nodes for ids that were referenced as dependencies but
/// never registered as tasks. Call `Dag::validate` to check for missing
/// dependencies and cycles.
#[derive(Default)]
pub struct Dag {
    /// Registered tasks, keyed by task id.
    tasks: HashMap<String, Arc<dyn Task>>,
    /// Dependency graph; every node's weight is the id it represents.
    graph: DiGraph<String, ()>,
    /// Maps an id to its node in `graph`, so each id gets exactly one node.
    indices: HashMap<String, NodeIndex>,
}
impl Dag {
    /// Creates an empty DAG with no tasks and no edges.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the graph node for `id`, creating it if it does not exist yet.
    ///
    /// Nodes may be created for ids that are only referenced as dependencies and
    /// have not been registered as tasks; `validate` reports those as missing.
    fn node_for(&mut self, id: &str) -> NodeIndex {
        if let Some(&idx) = self.indices.get(id) {
            return idx;
        }
        let idx = self.graph.add_node(id.to_string());
        self.indices.insert(id.to_string(), idx);
        idx
    }

    /// Registers `task` and adds an edge `dependency -> task` for each of its dependencies.
    ///
    /// Dependencies do not have to be registered first: placeholder nodes are created
    /// for them and checked later by [`Dag::validate`].
    ///
    /// # Errors
    ///
    /// - `ValidationError::DuplicateTask` if a task with the same id is already registered.
    /// - `ValidationError::SelfDependency` if the task lists its own id as a dependency.
    ///
    /// On error the DAG is left unchanged.
    pub fn add_task(&mut self, task: Arc<dyn Task>) -> Result<()> {
        let id = task.id().to_string();
        if self.tasks.contains_key(&id) {
            return Err(ValidationError::DuplicateTask(id).into());
        }
        let deps = task.dependencies();
        // Reject self-dependencies before touching the graph, so a failed call does not
        // leave orphan nodes or edges behind for a task that was never registered.
        for dep in &deps {
            if dep == &id {
                return Err(ValidationError::SelfDependency(id).into());
            }
        }
        let task_idx = self.node_for(&id);
        for dep in &deps {
            let dep_idx = self.node_for(dep);
            // `update_edge` reuses an existing edge, so a repeated dependency
            // does not create parallel edges.
            self.graph.update_edge(dep_idx, task_idx, ());
        }
        self.tasks.insert(id, task);
        Ok(())
    }

    /// Number of registered tasks (placeholder dependency nodes are not counted).
    pub fn len(&self) -> usize {
        self.tasks.len()
    }

    /// Returns `true` if no tasks have been registered.
    pub fn is_empty(&self) -> bool {
        self.tasks.is_empty()
    }

    /// Returns the registered task with the given id, if any.
    pub fn task(&self, id: &str) -> Option<Arc<dyn Task>> {
        self.tasks.get(id).cloned()
    }

    /// Ids of all registered tasks, in unspecified (hash map) order.
    pub fn task_ids(&self) -> Vec<String> {
        self.tasks.keys().cloned().collect()
    }

    /// Direct dependencies of `id`.
    ///
    /// The result may include ids that were never registered as tasks. It is empty
    /// if `id` is unknown.
    pub fn dependencies_of(&self, id: &str) -> Vec<String> {
        self.neighbors(id, Direction::Incoming)
    }

    /// Ids that directly depend on `id`; empty if `id` is unknown.
    pub fn dependents_of(&self, id: &str) -> Vec<String> {
        self.neighbors(id, Direction::Outgoing)
    }

    /// Ids of the nodes adjacent to `id` in direction `dir`; empty if `id` has no node.
    fn neighbors(&self, id: &str, dir: Direction) -> Vec<String> {
        match self.indices.get(id) {
            Some(&idx) => self
                .graph
                .neighbors_directed(idx, dir)
                .map(|n| self.graph[n].clone())
                .collect(),
            None => Vec::new(),
        }
    }

    /// Checks that every dependency is a registered task and that the graph is acyclic.
    ///
    /// # Errors
    ///
    /// - `ValidationError::MissingDependency` for the first unregistered dependency found.
    ///   Tasks are checked in sorted id order, so the reported error is deterministic.
    /// - `ValidationError::Cycle`, naming one node on a cycle, if the graph is not acyclic.
    pub fn validate(&self) -> Result<()> {
        let mut entries: Vec<_> = self.tasks.iter().collect();
        // HashMap iteration order is unspecified; sort so repeated runs report the same error.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (id, task) in entries {
            for dep in task.dependencies() {
                if !self.tasks.contains_key(&dep) {
                    return Err(ValidationError::MissingDependency {
                        task: id.clone(),
                        dep,
                    }
                    .into());
                }
            }
        }
        self.sorted_nodes().map(|_| ())
    }

    /// Returns registered task ids so that every task comes after all of its dependencies.
    ///
    /// This does not check for missing dependencies. Placeholder nodes for unregistered
    /// ids take part in the sort but are dropped from the result.
    ///
    /// # Errors
    ///
    /// `ValidationError::Cycle`, naming one node on a cycle, if the graph is not acyclic.
    pub fn topological_order(&self) -> Result<Vec<String>> {
        let order = self.sorted_nodes()?;
        Ok(order
            .into_iter()
            .map(|n| &self.graph[n])
            .filter(|id| self.tasks.contains_key(*id))
            .cloned()
            .collect())
    }

    /// Topologically sorts all graph nodes, including placeholder nodes, and maps a
    /// detected cycle to `ValidationError::Cycle`.
    fn sorted_nodes(&self) -> Result<Vec<NodeIndex>> {
        toposort(&self.graph, None).map_err(|cycle| {
            ValidationError::Cycle(self.graph[cycle.node_id()].clone()).into()
        })
    }
}