#![allow(dead_code)]
use crate::agent::AgentResult;
use crate::config::AgentConfig;
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Lifecycle state of a [`TaskNode`].
///
/// Serialised with an internal `"state"` tag whose value is the snake_case
/// variant name, e.g. `{"state": "failed", "error": "...", "retries": 0}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "state", rename_all = "snake_case")]
pub enum TaskStatus {
/// Initial state of a new task; also restored by `TaskGraph::reset_for_retry`.
Pending,
/// Set by `TaskGraph::mark_running` (only allowed from `Pending`).
Running,
/// Set by `TaskGraph::mark_completed`.
Completed,
/// Set by `TaskGraph::mark_failed`.
Failed {
/// Error message supplied by the caller of `mark_failed`.
error: String,
/// Number of failures recorded before this one (0 on the first failure).
retries: u32,
},
/// Set by `TaskGraph::mark_cancelled`.
Cancelled,
}
impl TaskStatus {
    /// Returns `true` for `Completed`, `Failed` and `Cancelled` — the states
    /// that `TaskGraph::is_finished` treats as done.
    pub fn is_terminal(&self) -> bool {
        match self {
            TaskStatus::Pending | TaskStatus::Running => false,
            TaskStatus::Completed | TaskStatus::Failed { .. } | TaskStatus::Cancelled => true,
        }
    }

    /// Returns `true` only for [`TaskStatus::Completed`].
    pub fn is_completed(&self) -> bool {
        *self == TaskStatus::Completed
    }

    /// Returns `true` for any [`TaskStatus::Failed`], whatever its error text or retry count.
    pub fn is_failed(&self) -> bool {
        match self {
            TaskStatus::Failed { .. } => true,
            TaskStatus::Pending
            | TaskStatus::Running
            | TaskStatus::Completed
            | TaskStatus::Cancelled => false,
        }
    }
}
/// A single unit of work tracked by a [`TaskGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
/// Identifier of the task; `TaskGraph::add_task` rejects duplicates.
pub id: String,
/// Description of the task, as given to `TaskNode::new` / `with_description`.
pub description: String,
/// Current lifecycle state; `TaskNode::new` starts it at `TaskStatus::Pending`.
pub status: TaskStatus,
/// Agent configuration attached to this task; `TaskNode::new` uses `AgentConfig::default()`.
pub agent_config: AgentConfig,
/// IDs of tasks that must be `Completed` before `TaskGraph::next_ready` returns this task.
pub depends_on: Vec<String>,
/// Incremented every time `TaskGraph::mark_failed` is called for this task.
pub retry_count: u32,
/// Result stored by `TaskGraph::mark_completed`. Skipped by serde, so it is
/// `None` after deserialisation.
#[serde(skip)]
pub result: Option<AgentResult>,
}
impl TaskNode {
    /// Creates a `Pending` task with no dependencies, no retries, no result and a
    /// default agent configuration.
    pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
        let (id, description) = (id.into(), description.into());
        Self {
            id,
            description,
            status: TaskStatus::Pending,
            agent_config: AgentConfig::default(),
            depends_on: Vec::new(),
            retry_count: 0,
            result: None,
        }
    }

    /// Builder step: appends `dep_id` to the list of tasks this one waits for.
    pub fn with_dependency(mut self, dep_id: impl Into<String>) -> Self {
        let dependency: String = dep_id.into();
        self.depends_on.push(dependency);
        self
    }

    /// Builder step: replaces the agent configuration.
    pub fn with_config(self, config: AgentConfig) -> Self {
        Self {
            agent_config: config,
            ..self
        }
    }

    /// Builder step: replaces the description.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..self
        }
    }
}
/// A dependency graph of [`TaskNode`]s, stored in insertion order.
///
/// Only `nodes` is serialised. Because the ID index is skipped by serde, call
/// `TaskGraph::rebuild_index` after deserialising and before any ID-based lookup.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct TaskGraph {
/// Tasks in the order they were added.
nodes: Vec<TaskNode>,
/// Task ID -> position in `nodes`. Empty after deserialisation until rebuilt.
#[serde(skip)]
id_to_index: HashMap<String, usize>,
}
impl TaskGraph {
    /// Creates an empty task graph.
    pub fn new() -> Self {
        TaskGraph::default()
    }

