use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use crate::agent::coder::CoderAgent;
use crate::agent::{Agent, AgentResult};
use crate::io::AgentIO;
use crate::llm::{LlmProvider, Message};
use crate::orchestrator::graph::TaskGraph;
use crate::tools::{ToolContext, ToolRegistry};
/// Summary of one full task-graph execution, produced by `TaskExecutor::run`.
#[derive(Debug)]
pub struct ExecutionReport {
    /// Result of every task that completed successfully, keyed by task id.
    pub task_results: HashMap<String, AgentResult>,
    /// Ids of tasks that failed permanently (exhausted their retry budget).
    pub failed: Vec<String>,
    /// Ids of tasks cancelled because a (transitive) dependency failed.
    pub cancelled: Vec<String>,
    /// Wall-clock duration of the whole run (zero for an empty graph).
    pub total_duration: Duration,
}
/// Drives a `TaskGraph` to completion: dispatches ready tasks concurrently,
/// retries failing tasks, and cancels the dependents of permanent failures.
pub struct TaskExecutor {
    /// The dependency graph of tasks to execute.
    pub graph: TaskGraph,
    /// Maximum number of tasks allowed to run at the same time.
    pub max_concurrent: usize,
    /// Maximum number of retries per task before it fails permanently.
    pub max_retries: u32,
}
impl TaskExecutor {
    /// Creates an executor over `graph` with the default limits
    /// (4 concurrent tasks, 2 retries per task).
    pub fn new(graph: TaskGraph) -> Self {
        TaskExecutor {
            graph,
            max_concurrent: 4,
            max_retries: 2,
        }
    }

    /// Builder-style override for the concurrency limit.
    pub fn with_max_concurrent(mut self, n: usize) -> Self {
        self.max_concurrent = n;
        self
    }

    /// Builder-style override for the per-task retry limit.
    #[allow(dead_code)]
    pub fn with_max_retries(mut self, n: u32) -> Self {
        self.max_retries = n;
        self
    }

