use petgraph::algo::{is_cyclic_directed, toposort};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single task (node) of a workflow graph.
///
/// The type is serde-serializable and comparable; `WorkflowGraph` stores one
/// `TaskNode` per graph node and looks tasks up by their `id`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskNode {
/// Identifier of the task; `WorkflowGraph` uses it as the lookup key.
pub id: String,
/// Name of the task (the tests use values such as `"Task 1"`).
pub name: String,
/// Optional free-form description.
pub description: Option<String>,
/// Optional location string. The tests use a `file:line` form
/// (`"test.rs:42"`) — presumably where the task was defined; TODO confirm.
pub source_location: Option<String>,
/// Additional key/value data attached to the task, stored as JSON values.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Data attached to a dependency (edge) between two tasks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct DependencyEdge {
/// Kind of dependency as a free-form string; the `Default` impl uses `"data"`.
pub dependency_type: String,
/// Optional numeric weight. None of the graph algorithms visible in this
/// file read it.
pub weight: Option<f64>,
/// Additional key/value data attached to the edge, stored as JSON values.
pub metadata: HashMap<String, serde_json::Value>,
}
impl Default for DependencyEdge {
fn default() -> Self {
Self {
dependency_type: "data".to_string(),
weight: None,
metadata: HashMap::new(),
}
}
}
/// A directed graph of tasks and the dependencies between them.
///
/// An edge from task `A` to task `B` records that `A` depends on `B`
/// (`get_dependencies` follows outgoing edges, `get_dependents` incoming ones).
#[derive(Debug, Clone)]
pub struct WorkflowGraph {
// Underlying petgraph storage: nodes are tasks, edges are dependencies.
graph: DiGraph<TaskNode, DependencyEdge>,
// Maps a task id to the index of its node in `graph`.
task_index: HashMap<String, NodeIndex>,
}
impl WorkflowGraph {
    /// Creates an empty workflow graph with no tasks and no dependencies.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            task_index: HashMap::new(),
        }
    }

    /// Adds `node` to the graph and returns the index of its graph node.
    ///
    /// Task ids are unique keys. If a task with the same id already exists its
    /// data is replaced in place (existing edges are kept) and the existing
    /// index is returned, so the graph never holds an unreachable duplicate
    /// node and `task_count` always matches the number of graph nodes.
    pub fn add_task(&mut self, node: TaskNode) -> NodeIndex {
        if let Some(&existing) = self.task_index.get(&node.id) {
            if let Some(slot) = self.graph.node_weight_mut(existing) {
                *slot = node;
                return existing;
            }
        }
        let task_id = node.id.clone();
        let index = self.graph.add_node(node);
        self.task_index.insert(task_id, index);
        index
    }

    /// Records that task `from_task_id` depends on task `to_task_id` by adding
    /// a directed edge `from -> to` carrying `edge`.
    ///
    /// Self-dependencies and parallel edges are not rejected.
    ///
    /// # Errors
    ///
    /// Returns a message naming the missing task if either id is unknown; the
    /// graph is left unchanged in that case.
    pub fn add_dependency(
        &mut self,
        from_task_id: &str,
        to_task_id: &str,
        edge: DependencyEdge,
    ) -> Result<(), String> {
        let from_index = self
            .task_index
            .get(from_task_id)
            .copied()
            .ok_or_else(|| format!("Task '{}' not found in graph", from_task_id))?;
        let to_index = self
            .task_index
            .get(to_task_id)
            .copied()
            .ok_or_else(|| format!("Task '{}' not found in graph", to_task_id))?;
        self.graph.add_edge(from_index, to_index, edge);
        Ok(())
    }

    /// Returns the task stored under `task_id`, or `None` if the id is unknown.
    pub fn get_task(&self, task_id: &str) -> Option<&TaskNode> {
        self.task_index
            .get(task_id)
            .and_then(|&index| self.graph.node_weight(index))
    }

    /// Iterates over the ids of all tasks, in no particular order.
    pub fn task_ids(&self) -> impl Iterator<Item = &str> {
        self.task_index.keys().map(String::as_str)
    }

    /// Returns the number of tasks in the graph.
    pub fn task_count(&self) -> usize {
        self.task_index.len()
    }

    /// Returns `true` if the dependency edges form at least one cycle.
    pub fn has_cycles(&self) -> bool {
        is_cyclic_directed(&self.graph)
    }

    /// Returns the task ids in petgraph's topological order of the edges.
    ///
    /// Every edge's source precedes its target, so a task appears *before* the
    /// tasks it depends on; iterate the result in reverse to visit
    /// dependencies first.
    ///
    /// # Errors
    ///
    /// Returns `"Graph contains cycles"` if no topological order exists.
    pub fn topological_sort(&self) -> Result<Vec<String>, String> {
        let order =
            toposort(&self.graph, None).map_err(|_| "Graph contains cycles".to_string())?;
        Ok(order
            .into_iter()
            .filter_map(|idx| self.graph.node_weight(idx))
            .map(|task| task.id.clone())
            .collect())
    }

    /// Iterates over the ids of the tasks that `task_id` directly depends on
    /// (targets of its outgoing edges). Yields nothing for an unknown id.
    pub fn get_dependencies(&self, task_id: &str) -> impl Iterator<Item = &str> {
        let graph = &self.graph;
        self.task_index
            .get(task_id)
            .copied()
            .into_iter()
            .flat_map(move |node_idx| {
                graph
                    .edges_directed(node_idx, petgraph::Direction::Outgoing)
                    .filter_map(move |edge| graph.node_weight(edge.target()))
                    .map(|task| task.id.as_str())
            })
    }

    /// Iterates over the ids of the tasks that directly depend on `task_id`
    /// (sources of its incoming edges). Yields nothing for an unknown id.
    pub fn get_dependents(&self, task_id: &str) -> impl Iterator<Item = &str> {
        let graph = &self.graph;
        self.task_index
            .get(task_id)
            .copied()
            .into_iter()
            .flat_map(move |node_idx| {
                graph
                    .edges_directed(node_idx, petgraph::Direction::Incoming)
                    .filter_map(move |edge| graph.node_weight(edge.source()))
                    .map(|task| task.id.as_str())
            })
    }

    /// Iterates over the ids of root tasks: tasks with no outgoing edge, i.e.
    /// tasks that do not depend on anything.
    pub fn find_roots(&self) -> impl Iterator<Item = &str> {
        let graph = &self.graph;
        graph
            .node_indices()
            .filter(move |&idx| {
                graph
                    .edges_directed(idx, petgraph::Direction::Outgoing)
                    .next()
                    .is_none()
            })
            .filter_map(move |idx| graph.node_weight(idx))
            .map(|task| task.id.as_str())
    }

    /// Iterates over the ids of leaf tasks: tasks with no incoming edge, i.e.
    /// tasks that nothing else depends on.
    pub fn find_leaves(&self) -> impl Iterator<Item = &str> {
        let graph = &self.graph;
        graph
            .node_indices()
            .filter(move |&idx| {
                graph
                    .edges_directed(idx, petgraph::Direction::Incoming)
                    .next()
                    .is_none()
            })
            .filter_map(move |idx| graph.node_weight(idx))
            .map(|task| task.id.as_str())
    }

    /// Computes the depth of every task, keyed by task id.
    ///
    /// A root (a task without dependencies, see [`Self::find_roots`]) has depth
    /// 0; any other task has depth `1 + max(depth of its dependencies)`. A task
    /// is therefore always strictly deeper than everything it depends on.
    ///
    /// If the graph contains a cycle no consistent depth exists; in that case
    /// only the roots are reported (at depth 0).
    pub fn calculate_depths(&self) -> HashMap<String, usize> {
        let mut depths: HashMap<String, usize> = self
            .find_roots()
            .map(|root_id| (root_id.to_string(), 0))
            .collect();

        let order = match toposort(&self.graph, None) {
            Ok(order) => order,
            Err(_) => return depths,
        };

        // `toposort` puts an edge's source before its target, and edges point
        // from a task to its dependency. Walking the order backwards visits
        // every dependency before the tasks that need it.
        for node_idx in order.into_iter().rev() {
            let task = match self.graph.node_weight(node_idx) {
                Some(task) => task,
                None => continue,
            };
            let depth = self
                .graph
                .edges_directed(node_idx, petgraph::Direction::Outgoing)
                .filter_map(|edge| self.graph.node_weight(edge.target()))
                .map(|dependency| depths.get(&dependency.id).copied().unwrap_or(0) + 1)
                .max()
                .unwrap_or(0);
            depths.insert(task.id.clone(), depth);
        }
        depths
    }

    /// Groups tasks by depth so that each group can be processed in parallel.
    ///
    /// Groups are returned in ascending depth order: the first group holds the
    /// tasks without dependencies, and every later group only depends on tasks
    /// from earlier groups. Ids inside a group are sorted so the result is
    /// deterministic.
    pub fn find_parallel_groups(&self) -> Vec<Vec<String>> {
        let depths = self.calculate_depths();
        let level_count = depths.values().max().map_or(0, |&deepest| deepest + 1);

        let mut groups: Vec<Vec<String>> = vec![Vec::new(); level_count];
        for (task_id, depth) in depths {
            // `depth < level_count` holds because `level_count` is max depth + 1.
            groups[depth].push(task_id);
        }

        groups.retain(|group| !group.is_empty());
        for group in &mut groups {
            group.sort_unstable();
        }
        groups
    }

    /// Converts the graph into a plain, serde-serializable snapshot containing
    /// all tasks, all dependency edges (by task id) and summary metadata.
    pub fn to_serializable(&self) -> WorkflowGraphData {
        let nodes: Vec<GraphNode> = self
            .graph
            .node_indices()
            .filter_map(|idx| {
                self.graph.node_weight(idx).map(|node| GraphNode {
                    id: node.id.clone(),
                    data: node.clone(),
                })
            })
            .collect();
        let edges: Vec<GraphEdge> = self
            .graph
            .edge_indices()
            .filter_map(|idx| {
                let (source, target) = self.graph.edge_endpoints(idx)?;
                let source_node = self.graph.node_weight(source)?;
                let target_node = self.graph.node_weight(target)?;
                let edge_data = self.graph.edge_weight(idx)?;
                Some(GraphEdge {
                    from: source_node.id.clone(),
                    to: target_node.id.clone(),
                    data: edge_data.clone(),
                })
            })
            .collect();
        // Number of depth levels is "deepest depth + 1"; when no depth could be
        // computed at all (e.g. an empty graph) there are zero levels, not one.
        let depth_levels = self
            .calculate_depths()
            .values()
            .max()
            .map_or(0, |&deepest| deepest + 1);
        let metadata = GraphMetadata {
            task_count: nodes.len(),
            edge_count: edges.len(),
            has_cycles: self.has_cycles(),
            depth_levels,
            root_tasks: self.find_roots().map(|s| s.to_string()).collect(),
            leaf_tasks: self.find_leaves().map(|s| s.to_string()).collect(),
        };
        WorkflowGraphData {
            nodes,
            edges,
            metadata,
        }
    }

    /// Rebuilds a graph from a snapshot produced by [`Self::to_serializable`].
    ///
    /// Only `nodes[..].data` and `edges` are used: tasks are keyed by
    /// `data.id` (the wrapper's `id` field is not consulted) and
    /// `data.metadata` is ignored because it can be recomputed.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Self::add_dependency`] if an edge refers to a
    /// task id that is not among the nodes.
    pub fn from_serializable(data: &WorkflowGraphData) -> Result<Self, String> {
        let mut graph = WorkflowGraph::new();
        for node in &data.nodes {
            graph.add_task(node.data.clone());
        }
        for edge in &data.edges {
            graph.add_dependency(&edge.from, &edge.to, edge.data.clone())?;
        }
        Ok(graph)
    }
}
impl Default for WorkflowGraph {
fn default() -> Self {
Self::new()
}
}
/// Serde-friendly snapshot of a `WorkflowGraph`, as produced by
/// `WorkflowGraph::to_serializable` and consumed by
/// `WorkflowGraph::from_serializable`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowGraphData {
/// All tasks of the graph.
pub nodes: Vec<GraphNode>,
/// All dependency edges, expressed with task ids.
pub edges: Vec<GraphEdge>,
/// Summary figures computed when the snapshot was taken;
/// `from_serializable` does not read them.
pub metadata: GraphMetadata,
}
/// One task inside a `WorkflowGraphData` snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
/// Task id; `to_serializable` fills it with a copy of `data.id`.
pub id: String,
/// The full task; this is the only field `from_serializable` reads.
pub data: TaskNode,
}
/// One dependency edge inside a `WorkflowGraphData` snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
/// Id of the task at the source of the edge (the task that has the dependency).
pub from: String,
/// Id of the task at the target of the edge (the task being depended on).
pub to: String,
/// The data carried by the edge.
pub data: DependencyEdge,
}
/// Summary information about a graph, computed by
/// `WorkflowGraph::to_serializable`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetadata {
/// Number of tasks (nodes) in the snapshot.
pub task_count: usize,
/// Number of dependency edges in the snapshot.
pub edge_count: usize,
/// Whether the dependency edges contained a cycle.
pub has_cycles: bool,
/// Number of depth levels, derived from `WorkflowGraph::calculate_depths`.
pub depth_levels: usize,
/// Ids reported by `WorkflowGraph::find_roots` (tasks without outgoing edges).
pub root_tasks: Vec<String>,
/// Ids reported by `WorkflowGraph::find_leaves` (tasks without incoming edges).
pub leaf_tasks: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task with the given id and name and no optional data.
    fn bare_task(id: &str, name: &str) -> TaskNode {
        TaskNode {
            id: id.to_string(),
            name: name.to_string(),
            description: None,
            source_location: None,
            metadata: HashMap::new(),
        }
    }

    /// Adds a default dependency edge `from -> to`, panicking if either task
    /// is missing.
    fn link(graph: &mut WorkflowGraph, from: &str, to: &str) {
        graph
            .add_dependency(from, to, DependencyEdge::default())
            .unwrap();
    }

    #[test]
    fn test_workflow_graph_creation() {
        let mut graph = WorkflowGraph::new();
        graph.add_task(bare_task("task1", "Task 1"));
        graph.add_task(bare_task("task2", "Task 2"));
        link(&mut graph, "task2", "task1");

        assert_eq!(graph.task_count(), 2);
        assert!(!graph.has_cycles());

        let dependencies: Vec<_> = graph.get_dependencies("task2").collect();
        assert_eq!(dependencies, vec!["task1"]);
        let dependents: Vec<_> = graph.get_dependents("task1").collect();
        assert_eq!(dependents, vec!["task2"]);
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = WorkflowGraph::new();
        for id in ["root", "a", "b", "end"] {
            graph.add_task(bare_task(id, id));
        }
        // Diamond shape: root -> {a, b} -> end.
        let diamond = [("root", "a"), ("root", "b"), ("a", "end"), ("b", "end")];
        for (from, to) in diamond {
            link(&mut graph, from, to);
        }

        let groups = graph.find_parallel_groups();
        assert_eq!(groups.len(), 3);
    }

    #[test]
    fn test_serialization() {
        let mut graph = WorkflowGraph::new();
        graph.add_task(TaskNode {
            description: Some("A test task".to_string()),
            source_location: Some("test.rs:42".to_string()),
            ..bare_task("test", "Test Task")
        });

        // Round-trip the snapshot through JSON.
        let json = serde_json::to_string(&graph.to_serializable()).unwrap();
        let deserialized: WorkflowGraphData = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.nodes.len(), 1);
        assert_eq!(deserialized.metadata.task_count, 1);
    }

    #[test]
    fn test_task_count() {
        let mut graph = WorkflowGraph::new();
        assert_eq!(graph.task_count(), 0);

        graph.add_task(bare_task("task1", "Task 1"));
        assert_eq!(graph.task_count(), 1);

        graph.add_task(bare_task("task2", "Task 2"));
        assert_eq!(graph.task_count(), 2);
    }

    #[test]
    fn test_task_ids_iterator() {
        let expected_ids = ["a", "b", "c"];
        let mut graph = WorkflowGraph::new();
        for id in expected_ids {
            graph.add_task(bare_task(id, id));
        }

        let ids: Vec<&str> = graph.task_ids().collect();
        assert_eq!(ids.len(), 3);
        for id in expected_ids {
            assert!(ids.contains(&id));
        }
    }
}