use crate::{TaskError, TaskId, TaskResult};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
/// Scheduling state of a single [`TaskNode`] inside a [`TaskDAG`].
///
/// NOTE(review): with index-based serde formats the variant position is part of
/// the wire format, so append new variants instead of reordering these.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskNodeStatus {
    /// Initial state assigned by `TaskNode::new`; the only state that
    /// `TaskDAG::get_ready_tasks` will report.
    Pending,
    /// NOTE(review): never assigned or checked anywhere in this file —
    /// `get_ready_tasks` computes readiness instead of using this variant.
    Ready,
    /// Assigned by `TaskDAG::mark_running`.
    Running,
    /// Assigned by `TaskDAG::mark_completed`; a dependent is reported ready
    /// only once all of its dependencies are in this state.
    Completed,
    /// Assigned by `TaskDAG::mark_failed`; while a dependency is in this state
    /// its dependents are not reported by `get_ready_tasks`.
    Failed,
}
/// One vertex of a [`TaskDAG`]: a task id, the ids it waits on, and its
/// current scheduling status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
    /// Identifier of this task; `TaskDAG::add_task` also uses it as the map key.
    id: TaskId,
    /// Tasks that must be `Completed` before this one is reported as ready.
    /// `add_dependency` keeps this list free of duplicates.
    dependencies: Vec<TaskId>,
    /// Current scheduling status; starts as `Pending`.
    status: TaskNodeStatus,
}
impl TaskNode {
    /// Builds a node for `id` that starts out `Pending` with no dependencies.
    pub fn new(id: TaskId) -> Self {
        let status = TaskNodeStatus::Pending;
        Self {
            id,
            status,
            dependencies: vec![],
        }
    }

    /// The identifier of this task.
    pub fn id(&self) -> TaskId {
        self.id
    }

    /// Ids of the tasks this node waits on, in insertion order.
    pub fn dependencies(&self) -> &[TaskId] {
        self.dependencies.as_slice()
    }

    /// Current scheduling status of this node.
    pub fn status(&self) -> TaskNodeStatus {
        self.status
    }

    /// Records `task_id` as a dependency; a repeated id is ignored so each
    /// dependency appears at most once.
    pub fn add_dependency(&mut self, task_id: TaskId) {
        if self.dependencies.contains(&task_id) {
            return;
        }
        self.dependencies.push(task_id);
    }

    /// Overwrites the status; no transition rules are enforced here.
    pub fn set_status(&mut self, status: TaskNodeStatus) {
        self.status = status;
    }