    /// Adds `node` to the graph.
    ///
    /// Every ID in `node.depends_on` must already be present, so tasks must be
    /// added in dependency order (dependencies before dependents).
    ///
    /// # Errors
    ///
    /// Fails if the ID is already used, if a dependency has not been added yet,
    /// or if the resulting graph is not a valid DAG. On any error the graph is
    /// left exactly as it was before the call.
    pub fn add_task(&mut self, node: TaskNode) -> Result<()> {
        if self.id_to_index.contains_key(&node.id) {
            bail!("Duplicate task ID: '{}'", node.id);
        }
        for dep_id in &node.depends_on {
            if !self.id_to_index.contains_key(dep_id) {
                bail!(
                    "Task '{}' depends on '{}' which has not been added yet. \
                     Add dependencies before dependents.",
                    node.id,
                    dep_id
                );
            }
        }
        let idx = self.nodes.len();
        self.id_to_index.insert(node.id.clone(), idx);
        self.nodes.push(node);
        // Tasks added through this method cannot form a cycle, but a graph restored
        // via serde + `rebuild_index` may already have an earlier node that names
        // this ID as a dependency. Undo the insertion on failure so a rejected
        // task does not linger in the graph.
        if let Err(err) = self.check_no_cycles() {
            if let Some(rejected) = self.nodes.pop() {
                self.id_to_index.remove(&rejected.id);
            }
            return Err(err);
        }
        Ok(())
    }

    /// Moves a `Pending` task to `Running`.
    ///
    /// # Errors
    ///
    /// Fails if the ID is unknown or the task is not currently `Pending`.
    pub fn mark_running(&mut self, id: &str) -> Result<()> {
        let node = self.get_mut(id)?;
        if node.status != TaskStatus::Pending {
            bail!(
                "Cannot mark '{}' as Running — current state is {:?}",
                id,
                node.status
            );
        }
        node.status = TaskStatus::Running;
        Ok(())
    }

    /// Marks a task `Completed` and stores its result.
    ///
    /// The current state is not checked, so a `Pending` task may be completed directly.
    ///
    /// # Errors
    ///
    /// Fails if the ID is unknown.
    pub fn mark_completed(&mut self, id: &str, result: AgentResult) -> Result<()> {
        let node = self.get_mut(id)?;
        node.status = TaskStatus::Completed;
        node.result = Some(result);
        Ok(())
    }

    /// Marks a task `Failed` with `error` and bumps its `retry_count`.
    ///
    /// The current state is not checked.
    ///
    /// # Errors
    ///
    /// Fails if the ID is unknown.
    pub fn mark_failed(&mut self, id: &str, error: String) -> Result<()> {
        let node = self.get_mut(id)?;
        // `retries` counts the failures recorded *before* this one (0 on the first failure).
        let retries = node.retry_count;
        node.retry_count += 1;
        node.status = TaskStatus::Failed { error, retries };
        Ok(())
    }

    /// Marks a task `Cancelled`, regardless of its current state.
    ///
    /// # Errors
    ///
    /// Fails if the ID is unknown.
    pub fn mark_cancelled(&mut self, id: &str) -> Result<()> {
        let node = self.get_mut(id)?;
        node.status = TaskStatus::Cancelled;
        Ok(())
    }

    /// Moves a `Failed` task back to `Pending` so it can be scheduled again.
    ///
    /// `retry_count` is kept, so the next failure reports a higher `retries` value.
    ///
    /// # Errors
    ///
    /// Fails if the ID is unknown or the task is not currently `Failed`.
    pub fn reset_for_retry(&mut self, id: &str) -> Result<()> {
        let node = self.get_mut(id)?;
        if !node.status.is_failed() {
            bail!("Cannot reset '{}' for retry — not in Failed state", id);
        }
        node.status = TaskStatus::Pending;
        Ok(())
    }

