use std::collections::HashMap;
use async_trait::async_trait;
use crate::{ExecutionMetrics, ExecutionResult, FailureBehavior, FlowPattern, RunError, RunId, RunStatus, WorkflowTask};
use super::{build_capability_output, execute_worker, PatternContext, PatternExecutor};
/// Executor for the `workflow` flow pattern.
///
/// The type carries no state, so it is a zero-sized unit struct that can be
/// freely copied; all per-run data lives in the arguments passed to the
/// executor's methods.
#[derive(Debug, Clone, Copy, Default)]
pub struct WorkflowExecutor;

impl WorkflowExecutor {
    /// Creates a new executor. Equivalent to `WorkflowExecutor::default()`.
    pub fn new() -> Self {
        Self
    }
}
/// Scheduling bookkeeping kept for one workflow task while the graph runs.
#[derive(Debug, Clone)]
struct TaskState {
    /// Number of dependencies that have not been resolved yet; the task is
    /// only considered for execution once this reaches zero.
    pending_deps: usize,
    /// Set when the task's worker finished with a completed status.
    completed: bool,
    /// Set when the task's worker reported a failure.
    failed: bool,
    /// Set when the task was not executed at all.
    skipped: bool,
}

impl TaskState {
    /// Builds the initial state for a task that still waits on
    /// `pending_deps` dependencies; every outcome flag starts cleared.
    fn new(pending_deps: usize) -> Self {
        let (completed, failed, skipped) = (false, false, false);
        Self {
            pending_deps,
            completed,
            failed,
            skipped,
        }
    }
}
#[async_trait]
impl PatternExecutor for WorkflowExecutor {
fn name(&self) -> &'static str {
"workflow"
}
async fn execute(
&self,
ctx: &PatternContext,
runtime: &dyn crate::RuntimeAdapter,
cancel: &crate::CancellationToken,
) -> Result<ExecutionResult, RunError> {
let (tasks, synthesis) = match &ctx.swarm.flow {
FlowPattern::Workflow { tasks, synthesis } => (tasks, synthesis),
_ => {
return Err(RunError::PatternError {
pattern: "workflow".into(),
step: "flow".into(),
message: "WorkflowExecutor requires Workflow pattern in flow".into(),
})
}
};
if tasks.is_empty() {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts: vec![],
error: None,
metrics: ExecutionMetrics::default(),
output: None,
});
}
let mut task_states: HashMap<String, TaskState> = HashMap::new();
let mut downstream: HashMap<String, Vec<String>> = HashMap::new();
for task in tasks {
let pending_deps = task.depends_on.len();
task_states.insert(task.name.clone(), TaskState::new(pending_deps));
for dep in &task.depends_on {
downstream
.entry(dep.clone())
.or_default()
.push(task.name.clone());
}
}
let mut artifacts = vec![];
let mut current_ctx = ctx.clone();
let on_failure = ctx.swarm.on_failure;
let mut failed_tasks: Vec<String> = vec![];
let mut skipped_tasks: Vec<String> = vec![];
loop {
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts,
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics: ExecutionMetrics::default(),
output: None,
});
}
let ready: Vec<String> = task_states
.iter()
.filter(|(_, state)| {
state.pending_deps == 0
&& !state.completed
&& !state.failed
&& !state.skipped
})
.map(|(name, _)| name.clone())
.collect();
if ready.is_empty() {
break;
}
for task_name in ready {
let should_skip = should_skip_task(&task_name, &task_states, tasks);
if should_skip {
if let Some(state) = task_states.get_mut(&task_name) {
state.skipped = true;
skipped_tasks.push(task_name.clone());
}
if let Some(downstream_tasks) = downstream.get(&task_name) {
for downstream_task in downstream_tasks {
if let Some(ds_state) = task_states.get_mut(downstream_task) {
ds_state.pending_deps = ds_state.pending_deps.saturating_sub(1);
}
}
}
continue;
}
let worker = current_ctx
.get_worker(&task_name)
.ok_or_else(|| RunError::PatternError {
pattern: "workflow".into(),
step: task_name.clone(),
message: format!("Worker '{}' not found in swarm", task_name),
})?;
let result = execute_worker(
worker,
runtime,
¤t_ctx.runtime_ctx,
¤t_ctx.scope,
cancel,
)
.await?;
match result.status {
RunStatus::Completed => {
artifacts.extend(result.artifacts);
if let Some(output) = &result.output {
current_ctx.add_step_output(&task_name, output.clone());
}
if let Some(state) = task_states.get_mut(&task_name) {
state.completed = true;
}
if let Some(downstream_tasks) = downstream.get(&task_name) {
for downstream_task in downstream_tasks {
if let Some(ds_state) = task_states.get_mut(downstream_task) {
ds_state.pending_deps = ds_state.pending_deps.saturating_sub(1);
}
}
}
}
RunStatus::Failed => {
if let Some(state) = task_states.get_mut(&task_name) {
state.failed = true;
}
failed_tasks.push(task_name.clone());
match on_failure {
FailureBehavior::FailFast => {
for (name, state) in &mut task_states {
if !state.completed && !state.failed {
state.skipped = true;
skipped_tasks.push(name.clone());
}
}
break;
}
FailureBehavior::Continue | FailureBehavior::Ignore => {
if let Some(downstream_tasks) = downstream.get(&task_name) {
for downstream_task in downstream_tasks {
if let Some(ds_state) = task_states.get_mut(downstream_task) {
ds_state.pending_deps = ds_state.pending_deps.saturating_sub(1);
}
}
}
}
}
}
RunStatus::Cancelled => {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts,
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics: result.metrics,
output: None,
});
}
_ => {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts,
error: Some(RunError::RuntimeError {
message: format!("Unexpected status: {:?}", result.status),
}),
metrics: result.metrics,
output: None,
});
}
}
}
}
if let Some(synth_name) = synthesis {
let any_failed = failed_tasks.iter().any(|name| {
task_states
.get(name)
.map(|s| s.failed)
.unwrap_or(false)
});
if !any_failed && !cancel.is_cancelled().await {
let synth_worker = current_ctx
.get_worker(synth_name)
.ok_or_else(|| RunError::PatternError {
pattern: "workflow".into(),
step: synth_name.to_string(),
message: format!("Synthesis worker '{}' not found", synth_name),
})?;
let result = execute_worker(
synth_worker,
runtime,
¤t_ctx.runtime_ctx,
¤t_ctx.scope,
cancel,
)
.await?;
match result.status {
RunStatus::Completed => {
artifacts.extend(result.artifacts);
if let Some(output) = &result.output {
current_ctx.add_step_output(synth_name, output.clone());
}
}
RunStatus::Failed => {
failed_tasks.push(synth_name.to_string());
}
_ => {}
}
}
}
let (final_status, final_error) = if failed_tasks.is_empty() && skipped_tasks.is_empty() {
(RunStatus::Completed, None)
} else if !failed_tasks.is_empty() {
match on_failure {
FailureBehavior::Ignore => (RunStatus::Completed, None),
_ => (
RunStatus::Failed,
Some(RunError::PatternError {
pattern: "workflow".into(),
step: "summary".into(),
message: format!(
"{} task(s) failed: {}, {} task(s) skipped: {}",
failed_tasks.len(),
failed_tasks.join(", "),
skipped_tasks.len(),
skipped_tasks.join(", ")
),
}),
),
}
} else {
match on_failure {
FailureBehavior::Ignore => (RunStatus::Completed, None),
_ => (
RunStatus::Failed,
Some(RunError::PatternError {
pattern: "workflow".into(),
step: "summary".into(),
message: format!("{} task(s) skipped due to upstream failure", skipped_tasks.len()),
}),
),
}
};
let result = ExecutionResult {
run_id: RunId::new(),
status: final_status,
artifacts,
error: final_error,
metrics: ExecutionMetrics::default(),
output: None,
};
Ok(build_capability_output(
result,
&ctx.swarm,
¤t_ctx.scope,
))
}
async fn on_failure(
&self,
_ctx: &mut PatternContext,
_runtime: &dyn crate::RuntimeAdapter,
_failed_worker: &str,
_error: &RunError,
) -> Result<bool, RunError> {
Ok(false)
}
}
/// Decides whether `task_name` must be skipped because one of its upstream
/// tasks did not succeed.
///
/// Looks up the task definition in `tasks` and returns `true` as soon as any
/// of its `depends_on` entries has a recorded state that is failed or skipped.
/// Returns `false` when the task is not declared in `tasks`; dependencies
/// without an entry in `task_states` are ignored.
fn should_skip_task(
    task_name: &str,
    task_states: &HashMap<String, TaskState>,
    tasks: &[WorkflowTask],
) -> bool {
    tasks
        .iter()
        .find(|candidate| candidate.name == task_name)
        .map_or(false, |definition| {
            definition
                .depends_on
                .iter()
                .filter_map(|dep| task_states.get(dep))
                .any(|upstream| upstream.failed || upstream.skipped)
        })
}
// Unit tests for the workflow executor and for workflow-related parsing and
// validation of `SwarmFile`.
#[cfg(test)]
mod tests {
use super::*;
use crate::{CancellationToken, ExecutionContext, RuntimeKind, SwarmFile, Worker};
use std::io::Write;
/// The executor identifies itself as "workflow".
#[test]
fn test_workflow_executor_name() {
let executor = WorkflowExecutor::new();
assert_eq!(executor.name(), "workflow");
}
/// Executing against a non-workflow flow (here `Parallel`) yields an `Err`.
#[tokio::test]
async fn test_workflow_executor_wrong_pattern() {
let executor = WorkflowExecutor::new();
let swarm = SwarmFile::new(
"test",
FlowPattern::Parallel {
branches: vec![],
fail_fast: false,
},
);
let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
let cancel = CancellationToken::new();
let result = executor
.execute(&ctx, &crate::LocalRuntime::new(), &cancel)
.await;
assert!(result.is_err());
}
/// A workflow without tasks completes successfully.
#[tokio::test]
async fn test_workflow_empty_tasks() {
let executor = WorkflowExecutor::new();
let swarm = SwarmFile::new(
"test",
FlowPattern::Workflow {
tasks: vec![],
synthesis: None,
},
);
let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
let cancel = CancellationToken::new();
let result = executor
.execute(&ctx, &crate::LocalRuntime::new(), &cancel)
.await
.unwrap();
assert_eq!(result.status, RunStatus::Completed);
}
/// A single task backed by a local agent spec runs to completion.
#[tokio::test]
async fn test_workflow_single_task() {
let executor = WorkflowExecutor::new();
// Write a throwaway agent spec into a fixed directory under the system temp dir.
let temp_dir = std::env::temp_dir().join("bzzz-workflow-single-test");
std::fs::create_dir_all(&temp_dir).unwrap();
let spec_path = temp_dir.join("agent.yaml");
let mut file = std::fs::File::create(&spec_path).unwrap();
// NOTE(review): the nested keys below rely on YAML indentation inside the
// string literals — confirm the leading spaces are intact.
writeln!(file, "apiVersion: v1").unwrap();
writeln!(file, "id: test-agent").unwrap();
writeln!(file, "runtime:").unwrap();
writeln!(file, " kind: Local").unwrap();
writeln!(file, " config:").unwrap();
writeln!(file, " command: /usr/bin/true").unwrap();
// Close the file before the executor is run.
drop(file);
let swarm = SwarmFile::new(
"test",
FlowPattern::Workflow {
tasks: vec![WorkflowTask::new("w1")],
synthesis: None,
},
)
.with_worker(Worker::new("w1", spec_path.to_string_lossy().to_string()));
let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
let cancel = CancellationToken::new();
let result = executor
.execute(&ctx, &crate::LocalRuntime::new(), &cancel)
.await
.unwrap();
// Best-effort cleanup; removal errors are ignored.
std::fs::remove_dir_all(&temp_dir).ok();
assert_eq!(result.status, RunStatus::Completed);
}
/// A linear chain a -> b -> c, all sharing one agent spec, runs to completion.
#[tokio::test]
async fn test_workflow_linear_dependency() {
let executor = WorkflowExecutor::new();
// Write a throwaway agent spec into a fixed directory under the system temp dir.
let temp_dir = std::env::temp_dir().join("bzzz-workflow-linear-test");
std::fs::create_dir_all(&temp_dir).unwrap();
let spec_path = temp_dir.join("agent.yaml");
let mut file = std::fs::File::create(&spec_path).unwrap();
writeln!(file, "apiVersion: v1").unwrap();
writeln!(file, "id: test-agent").unwrap();
writeln!(file, "runtime:").unwrap();
writeln!(file, " kind: Local").unwrap();
writeln!(file, " config:").unwrap();
writeln!(file, " command: /usr/bin/true").unwrap();
// Close the file before the executor is run.
drop(file);
let swarm = SwarmFile::new(
"test",
FlowPattern::Workflow {
tasks: vec![
WorkflowTask::new("a"),
WorkflowTask::new("b").with_depends_on(vec!["a".into()]),
WorkflowTask::new("c").with_depends_on(vec!["b".into()]),
],
synthesis: None,
},
)
.with_worker(Worker::new("a", spec_path.to_string_lossy().to_string()))
.with_worker(Worker::new("b", spec_path.to_string_lossy().to_string()))
.with_worker(Worker::new("c", spec_path.to_string_lossy().to_string()));
let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
let cancel = CancellationToken::new();
let result = executor
.execute(&ctx, &crate::LocalRuntime::new(), &cancel)
.await
.unwrap();
// Best-effort cleanup; removal errors are ignored.
std::fs::remove_dir_all(&temp_dir).ok();
assert_eq!(result.status, RunStatus::Completed);
}
/// The `WorkflowTask` builder stores the task name and its dependency list.
#[test]
fn test_workflow_task_creation() {
let task = WorkflowTask::new("test-task")
.with_depends_on(vec!["dep1".into(), "dep2".into()]);
assert_eq!(task.name, "test-task");
assert_eq!(task.depends_on, vec!["dep1", "dep2"]);
}
/// A swarm document with `type: workflow` deserializes into
/// `FlowPattern::Workflow` with the expected tasks and synthesis worker.
#[test]
fn test_workflow_pattern_yaml_parsing() {
// NOTE(review): this document depends on YAML indentation inside the raw
// string — confirm the nesting is intact.
let yaml = r#"
apiVersion: bzzz.dev/v1
kind: swarm
id: workflow-test
workers:
- name: a
spec: a.yaml
- name: b
spec: b.yaml
flow:
type: workflow
tasks:
- name: a
depends_on: []
- name: b
depends_on: [a]
synthesis: b
"#;
let parsed: SwarmFile = serde_yaml::from_str(yaml).unwrap();
assert_eq!(parsed.id.as_str(), "workflow-test");
if let FlowPattern::Workflow { tasks, synthesis } = &parsed.flow {
assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0].name, "a");
assert_eq!(tasks[0].depends_on, Vec::<String>::new());
assert_eq!(tasks[1].name, "b");
assert_eq!(tasks[1].depends_on, vec!["a"]);
assert_eq!(synthesis, &Some("b".to_string()));
} else {
panic!("Expected Workflow pattern");
}
}
/// Validation rejects a dependency cycle (a -> c -> b -> a) with an error
/// whose message mentions "cycle".
#[test]
fn test_workflow_validation_cycle_detection() {
let swarm = SwarmFile::new(
"cycle-test",
FlowPattern::Workflow {
tasks: vec![
WorkflowTask::new("a").with_depends_on(vec!["c".into()]),
WorkflowTask::new("b").with_depends_on(vec!["a".into()]),
WorkflowTask::new("c").with_depends_on(vec!["b".into()]),
],
synthesis: None,
},
)
.with_worker(Worker::new("a", "a.yaml"))
.with_worker(Worker::new("b", "b.yaml"))
.with_worker(Worker::new("c", "c.yaml"));
let result = swarm.validate();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("cycle"));
}
/// Validation rejects a dependency on a task that is not declared, with an
/// error whose message mentions "undefined task".
#[test]
fn test_workflow_validation_undefined_dependency() {
let swarm = SwarmFile::new(
"test",
FlowPattern::Workflow {
tasks: vec![
WorkflowTask::new("a"),
WorkflowTask::new("b").with_depends_on(vec!["nonexistent".into()]),
],
synthesis: None,
},
)
.with_worker(Worker::new("a", "a.yaml"))
.with_worker(Worker::new("b", "b.yaml"));
let result = swarm.validate();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("undefined task"));
}
/// Validation rejects two tasks sharing a name, with an error whose message
/// mentions "duplicate".
#[test]
fn test_workflow_validation_duplicate_task_name() {
let swarm = SwarmFile::new(
"test",
FlowPattern::Workflow {
tasks: vec![
WorkflowTask::new("a"),
WorkflowTask::new("a"), ],
synthesis: None,
},
)
.with_worker(Worker::new("a", "a.yaml"));
let result = swarm.validate();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("duplicate"));
}
}