use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use crate::{
CancellationToken, ExecutionMetrics, ExecutionResult, FailureBehavior, FlowPattern, RunError,
RunId, RunStatus, RuntimeAdapter,
};
use serde_json::Value;
use super::{build_capability_output, execute_worker_with_arc, PatternContext, PatternExecutor};
/// Executes a swarm's `Parallel` flow by running every branch concurrently and
/// aggregating the branch results into a single `ExecutionResult`.
#[derive(Debug, Clone)]
pub struct ParallelExecutor {
    /// Capacity of the channel through which branch tasks report their results.
    result_buffer_size: usize,
}

impl ParallelExecutor {
    /// Result-channel capacity used by [`ParallelExecutor::new`].
    const DEFAULT_RESULT_BUFFER_SIZE: usize = 32;

    /// Creates an executor with the default result-channel capacity (32).
    #[must_use]
    pub fn new() -> Self {
        Self::with_buffer_size(Self::DEFAULT_RESULT_BUFFER_SIZE)
    }

    /// Creates an executor whose result channel buffers up to `buffer_size` results.
    ///
    /// A `buffer_size` of zero is raised to one, because a bounded Tokio channel
    /// cannot be created with zero capacity.
    #[must_use]
    pub fn with_buffer_size(buffer_size: usize) -> Self {
        ParallelExecutor {
            result_buffer_size: buffer_size.max(1),
        }
    }
}