    /// Iterates over all tasks in insertion order.
    pub fn nodes(&self) -> impl Iterator<Item = &TaskNode> {
        self.nodes.iter()
    }

    /// Number of tasks in the graph.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the graph has no tasks.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Looks a task up by ID.
    ///
    /// Returns `None` if the ID is unknown, including when the index has not been
    /// rebuilt after deserialisation.
    pub fn get(&self, id: &str) -> Option<&TaskNode> {
        self.id_to_index.get(id).map(|&i| &self.nodes[i])
    }

    /// Returns every task ID ordered so that each task comes after all of its dependencies.
    ///
    /// Among tasks that become ready at the same time, insertion order is kept.
    ///
    /// # Errors
    ///
    /// Fails if a task depends on an ID missing from the index (e.g. a deserialised
    /// graph whose index was not rebuilt) or if the dependencies form a cycle.
    pub fn topological_sort(&self) -> Result<Vec<String>> {
        let order = self.checked_order()?;
        Ok(order
            .into_iter()
            .map(|idx| self.nodes[idx].id.clone())
            .collect())
    }

    /// Groups task IDs into waves that can run in parallel.
    ///
    /// Wave 0 holds tasks without dependencies. Every other task sits one wave after
    /// the latest wave among its dependencies. Within a wave, IDs follow
    /// [`TaskGraph::topological_sort`] order. An empty graph yields no waves.
    ///
    /// # Errors
    ///
    /// Same conditions as [`TaskGraph::topological_sort`].
    pub fn compute_waves(&self) -> Result<Vec<Vec<String>>> {
        if self.nodes.is_empty() {
            return Ok(Vec::new());
        }
        let order = self.checked_order()?;
        let mut wave_of = vec![0usize; self.nodes.len()];
        for &idx in &order {
            // `checked_order` already resolved every dependency, so these lookups
            // cannot miss. Dependencies precede `idx` in `order`, so their waves are final.
            let wave = self.nodes[idx]
                .depends_on
                .iter()
                .map(|dep_id| wave_of[self.id_to_index[dep_id]] + 1)
                .max()
                .unwrap_or(0);
            wave_of[idx] = wave;
        }
        let wave_count = wave_of.iter().max().map_or(0, |&w| w + 1);
        let mut waves: Vec<Vec<String>> = vec![Vec::new(); wave_count];
        // Walking `order` keeps each wave sorted by topological position.
        for &idx in &order {
            waves[wave_of[idx]].push(self.nodes[idx].id.clone());
        }
        Ok(waves)
    }

    /// Returns the `Pending` tasks whose dependencies are all `Completed`, in insertion order.
    ///
    /// A dependency that cannot be found by ID counts as not completed.
    pub fn next_ready(&self) -> Vec<&TaskNode> {
        self.nodes
            .iter()
            .filter(|node| {
                node.status == TaskStatus::Pending
                    && node.depends_on.iter().all(|dep_id| {
                        self.get(dep_id)
                            .map(|dep| dep.status.is_completed())
                            .unwrap_or(false)
                    })
            })
            .collect()
    }

    /// Returns `true` when every task is in a terminal state (vacuously true when empty).
    pub fn is_finished(&self) -> bool {
        self.nodes.iter().all(|n| n.status.is_terminal())
    }

    /// Returns `true` when every task is `Completed` (vacuously true when empty).
    pub fn is_all_completed(&self) -> bool {
        self.nodes.iter().all(|n| n.status.is_completed())
    }