    /// Drops every occurrence of `task_id` from the dependency list.
    pub(crate) fn remove_dependency(&mut self, task_id: TaskId) {
        self.dependencies.retain(|dep| *dep != task_id);
    }
}
/// Dependency graph of tasks, stored as forward and reverse adjacency lists.
///
/// Invariant maintained by `add_task` / `add_dependency`: every node has an
/// entry in `dependents`, and `b` is listed in `dependents[a]` exactly when `a`
/// is listed in `nodes[b].dependencies()`.
///
/// NOTE(review): the derived `Deserialize` does not re-check this invariant or
/// acyclicity — confirm serialized graphs only come from trusted sources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskDAG {
    /// All tasks, keyed by their id.
    nodes: HashMap<TaskId, TaskNode>,
    /// Reverse edges: for each task, the tasks that depend on it.
    dependents: HashMap<TaskId, Vec<TaskId>>,
}
impl TaskDAG {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
dependents: HashMap::new(),
}
}
pub fn add_task(&mut self, task_id: TaskId) -> TaskResult<()> {
if self.nodes.contains_key(&task_id) {
return Err(TaskError::ExecutionFailed(format!(
"Task {} already exists in DAG",
task_id
)));
}
self.nodes.insert(task_id, TaskNode::new(task_id));
self.dependents.insert(task_id, Vec::new());
Ok(())
}
pub fn add_dependency(&mut self, task_id: TaskId, depends_on: TaskId) -> TaskResult<()> {
if !self.nodes.contains_key(&task_id) {
return Err(TaskError::TaskNotFound(task_id.to_string()));
}
if !self.nodes.contains_key(&depends_on) {
return Err(TaskError::TaskNotFound(depends_on.to_string()));
}
if let Some(node) = self.nodes.get_mut(&task_id) {
node.add_dependency(depends_on);
}
if let Some(deps) = self.dependents.get_mut(&depends_on)
&& !deps.contains(&task_id)
{
deps.push(task_id);
}
if let Err(e) = self.detect_cycle() {
if let Some(node) = self.nodes.get_mut(&task_id) {
node.remove_dependency(depends_on);
}
if let Some(deps) = self.dependents.get_mut(&depends_on) {
deps.retain(|&id| id != task_id);
}
return Err(e);
}
Ok(())
}
pub fn task_count(&self) -> usize {
self.nodes.len()
}
pub fn get_task(&self, task_id: TaskId) -> Option<&TaskNode> {
self.nodes.get(&task_id)
}
pub fn get_ready_tasks(&self) -> Vec<TaskId> {
self.nodes
.values()
.filter(|node| {
node.status() == TaskNodeStatus::Pending
&& node
.dependencies()
.iter()
.all(|dep_id| match self.nodes.get(dep_id) {
Some(dep_node) => dep_node.status() == TaskNodeStatus::Completed,
None => false,
})
})
.map(|node| node.id())
.collect()
}
pub fn mark_completed(&mut self, task_id: TaskId) -> TaskResult<()> {
let node = self
.nodes
.get_mut(&task_id)
.ok_or_else(|| TaskError::TaskNotFound(task_id.to_string()))?;
node.set_status(TaskNodeStatus::Completed);
Ok(())
}
pub fn mark_failed(&mut self, task_id: TaskId) -> TaskResult<()> {
let node = self
.nodes
.get_mut(&task_id)
.ok_or_else(|| TaskError::TaskNotFound(task_id.to_string()))?;
node.set_status(TaskNodeStatus::Failed);
Ok(())
}
pub fn mark_running(&mut self, task_id: TaskId) -> TaskResult<()> {
let node = self
.nodes
.get_mut(&task_id)
.ok_or_else(|| TaskError::TaskNotFound(task_id.to_string()))?;
node.set_status(TaskNodeStatus::Running);
Ok(())
}
pub fn topological_sort(&self) -> TaskResult<Vec<TaskId>> {
let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
for (task_id, node) in &self.nodes {
in_degree.insert(*task_id, node.dependencies().len());
}
let mut queue: VecDeque<TaskId> = in_degree
.iter()
.filter(|(_, degree)| **degree == 0)
.map(|(task_id, _)| *task_id)
.collect();
let mut sorted = Vec::new();
while let Some(task_id) = queue.pop_front() {
sorted.push(task_id);
if let Some(deps) = self.dependents.get(&task_id) {
for &dependent in deps {
if let Some(degree) = in_degree.get_mut(&dependent) {
*degree -= 1;
if *degree == 0 {
queue.push_back(dependent);
}
}
}
}
}
if sorted.len() != self.nodes.len() {
return Err(TaskError::ExecutionFailed(
"Cycle detected in task dependencies".to_string(),
));
}
Ok(sorted)
}
fn detect_cycle(&self) -> TaskResult<()> {
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
for &start_id in self.nodes.keys() {
if visited.contains(&start_id) {
continue;
}
let mut stack: Vec<(TaskId, usize, bool)> = vec![(start_id, 0, true)];
while let Some((task_id, dep_idx, is_entering)) = stack.last_mut() {
if *is_entering {
visited.insert(*task_id);
rec_stack.insert(*task_id);
*is_entering = false;
}
let deps = self
.nodes
.get(task_id)
.map(|n| n.dependencies())
.unwrap_or(&[]);
if *dep_idx < deps.len() {
let dep_id = deps[*dep_idx];
*dep_idx += 1;
if rec_stack.contains(&dep_id) {
return Err(TaskError::ExecutionFailed(format!(
"Cycle detected: {} -> {}",
task_id, dep_id
)));
}
if !visited.contains(&dep_id) {
stack.push((dep_id, 0, true));
}
} else {
rec_stack.remove(task_id);
stack.pop();
}
}
}
Ok(())
}
}
impl Default for TaskDAG {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::collections::HashMap;