    /// Executes every task in the graph, respecting dependencies and the
    /// concurrency limit.
    ///
    /// Each ready task is spawned as a `CoderAgent` run on the tokio runtime.
    /// A failed task is retried while its failure count stays within
    /// `max_retries`; once it fails permanently, all of its transitive
    /// dependents are cancelled. Progress is reported through `io`.
    ///
    /// # Errors
    ///
    /// Returns an error if the graph reaches a deadlocked state (pending
    /// tasks while nothing is ready or running and nothing has failed), if a
    /// graph state transition is rejected, or if reporting through `io` fails.
    pub async fn run(
        mut self,
        llm: Arc<dyn LlmProvider>,
        tools: Arc<ToolRegistry>,
        tool_ctx: ToolContext,
        io: Arc<dyn AgentIO>,
    ) -> Result<ExecutionReport> {
        // Empty graph: report an empty, zero-duration run without spawning.
        if self.graph.is_empty() {
            return Ok(ExecutionReport {
                task_results: HashMap::new(),
                failed: Vec::new(),
                cancelled: Vec::new(),
                total_duration: Duration::ZERO,
            });
        }
        let start = Instant::now();
        // A tokio Mutex so the guard can be dropped across `.await` points.
        let graph = Arc::new(Mutex::new(self.graph));
        let mut join_set: JoinSet<(String, Result<AgentResult>)> = JoinSet::new();
        let mut running_count: usize = 0;
        let mut task_results: HashMap<String, AgentResult> = HashMap::new();
        let mut failed_ids: Vec<String> = Vec::new();
        let mut cancelled_ids: Vec<String> = Vec::new();
        loop {
            // Phase 1: dispatch as many ready tasks as the limit allows.
            {
                let mut g = graph.lock().await;
                let available = self.max_concurrent.saturating_sub(running_count);
                if available > 0 {
                    // Clone out what the spawned task needs so the node
                    // borrows end before the graph is mutated.
                    let to_dispatch: Vec<(String, String, crate::config::AgentConfig)> = g
                        .next_ready()
                        .into_iter()
                        .take(available)
                        .map(|node| {
                            (
                                node.id.clone(),
                                node.description.clone(),
                                node.agent_config.clone(),
                            )
                        })
                        .collect();
                    for (task_id, description, agent_config) in to_dispatch {
                        g.mark_running(&task_id)?;
                        running_count += 1;
                        let (pending, running_now, completed, failed_n, cancelled_n) =
                            g.status_counts();
                        let status_msg = format!(
                            "▶ Dispatching '{}' | pending={} running={} done={} failed={} cancelled={}",
                            task_id,
                            pending,
                            running_now,
                            completed,
                            failed_n,
                            cancelled_n,
                        );
                        // Never hold the graph lock across the I/O await.
                        drop(g);
                        io.show_status(&status_msg).await?;
                        g = graph.lock().await;
                        let llm_clone: Arc<dyn LlmProvider> = Arc::clone(&llm);
                        let tools_clone = Arc::clone(&tools);
                        let tool_ctx_clone = tool_ctx.clone();
                        let task_id_clone = task_id.clone();
                        join_set.spawn(async move {
                            let agent = CoderAgent::new(agent_config);
                            let mut messages: Vec<Message> = vec![Message::user(description)];
                            let result = agent
                                .run(
                                    &mut messages,
                                    &tools_clone,
                                    llm_clone.as_ref(),
                                    &tool_ctx_clone,
                                )
                                .await;
                            (task_id_clone, result)
                        });
                    }
                }
            }
            // Phase 2: decide whether the run is finished (or stuck).
            {
                let g = graph.lock().await;
                if g.is_finished() {
                    break;
                }
                if running_count == 0 {
                    let (pending, _, _, failed_n, _) = g.status_counts();
                    // Pending work with nothing running, ready, or failed
                    // means the graph can never make progress again.
                    if pending > 0 && failed_n == 0 {
                        anyhow::bail!(
                            "Task graph deadlock: {} pending tasks but none are ready \
                             and none are running. This is a bug — check for missing \
                             dependency cancellations.",
                            pending
                        );
                    }
                    break;
                }
            }
            // Phase 3: wait for one in-flight task to finish.
            let Some(join_result) = join_set.join_next().await else {
                break;
            };
            running_count -= 1;
            let (task_id, agent_result) = match join_result {
                Ok(pair) => pair,
                Err(join_err) => {
                    // NOTE(review): a panicked task loses its id here, so its
                    // graph node is stuck in Running forever; dependents of it
                    // can end the run stranded as pending. A real fix needs
                    // `JoinSet::join_next_with_id` plus an id→task-id map —
                    // flagged rather than changed in this pass.
                    io.write_error(&format!("⚠ A task panicked: {}", join_err))
                        .await?;
                    continue;
                }
            };
            match agent_result {
                Ok(result) => {
                    let mut g = graph.lock().await;
                    g.mark_completed(&task_id, result.clone())?;
                    task_results.insert(task_id.clone(), result);
                    let (pending, running_now, completed, failed_n, cancelled_n) =
                        g.status_counts();
                    drop(g);
                    io.show_status(&format!(
                        "✓ '{}' completed | pending={} running={} done={} failed={} cancelled={}",
                        task_id, pending, running_now, completed, failed_n, cancelled_n,
                    ))
                    .await?;
                }
                Err(err) => {
                    let mut g = graph.lock().await;
                    g.mark_failed(&task_id, err.to_string())?;
                    // `retry_count` is the number of failed attempts so far; a
                    // node that vanished maps to u32::MAX and is therefore
                    // treated as permanently failed.
                    let retry_count = g.get(&task_id).map(|n| n.retry_count).unwrap_or(u32::MAX);
                    if retry_count <= self.max_retries {
                        g.reset_for_retry(&task_id)?;
                        drop(g);
                        io.show_status(&format!(
                            "↺ '{}' failed (attempt {}), will retry (max {})",
                            task_id, retry_count, self.max_retries,
                        ))
                        .await?;
                    } else {
                        // Permanent failure: cancel everything that
                        // (transitively) depends on this task.
                        let to_cancel = collect_dependents(&g, &task_id);
                        let mut newly_cancelled: usize = 0;
                        for dep_id in &to_cancel {
                            // Only record dependents cancelled *now*: a
                            // dependent shared with an earlier failure may
                            // already be cancelled, and unconditionally
                            // pushing would duplicate it in the report.
                            if g.mark_cancelled(dep_id).is_ok() {
                                cancelled_ids.push(dep_id.clone());
                                newly_cancelled += 1;
                            }
                        }
                        failed_ids.push(task_id.clone());
                        let (pending, running_now, completed, failed_n, cancelled_n) =
                            g.status_counts();
                        drop(g);
                        // Report attempts, not "retries": retry_count counts
                        // every failed attempt including the first.
                        io.write_error(&format!(
                            "✗ '{}' permanently failed after {} attempt(s): {}",
                            task_id, retry_count, err
                        ))
                        .await?;
                        io.show_status(&format!(
                            " Cancelled {} dependents | pending={} running={} done={} failed={} cancelled={}",
                            newly_cancelled, pending, running_now, completed, failed_n, cancelled_n,
                        ))
                        .await?;
                    }
                }
            }
        }
        // No clone of `graph` escapes this function (spawned tasks never
        // capture it), so the local Arc is the sole holder here.
        self.graph = Arc::try_unwrap(graph)
            .expect("Arc still has other holders — this is a bug in the executor")
            .into_inner();
        Ok(ExecutionReport {
            task_results,
            failed: failed_ids,
            cancelled: cancelled_ids,
            total_duration: start.elapsed(),
        })
    }
}
fn collect_dependents(graph: &TaskGraph, failed_id: &str) -> Vec<String> {
let mut result = Vec::new();
let mut frontier: Vec<String> = vec![failed_id.to_string()];
let mut visited = std::collections::HashSet::new();
visited.insert(failed_id.to_string());
while !frontier.is_empty() {
let current_frontier = std::mem::take(&mut frontier);
for current_id in ¤t_frontier {
for node in graph.nodes() {
if node.depends_on.contains(current_id) && !visited.contains(&node.id) {
visited.insert(node.id.clone());
result.push(node.id.clone());
frontier.push(node.id.clone());
}
}
}
}
result
}
#[cfg(test)]
mod tests {
    //! Unit tests for `TaskExecutor::run` and `collect_dependents`.
    use super::*;
    use crate::io::NullIO;
    use crate::orchestrator::graph::{TaskGraph, TaskNode};
    use crate::tools::ToolRegistry;
    use crate::tracking::SessionTracker;
    use std::sync::Arc;
    use crate::llm::{LlmProvider, LlmResponse, Message as LlmMessage, ToolDefinition};
    use async_trait::async_trait;