    /// Counts tasks per state as `(pending, running, completed, failed, cancelled)`.
    pub fn status_counts(&self) -> (usize, usize, usize, usize, usize) {
        let mut counts = (0, 0, 0, 0, 0);
        for node in &self.nodes {
            match &node.status {
                TaskStatus::Pending => counts.0 += 1,
                TaskStatus::Running => counts.1 += 1,
                TaskStatus::Completed => counts.2 += 1,
                TaskStatus::Failed { .. } => counts.3 += 1,
                TaskStatus::Cancelled => counts.4 += 1,
            }
        }
        counts
    }

    /// Mutable lookup by ID, with an error instead of `None` for unknown IDs.
    fn get_mut(&mut self, id: &str) -> Result<&mut TaskNode> {
        match self.id_to_index.get(id) {
            Some(&i) => Ok(&mut self.nodes[i]),
            None => bail!("Task ID '{}' not found in graph", id),
        }
    }

    /// Like `topological_sort_internal`, but rejects partial orders (cycles).
    fn checked_order(&self) -> Result<Vec<usize>> {
        let order = self.topological_sort_internal()?;
        if order.len() != self.nodes.len() {
            bail!("Cycle detected in task graph — remove the circular dependency");
        }
        Ok(order)
    }

    /// Verifies the graph is acyclic. The error names the most recently added task.
    fn check_no_cycles(&self) -> Result<()> {
        let sorted = self.topological_sort_internal()?;
        if sorted.len() != self.nodes.len() {
            bail!(
                "Cycle detected after adding task '{}' — remove the circular dependency",
                self.nodes.last().map(|n| n.id.as_str()).unwrap_or("?")
            );
        }
        Ok(())
    }

    /// Kahn's algorithm over node indices.
    ///
    /// Returns indices in dependency order, seeding the queue in insertion order.
    /// Nodes on or downstream of a cycle are never emitted, so a cycle produces a
    /// result shorter than `self.nodes`; callers compare lengths to detect it.
    ///
    /// # Errors
    ///
    /// Fails if a task names a dependency that is not in `id_to_index`, instead of
    /// panicking on the map index.
    fn topological_sort_internal(&self) -> Result<Vec<usize>> {
        let n = self.nodes.len();
        let mut in_degree = vec![0usize; n];
        // dependents[d] lists the nodes that depend on node d.
        let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
        for (idx, node) in self.nodes.iter().enumerate() {
            for dep_id in &node.depends_on {
                let dep_idx = match self.id_to_index.get(dep_id) {
                    Some(&i) => i,
                    None => bail!(
                        "Task '{}' depends on unknown task '{}'",
                        node.id,
                        dep_id
                    ),
                };
                dependents[dep_idx].push(idx);
                in_degree[idx] += 1;
            }
        }
        let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
        let mut order = Vec::with_capacity(n);
        while let Some(idx) = queue.pop_front() {
            order.push(idx);
            for &next in &dependents[idx] {
                in_degree[next] -= 1;
                if in_degree[next] == 0 {
                    queue.push_back(next);
                }
            }
        }
        Ok(order)
    }
}
impl TaskGraph {
    /// Recomputes the ID -> position index from `nodes`, discarding any old entries.
    ///
    /// The index is skipped by serde, so call this after deserialising a graph.
    /// If several nodes share an ID, the last one wins.
    pub fn rebuild_index(&mut self) {
        self.id_to_index = self
            .nodes
            .iter()
            .enumerate()
            .map(|(position, node)| (node.id.clone(), position))
            .collect();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a minimal `AgentResult` for tests that need to complete a task.
    fn dummy_result() -> AgentResult {
        use crate::tracking::SessionTracker;
        AgentResult {
            final_message: "done".to_string(),
            iterations: 1,
            tool_calls_total: 0,
            auto_continues: 0,
            tracker: SessionTracker::new("test-model"),
        }
    }

    /// Maps each ID to its index in `order`, so tests can assert "a comes before b".
    fn positions(order: &[String]) -> HashMap<&str, usize> {
        order
            .iter()
            .enumerate()
            .map(|(i, s)| (s.as_str(), i))
            .collect()
    }

    #[test]
    fn test_empty_graph() {
        let g = TaskGraph::new();
        assert!(g.is_empty());
        assert_eq!(g.len(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Do something")).unwrap();
        assert_eq!(g.len(), 1);
        let node = g.get("t1").unwrap();
        assert_eq!(node.id, "t1");
        assert_eq!(node.description, "Do something");
        assert_eq!(node.status, TaskStatus::Pending);
    }

    #[test]
    fn test_duplicate_id_rejected() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "First")).unwrap();
        let err = g.add_task(TaskNode::new("t1", "Duplicate")).unwrap_err();
        assert!(err.to_string().contains("Duplicate task ID"));
    }

