use crate::models::plan::Task;
use anyhow::{Context, Result};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use std::collections::HashMap;
/// Label stored on every edge of the task dependency graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DependencyEdge {
/// Weight used by `TaskGraph::add_dependency` for the edge it adds from
/// `task_a` to `task_b`; a topological sort places the edge's source first.
Blocks,
/// NOTE(review): never constructed in this file — presumably the inverse
/// relation of `Blocks`; confirm against other modules before relying on it.
BlockedBy,
}
/// A directed graph of tasks together with a lookup table from task ID to
/// graph node.
pub struct TaskGraph {
/// Nodes carry the `Task` values; edges carry the dependency label.
graph: DiGraph<Task, DependencyEdge>,
/// Maps a task's `id` to the node that `add_task` created for it.
task_index: HashMap<String, NodeIndex>,
}
impl TaskGraph {
/// Creates a graph that contains no tasks and no dependencies.
pub fn new() -> Self {
    let graph = DiGraph::new();
    let task_index = HashMap::new();
    Self { graph, task_index }
}
/// Inserts `task` into the graph and returns the node that holds it.
///
/// If a task with the same `id` is already present, its stored value is
/// replaced in place and the existing node (with all of its edges) is
/// returned. Adding a second node instead would leave the old one in the
/// graph but unreachable through the ID index, so `len()`, `tasks()` and the
/// sort results would disagree with ID-based lookups.
pub fn add_task(&mut self, task: Task) -> NodeIndex {
    if let Some(existing) = self.task_index.get(&task.id).copied() {
        self.graph[existing] = task;
        return existing;
    }

    let task_id = task.id.clone();
    let node = self.graph.add_node(task);
    self.task_index.insert(task_id, node);
    node
}
pub fn add_dependency(&mut self, task_a_id: &str, task_b_id: &str) -> Result<()> {
let node_a = self
.task_index
.get(task_a_id)
.copied()
.context(format!("Task not found: {}", task_a_id))?;
let node_b = self
.task_index
.get(task_b_id)
.copied()
.context(format!("Task not found: {}", task_b_id))?;
self.graph.add_edge(node_a, node_b, DependencyEdge::Blocks);
Ok(())
}
/// Returns the task IDs in a topological order of the graph: the source of
/// every dependency edge appears before that edge's target.
///
/// # Errors
///
/// Returns an error mentioning the offending node when the graph contains a
/// cycle.
pub fn topological_sort(&self) -> Result<Vec<String>> {
    match petgraph::algo::toposort(&self.graph, None) {
        Ok(order) => Ok(order
            .iter()
            .map(|&idx| self.graph[idx].id.clone())
            .collect()),
        Err(cycle) => Err(anyhow::anyhow!(
            "Cycle detected at node {:?}",
            cycle.node_id()
        )),
    }
}
/// Returns the task IDs of every dependency cycle, one inner vector per
/// strongly connected component that contains a cycle.
///
/// A component counts as cyclic when it has more than one task, or when its
/// single task has an edge to itself. The self-edge case must be checked
/// explicitly because a one-node component is reported whether or not the
/// node depends on itself.
///
/// An empty result means the graph is acyclic.
pub fn detect_cycles(&self) -> Vec<Vec<String>> {
    use petgraph::algo::kosaraju_scc;

    kosaraju_scc(&self.graph)
        .into_iter()
        .filter(|scc| match scc.as_slice() {
            // Lone node: cyclic only if the task depends on itself.
            [only] => self.graph.contains_edge(*only, *only),
            members => members.len() > 1,
        })
        .map(|scc| {
            scc.into_iter()
                .map(|node| self.graph[node].id.clone())
                .collect()
        })
        .collect()
}
pub fn critical_path(&self) -> Result<Vec<String>> {
if !self.detect_cycles().is_empty() {
return Err(anyhow::anyhow!(
"Cannot compute critical path: graph contains cycles"
));
}
let topo_order = self.topological_sort()?;
let id_to_node: HashMap<_, _> = self
.task_index
.iter()
.map(|(id, &node)| (id.clone(), node))
.collect();
let mut distances: HashMap<NodeIndex, f64> = HashMap::new();
let mut predecessors: HashMap<NodeIndex, NodeIndex> = HashMap::new();
for node in self.graph.node_indices() {
distances.insert(node, 0.0);
}
for task_id in topo_order {
let node = id_to_node[&task_id];
let task = &self.graph[node];
let duration = Self::parse_duration(&task.duration);
let current_distance = distances[&node];
for edge in self.graph.edges_directed(node, Direction::Outgoing) {
let successor = edge.target();
let new_distance = current_distance + duration;
if new_distance > distances[&successor] {
distances.insert(successor, new_distance);
predecessors.insert(successor, node);
}
}
}
let (&end_node, _) = distances
.iter()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.context("No tasks in graph")?;
let mut path = vec![end_node];
let mut current = end_node;
while let Some(&pred) = predecessors.get(¤t) {
path.push(pred);
current = pred;
}
path.reverse();
let critical_path_ids = path
.into_iter()
.map(|node| self.graph[node].id.clone())
.collect();
Ok(critical_path_ids)
}
/// Extracts the leading number from a duration string such as `"2h"`,
/// `"3-4h"` or `"0.5h"`.
///
/// Leading whitespace is skipped. The number is the longest prefix of ASCII
/// digits, optionally followed by a decimal point and more ASCII digits, so a
/// range like `"3-4h"` yields its lower bound. Any unit suffix is ignored.
///
/// Returns `1.0` when `duration` is `None` or does not start with a number.
fn parse_duration(duration: &Option<String>) -> f64 {
    const DEFAULT_DURATION: f64 = 1.0;

    let text = match duration.as_deref() {
        Some(text) => text.trim_start(),
        None => return DEFAULT_DURATION,
    };

    // Byte length of the run of ASCII digits at the start of `s`.
    let digit_run = |s: &str| s.find(|c: char| !c.is_ascii_digit()).unwrap_or(s.len());

    let mut end = digit_run(text);
    // Keep a fractional part so "0.5h" is not truncated to 0.
    if let Some(fraction) = text[end..].strip_prefix('.') {
        end += 1 + digit_run(fraction);
    }

    text[..end].parse().unwrap_or(DEFAULT_DURATION)
}
/// Looks up a task by its ID, returning `None` when no such task was added.
pub fn get_task(&self, task_id: &str) -> Option<&Task> {
    let node = self.task_index.get(task_id).copied()?;
    Some(&self.graph[node])
}
/// Returns a reference to every task stored in the graph, in the graph's
/// node order.
pub fn tasks(&self) -> Vec<&Task> {
    let mut all = Vec::new();
    all.extend(self.graph.node_weights());
    all
}
/// Returns the IDs of the tasks that have an edge pointing at `task_id`
/// (the sources of its incoming edges).
///
/// An unknown `task_id` yields an empty vector.
pub fn dependencies(&self, task_id: &str) -> Vec<String> {
    match self.task_index.get(task_id) {
        None => Vec::new(),
        Some(&node) => {
            let mut ids = Vec::new();
            for edge in self.graph.edges_directed(node, Direction::Incoming) {
                ids.push(self.graph[edge.source()].id.clone());
            }
            ids
        }
    }
}
/// Returns the IDs of the tasks that `task_id` has an edge pointing to
/// (the targets of its outgoing edges).
///
/// An unknown `task_id` yields an empty vector.
pub fn dependents(&self, task_id: &str) -> Vec<String> {
    match self.task_index.get(task_id) {
        None => Vec::new(),
        Some(&node) => {
            let mut ids = Vec::new();
            for edge in self.graph.edges_directed(node, Direction::Outgoing) {
                ids.push(self.graph[edge.target()].id.clone());
            }
            ids
        }
    }
}
/// Returns the number of task nodes in the graph.
pub fn len(&self) -> usize {
    let graph = &self.graph;
    graph.node_count()
}
/// Returns `true` when the graph holds no task nodes.
pub fn is_empty(&self) -> bool {
    let node_total = self.graph.node_count();
    node_total == 0
}
}
impl Default for TaskGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Task` with only `id`, `title` and `duration` populated;
    /// every other field is `None`.
    fn create_test_task(id: &str, title: &str, duration: Option<&str>) -> Task {
        let duration = duration.map(String::from);
        Task {
            id: String::from(id),
            title: String::from(title),
            description: None,
            priority: None,
            duration,
            difficulty: None,
            crate_name: None,
            issue: None,
        }
    }

    /// A single added task is counted and can be found by ID.
    #[test]
    fn test_add_task() {
        let mut graph = TaskGraph::new();
        graph.add_task(create_test_task("T1", "Task 1", None));

        assert_eq!(graph.len(), 1);
        assert!(graph.get_task("T1").is_some());
    }

    /// A chain T1 -> T2 -> T3 is sorted in dependency order.
    #[test]
    fn test_topological_sort_simple() {
        let mut graph = TaskGraph::new();
        for &(id, title) in [("T1", "Task 1"), ("T2", "Task 2"), ("T3", "Task 3")].iter() {
            graph.add_task(create_test_task(id, title, None));
        }
        graph.add_dependency("T1", "T2").unwrap();
        graph.add_dependency("T2", "T3").unwrap();

        let order = graph.topological_sort().unwrap();
        let position_of = |wanted: &str| order.iter().position(|id| id == wanted).unwrap();

        assert!(position_of("T1") < position_of("T2"));
        assert!(position_of("T2") < position_of("T3"));
    }

    /// A single forward dependency produces no cycles.
    #[test]
    fn test_detect_cycles_none() {
        let mut graph = TaskGraph::new();
        graph.add_task(create_test_task("T1", "Task 1", None));
        graph.add_task(create_test_task("T2", "Task 2", None));
        graph.add_dependency("T1", "T2").unwrap();

        assert!(graph.detect_cycles().is_empty());
    }

    /// Two tasks depending on each other form one cycle of two members.
    #[test]
    fn test_detect_cycles_simple() {
        let mut graph = TaskGraph::new();
        graph.add_task(create_test_task("T1", "Task 1", None));
        graph.add_task(create_test_task("T2", "Task 2", None));
        graph.add_dependency("T1", "T2").unwrap();
        graph.add_dependency("T2", "T1").unwrap();

        let found = graph.detect_cycles();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].len(), 2);
    }

    /// Duration strings yield their leading number; a missing duration is 1.0.
    #[test]
    fn test_parse_duration() {
        let cases = [
            (Some("3-4h"), 3.0),
            (Some("2h"), 2.0),
            (Some("10-12h"), 10.0),
            (None, 1.0),
        ];
        for &(raw, expected) in cases.iter() {
            let input = raw.map(String::from);
            assert_eq!(TaskGraph::parse_duration(&input), expected);
        }
    }
}