    /// Stub provider: every completion immediately answers "[TASK_COMPLETE]",
    /// so agent runs finish on the first iteration with no tool calls.
    struct ImmediateCompleteProvider;

    #[async_trait]
    impl LlmProvider for ImmediateCompleteProvider {
        async fn chat_completion(
            &self,
            _messages: &[LlmMessage],
            _tools: &[ToolDefinition],
        ) -> Result<LlmResponse> {
            Ok(LlmResponse {
                content: Some("[TASK_COMPLETE]".to_string()),
                tool_calls: None,
                usage: None,
            })
        }
    }

    /// Builds a minimal `ToolContext` for tests: sandbox off, empty tool
    /// registry, no LSP/MCP clients, working dir /tmp.
    fn make_tool_ctx(io: Arc<dyn AgentIO>) -> crate::tools::ToolContext {
        use crate::tools::ToolContext;
        use std::path::PathBuf;
        use tokio::sync::Mutex;
        ToolContext {
            working_dir: PathBuf::from("/tmp"),
            sandbox_enabled: false,
            io,
            compact_mode: false,
            lsp_client: Arc::new(Mutex::new(None)),
            mcp_client: None,
            nesting_depth: 0,
            llm: Arc::new(ImmediateCompleteProvider),
            tools: Arc::new(ToolRegistry::new()),
            permissions: vec![],
            formatters: std::collections::HashMap::new(),
        }
    }

    // NOTE(review): not referenced by any test below — presumably kept for
    // future tests; confirm before removing.
    fn dummy_result() -> AgentResult {
        AgentResult {
            final_message: "done".to_string(),
            iterations: 1,
            tool_calls_total: 0,
            auto_continues: 0,
            tracker: SessionTracker::new("test-model"),
        }
    }