    #[test]
    fn test_unknown_dependency_rejected() {
        let mut g = TaskGraph::new();
        let err = g
            .add_task(TaskNode::new("t2", "Second").with_dependency("t1"))
            .unwrap_err();
        assert!(err.to_string().contains("has not been added yet"));
    }

    #[test]
    fn test_self_dependency_rejected() {
        let mut g = TaskGraph::new();
        let err = g
            .add_task(TaskNode::new("t1", "Loop").with_dependency("t1"))
            .unwrap_err();
        assert!(err.to_string().contains("has not been added yet"));
        assert!(g.is_empty());
    }

    #[test]
    fn test_topological_sort_linear_chain() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Step 1")).unwrap();
        g.add_task(TaskNode::new("t2", "Step 2").with_dependency("t1"))
            .unwrap();
        g.add_task(TaskNode::new("t3", "Step 3").with_dependency("t2"))
            .unwrap();
        let order = g.topological_sort().unwrap();
        let pos = positions(&order);
        assert!(pos["t1"] < pos["t2"]);
        assert!(pos["t2"] < pos["t3"]);
    }

    #[test]
    fn test_topological_sort_diamond() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Root")).unwrap();
        g.add_task(TaskNode::new("t2", "Left").with_dependency("t1"))
            .unwrap();
        g.add_task(TaskNode::new("t3", "Right").with_dependency("t1"))
            .unwrap();
        g.add_task(
            TaskNode::new("t4", "Merge")
                .with_dependency("t2")
                .with_dependency("t3"),
        )
        .unwrap();
        let order = g.topological_sort().unwrap();
        assert_eq!(order.len(), 4);
        let pos = positions(&order);
        assert!(pos["t1"] < pos["t2"]);
        assert!(pos["t1"] < pos["t3"]);
        assert!(pos["t2"] < pos["t4"]);
        assert!(pos["t3"] < pos["t4"]);
    }

    #[test]
    fn test_topological_sort_empty() {
        let g = TaskGraph::new();
        let order = g.topological_sort().unwrap();
        assert!(order.is_empty());
    }

    #[test]
    fn test_compute_waves_no_deps() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("a", "A")).unwrap();
        g.add_task(TaskNode::new("b", "B")).unwrap();
        g.add_task(TaskNode::new("c", "C")).unwrap();
        let waves = g.compute_waves().unwrap();
        assert_eq!(waves.len(), 1);
        assert_eq!(waves[0].len(), 3);
        let in_wave0: HashSet<&str> = waves[0].iter().map(|s| s.as_str()).collect();
        assert!(in_wave0.contains("a"));
        assert!(in_wave0.contains("b"));
        assert!(in_wave0.contains("c"));
    }

    #[test]
    fn test_compute_waves_linear_chain() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Step 1")).unwrap();
        g.add_task(TaskNode::new("t2", "Step 2").with_dependency("t1"))
            .unwrap();
        g.add_task(TaskNode::new("t3", "Step 3").with_dependency("t2"))
            .unwrap();
        let waves = g.compute_waves().unwrap();
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0], vec!["t1"]);
        assert_eq!(waves[1], vec!["t2"]);
        assert_eq!(waves[2], vec!["t3"]);
    }

    #[test]
    fn test_compute_waves_diamond() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Root")).unwrap();
        g.add_task(TaskNode::new("t2", "Left").with_dependency("t1"))
            .unwrap();
        g.add_task(TaskNode::new("t3", "Right").with_dependency("t1"))
            .unwrap();
        g.add_task(
            TaskNode::new("t4", "Merge")
                .with_dependency("t2")
                .with_dependency("t3"),
        )
        .unwrap();
        let waves = g.compute_waves().unwrap();
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0], vec!["t1"]);
        assert_eq!(waves[1].len(), 2);
        assert!(waves[1].contains(&"t2".to_string()));
        assert!(waves[1].contains(&"t3".to_string()));
        assert_eq!(waves[2], vec!["t4"]);
    }

    #[test]
    fn test_compute_waves_empty() {
        let g = TaskGraph::new();
        let waves = g.compute_waves().unwrap();
        assert!(waves.is_empty());
    }

    #[test]
    fn test_next_ready_initial_state() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Root")).unwrap();
        g.add_task(TaskNode::new("t2", "Branch").with_dependency("t1"))
            .unwrap();
        g.add_task(TaskNode::new("t3", "Independent")).unwrap();
        let ready: Vec<&str> = g.next_ready().iter().map(|n| n.id.as_str()).collect();
        assert!(ready.contains(&"t1"));
        assert!(ready.contains(&"t3"));
        assert!(!ready.contains(&"t2"));
    }

    #[test]
    fn test_next_ready_after_completion() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Root")).unwrap();
        g.add_task(TaskNode::new("t2", "Next").with_dependency("t1"))
            .unwrap();
        g.mark_completed("t1", dummy_result()).unwrap();
        let ready: Vec<&str> = g.next_ready().iter().map(|n| n.id.as_str()).collect();
        assert_eq!(ready, vec!["t2"]);
    }

    #[test]
    fn test_next_ready_blocked_by_failed_dependency() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Root")).unwrap();
        g.add_task(TaskNode::new("t2", "Next").with_dependency("t1"))
            .unwrap();
        g.mark_failed("t1", "boom".to_string()).unwrap();
        assert!(g.next_ready().is_empty());
        assert!(!g.is_finished());
    }

    #[test]
    fn test_mark_running_and_completed() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Task")).unwrap();
        g.mark_running("t1").unwrap();
        assert_eq!(g.get("t1").unwrap().status, TaskStatus::Running);
        g.mark_completed("t1", dummy_result()).unwrap();
        assert_eq!(g.get("t1").unwrap().status, TaskStatus::Completed);
        assert!(g.get("t1").unwrap().result.is_some());
    }

    #[test]
    fn test_mark_failed_increments_retries() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Task")).unwrap();
        g.mark_failed("t1", "network error".to_string()).unwrap();
        match &g.get("t1").unwrap().status {
            TaskStatus::Failed { retries, .. } => assert_eq!(*retries, 0),
            _ => panic!("expected Failed"),
        }
        g.reset_for_retry("t1").unwrap();
        g.mark_failed("t1", "timeout".to_string()).unwrap();
        match &g.get("t1").unwrap().status {
            TaskStatus::Failed { retries, .. } => assert_eq!(*retries, 1),
            _ => panic!("expected Failed"),
        }
    }

    #[test]
    fn test_reset_for_retry_requires_failed_state() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Task")).unwrap();
        let err = g.reset_for_retry("t1").unwrap_err();
        assert!(err.to_string().contains("not in Failed state"));
        g.mark_failed("t1", "boom".to_string()).unwrap();
        g.reset_for_retry("t1").unwrap();
        let node = g.get("t1").unwrap();
        assert_eq!(node.status, TaskStatus::Pending);
        assert_eq!(node.retry_count, 1);
    }

    #[test]
    fn test_mark_cancelled() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Task")).unwrap();
        g.mark_cancelled("t1").unwrap();
        assert_eq!(g.get("t1").unwrap().status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_unknown_id_rejected_by_mutators() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Task")).unwrap();
        let errors = [
            g.mark_running("nope").unwrap_err(),
            g.mark_completed("nope", dummy_result()).unwrap_err(),
            g.mark_failed("nope", "x".to_string()).unwrap_err(),
            g.mark_cancelled("nope").unwrap_err(),
            g.reset_for_retry("nope").unwrap_err(),
        ];
        for err in &errors {
            assert!(err.to_string().contains("not found in graph"));
        }
        assert_eq!(g.get("t1").unwrap().status, TaskStatus::Pending);
    }

    #[test]
    fn test_cannot_start_from_non_pending() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "Task")).unwrap();
        g.mark_running("t1").unwrap();
        let err = g.mark_running("t1").unwrap_err();
        assert!(err.to_string().contains("Cannot mark"));
    }

    #[test]
    fn test_is_finished() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "A")).unwrap();
        g.add_task(TaskNode::new("t2", "B")).unwrap();
        assert!(!g.is_finished());
        g.mark_completed("t1", dummy_result()).unwrap();
        assert!(!g.is_finished());
        g.mark_cancelled("t2").unwrap();
        assert!(g.is_finished());
    }

    #[test]
    fn test_is_all_completed() {
        let mut g = TaskGraph::new();
        assert!(g.is_all_completed());
        assert!(g.is_finished());
        g.add_task(TaskNode::new("t1", "A")).unwrap();
        g.add_task(TaskNode::new("t2", "B")).unwrap();
        g.mark_completed("t1", dummy_result()).unwrap();
        assert!(!g.is_all_completed());
        g.mark_cancelled("t2").unwrap();
        assert!(g.is_finished());
        assert!(!g.is_all_completed());
    }

    #[test]
    fn test_status_counts() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("t1", "A")).unwrap();
        g.add_task(TaskNode::new("t2", "B")).unwrap();
        g.add_task(TaskNode::new("t3", "C")).unwrap();
        g.mark_completed("t1", dummy_result()).unwrap();
        g.mark_failed("t2", "oops".to_string()).unwrap();
        let (pending, running, completed, failed, cancelled) = g.status_counts();
        assert_eq!(pending, 1);
        assert_eq!(running, 0);
        assert_eq!(completed, 1);
        assert_eq!(failed, 1);
        assert_eq!(cancelled, 0);
    }

    #[test]
    fn test_task_status_serialises_with_state_tag() {
        let failed = TaskStatus::Failed {
            error: "boom".to_string(),
            retries: 2,
        };
        let value = serde_json::to_value(&failed).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"state": "failed", "error": "boom", "retries": 2})
        );
        let back: TaskStatus = serde_json::from_value(value).unwrap();
        assert_eq!(back, failed);
        assert_eq!(
            serde_json::to_value(&TaskStatus::Pending).unwrap(),
            serde_json::json!({"state": "pending"})
        );
    }

    #[test]
    fn test_serialisation_roundtrip() {
        let mut g = TaskGraph::new();
        g.add_task(TaskNode::new("setup", "Prepare environment"))
            .unwrap();
        g.add_task(TaskNode::new("build", "Compile project").with_dependency("setup"))
            .unwrap();
        g.add_task(TaskNode::new("test", "Run test suite").with_dependency("build"))
            .unwrap();
        let json = serde_json::to_string_pretty(&g).unwrap();
        assert!(json.contains("\"setup\""));
        assert!(json.contains("\"build\""));
        assert!(json.contains("\"test\""));
        let mut g2: TaskGraph = serde_json::from_str(&json).unwrap();
        g2.rebuild_index();
        assert_eq!(g2.len(), 3);
        let waves = g2.compute_waves().unwrap();
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0], vec!["setup"]);
        assert_eq!(waves[1], vec!["build"]);
        assert_eq!(waves[2], vec!["test"]);
    }
}