use crate::flow_dispatcher::{DispatchCtx, DispatchError, NodeOutcome};
use crate::flow_execution_event::{now_ms, FlowExecutionEvent};
use crate::ir_nodes::{IRFlowNode, IRParallelBlock};
/// Emits the event pair for a `par` node without executing any branches.
///
/// The node argument is not inspected. The function claims the next step
/// index from `ctx.step_counter`, sends a `StepStart` event and then a
/// `StepComplete` event (success, empty output, zero tokens) on `ctx.tx`,
/// and reports an empty `Completed` outcome carrying the claimed index.
///
/// # Errors
///
/// * [`DispatchError::UpstreamCancelled`] if `ctx.cancel` is already set on
///   entry; in that case no step index is consumed and nothing is sent.
/// * [`DispatchError::ChannelClosed`] if either event cannot be sent.
pub async fn run_par(
    _node: &IRParallelBlock,
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    if ctx.cancel.is_cancelled() {
        return Err(DispatchError::UpstreamCancelled);
    }

    // Claim this step's index before any event is emitted.
    let step_index = ctx.step_counter;
    ctx.step_counter += 1;

    let started = FlowExecutionEvent::StepStart {
        step_name: String::from("Par"),
        step_index,
        step_type: String::from("par"),
        timestamp_ms: now_ms(),
    };
    ctx.tx
        .send(started)
        .map_err(|_| DispatchError::ChannelClosed)?;

    // Built only after the start event went out, so its timestamp is taken
    // after the first send.
    let finished = FlowExecutionEvent::StepComplete {
        step_name: String::from("Par"),
        step_index,
        success: true,
        full_output: String::new(),
        tokens_input: 0,
        tokens_output: 0,
        timestamp_ms: now_ms(),
    };
    ctx.tx
        .send(finished)
        .map_err(|_| DispatchError::ChannelClosed)?;

    Ok(NodeOutcome::Completed {
        output: String::new(),
        tokens_emitted: 0,
        step_index,
    })
}
/// Drives every branch body concurrently and folds the results into one
/// outcome.
///
/// Each branch runs against its own clone of `ctx`, with `par[<idx>]`
/// appended to the clone's `branch_path` and `pending_effect_policy` reset
/// to `None`. All branch futures are awaited together with `join_all`, so
/// every branch runs to its own completion before any result is examined.
///
/// Merge rules, applied in branch order:
/// * `Completed` — non-empty outputs are collected and later joined with
///   `"\n"`; emitted tokens are summed.
/// * `Return` — the first branch (by index) that returned wins; the overall
///   outcome is then `Return` and the collected outputs are not reported.
/// * `Break` / `LoopContinue` — ignored.
/// * `Err` — the first error in branch order is returned immediately.
///
/// On success the only field this function writes back to `ctx` is
/// `step_counter`, set to the maximum of its entry value and every branch's
/// final counter. On error `ctx.step_counter` is left untouched.
///
/// # Errors
///
/// [`DispatchError::UpstreamCancelled`] if `ctx.cancel` is set on entry,
/// otherwise whatever error the earliest failing branch produced.
pub async fn run_branches_concurrently(
    branches: &[Vec<IRFlowNode>],
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    use futures::future::join_all;

    if ctx.cancel.is_cancelled() {
        return Err(DispatchError::UpstreamCancelled);
    }

    let entry_step_index = ctx.step_counter;

    // Fork one context per branch so branches never share a `&mut`.
    let mut forked: Vec<DispatchCtx> = branches
        .iter()
        .enumerate()
        .map(|(idx, _)| {
            let mut branch_ctx = ctx.clone();
            branch_ctx.branch_path.push(format!("par[{idx}]"));
            branch_ctx.pending_effect_policy = None;
            branch_ctx
        })
        .collect();

    let pending: Vec<_> = branches
        .iter()
        .zip(forked.iter_mut())
        .map(|(body, branch_ctx)| {
            let nodes = body.as_slice();
            async move { dispatch_branch_body(nodes, branch_ctx).await }
        })
        .collect();
    let outcomes = join_all(pending).await;

    let mut merged_counter = entry_step_index;
    let mut outputs: Vec<String> = Vec::new();
    let mut token_total: u64 = 0;
    let mut first_return: Option<String> = None;

    for (branch_ctx, result) in forked.iter().zip(outcomes) {
        merged_counter = merged_counter.max(branch_ctx.step_counter);
        match result? {
            NodeOutcome::Completed {
                output,
                tokens_emitted,
                ..
            } => {
                token_total += tokens_emitted;
                if !output.is_empty() {
                    outputs.push(output);
                }
            }
            NodeOutcome::Return { value } => {
                // Keep the earliest branch's value; later ones are dropped.
                first_return = first_return.or(Some(value));
            }
            NodeOutcome::Break | NodeOutcome::LoopContinue => {}
        }
    }

    ctx.step_counter = merged_counter;

    Ok(match first_return {
        Some(value) => NodeOutcome::Return { value },
        None => NodeOutcome::Completed {
            output: outputs.join("\n"),
            tokens_emitted: token_total,
            step_index: entry_step_index,
        },
    })
}
/// Runs the nodes of a single branch one after another on `ctx`.
///
/// Before each node the cancellation flag is checked. While a node is being
/// dispatched, `step[<i>]` sits on top of `ctx.branch_path`; it is popped
/// again before the node's result is inspected, so the path is restored even
/// when the node failed.
///
/// Returns `Completed` carrying the most recent non-empty output, the sum of
/// emitted tokens, and the step counter value observed on entry. A `Break`,
/// `LoopContinue` or `Return` from any node ends the branch at once and is
/// handed back unchanged.
///
/// # Errors
///
/// [`DispatchError::UpstreamCancelled`] if cancellation is observed before a
/// node starts, or any error returned by `dispatch_node`.
async fn dispatch_branch_body(
    body: &[IRFlowNode],
    ctx: &mut DispatchCtx,
) -> Result<NodeOutcome, DispatchError> {
    let entry_step_index = ctx.step_counter;
    let mut tail_output = String::new();
    let mut token_sum: u64 = 0;

    for (position, node) in body.iter().enumerate() {
        if ctx.cancel.is_cancelled() {
            return Err(DispatchError::UpstreamCancelled);
        }

        ctx.branch_path.push(format!("step[{position}]"));
        // The nested dispatch future is boxed before being awaited.
        let dispatched = Box::pin(crate::flow_dispatcher::dispatch_node(node, ctx)).await;
        ctx.branch_path.pop();

        match dispatched? {
            NodeOutcome::Completed {
                output,
                tokens_emitted,
                ..
            } => {
                token_sum += tokens_emitted;
                if !output.is_empty() {
                    tail_output = output;
                }
            }
            early @ (NodeOutcome::Break
            | NodeOutcome::LoopContinue
            | NodeOutcome::Return { .. }) => return Ok(early),
        }
    }

    Ok(NodeOutcome::Completed {
        output: tail_output,
        tokens_emitted: token_sum,
        step_index: entry_step_index,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cancel_token::CancellationFlag;
    use crate::ir_nodes::*;
    use tokio::sync::mpsc;

    /// Receiving half of the event channel handed to a test context.
    type EventRx = mpsc::UnboundedReceiver<FlowExecutionEvent>;

    /// Context for flow `"TestFlow"` with an untripped cancellation flag.
    fn fresh_ctx() -> (DispatchCtx, EventRx) {
        let (tx, rx) = mpsc::unbounded_channel();
        let ctx = DispatchCtx::new("TestFlow", "stub", "", CancellationFlag::new(), tx);
        (ctx, rx)
    }

    /// Context for flow `"F"` whose cancellation flag is already set. The
    /// receiver is returned so the caller can keep the channel open.
    fn cancelled_ctx() -> (DispatchCtx, EventRx) {
        let flag = CancellationFlag::new();
        flag.cancel();
        let (tx, rx) = mpsc::unbounded_channel();
        (DispatchCtx::new("F", "stub", "", flag, tx), rx)
    }

    /// A `par` IR node positioned at line 0, column 0.
    fn par_block() -> IRParallelBlock {
        IRParallelBlock {
            node_type: "par",
            source_line: 0,
            source_column: 0,
        }
    }

    /// A branch consisting of a single literal `let` binding.
    fn let_branch(target: &str, value: &str) -> Vec<IRFlowNode> {
        let binding = IRLetBinding {
            node_type: "let",
            source_line: 0,
            source_column: 0,
            target: target.into(),
            value: value.into(),
            value_kind: "literal".into(),
        };
        vec![IRFlowNode::Let(binding)]
    }

    /// A step node with the given name, ask `"hi"`, and every other field
    /// empty or `None`.
    fn make_step(name: &str) -> IRStep {
        IRStep {
            node_type: "step",
            source_line: 0,
            source_column: 0,
            name: name.into(),
            persona_ref: String::new(),
            given: String::new(),
            ask: "hi".into(),
            use_tool: None,
            probe: None,
            reason: None,
            weave: None,
            output_type: String::new(),
            confidence_floor: None,
            navigate_ref: String::new(),
            apply_ref: String::new(),
            body: Vec::new(),
        }
    }

    /// Unwraps a `Completed` outcome into `(output, tokens_emitted)`,
    /// panicking on any other variant.
    fn completed_parts(outcome: NodeOutcome) -> (String, u64) {
        match outcome {
            NodeOutcome::Completed {
                output,
                tokens_emitted,
                ..
            } => (output, tokens_emitted),
            other => panic!("expected Completed, got {other:?}"),
        }
    }

    // `run_par` yields an empty Completed at index 0 and sends exactly a
    // StepStart followed by a StepComplete.
    #[tokio::test]
    async fn run_par_emits_canonical_wire_shape() {
        let (mut ctx, mut rx) = fresh_ctx();
        let node = par_block();

        match run_par(&node, &mut ctx).await.unwrap() {
            NodeOutcome::Completed {
                output,
                tokens_emitted,
                step_index,
            } => {
                assert_eq!(output, "");
                assert_eq!(tokens_emitted, 0);
                assert_eq!(step_index, 0);
            }
            other => panic!("expected Completed, got {other:?}"),
        }

        // Drain whatever is queued right now without waiting.
        let events: Vec<FlowExecutionEvent> =
            std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert_eq!(events.len(), 2);

        match &events[0] {
            FlowExecutionEvent::StepStart { step_type, .. } => {
                assert_eq!(step_type, "par");
            }
            e => panic!("expected StepStart, got {e:?}"),
        }
        match &events[1] {
            FlowExecutionEvent::StepComplete { tokens_output, .. } => {
                assert_eq!(*tokens_output, 0);
            }
            e => panic!("expected StepComplete, got {e:?}"),
        }
    }

    // A pre-cancelled context makes `run_par` fail with UpstreamCancelled.
    #[tokio::test]
    async fn run_par_cancel_short_circuits() {
        let (mut ctx, _rx) = cancelled_ctx();
        let node = par_block();
        let result = run_par(&node, &mut ctx).await;
        assert!(matches!(result, Err(DispatchError::UpstreamCancelled)));
    }

    // Two `let` branches: both values show up in the aggregated output and
    // neither binding appears in the parent context afterwards.
    #[tokio::test]
    async fn run_branches_concurrently_two_let_branches() {
        let (mut ctx, _rx) = fresh_ctx();
        let branches = vec![let_branch("a", "A-value"), let_branch("b", "B-value")];

        let outcome = run_branches_concurrently(&branches, &mut ctx).await.unwrap();
        let (output, tokens_emitted) = completed_parts(outcome);
        assert!(
            output.contains("A-value") && output.contains("B-value"),
            "expected both branch outputs aggregated, got {output:?}"
        );
        assert_eq!(tokens_emitted, 0);

        assert!(!ctx.let_bindings.contains_key("a"));
        assert!(!ctx.let_bindings.contains_key("b"));
    }

    // With no branches at all the result is an empty Completed.
    #[tokio::test]
    async fn run_branches_concurrently_zero_branches_returns_completed() {
        let (mut ctx, _rx) = fresh_ctx();

        let outcome = run_branches_concurrently(&[], &mut ctx).await.unwrap();
        let (output, tokens_emitted) = completed_parts(outcome);
        assert_eq!(output, "");
        assert_eq!(tokens_emitted, 0);
    }

    // A `return` in the second branch surfaces as the overall outcome.
    #[tokio::test]
    async fn run_branches_concurrently_propagates_return_sentinel() {
        let (mut ctx, _rx) = fresh_ctx();
        let returning_branch = vec![IRFlowNode::Return(IRReturnStep {
            node_type: "return",
            source_line: 0,
            source_column: 0,
            value_expr: "computed-from-branch-1".into(),
        })];
        let branches = vec![let_branch("a", "side"), returning_branch];

        match run_branches_concurrently(&branches, &mut ctx).await.unwrap() {
            NodeOutcome::Return { value } => {
                assert_eq!(value, "computed-from-branch-1");
            }
            other => panic!("expected Return propagation, got {other:?}"),
        }
    }

    // A pre-cancelled context makes the concurrent runner fail up front.
    #[tokio::test]
    async fn run_branches_concurrently_cancel_propagation() {
        let (mut ctx, _rx) = cancelled_ctx();
        let branches = vec![let_branch("a", "v")];
        let result = run_branches_concurrently(&branches, &mut ctx).await;
        assert!(matches!(result, Err(DispatchError::UpstreamCancelled)));
    }

    // Branches of two steps and one step leave the parent counter at 2.
    #[tokio::test]
    async fn run_branches_concurrently_step_counter_merges_max() {
        let (mut ctx, _rx) = fresh_ctx();
        let two_steps = vec![
            IRFlowNode::Step(make_step("A1")),
            IRFlowNode::Step(make_step("A2")),
        ];
        let one_step = vec![IRFlowNode::Step(make_step("B1"))];
        let branches = vec![two_steps, one_step];

        run_branches_concurrently(&branches, &mut ctx).await.unwrap();
        assert_eq!(
            ctx.step_counter, 2,
            "parent counter merges to max(2, 1) = 2 post-join"
        );
    }

    // The parent's branch path is unchanged after the join, and the audit
    // records reachable from the parent hold one entry per executed step.
    #[tokio::test]
    async fn run_branches_concurrently_branch_path_isolated_per_branch() {
        let (mut ctx, _rx) = fresh_ctx();
        ctx.branch_path.push("outer".into());
        let branches = vec![
            vec![IRFlowNode::Step(make_step("InA"))],
            vec![IRFlowNode::Step(make_step("InB"))],
        ];

        run_branches_concurrently(&branches, &mut ctx).await.unwrap();
        assert_eq!(ctx.branch_path, vec!["outer".to_string()]);

        let records = ctx.step_audit_records.lock().await;
        assert_eq!(records.len(), 2);
    }
}