use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use crate::{ExecutionMetrics, ExecutionResult, FlowPattern, RunError, RunId, RunStatus};
use super::{build_capability_output, execute_worker_with_arc, PatternContext, PatternExecutor};
/// Pattern executor for the `compete` flow.
///
/// Every worker listed in `FlowPattern::Compete` is spawned concurrently and the
/// first one that finishes with `RunStatus::Completed` wins; the others are then
/// aborted. The executor itself carries no state.
#[derive(Debug, Clone, Copy, Default)]
pub struct CompeteExecutor;

impl CompeteExecutor {
    /// Creates a new compete executor (equivalent to `CompeteExecutor::default()`).
    pub fn new() -> Self {
        Self
    }
}
/// Message a spawned competitor task sends back over the mpsc channel to the
/// coordinating loop in `execute_with_arc`.
struct CompetitorResult {
    /// What `execute_worker_with_arc` returned for this competitor.
    result: Result<ExecutionResult, RunError>,
    /// Name of the competing worker, as listed in the `Compete` flow.
    worker_name: String,
}
#[async_trait]
impl PatternExecutor for CompeteExecutor {
fn name(&self) -> &'static str {
"compete"
}
async fn execute(
&self,
_ctx: &PatternContext,
_runtime: &dyn crate::RuntimeAdapter,
_cancel: &crate::CancellationToken,
) -> Result<ExecutionResult, RunError> {
Err(RunError::RuntimeError {
message: "CompeteExecutor requires Arc runtime. Use execute_with_arc() instead.".into(),
})
}
async fn execute_with_arc(
&self,
ctx: &PatternContext,
runtime: Arc<dyn crate::RuntimeAdapter>,
cancel: &crate::CancellationToken,
) -> Result<ExecutionResult, RunError> {
let workers = match &ctx.swarm.flow {
FlowPattern::Compete { workers } => workers.clone(),
_ => {
return Err(RunError::PatternError {
pattern: "compete".into(),
step: "flow".into(),
message: "CompeteExecutor requires Compete pattern in flow".into(),
})
}
};
if workers.is_empty() {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts: vec![],
error: None,
metrics: ExecutionMetrics::default(),
output: None,
});
}
if cancel.is_cancelled().await {
return Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Cancelled,
artifacts: vec![],
error: Some(RunError::Cancelled {
reason: "Execution cancelled".into(),
}),
metrics: ExecutionMetrics::default(),
output: None,
});
}
let (result_tx, mut result_rx) = mpsc::channel::<CompetitorResult>(workers.len().max(1));
let mut tasks: JoinSet<()> = JoinSet::new();
for worker_name in workers.iter() {
let worker = ctx
.get_worker(worker_name)
.ok_or_else(|| RunError::PatternError {
pattern: "compete".into(),
step: worker_name.clone(),
message: format!("Worker '{}' not found in swarm", worker_name),
})?;
let worker = worker.clone();
let runtime_clone = runtime.clone();
let runtime_ctx = ctx.runtime_ctx.clone();
let scope = ctx.scope.clone();
let cancel_clone = cancel.clone();
let tx = result_tx.clone();
let worker_name_clone = worker_name.clone();
tasks.spawn(async move {
let result = execute_worker_with_arc(
&worker,
runtime_clone,
&runtime_ctx,
&scope,
&cancel_clone,
)
.await;
let _ = tx
.send(CompetitorResult {
worker_name: worker_name_clone,
result,
})
.await;
});
}
drop(result_tx);
let mut last_error: Option<RunError> = None;
let mut error_count: usize = 0;
while let Some(competitor_result) = result_rx.recv().await {
match competitor_result.result {
Ok(exec_result) if exec_result.status == RunStatus::Completed => {
cancel.cancel().await;
tasks.abort_all();
while result_rx.recv().await.is_some() {}
let mut final_scope = ctx.scope.clone();
if let Some(ref output) = exec_result.output {
final_scope
.add_step_output(competitor_result.worker_name.clone(), output.clone());
final_scope.add_step_output("winner".to_string(), output.clone());
}
let winning_result = ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Completed,
artifacts: exec_result.artifacts,
error: None,
metrics: exec_result.metrics,
output: exec_result.output,
};
return Ok(build_capability_output(
winning_result,
&ctx.swarm,
&final_scope,
));
}
Ok(exec_result) if exec_result.status == RunStatus::Failed => {
error_count += 1;
last_error = exec_result
.error
.or(Some(RunError::RuntimeError {
message: format!(
"Worker '{}' failed",
competitor_result.worker_name
),
}));
}
Ok(_) => {
}
Err(e) => {
error_count += 1;
last_error = Some(e);
}
}
}
while tasks.join_next().await.is_some() {}
Ok(ExecutionResult {
run_id: RunId::new(),
status: RunStatus::Failed,
artifacts: vec![],
error: Some(last_error.unwrap_or(RunError::RuntimeError {
message: format!(
"All {} competitors failed or were cancelled without a winner",
error_count
),
})),
metrics: ExecutionMetrics::default(),
output: None,
})
}
async fn on_failure(
&self,
ctx: &mut PatternContext,
_runtime: &dyn crate::RuntimeAdapter,
failed_worker: &str,
_error: &RunError,
) -> Result<bool, RunError> {
ctx.state.failed.push(failed_worker.to_string());
Ok(true)
}
}
#[cfg(test)]
mod tests {
    //! Unit and integration-style tests for `CompeteExecutor`.

    use super::*;
    use crate::template::Scope;
    use crate::{
        CancellationToken, ExecutionContext, FlowPattern, RuntimeKind, SwarmFile, Worker,
    };
    use serde_json::json;

