use std::time::Instant;

use anyhow::Result;
use async_trait::async_trait;
use futures::future::join_all;
use log::{debug, info, warn};

use super::task_plan::{
    ParallelTask, ParallelTaskResult, StepResult, StepType, TaskPlan, TaskStep,
};
use crate::agent::task::{Task, TaskResult};
#[async_trait]
pub trait AgentOrchestrator: Send + Sync {
async fn orchestrate_task(&self, task: &Task) -> Result<TaskResult> {
info!("Starting orchestration for task: {}", task.id);
let plan = self.analyze_task(task).await?;
debug!("Created plan with {} steps", plan.steps.len());
let mut all_results = Vec::new();
let mut current_plan = plan;
for (index, step) in current_plan.steps.clone().iter().enumerate() {
info!(
"Executing step {}/{}: {}",
index + 1,
current_plan.steps.len(),
step.name
);
let step_result = self.execute_step(step, ¤t_plan.context).await?;
for (key, value) in &step_result.outputs {
current_plan.update_context(key.clone(), value.clone());
}
all_results.push(step_result);
if index < current_plan.steps.len() - 1 && current_plan.adaptive {
debug!("Reviewing progress and adapting plan...");
if let Ok(adapted_plan) =
self.review_and_adapt(&all_results, &mut current_plan).await
{
current_plan = adapted_plan;
info!(
"Plan adapted, now has {} remaining steps",
current_plan.steps.len() - index - 1
);
}
}
}
self.synthesize_results(task, all_results).await
}
async fn analyze_task(&self, task: &Task) -> Result<TaskPlan>;
async fn execute_step(
&self,
step: &TaskStep,
context: &std::collections::HashMap<String, String>,
) -> Result<StepResult> {
let start = Instant::now();
let mut result = StepResult::new(step.id.clone());
if !step.parallel_tasks.is_empty() {
info!(
"Executing {} parallel tasks for step: {}",
step.parallel_tasks.len(),
step.name
);
let futures: Vec<_> = step
.parallel_tasks
.iter()
.map(|task| self.execute_parallel_task(task, context))
.collect();
let parallel_results = join_all(futures).await;
for (task, task_result) in step.parallel_tasks.iter().zip(parallel_results) {
match task_result {
Ok(res) => {
if !res.success && task.critical {
result = result.failed(format!("Critical task '{}' failed", task.name));
}
result.parallel_results.push(res);
}
Err(e) => {
let error_msg = format!("Task '{}' error: {}", task.name, e);
if task.critical {
result = result.failed(error_msg);
} else {
result.errors.push(error_msg);
}
}
}
}
}
let summary = self.generate_step_summary(step, &result).await?;
result = result.with_summary(summary);
result.duration_ms = start.elapsed().as_millis() as u64;
Ok(result)
}
async fn execute_parallel_task(
&self,
task: &ParallelTask,
context: &std::collections::HashMap<String, String>,
) -> Result<ParallelTaskResult>;
async fn review_and_adapt(
&self,
_results: &[StepResult],
current_plan: &mut TaskPlan,
) -> Result<TaskPlan> {
Ok(current_plan.clone())
}
async fn synthesize_results(&self, task: &Task, results: Vec<StepResult>)
-> Result<TaskResult>;
async fn generate_step_summary(&self, step: &TaskStep, result: &StepResult) -> Result<String> {
let success_count = result.parallel_results.iter().filter(|r| r.success).count();
let total_count = result.parallel_results.len();
Ok(format!(
"Step '{}' completed: {}/{} tasks successful. Duration: {}ms",
step.name, success_count, total_count, result.duration_ms
))
}
}
/// Convenience constructors for the steps and tasks that make up a [`TaskPlan`].
pub struct OrchestrationBuilder;

impl OrchestrationBuilder {
    /// Creates an analysis step whose `tasks` are registered as parallel tasks,
    /// in the order given.
    pub fn analysis_step(id: &str, name: &str, tasks: Vec<ParallelTask>) -> TaskStep {
        let initial = TaskStep::new(id.to_string(), name.to_string(), StepType::Analysis);
        tasks.into_iter().fold(initial, |mut acc, parallel| {
            acc.add_parallel_task(parallel);
            acc
        })
    }

    /// Creates an execution step that depends on each id in `dependencies`.
    pub fn execution_step(id: &str, name: &str, dependencies: Vec<&str>) -> TaskStep {
        let initial = TaskStep::new(id.to_string(), name.to_string(), StepType::Execution);
        dependencies
            .into_iter()
            .fold(initial, |acc, dep_id| acc.depends_on(dep_id.to_string()))
    }

    /// Creates a validation step with no tasks or dependencies.
    pub fn validation_step(id: &str, name: &str) -> TaskStep {
        let (step_id, step_name) = (id.to_string(), name.to_string());
        TaskStep::new(step_id, step_name, StepType::Validation)
    }

    /// Creates a parallel task with a 1000 ms expected duration that is not
    /// expected to fail.
    pub fn parallel_task(id: &str, name: &str, command: &str, critical: bool) -> ParallelTask {
        ParallelTask {
            critical,
            expect_failure: false,
            expected_duration_ms: 1000,
            id: id.to_owned(),
            name: name.to_owned(),
            command: command.to_owned(),
        }
    }
}