    /// Adds `depth` fresh tasks and links them into a chain where each task
    /// depends on the previous one. Returns the ids in chain order.
    fn build_chain(dag: &mut TaskDAG, depth: usize) -> Vec<TaskId> {
        let task_ids: Vec<TaskId> = (0..depth).map(|_| TaskId::new()).collect();
        for &id in &task_ids {
            dag.add_task(id).unwrap();
        }
        for pair in task_ids.windows(2) {
            dag.add_dependency(pair[1], pair[0]).unwrap();
        }
        task_ids
    }

    /// Maps each id in `order` to its index, so ordering checks are O(1)
    /// instead of a linear `position` scan per lookup.
    fn positions(order: &[TaskId]) -> HashMap<TaskId, usize> {
        order.iter().enumerate().map(|(i, &id)| (id, i)).collect()
    }

    #[rstest]
    fn test_dag_creation() {
        let dag = TaskDAG::new();
        assert_eq!(dag.task_count(), 0);
    }

    #[rstest]
    fn test_add_task() {
        let mut dag = TaskDAG::new();
        let task_id = TaskId::new();
        dag.add_task(task_id).unwrap();
        assert_eq!(dag.task_count(), 1);
        assert!(dag.get_task(task_id).is_some());
    }

    #[rstest]
    fn test_add_duplicate_task() {
        let mut dag = TaskDAG::new();
        let task_id = TaskId::new();
        dag.add_task(task_id).unwrap();
        let result = dag.add_task(task_id);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_add_dependency() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        let node_b = dag.get_task(task_b).unwrap();
        assert_eq!(node_b.dependencies().len(), 1);
        assert_eq!(node_b.dependencies()[0], task_a);
    }