    #[test]
    fn test_compete_executor_name() {
        let executor = CompeteExecutor::new();
        assert_eq!(executor.name(), "compete");
    }

    #[tokio::test]
    async fn test_compete_execute_requires_arc() {
        let executor = CompeteExecutor::new();
        let swarm = SwarmFile::new("test", FlowPattern::Compete { workers: vec![] });
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::LocalRuntime::new();
        let result = executor.execute(&ctx, &runtime, &cancel).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            RunError::RuntimeError { message } => {
                assert!(
                    message.contains("execute_with_arc"),
                    "Error message should mention execute_with_arc, got: {}",
                    message
                );
            }
            other => panic!("Expected RuntimeError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_compete_executor_wrong_pattern() {
        let executor = CompeteExecutor::new();
        let swarm = SwarmFile::new("test", FlowPattern::Sequence { steps: vec![] });
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        let result = executor.execute_with_arc(&ctx, runtime, &cancel).await;
        assert!(result.is_err());
    }

    /// A worker name that is not defined in the swarm must be rejected with a
    /// `PatternError` naming that worker.
    #[tokio::test]
    async fn test_compete_executor_missing_worker() {
        let executor = CompeteExecutor::new();
        let swarm = SwarmFile::new(
            "test",
            FlowPattern::Compete {
                workers: vec!["ghost".into()],
            },
        );
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        match executor.execute_with_arc(&ctx, runtime, &cancel).await {
            Err(RunError::PatternError { step, .. }) => assert_eq!(step, "ghost"),
            Err(other) => panic!("Expected PatternError, got: {:?}", other),
            Ok(_) => panic!("Expected PatternError for a missing worker, got Ok"),
        }
    }

    #[tokio::test]
    async fn test_compete_executor_empty_workers() {
        let executor = CompeteExecutor::new();
        let swarm = SwarmFile::new("test", FlowPattern::Compete { workers: vec![] });
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        let result = executor.execute_with_arc(&ctx, runtime, &cancel).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.status, RunStatus::Completed);
    }

    #[tokio::test]
    async fn test_compete_executor_cancellation_before_start() {
        let executor = CompeteExecutor::new();
        let swarm = SwarmFile::new(
            "test",
            FlowPattern::Compete {
                workers: vec!["w1".into()],
            },
        )
        .with_worker(Worker::new("w1", "agent.yaml"));
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        cancel.cancel().await;
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        let result = executor.execute_with_arc(&ctx, runtime, &cancel).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.status, RunStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_compete_executor_real_execution_winner() {
        let executor = CompeteExecutor::new();
        let swarm = SwarmFile::new(
            "test",
            FlowPattern::Compete {
                workers: vec!["w1".into(), "w2".into()],
            },
        )
        .with_worker(Worker::new("w1", "agent.yaml"))
        .with_worker(Worker::new("w2", "agent.yaml"));
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        let result = executor.execute_with_arc(&ctx, runtime, &cancel).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.status, RunStatus::Completed);
    }

    #[tokio::test]
    async fn test_compete_executor_all_fail() {
        use std::io::Write;
        let executor = CompeteExecutor::new();
        // Per-process directory so concurrent test runs cannot clobber (or
        // `remove_dir_all`) each other's fixture.
        let temp_dir = std::env::temp_dir().join(format!(
            "bzzz-compete-fail-test-{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&temp_dir).unwrap();
        let failing_spec_path = temp_dir.join("failing.yaml");
        let mut file = std::fs::File::create(&failing_spec_path).unwrap();
        writeln!(file, "apiVersion: v1").unwrap();
        writeln!(file, "id: failing-agent").unwrap();
        writeln!(file, "runtime:").unwrap();
        writeln!(file, " kind: Local").unwrap();
        writeln!(file, " config:").unwrap();
        writeln!(file, " command: /usr/bin/false").unwrap();
        drop(file);
        let swarm = SwarmFile::new(
            "test",
            FlowPattern::Compete {
                workers: vec!["fail1".into()],
            },
        )
        .with_worker(Worker::new(
            "fail1",
            failing_spec_path.to_string_lossy().to_string(),
        ));
        let ctx = PatternContext::new(swarm, ExecutionContext::new("ctx", RuntimeKind::Local));
        let cancel = CancellationToken::new();
        let runtime = crate::create_runtime(RuntimeKind::Local).unwrap();
        let result = executor.execute_with_arc(&ctx, runtime, &cancel).await;
        std::fs::remove_dir_all(&temp_dir).ok();
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.status, RunStatus::Failed);
        assert!(result.error.is_some());
    }

    // NOTE(review): this only exercises moving fields between `ExecutionResult`
    // values (mirroring the winner construction); it does not call the executor.
    #[test]
    fn test_compete_winner_output_passthrough() {
        let output = json!({ "answer": 42 });
        let result = ExecutionResult {
            run_id: RunId::new(),
            status: RunStatus::Completed,
            artifacts: vec![],
            error: None,
            metrics: ExecutionMetrics::default(),
            output: Some(output.clone()),
        };
        let exec_result = ExecutionResult {
            run_id: RunId::new(),
            status: RunStatus::Completed,
            artifacts: result.artifacts,
            error: None,
            metrics: result.metrics,
            output: result.output,
        };
        assert_eq!(exec_result.output, Some(output));
    }

    // Checks the scope layout that the winner's `"winner"` step output relies on.
    #[test]
    fn test_compete_winner_scope_write() {
        let winner_output = json!({ "score": 99 });
        let mut scope = Scope::with_input(json!({}));
        scope.add_step_output("winner".to_string(), winner_output.clone());
        let data = scope.to_json();
        assert_eq!(data["steps"]["winner"]["output"]["score"], json!(99));
    }
}