    /// An empty graph must short-circuit to an empty report.
    #[tokio::test]
    async fn test_executor_empty_graph() {
        let graph = TaskGraph::new();
        let executor = TaskExecutor::new(graph);
        let io: Arc<dyn AgentIO> = Arc::new(NullIO);
        let report = executor
            .run(
                Arc::new(ImmediateCompleteProvider),
                Arc::new(ToolRegistry::new()),
                make_tool_ctx(Arc::clone(&io)),
                io,
            )
            .await
            .unwrap();
        assert!(report.task_results.is_empty());
        assert!(report.failed.is_empty());
        assert!(report.cancelled.is_empty());
        // Empty graph returns Duration::ZERO, so this is comfortably under 1s.
        assert!(report.total_duration < Duration::from_secs(1));
    }

    /// A single dependency-free task completes successfully.
    #[tokio::test]
    async fn test_executor_single_task() {
        let mut graph = TaskGraph::new();
        graph
            .add_task(TaskNode::new("t1", "Write a hello world function"))
            .unwrap();
        let executor = TaskExecutor::new(graph).with_max_retries(0);
        let io: Arc<dyn AgentIO> = Arc::new(NullIO);
        let report = executor
            .run(
                Arc::new(ImmediateCompleteProvider),
                Arc::new(ToolRegistry::new()),
                make_tool_ctx(Arc::clone(&io)),
                io,
            )
            .await
            .unwrap();
        assert_eq!(report.task_results.len(), 1, "expected 1 completed task");
        assert!(report.task_results.contains_key("t1"));
        assert!(report.failed.is_empty());
        assert!(report.cancelled.is_empty());
    }

    /// Three independent tasks all complete when run with max_concurrent=3.
    #[tokio::test]
    async fn test_executor_parallel_tasks() {
        let mut graph = TaskGraph::new();
        graph.add_task(TaskNode::new("a", "Task A")).unwrap();
        graph.add_task(TaskNode::new("b", "Task B")).unwrap();
        graph.add_task(TaskNode::new("c", "Task C")).unwrap();
        let executor = TaskExecutor::new(graph)
            .with_max_concurrent(3)
            .with_max_retries(0);
        let io: Arc<dyn AgentIO> = Arc::new(NullIO);
        let report = executor
            .run(
                Arc::new(ImmediateCompleteProvider),
                Arc::new(ToolRegistry::new()),
                make_tool_ctx(Arc::clone(&io)),
                io,
            )
            .await
            .unwrap();
        assert_eq!(report.task_results.len(), 3);
        assert!(report.task_results.contains_key("a"));
        assert!(report.task_results.contains_key("b"));
        assert!(report.task_results.contains_key("c"));
        assert!(report.failed.is_empty());
        assert!(report.cancelled.is_empty());
    }

    /// A t1 → t2 → t3 chain completes with max_concurrent=1, exercising
    /// the "dependencies gate readiness" path of the dispatcher.
    #[tokio::test]
    async fn test_executor_linear_chain() {
        let mut graph = TaskGraph::new();
        graph.add_task(TaskNode::new("t1", "Step 1")).unwrap();
        graph
            .add_task(TaskNode::new("t2", "Step 2").with_dependency("t1"))
            .unwrap();
        graph
            .add_task(TaskNode::new("t3", "Step 3").with_dependency("t2"))
            .unwrap();
        let executor = TaskExecutor::new(graph)
            .with_max_concurrent(1)
            .with_max_retries(0);
        let io: Arc<dyn AgentIO> = Arc::new(NullIO);
        let report = executor
            .run(
                Arc::new(ImmediateCompleteProvider),
                Arc::new(ToolRegistry::new()),
                make_tool_ctx(Arc::clone(&io)),
                io,
            )
            .await
            .unwrap();
        assert_eq!(
            report.task_results.len(),
            3,
            "all three tasks must complete"
        );
        assert!(report.failed.is_empty());
        assert!(report.cancelled.is_empty());
    }

