use std::sync::Arc;
use crate::cancellation_reason::CancellationReason;
use crate::dsl::ParallelNode;
use crate::engine::{record_step_success, resolve_schema, ExecutionState};
use crate::engine_error::{EngineError, Result};
use crate::status::WorkflowStepStatus;
use crate::traits::action_executor::{ActionOutput, ActionParams, StepInfo};
use crate::traits::persistence::StepUpdate;
use crate::traits::run_context::RunContext;
use crate::types::{StepResult, StepSuccess};
/// Execute a `parallel` node: fan each agent call out onto its own OS thread,
/// collect branch results over an mpsc channel, persist per-branch outcomes,
/// and record a synthetic `parallel:<group_id>` step summarizing the group.
///
/// Branch semantics:
/// - Steps already completed in a previous run (resume) are skipped.
/// - A branch whose `call_if` marker condition is unsatisfied is recorded as
///   `Skipped`; skipped branches count toward the `min_success` threshold.
/// - Each branch retries up to its `call_retries` budget; executor panics are
///   caught and surfaced as branch failures instead of aborting the engine.
/// - With `fail_fast`, the first failure cancels remaining in-flight branches
///   through a child cancellation token.
///
/// # Errors
/// Propagates schema-resolution and persistence errors. Individual branch
/// failures do not return an error here; falling below `min_success` only
/// clears `state.all_succeeded` and marks the synthetic step `Failed`.
pub fn execute_parallel(
    state: &mut ExecutionState,
    node: &ParallelNode,
    iteration: u32,
) -> Result<()> {
    // Unique id for this parallel group; keys the synthetic group result.
    let group_id = ulid::Ulid::new().to_string();
    let pos_base = state.position;
    tracing::info!(
        "parallel: spawning {} agents (fail_fast={}, min_success={:?})",
        node.calls.len(),
        node.fail_fast,
        node.min_success,
    );
    // Block-level output schema, used by calls without a per-call override.
    let block_schema = node
        .output
        .as_deref()
        .map(|name| resolve_schema(state, name))
        .transpose()?;
    // Result of one finished branch, as delivered over the channel.
    struct ParallelCallResult {
        agent_name: String,
        step_id: String,
        agent_step_key: String,
        result: std::result::Result<ActionOutput, EngineError>,
        attempt: u32,
    }
    // Everything a worker thread needs to run one branch.
    struct DispatchInput {
        step_id: String,
        agent_name: String,
        agent_step_key: String,
        run_ctx: Arc<dyn RunContext>,
        info: StepInfo,
        params: ActionParams,
        retries: u32,
    }
    // Per-call configuration resolved during the first pass.
    struct CallInput {
        idx: usize,
        agent_step_key: String,
        call_schema: Option<crate::output_schema::OutputSchema>,
        effective_with: Vec<String>,
        retries: u32,
    }
    let mut skipped_count = 0u32;
    // Pass 1: resolve per-call schema/`with`/retries, skip already-completed
    // steps (resume support), and advance `state.position` past each call.
    let mut call_inputs: Vec<CallInput> = Vec::new();
    for (i, agent_ref) in node.calls.iter().enumerate() {
        let pos = pos_base + i as i64;
        state.position = pos + 1;
        let agent_label = agent_ref.label();
        let agent_step_key = agent_ref.step_key();
        if super::skip_if_already_completed(state, &agent_step_key, iteration, agent_label) {
            skipped_count += 1;
            continue;
        }
        // A per-call output schema takes precedence over the block schema.
        let call_schema = node
            .call_outputs
            .get(&i.to_string())
            .map(|name| resolve_schema(state, name))
            .transpose()?;
        let effective_schema = call_schema.as_ref().or(block_schema.as_ref()).cloned();
        // Per-call `with` entries are appended to the block-level list.
        let effective_with = if let Some(extra) = node.call_with.get(&i.to_string()) {
            let mut w = node.with.clone();
            w.extend(extra.iter().cloned());
            w
        } else {
            node.with.clone()
        };
        let retries = node.call_retries.get(&i.to_string()).copied().unwrap_or(0);
        call_inputs.push(CallInput {
            idx: i,
            agent_step_key: agent_step_key.clone(),
            call_schema: effective_schema,
            effective_with,
            retries,
        });
    }
    // Pass 2: evaluate `call_if` conditions, insert step records, and build
    // the dispatch inputs handed to the worker threads.
    let mut dispatch_queue: Vec<DispatchInput> = Vec::new();
    let shared_inputs = super::build_inputs_map(state);
    // Child token so fail-fast / user cancellation can stop in-flight
    // branches without cancelling the parent workflow token.
    let scope_token = state.cancellation.child();
    for call_input in call_inputs {
        let CallInput {
            idx: i,
            agent_step_key,
            call_schema,
            effective_with,
            retries,
        } = call_input;
        let pos = pos_base + i as i64;
        let agent_ref = &node.calls[i];
        let agent_label = agent_ref.label();
        // `call_if` gates a branch on a marker emitted by an earlier step.
        if let Some((cond_step, cond_marker)) = node.call_if.get(&i.to_string()) {
            let has_marker = state
                .step_results
                .get(cond_step)
                .map(|r| r.markers.iter().any(|m| m == cond_marker))
                .unwrap_or(false);
            if !has_marker {
                tracing::info!(
                    "parallel: skipping '{}' (if={}.{} not satisfied)",
                    agent_label,
                    cond_step,
                    cond_marker
                );
                super::insert_step_with_status(
                    state,
                    agent_label,
                    "actor",
                    pos,
                    iteration,
                    None,
                    WorkflowStepStatus::Skipped,
                    Some(format!("skipped: {cond_step}.{cond_marker} not emitted")),
                )?;
                skipped_count += 1;
                continue;
            }
        }
        let step_id =
            super::insert_step_record(state, agent_label, "actor", pos, iteration, Some(0))?;
        let inputs = Arc::clone(&shared_inputs);
        let info = super::build_step_info(state, &step_id);
        let params = super::build_action_params(
            agent_label,
            inputs,
            effective_with,
            state.exec_config.dry_run,
            state.last_gate_feedback.clone(),
            call_schema,
            retries,
            None,
            state.model.clone(),
            state.default_as_identity.clone(),
            state.extra_plugin_dirs.clone(),
            None,
        );
        dispatch_queue.push(DispatchInput {
            step_id,
            agent_name: agent_label.to_string(),
            agent_step_key,
            run_ctx: Arc::clone(&state.run_ctx),
            info,
            params,
            retries,
        });
    }
    // Workers send (step_id, agent_name, agent_step_key, result, attempt).
    let (completion_tx, completion_rx) = std::sync::mpsc::channel::<(
        String,
        String,
        String,
        std::result::Result<ActionOutput, EngineError>,
        u32,
    )>();
    for dispatch_input in dispatch_queue {
        let tx = completion_tx.clone();
        let registry = Arc::clone(&state.action_registry);
        let scope = scope_token.clone();
        std::thread::spawn(move || {
            let max_attempts = 1 + dispatch_input.retries;
            let mut last_error = String::new();
            let mut params = dispatch_input.params;
            let mut final_attempt = 0u32;
            let mut result: std::result::Result<ActionOutput, EngineError> =
                Err(EngineError::Workflow("no attempts made".into()));
            for attempt in 0..max_attempts {
                // Bail out between attempts if the group was cancelled.
                if scope.is_cancelled() {
                    result = Err(EngineError::Cancelled(CancellationReason::FailFast));
                    break;
                }
                params.retries_remaining = max_attempts - attempt - 1;
                params.retry_error = if attempt == 0 {
                    None
                } else {
                    Some(last_error.clone())
                };
                final_attempt = attempt;
                // Catch panics so a misbehaving executor fails only its own
                // branch rather than tearing down the whole run.
                result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    registry.dispatch(
                        &params.name,
                        &*dispatch_input.run_ctx,
                        &dispatch_input.info,
                        &params,
                    )
                }))
                .unwrap_or_else(|payload| {
                    // Panic payloads are usually &str or String; anything
                    // else falls back to a generic message.
                    let msg = if let Some(s) = payload.downcast_ref::<&str>() {
                        format!("executor '{}' panicked: {s}", params.name)
                    } else if let Some(s) = payload.downcast_ref::<String>() {
                        format!("executor '{}' panicked: {s}", params.name)
                    } else {
                        format!("executor '{}' panicked", params.name)
                    };
                    Err(EngineError::Workflow(msg))
                });
                match &result {
                    Ok(_) => break,
                    // Cancellation is terminal: don't burn remaining retries.
                    Err(EngineError::Cancelled(_)) => break,
                    Err(e) => {
                        last_error = e.to_string();
                    }
                }
            }
            if let Err(e) = tx.send((
                dispatch_input.step_id,
                dispatch_input.agent_name,
                dispatch_input.agent_step_key,
                result,
                final_attempt,
            )) {
                tracing::warn!("parallel: result channel broken (receiver dropped): {}", e);
            }
        });
    }
    // Drop our sender so the receiver disconnects once all workers finish.
    drop(completion_tx);
    let mut results: Vec<ParallelCallResult> = Vec::new();
    // Wait loop: drain worker results; the 500ms timeout lets us poll for
    // user-requested cancellation while long branches are still running.
    loop {
        match completion_rx.recv_timeout(std::time::Duration::from_millis(500)) {
            Ok((step_id, agent_name, agent_step_key, result, attempt)) => {
                let failed = result.is_err();
                results.push(ParallelCallResult {
                    agent_name,
                    step_id,
                    agent_step_key,
                    result,
                    attempt,
                });
                if failed && node.fail_fast {
                    scope_token.cancel(CancellationReason::FailFast);
                }
            }
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                if state.check_cancellation_throttled().is_err() {
                    scope_token.cancel(CancellationReason::UserRequested(None));
                }
            }
        }
    }
    // Aggregate: persist each branch's outcome and merge emitted markers.
    let mut merged_markers: Vec<String> = Vec::new();
    let mut successes = 0u32;
    let mut failures = 0u32;
    let results_count = results.len();
    for pr in results {
        match pr.result {
            Ok(output) => {
                let markers_json = crate::helpers::serialize_or_empty_array(
                    &output.markers,
                    &format!("parallel: '{}'", pr.agent_name),
                );
                let context = output.context.clone().unwrap_or_default();
                super::persist_completed_step(
                    state,
                    &pr.step_id,
                    output.child_run_id.clone(),
                    output.result_text.clone(),
                    Some(context.clone()),
                    Some(markers_json),
                    pr.attempt,
                    output.structured_output.clone(),
                )?;
                tracing::info!("parallel: '{}' completed", pr.agent_name,);
                successes += 1;
                merged_markers.extend(output.markers.iter().cloned());
                record_step_success(
                    state,
                    pr.agent_step_key.clone(),
                    StepSuccess::from_action_output(
                        &output,
                        pr.agent_name.clone(),
                        context,
                        iteration,
                        None,
                    ),
                );
            }
            Err(e) => {
                tracing::warn!("parallel: '{}' failed: {e}", pr.agent_name);
                let generation = state.expect_lease_generation();
                state.persistence.update_step(
                    &pr.step_id,
                    StepUpdate::failed(generation, e.to_string(), pr.attempt),
                )?;
                failures += 1;
            }
        }
    }
    // Skipped branches count as successes for the min_success threshold;
    // min_success defaults to "all branches" when unset.
    let effective_successes = successes + skipped_count;
    let total_agents = results_count as u32 + skipped_count;
    let min_required = node.min_success.unwrap_or(total_agents);
    tracing::info!(
        "parallel: {successes} succeeded, {failures} failed, {skipped_count} skipped out of {total_agents} agents",
    );
    if effective_successes < min_required {
        tracing::warn!(
            "parallel: only {}/{} succeeded (min_success={})",
            effective_successes,
            total_agents,
            min_required
        );
        state.all_succeeded = false;
    }
    // Record a synthetic group-level step so downstream nodes can reference
    // the merged markers of the whole parallel block.
    let synthetic_result = StepResult {
        step_name: format!("parallel:{}", group_id),
        status: if effective_successes >= min_required {
            WorkflowStepStatus::Completed
        } else {
            WorkflowStepStatus::Failed
        },
        result_text: None,
        markers: merged_markers,
        context: String::new(),
        child_run_id: None,
        structured_output: None,
        output_file: None,
    };
    state
        .step_results
        .insert(format!("parallel:{}", group_id), synthetic_result);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{AgentRef, ParallelNode};
    use crate::engine::ExecutionState;
    use crate::engine_error::EngineError;
    use crate::persistence_memory::InMemoryWorkflowPersistence;
    use crate::status::WorkflowStepStatus;
    use crate::traits::action_executor::{ActionExecutor, ActionOutput, ActionParams};
    use crate::traits::item_provider::ItemProviderRegistry;
    use crate::traits::persistence::WorkflowPersistence;
    use crate::traits::run_context::NoopRunContext;
    use crate::traits::script_env_provider::NoOpScriptEnvProvider;
    use crate::types::WorkflowExecConfig;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Executor that always succeeds with fixed markers and context, plus a
    // small COST_USD metadata entry so cost accumulation can be asserted.
    struct MarkersExecutor {
        markers: Vec<String>,
        context: String,
    }
    impl ActionExecutor for MarkersExecutor {
        fn name(&self) -> &str {
            "markers_exec"
        }
        fn execute(
            &self,
            _ctx: &dyn crate::traits::run_context::RunContext,
            _info: &crate::traits::action_executor::StepInfo,
            _params: &ActionParams,
        ) -> Result<ActionOutput, EngineError> {
            let mut metadata = std::collections::HashMap::new();
            metadata.insert(
                crate::constants::metadata_keys::COST_USD.to_string(),
                "0.05".to_string(),
            );
            Ok(ActionOutput {
                markers: self.markers.clone(),
                context: Some(self.context.clone()),
                result_text: None,
                structured_output: None,
                metadata,
                child_run_id: None,
            })
        }
    }

    // Creates an in-memory persistence layer with one freshly created run;
    // returns the persistence handle and the new run's id.
    fn make_persistence_with_run() -> (Arc<InMemoryWorkflowPersistence>, String) {
        let p = Arc::new(InMemoryWorkflowPersistence::new());
        let run = p
            .create_run(crate::traits::persistence::NewRun {
                workflow_name: "wf".to_string(),
                parent_run_id: String::new(),
                dry_run: false,
                trigger: "manual".to_string(),
                definition_snapshot: None,
                parent_workflow_run_id: None,
            })
            .unwrap();
        (p, run.id)
    }

    // Builds a minimal ExecutionState wired to the given persistence, run id,
    // and action registry; every other field gets an inert default.
    fn make_state(
        persistence: Arc<InMemoryWorkflowPersistence>,
        run_id: String,
        registry: crate::traits::action_executor::ActionRegistry,
    ) -> ExecutionState {
        ExecutionState {
            persistence,
            action_registry: Arc::new(registry),
            script_env_provider: Arc::new(NoOpScriptEnvProvider),
            workflow_run_id: run_id,
            workflow_name: "wf".to_string(),
            run_ctx: Arc::new(NoopRunContext::default())
                as Arc<dyn crate::traits::run_context::RunContext>,
            extra_plugin_dirs: vec![],
            model: None,
            exec_config: WorkflowExecConfig::default(),
            inputs: HashMap::new(),
            parent_run_id: String::new(),
            depth: 0,
            target_label: None,
            step_results: HashMap::new(),
            contexts: vec![],
            position: 0,
            all_succeeded: true,
            total_cost: 0.0,
            total_turns: 0,
            total_duration_ms: 0,
            total_input_tokens: 0,
            total_output_tokens: 0,
            total_cache_read_input_tokens: 0,
            total_cache_creation_input_tokens: 0,
            has_llm_metrics: false,
            last_gate_feedback: None,
            block_output: None,
            block_with: vec![],
            resume_ctx: None,
            default_as_identity: None,
            triggered_by_hook: false,
            schema_resolver: None,
            child_runner: None,
            last_heartbeat_at: ExecutionState::new_heartbeat(),
            registry: Arc::new(ItemProviderRegistry::new()),
            event_sinks: Arc::from(vec![]),
            cancellation: crate::cancellation::CancellationToken::new(),
            current_execution_id: Arc::new(std::sync::Mutex::new(None)),
            owner_token: None,
            lease_generation: Some(0),
        }
    }

    // Happy path: a successful branch persists its markers and context and
    // the executor's COST_USD metadata is accumulated into state.total_cost.
    #[test]
    fn parallel_actionoutput_dispatch_path_records_markers_and_context() {
        let mut named = HashMap::new();
        named.insert(
            "markers_exec".to_string(),
            Box::new(MarkersExecutor {
                markers: vec!["m1".to_string(), "m2".to_string()],
                context: "step context".to_string(),
            }) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("markers_exec".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        };
        execute_parallel(&mut state, &node, 0).unwrap();
        let steps = persistence.get_steps(&run_id).unwrap();
        assert_eq!(steps.len(), 1, "expected one step record");
        let step = &steps[0];
        assert_eq!(
            step.status,
            WorkflowStepStatus::Completed,
            "step should be Completed; got {:?}",
            step.status
        );
        // Markers are persisted as a JSON array string.
        let markers: Vec<String> = step
            .markers_out
            .as_deref()
            .and_then(|m| serde_json::from_str(m).ok())
            .unwrap_or_default();
        assert_eq!(
            markers,
            vec!["m1", "m2"],
            "markers should match executor output"
        );
        assert_eq!(
            step.context_out.as_deref(),
            Some("step context"),
            "context should match executor output"
        );
        assert!(
            state.contexts.iter().any(|c| c.context == "step context"),
            "executor context should be in state.contexts"
        );
        assert!(
            state.total_cost > 0.0,
            "cost should be accumulated from ActionOutput"
        );
    }

    // A panicking executor (&str payload) is caught; the step is Failed and
    // the error names the executor and includes the panic message.
    #[test]
    fn parallel_panicking_executor_is_caught_and_step_is_failed() {
        struct PanicExec;
        impl ActionExecutor for PanicExec {
            fn name(&self) -> &str {
                "panic_exec"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> Result<ActionOutput, EngineError> {
                panic!("deliberate panic in test executor");
            }
        }
        let mut named = HashMap::new();
        named.insert(
            "panic_exec".to_string(),
            Box::new(PanicExec) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("panic_exec".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        };
        execute_parallel(&mut state, &node, 0).unwrap();
        let steps = persistence.get_steps(&run_id).unwrap();
        assert_eq!(steps.len(), 1, "expected one step record");
        let step = &steps[0];
        assert_eq!(
            step.status,
            WorkflowStepStatus::Failed,
            "panicking executor should produce a Failed step; got {:?}",
            step.status
        );
        let error_msg = step.step_error.as_deref().unwrap_or("");
        assert!(
            error_msg.contains("panic_exec"),
            "step_error should name the executor; got: {error_msg:?}"
        );
        assert!(
            error_msg.contains("deliberate panic in test executor"),
            "step_error should include the panic payload; got: {error_msg:?}"
        );
    }

    // Same as above but with a String panic payload (the second downcast arm
    // of the panic handler).
    #[test]
    fn parallel_panicking_executor_string_payload_is_surfaced() {
        struct PanicStringExec;
        impl ActionExecutor for PanicStringExec {
            fn name(&self) -> &str {
                "panic_string_exec"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> Result<ActionOutput, EngineError> {
                panic!("{}", "string payload panic".to_string())
            }
        }
        let mut named = HashMap::new();
        named.insert(
            "panic_string_exec".to_string(),
            Box::new(PanicStringExec) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("panic_string_exec".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        };
        execute_parallel(&mut state, &node, 0).unwrap();
        let steps = persistence.get_steps(&run_id).unwrap();
        assert_eq!(steps.len(), 1);
        let error_msg = steps[0].step_error.as_deref().unwrap_or("");
        assert!(
            error_msg.contains("panic_string_exec"),
            "step_error should name the executor; got: {error_msg:?}"
        );
        assert!(
            error_msg.contains("string payload panic"),
            "step_error should include the String panic payload; got: {error_msg:?}"
        );
    }

    // A non-string panic payload (i32) hits the generic fallback: the error
    // names the executor but does not include the payload value.
    #[test]
    fn parallel_panicking_executor_unknown_payload_fallback() {
        struct PanicUnknownExec;
        impl ActionExecutor for PanicUnknownExec {
            fn name(&self) -> &str {
                "panic_unknown_exec"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> Result<ActionOutput, EngineError> {
                std::panic::panic_any(42i32)
            }
        }
        let mut named = HashMap::new();
        named.insert(
            "panic_unknown_exec".to_string(),
            Box::new(PanicUnknownExec) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("panic_unknown_exec".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        };
        execute_parallel(&mut state, &node, 0).unwrap();
        let steps = persistence.get_steps(&run_id).unwrap();
        assert_eq!(steps.len(), 1);
        let error_msg = steps[0].step_error.as_deref().unwrap_or("");
        assert!(
            error_msg.contains("panic_unknown_exec"),
            "step_error should name the executor; got: {error_msg:?}"
        );
        assert!(
            !error_msg.contains("42"),
            "step_error should NOT contain the unknown payload value; got: {error_msg:?}"
        );
    }

    // fail_fast: one failing branch cancels the group; at least one step is
    // Failed and all_succeeded is cleared. (Other branches may still finish
    // before cancellation lands, so only a lower bound is asserted.)
    #[test]
    fn parallel_fail_fast_stops_after_first_failure() {
        struct FailExec;
        impl ActionExecutor for FailExec {
            fn name(&self) -> &str {
                "fail_exec"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> Result<ActionOutput, EngineError> {
                Err(EngineError::Workflow("intentional failure".to_string()))
            }
        }
        let mut named = HashMap::new();
        named.insert(
            "fail_exec".to_string(),
            Box::new(FailExec) as Box<dyn ActionExecutor>,
        );
        named.insert(
            "markers_exec".to_string(),
            Box::new(MarkersExecutor {
                markers: vec!["ok".to_string()],
                context: String::new(),
            }) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let node = ParallelNode {
            fail_fast: true,
            min_success: None,
            calls: vec![
                AgentRef::Name("fail_exec".to_string()),
                AgentRef::Name("markers_exec".to_string()),
                AgentRef::Name("markers_exec".to_string()),
            ],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        };
        execute_parallel(&mut state, &node, 0).ok();
        let steps = persistence.get_steps(&run_id).unwrap();
        let failed = steps
            .iter()
            .filter(|s| s.status == WorkflowStepStatus::Failed)
            .count();
        assert!(
            failed >= 1,
            "at least one branch should be Failed; steps: {:?}",
            steps
        );
        assert!(
            !state.all_succeeded,
            "all_succeeded should be false when fail_fast fires"
        );
    }

    // A branch sleeping longer than the 500ms recv timeout forces the wait
    // loop through its Timeout arm (cancellation polling) at least once;
    // CountingPersistence tracks the resulting persistence traffic.
    #[test]
    fn parallel_wait_loop_polls_cancellation_during_long_branches() {
        struct SleepingExecutor;
        impl ActionExecutor for SleepingExecutor {
            fn name(&self) -> &str {
                "sleeping_exec"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> std::result::Result<ActionOutput, EngineError> {
                std::thread::sleep(std::time::Duration::from_millis(1300));
                Ok(ActionOutput::default())
            }
        }
        let mut named: HashMap<String, Box<dyn ActionExecutor>> = HashMap::new();
        named.insert(
            "sleeping_exec".to_string(),
            Box::new(SleepingExecutor) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let cp = Arc::new(crate::test_helpers::CountingPersistence::new());
        let run_id = cp
            .create_run(crate::traits::persistence::NewRun {
                workflow_name: "wf".to_string(),
                parent_run_id: String::new(),
                dry_run: false,
                trigger: "manual".to_string(),
                definition_snapshot: None,
                parent_workflow_run_id: None,
            })
            .unwrap()
            .id;
        let cp_for_state: Arc<dyn WorkflowPersistence> = Arc::clone(&cp) as _;
        let mut state = crate::test_helpers::make_test_execution_state(cp_for_state, run_id);
        state.action_registry = Arc::new(registry);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("sleeping_exec".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries: HashMap::new(),
        };
        execute_parallel(&mut state, &node, 0).unwrap();
    }

    // Retry path: an executor that fails once then succeeds is dispatched
    // twice (initial + 1 retry) and the step ends Completed.
    #[test]
    fn parallel_retries_on_failed_branch_succeeds_on_second_attempt() {
        use std::sync::atomic::{AtomicU32, Ordering};
        struct FailOnceThenSucceed {
            call_count: Arc<AtomicU32>,
        }
        impl ActionExecutor for FailOnceThenSucceed {
            fn name(&self) -> &str {
                "fail_once"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> Result<ActionOutput, EngineError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Err(EngineError::Workflow("first attempt fails".to_string()))
                } else {
                    Ok(ActionOutput {
                        markers: vec!["retried_ok".to_string()],
                        ..Default::default()
                    })
                }
            }
        }
        let call_count = Arc::new(AtomicU32::new(0));
        let mut named = HashMap::new();
        named.insert(
            "fail_once".to_string(),
            Box::new(FailOnceThenSucceed {
                call_count: Arc::clone(&call_count),
            }) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let mut call_retries = HashMap::new();
        call_retries.insert("0".to_string(), 1u32);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("fail_once".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries,
        };
        execute_parallel(&mut state, &node, 0).unwrap();
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "executor should be dispatched twice (initial + 1 retry)"
        );
        let steps = persistence.get_steps(&run_id).unwrap();
        assert_eq!(steps.len(), 1);
        assert_eq!(
            steps[0].status,
            WorkflowStepStatus::Completed,
            "step should be Completed after successful retry; got {:?}",
            steps[0].status
        );
    }

    // Retry exhaustion: an always-failing executor with 1 retry is dispatched
    // exactly twice and the step ends Failed.
    #[test]
    fn parallel_retries_exhausted_marks_step_failed() {
        use std::sync::atomic::{AtomicU32, Ordering};
        struct AlwaysFail {
            call_count: Arc<AtomicU32>,
        }
        impl ActionExecutor for AlwaysFail {
            fn name(&self) -> &str {
                "always_fail"
            }
            fn execute(
                &self,
                _ctx: &dyn crate::traits::run_context::RunContext,
                _info: &crate::traits::action_executor::StepInfo,
                _params: &ActionParams,
            ) -> Result<ActionOutput, EngineError> {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Err(EngineError::Workflow("always fails".to_string()))
            }
        }
        let call_count = Arc::new(AtomicU32::new(0));
        let mut named = HashMap::new();
        named.insert(
            "always_fail".to_string(),
            Box::new(AlwaysFail {
                call_count: Arc::clone(&call_count),
            }) as Box<dyn ActionExecutor>,
        );
        let registry = crate::traits::action_executor::ActionRegistry::new(named, None);
        let (persistence, run_id) = make_persistence_with_run();
        let mut state = make_state(Arc::clone(&persistence), run_id.clone(), registry);
        let mut call_retries = HashMap::new();
        call_retries.insert("0".to_string(), 1u32);
        let node = ParallelNode {
            fail_fast: false,
            min_success: None,
            calls: vec![AgentRef::Name("always_fail".to_string())],
            output: None,
            call_outputs: HashMap::new(),
            with: vec![],
            call_with: HashMap::new(),
            call_if: HashMap::new(),
            call_retries,
        };
        execute_parallel(&mut state, &node, 0).unwrap();
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "executor should be dispatched twice (initial + 1 retry) before giving up"
        );
        let steps = persistence.get_steps(&run_id).unwrap();
        assert_eq!(steps.len(), 1);
        assert_eq!(
            steps[0].status,
            WorkflowStepStatus::Failed,
            "step should be Failed after all retries exhausted; got {:?}",
            steps[0].status
        );
    }
}