impl Default for ParallelExecutor {
    /// Equivalent to [`ParallelExecutor::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Outcome of one parallel branch, sent from the branch's task back to the
/// executor over the result channel.
struct BranchResult {
    /// Name of the branch (and of the worker it ran).
    branch_name: String,
    /// Whatever executing the branch's worker returned.
    result: Result<ExecutionResult, RunError>,
}
#[async_trait]
impl PatternExecutor for ParallelExecutor {
fn name(&self) -> &'static str {
"parallel"
}
async fn execute(
&self,
_ctx: &PatternContext,
_runtime: &dyn crate::RuntimeAdapter,
_cancel: &CancellationToken,
) -> Result<ExecutionResult, RunError> {
Err(RunError::RuntimeError {
message: "ParallelExecutor requires Arc runtime. Use execute_with_arc() instead."
.into(),
})
}
async fn execute_with_arc(
&self,
ctx: &PatternContext,
runtime: Arc<dyn RuntimeAdapter>,
cancel: &CancellationToken,
) -> Result<ExecutionResult, RunError> {
let (branches, pattern_fail_fast) = match &ctx.swarm.flow {
FlowPattern::Parallel {
branches,
fail_fast,
} => (branches.clone(), *fail_fast),
_ => {
return Err(RunError::PatternError {
pattern: "parallel".into(),
step: "flow".into(),
message: "ParallelExecutor requires Parallel pattern in flow".into(),
})
}
};
let on_failure = ctx.swarm.on_failure;
let fail_fast = match on_failure {
FailureBehavior::FailFast => pattern_fail_fast,
FailureBehavior::Continue | FailureBehavior::Ignore => false,
};
if branches.is_empty() {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts: vec![],
error: None,
metrics: ExecutionMetrics::default(),
output: None,
});
}
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts: vec![],
error: Some(RunError::Cancelled {
reason: "Execution cancelled before start".into(),
}),
metrics: ExecutionMetrics::default(),
output: None,
});
}
let (result_tx, mut result_rx) = mpsc::channel::<BranchResult>(self.result_buffer_size);
let total_wall_time = Arc::new(AtomicU64::new(0));
let completed_count = Arc::new(AtomicU64::new(0));
let failed_count = Arc::new(AtomicU64::new(0));
let mut tasks = JoinSet::new();
for branch_name in branches.iter() {
let worker = ctx
.get_worker(branch_name)
.ok_or_else(|| RunError::PatternError {
pattern: "parallel".into(),
step: branch_name.clone(),
message: format!("Worker '{}' not found in swarm", branch_name),
})?;
let worker = worker.clone();
let runtime_clone = runtime.clone();
let runtime_ctx = ctx.runtime_ctx.clone();
let scope = ctx.scope.clone();
let cancel_clone = cancel.clone();
let tx = result_tx.clone();
let branch_name_clone = branch_name.clone();
let wall_time_counter = total_wall_time.clone();
let completed_counter = completed_count.clone();
let failed_counter = failed_count.clone();
tasks.spawn(async move {
let result = execute_worker_with_arc(
&worker,
runtime_clone.clone(),
&runtime_ctx,
&scope,
&cancel_clone,
)
.await;
match &result {
Ok(exec_result) => {
wall_time_counter
.fetch_add(exec_result.metrics.wall_time_ms, Ordering::Relaxed);
if exec_result.status == RunStatus::Completed {
completed_counter.fetch_add(1, Ordering::Relaxed);
} else if exec_result.status == RunStatus::Failed {
failed_counter.fetch_add(1, Ordering::Relaxed);
}
}
Err(_) => {
failed_counter.fetch_add(1, Ordering::Relaxed);
}
}
if tx
.send(BranchResult {
branch_name: branch_name_clone,
result,
})
.await
.is_err()
{
}
});
}
drop(result_tx);
let mut artifacts = vec![];
let mut first_error: Option<RunError> = None;
let mut branch_outputs: serde_json::Map<String, Value> = serde_json::Map::new();
while let Some(branch_result) = result_rx.recv().await {
match branch_result.result {
Ok(exec_result) => {
match exec_result.status {
RunStatus::Completed => {
artifacts.extend(exec_result.artifacts);
if let Some(output) = exec_result.output {
branch_outputs.insert(branch_result.branch_name.clone(), output);
}
}
RunStatus::Failed => {
if fail_fast && first_error.is_none() {
cancel.cancel().await;
tasks.abort_all();
while result_rx.recv().await.is_some() {}
let error = exec_result.error;
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts,
error,
metrics: ExecutionMetrics {
wall_time_ms: total_wall_time.load(Ordering::Relaxed),
..Default::default()
},
output: if branch_outputs.is_empty() {
None
} else {
Some(Value::Object(branch_outputs))
},
});
} else if first_error.is_none() {
first_error = exec_result.error;
}
}
RunStatus::Cancelled => {
}
_ => {}
}
}
Err(e) => {
if fail_fast && first_error.is_none() {
cancel.cancel().await;
tasks.abort_all();
while result_rx.recv().await.is_some() {}
return Err(e);
} else if first_error.is_none() {
first_error = Some(e);
}
}
}
}
while tasks.join_next().await.is_some() {}
let failures = failed_count.load(Ordering::Relaxed);
let (final_status, final_error) = if failures == 0 {
(RunStatus::Completed, None)
} else {
match on_failure {
FailureBehavior::FailFast | FailureBehavior::Continue => (
RunStatus::Failed,
first_error.or_else(|| {
Some(RunError::PatternError {
pattern: "parallel".into(),
step: "summary".into(),
message: format!("{} branch(es) failed", failures),
})
}),
),
FailureBehavior::Ignore => (RunStatus::Completed, None),
}
};
let mut final_scope = ctx.scope.clone();
for (branch_name, output) in &branch_outputs {
final_scope.add_step_output(branch_name.clone(), output.clone());
}
let result = ExecutionResult {
run_id: RunId::new(),
status: final_status,
artifacts,
error: final_error,
metrics: ExecutionMetrics {
wall_time_ms: total_wall_time.load(Ordering::Relaxed),
..Default::default()
},
output: None, };
Ok(build_capability_output(result, &ctx.swarm, &final_scope))
}
async fn on_failure(
&self,
ctx: &mut PatternContext,
_runtime: &dyn crate::RuntimeAdapter,
failed_worker: &str,
_error: &RunError,
) -> Result<bool, RunError> {
ctx.state.failed.push(failed_worker.to_string());
let fail_fast = match &ctx.swarm.flow {
FlowPattern::Parallel { fail_fast, .. } => *fail_fast,
_ => true,
};
Ok(!fail_fast)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CancellationToken, ExecutionContext, FlowPattern, RuntimeKind, SwarmFile, Worker};
    use std::path::{Path, PathBuf};

    /// Runs a fresh `ParallelExecutor` over `swarm` on a Local runtime.
    async fn run_parallel(
        swarm: SwarmFile,
        cancel: &CancellationToken,
    ) -> Result<ExecutionResult, RunError> {
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        ParallelExecutor::new()
            .execute_with_arc(&ctx, runtime, cancel)
            .await
    }

    /// Builds a swarm named "test" whose flow runs `branches` in parallel.
    fn parallel_swarm(branches: &[&str], fail_fast: bool) -> SwarmFile {
        SwarmFile::new(
            "test",
            FlowPattern::Parallel {
                branches: branches.iter().map(|&name| name.into()).collect(),
                fail_fast,
            },
        )
    }

    /// Creates (if missing) a scratch directory under the system temp dir.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(name);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Writes a minimal Local-runtime agent spec that runs `command`; returns its path.
    fn write_spec(dir: &Path, file_name: &str, id: &str, command: &str) -> PathBuf {
        use std::io::Write;
        let path = dir.join(file_name);
        let mut file = std::fs::File::create(&path).unwrap();
        writeln!(file, "apiVersion: v1").unwrap();
        writeln!(file, "id: {}", id).unwrap();
        writeln!(file, "runtime:").unwrap();
        writeln!(file, " kind: Local").unwrap();
        writeln!(file, " config:").unwrap();
        writeln!(file, " command: {}", command).unwrap();
        path
    }

    #[test]
    fn test_parallel_executor_name() {
        assert_eq!(ParallelExecutor::new().name(), "parallel");
    }

    #[test]
    fn test_parallel_executor_with_buffer_size() {
        assert_eq!(ParallelExecutor::with_buffer_size(64).result_buffer_size, 64);
    }

    #[test]
    fn test_parallel_executor_buffer_size_min() {
        assert_eq!(ParallelExecutor::with_buffer_size(0).result_buffer_size, 1);
    }

    #[tokio::test]
    async fn test_parallel_executor_wrong_pattern() {
        let swarm = SwarmFile::new("test", FlowPattern::Sequence { steps: vec![] });
        let outcome = run_parallel(swarm, &CancellationToken::new()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_parallel_executor_empty_branches() {
        let outcome = run_parallel(parallel_swarm(&[], false), &CancellationToken::new())
            .await
            .expect("an empty parallel flow should succeed");
        assert_eq!(outcome.status, RunStatus::Completed);
        assert!(outcome.artifacts.is_empty());
    }

    #[tokio::test]
    async fn test_parallel_executor_real_execution() {
        let swarm = parallel_swarm(&["w1", "w2"], false)
            .with_worker(Worker::new("w1", "agent.yaml"))
            .with_worker(Worker::new("w2", "agent.yaml"));
        let outcome = run_parallel(swarm, &CancellationToken::new())
            .await
            .expect("parallel run should succeed");
        assert_eq!(outcome.status, RunStatus::Completed);
    }

    #[tokio::test]
    async fn test_parallel_executor_with_failure() {
        let dir = scratch_dir("bzzz-parallel-fail-test");
        let failing = write_spec(&dir, "failing.yaml", "failing-agent", "/usr/bin/false");
        let swarm = parallel_swarm(&["failing"], false)
            .with_worker(Worker::new("failing", failing.to_string_lossy().to_string()));
        let outcome = run_parallel(swarm, &CancellationToken::new()).await;
        std::fs::remove_dir_all(&dir).ok();
        let outcome = outcome.expect("a failing branch should still yield a result");
        assert_eq!(outcome.status, RunStatus::Failed);
        assert!(outcome.error.is_some());
    }

    #[tokio::test]
    async fn test_parallel_executor_fail_fast() {
        let dir = scratch_dir("bzzz-parallel-failfast-test");
        let failing = write_spec(&dir, "failing.yaml", "failing-agent", "/usr/bin/false");
        let swarm = parallel_swarm(&["failing"], true)
            .with_worker(Worker::new("failing", failing.to_string_lossy().to_string()));
        let outcome = run_parallel(swarm, &CancellationToken::new()).await;
        std::fs::remove_dir_all(&dir).ok();
        let outcome = outcome.expect("execute_with_arc should return Ok for executed workers");
        assert_eq!(outcome.status, RunStatus::Failed);
    }

    #[tokio::test]
    async fn test_parallel_executor_cancellation() {
        let swarm = parallel_swarm(&["w1"], false).with_worker(Worker::new("w1", "agent.yaml"));
        let cancel = CancellationToken::new();
        cancel.cancel().await;
        let outcome = run_parallel(swarm, &cancel)
            .await
            .expect("a pre-cancelled run should still yield a result");
        assert_eq!(outcome.status, RunStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_parallel_executor_metrics_aggregation() {
        let swarm = parallel_swarm(&["w1", "w2", "w3"], false)
            .with_worker(Worker::new("w1", "agent.yaml"))
            .with_worker(Worker::new("w2", "agent.yaml"))
            .with_worker(Worker::new("w3", "agent.yaml"));
        let outcome = run_parallel(swarm, &CancellationToken::new())
            .await
            .expect("parallel run should succeed");
        assert!(outcome.metrics.wall_time_ms > 0 || outcome.status == RunStatus::Completed);
    }

    #[tokio::test]
    async fn test_parallel_on_failure_continue() {
        let dir = scratch_dir("bzzz-par-continue-test");
        let failing = write_spec(&dir, "failing.yaml", "failing-agent", "/usr/bin/false");
        let ok = write_spec(&dir, "ok.yaml", "ok-agent", "/usr/bin/true");
        // The pattern asks for fail-fast, but the swarm-level `Continue` policy wins.
        let swarm = parallel_swarm(&["failing", "ok"], true)
            .with_worker(Worker::new("failing", failing.to_string_lossy().to_string()))
            .with_worker(Worker::new("ok", ok.to_string_lossy().to_string()))
            .with_failure_behavior(FailureBehavior::Continue);
        let outcome = run_parallel(swarm, &CancellationToken::new()).await.unwrap();
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(outcome.status, RunStatus::Failed);
        assert!(outcome.error.is_some());
    }

    #[tokio::test]
    async fn test_parallel_on_failure_ignore() {
        let dir = scratch_dir("bzzz-par-ignore-test");
        let failing = write_spec(&dir, "failing.yaml", "failing-agent", "/usr/bin/false");
        let swarm = parallel_swarm(&["failing"], false)
            .with_worker(Worker::new("failing", failing.to_string_lossy().to_string()))
            .with_failure_behavior(FailureBehavior::Ignore);
        let outcome = run_parallel(swarm, &CancellationToken::new()).await.unwrap();
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(outcome.status, RunStatus::Completed);
        assert!(outcome.error.is_none());
    }
}