    /// When t1 fails permanently, its transitive dependents t2 and t3 are
    /// cancelled while the independent t4 still completes.
    #[tokio::test]
    async fn test_executor_cancels_dependents() {
        let mut graph = TaskGraph::new();
        graph.add_task(TaskNode::new("t1", "Failing task")).unwrap();
        graph
            .add_task(TaskNode::new("t2", "Depends on t1").with_dependency("t1"))
            .unwrap();
        graph
            .add_task(TaskNode::new("t3", "Depends on t2").with_dependency("t2"))
            .unwrap();
        graph.add_task(TaskNode::new("t4", "Independent")).unwrap();
        // Pre-fail and reset t1 once before running — presumably to seed a
        // nonzero retry_count so the real failure below is permanent under
        // max_retries=0; confirm against TaskGraph's retry accounting.
        graph
            .mark_failed("t1", "injected failure".to_string())
            .unwrap();
        graph.reset_for_retry("t1").unwrap();
        let executor = TaskExecutor::new(graph)
            .with_max_concurrent(4)
            .with_max_retries(0);
        let io: Arc<dyn AgentIO> = Arc::new(NullIO);
        let report = executor
            .run(
                Arc::new(FailTaskProvider {
                    fail_prefix: "Failing task".to_string(),
                }),
                Arc::new(ToolRegistry::new()),
                make_tool_ctx(Arc::clone(&io)),
                io,
            )
            .await
            .unwrap();
        assert!(
            report.task_results.contains_key("t4"),
            "t4 (independent) should complete even when t1 fails"
        );
        assert!(
            report.failed.contains(&"t1".to_string()),
            "t1 should be failed"
        );
        assert!(
            report.cancelled.contains(&"t2".to_string()),
            "t2 should be cancelled"
        );
        assert!(
            report.cancelled.contains(&"t3".to_string()),
            "t3 should be cancelled"
        );
    }

    /// Diamond graph: dependents of "root" include both branches and their join.
    #[test]
    fn test_collect_dependents_transitive() {
        let mut graph = TaskGraph::new();
        graph.add_task(TaskNode::new("root", "Root")).unwrap();
        graph
            .add_task(TaskNode::new("a", "A").with_dependency("root"))
            .unwrap();
        graph
            .add_task(TaskNode::new("b", "B").with_dependency("root"))
            .unwrap();
        graph
            .add_task(
                TaskNode::new("c", "C")
                    .with_dependency("a")
                    .with_dependency("b"),
            )
            .unwrap();
        let deps = collect_dependents(&graph, "root");
        assert!(deps.contains(&"a".to_string()));
        assert!(deps.contains(&"b".to_string()));
        assert!(deps.contains(&"c".to_string()));
        // The failed task itself is never part of its own dependents.
        assert!(!deps.contains(&"root".to_string()));
    }

    /// A chain a → b → c: dependents of "a" include the indirect "c" too.
    #[test]
    fn test_collect_dependents_direct_only() {
        let mut graph = TaskGraph::new();
        graph.add_task(TaskNode::new("a", "A")).unwrap();
        graph
            .add_task(TaskNode::new("b", "B").with_dependency("a"))
            .unwrap();
        graph
            .add_task(TaskNode::new("c", "C").with_dependency("b"))
            .unwrap();
        let deps = collect_dependents(&graph, "a");
        assert!(deps.contains(&"b".to_string()));
        assert!(deps.contains(&"c".to_string()));
    }

    /// A leaf node has no dependents.
    #[test]
    fn test_collect_dependents_leaf() {
        let mut graph = TaskGraph::new();
        graph.add_task(TaskNode::new("a", "A")).unwrap();
        graph
            .add_task(TaskNode::new("b", "B").with_dependency("a"))
            .unwrap();
        let deps = collect_dependents(&graph, "b");
        assert!(deps.is_empty());
    }

    /// Provider that fails any conversation whose messages start with
    /// `fail_prefix`, and completes everything else immediately.
    struct FailTaskProvider {
        fail_prefix: String,
    }

    #[async_trait]
    impl LlmProvider for FailTaskProvider {
        async fn chat_completion(
            &self,
            messages: &[LlmMessage],
            _tools: &[ToolDefinition],
        ) -> Result<LlmResponse> {
            // Match on the task description, which the executor places in the
            // first user message.
            let should_fail = messages.iter().any(|m| {
                m.text_content()
                    .map(|t| t.starts_with(&self.fail_prefix))
                    .unwrap_or(false)
            });
            if should_fail {
                Err(anyhow::anyhow!("injected failure for testing"))
            } else {
                Ok(LlmResponse {
                    content: Some("[TASK_COMPLETE]".to_string()),
                    tool_calls: None,
                    usage: None,
                })
            }
        }
    }
}