    #[rstest]
    fn test_add_duplicate_dependency_is_idempotent() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        assert_eq!(dag.get_task(task_b).unwrap().dependencies(), &[task_a]);
        assert_eq!(dag.topological_sort().unwrap().len(), 2);
    }

    #[rstest]
    fn test_add_dependency_nonexistent_task() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        dag.add_task(task_a).unwrap();
        let result = dag.add_dependency(task_a, task_b);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_self_dependency_rejected() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        dag.add_task(task_a).unwrap();
        assert!(dag.add_dependency(task_a, task_a).is_err());
        assert!(dag.get_task(task_a).unwrap().dependencies().is_empty());
        assert_eq!(dag.get_ready_tasks(), vec![task_a]);
    }

    #[rstest]
    fn test_cycle_detection() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        let task_c = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_task(task_c).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        dag.add_dependency(task_c, task_b).unwrap();
        let result = dag.add_dependency(task_a, task_c);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_rejected_cycle_leaves_graph_unchanged() {
        let mut dag = TaskDAG::new();
        let task_ids = build_chain(&mut dag, 3);
        assert!(dag.add_dependency(task_ids[0], task_ids[2]).is_err());
        assert!(dag.get_task(task_ids[0]).unwrap().dependencies().is_empty());
        assert_eq!(dag.topological_sort().unwrap().len(), 3);
        assert_eq!(dag.get_ready_tasks(), vec![task_ids[0]]);
    }

    #[rstest]
    fn test_topological_sort_simple() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        let task_c = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_task(task_c).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        dag.add_dependency(task_c, task_b).unwrap();
        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        let pos = positions(&order);
        assert!(pos[&task_a] < pos[&task_b]);
        assert!(pos[&task_b] < pos[&task_c]);
    }

    #[rstest]
    fn test_topological_sort_diamond() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        let task_c = TaskId::new();
        let task_d = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_task(task_c).unwrap();
        dag.add_task(task_d).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        dag.add_dependency(task_c, task_a).unwrap();
        dag.add_dependency(task_d, task_b).unwrap();
        dag.add_dependency(task_d, task_c).unwrap();
        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), 4);
        let pos = positions(&order);
        assert!(pos[&task_a] < pos[&task_b]);
        assert!(pos[&task_a] < pos[&task_c]);
        assert!(pos[&task_b] < pos[&task_d]);
        assert!(pos[&task_c] < pos[&task_d]);
    }

    #[rstest]
    fn test_get_ready_tasks() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        let task_c = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_task(task_c).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        dag.add_dependency(task_c, task_b).unwrap();
        let ready = dag.get_ready_tasks();
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&task_a));
        dag.mark_completed(task_a).unwrap();
        let ready = dag.get_ready_tasks();
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&task_b));
        dag.mark_completed(task_b).unwrap();
        let ready = dag.get_ready_tasks();
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&task_c));
    }

    #[rstest]
    fn test_failed_dependency_blocks_dependents() {
        let mut dag = TaskDAG::new();
        let task_ids = build_chain(&mut dag, 2);
        dag.mark_failed(task_ids[0]).unwrap();
        assert!(dag.get_ready_tasks().is_empty());
    }

    #[rstest]
    fn test_mark_status() {
        let mut dag = TaskDAG::new();
        let task_id = TaskId::new();
        dag.add_task(task_id).unwrap();
        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Pending
        );
        dag.mark_running(task_id).unwrap();
        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Running
        );
        dag.mark_completed(task_id).unwrap();
        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Completed
        );
    }

    #[rstest]
    fn test_mark_failed() {
        let mut dag = TaskDAG::new();
        let task_id = TaskId::new();
        dag.add_task(task_id).unwrap();
        dag.mark_failed(task_id).unwrap();
        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Failed
        );
    }

    #[rstest]
    fn test_parallel_execution_detection() {
        let mut dag = TaskDAG::new();
        let task_a = TaskId::new();
        let task_b = TaskId::new();
        let task_c = TaskId::new();
        let task_d = TaskId::new();
        dag.add_task(task_a).unwrap();
        dag.add_task(task_b).unwrap();
        dag.add_task(task_c).unwrap();
        dag.add_task(task_d).unwrap();
        dag.add_dependency(task_b, task_a).unwrap();
        dag.add_dependency(task_c, task_a).unwrap();
        dag.add_dependency(task_d, task_b).unwrap();
        dag.add_dependency(task_d, task_c).unwrap();
        dag.mark_completed(task_a).unwrap();
        let ready = dag.get_ready_tasks();
        assert_eq!(ready.len(), 2);
        assert!(ready.contains(&task_b));
        assert!(ready.contains(&task_c));
    }

    #[rstest]
    fn test_deep_dependency_chain_does_not_stack_overflow() {
        let mut dag = TaskDAG::new();
        let depth = 1000;
        let task_ids = build_chain(&mut dag, depth);
        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), depth);
        let pos = positions(&order);
        for pair in task_ids.windows(2) {
            assert!(pos[&pair[0]] < pos[&pair[1]]);
        }
    }

    #[rstest]
    fn test_cycle_detection_on_deep_chain_with_back_edge() {
        let mut dag = TaskDAG::new();
        let depth = 500;
        let task_ids = build_chain(&mut dag, depth);
        let result = dag.add_dependency(task_ids[0], task_ids[depth - 1]);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_deep_chain_10k_nodes_does_not_stack_overflow() {
        let mut dag = TaskDAG::new();
        let depth = 10_000;
        let task_ids = build_chain(&mut dag, depth);
        let order = dag.topological_sort().unwrap();
        assert_eq!(order.len(), depth);
        let pos = positions(&order);
        for (i, pair) in task_ids.windows(2).enumerate() {
            assert!(
                pos[&pair[0]] < pos[&pair[1]],
                "task_ids[{}] must precede task_ids[{}]",
                i,
                i + 1
            );
        }
    }

    #[rstest]
    fn test_deep_chain_10k_nodes_back_edge_detected() {
        let mut dag = TaskDAG::new();
        let depth = 10_000;
        let task_ids = build_chain(&mut dag, depth);
        let result = dag.add_dependency(task_ids[0], task_ids[depth - 1]);
        assert!(
            result.is_err(),
            "back-edge cycle must be detected in 10k chain"
        );
    }
}