// a3s_flow/runner.rs
1//! Flow execution engine.
2//!
3//! [`FlowRunner`] takes a [`DagGraph`] and a [`NodeRegistry`], executes each
4//! node wave-by-wave, and returns a [`FlowResult`].
5//!
6//! Two execution modes are available:
7//! - [`FlowRunner::run`] — fire-and-forget: run to completion with no external control
8//! - [`FlowRunner::run_controlled`] — used by [`FlowEngine`] to support pause / resume / terminate
9
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13
14use tokio::sync::Semaphore;
15
16use serde_json::Value;
17use tokio::sync::watch;
18use tokio::task::JoinSet;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, info, instrument, Instrument};
21use uuid::Uuid;
22
23use crate::error::{FlowError, Result};
24use crate::event::{EventEmitter, NoopEventEmitter};
25use crate::flow_store::FlowStore;
26use crate::graph::DagGraph;
27use crate::node::{ExecContext, Node, RetryPolicy};
28use crate::registry::NodeRegistry;
29use crate::result::FlowResult;
30
/// Signal used to control a running execution.
///
/// Broadcast from the controlling side to the runner over a
/// `tokio::sync::watch` channel; the runner checks the current value at
/// every wave boundary. There is no `Terminate` variant — termination is
/// carried separately by a [`CancellationToken`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FlowSignal {
    /// The flow should continue executing.
    Run,
    /// The flow should pause at the next wave boundary.
    Pause,
}
39
/// Executes a [`DagGraph`] using registered [`Node`](crate::node::Node) implementations.
///
/// For lifecycle control (pause / resume / terminate), use [`FlowEngine`](crate::engine::FlowEngine)
/// instead of constructing a `FlowRunner` directly.
///
/// # Example
///
/// ```rust,no_run
/// use a3s_flow::{DagGraph, FlowRunner, NodeRegistry};
/// use serde_json::json;
///
/// #[tokio::main]
/// async fn main() {
///     let def = json!({
///         "nodes": [
///             { "id": "start", "type": "noop" },
///             { "id": "end",   "type": "noop" }
///         ],
///         "edges": [{ "source": "start", "target": "end" }]
///     });
///     let dag = DagGraph::from_json(&def).unwrap();
///     let registry = NodeRegistry::with_defaults();
///     let runner = FlowRunner::new(dag, registry);
///     let result = runner.run(Default::default()).await.unwrap();
///     println!("{:?}", result.outputs);
/// }
/// ```
pub struct FlowRunner {
    /// The validated DAG to execute.
    dag: DagGraph,
    /// Node-type → implementation lookup; `Arc` so sub-flow runners can share it.
    registry: Arc<NodeRegistry>,
    /// Receives node / flow lifecycle events; defaults to [`NoopEventEmitter`].
    emitter: Arc<dyn EventEmitter>,
    /// Optional store of named flow definitions, passed into each [`ExecContext`].
    flow_store: Option<Arc<dyn FlowStore>>,
    /// When set, at most this many nodes execute concurrently within a wave.
    max_concurrency: Option<usize>,
}
75
impl FlowRunner {
    /// Create a new runner from a validated DAG and a node registry.
    ///
    /// Uses [`NoopEventEmitter`] by default. Call
    /// [`.with_event_emitter`](Self::with_event_emitter) to register a custom
    /// listener before executing.
    pub fn new(dag: DagGraph, registry: NodeRegistry) -> Self {
        // Wrap the registry and delegate to the Arc-sharing constructor;
        // the resulting runner is identical.
        Self::with_arc_registry(dag, Arc::new(registry))
    }
91
92    /// Create a new runner sharing an existing `Arc<NodeRegistry>`.
93    ///
94    /// Used by the `"iteration"` and `"sub-flow"` nodes so that sub-flow
95    /// runners share the same registry without extra `Arc` wrapping.
96    pub fn with_arc_registry(dag: DagGraph, registry: Arc<NodeRegistry>) -> Self {
97        Self {
98            dag,
99            registry,
100            emitter: Arc::new(NoopEventEmitter),
101            flow_store: None,
102            max_concurrency: None,
103        }
104    }
105
106    /// Attach a custom event emitter to this runner.
107    ///
108    /// The emitter receives node and flow lifecycle events during execution.
109    /// Returns `self` for method chaining.
110    pub fn with_event_emitter(mut self, emitter: Arc<dyn EventEmitter>) -> Self {
111        self.emitter = emitter;
112        self
113    }
114
115    /// Attach a flow definition store to this runner.
116    ///
117    /// When set, the store is passed to every [`ExecContext`] so that nodes
118    /// like `"sub-flow"` can load named flow definitions at execution time.
119    /// Returns `self` for method chaining.
120    pub fn with_flow_store(mut self, store: Arc<dyn FlowStore>) -> Self {
121        self.flow_store = Some(store);
122        self
123    }
124
125    /// Limit the number of nodes that may execute concurrently within a single
126    /// wave.
127    ///
128    /// By default all ready nodes in a wave run in parallel. Setting
129    /// `max_concurrency` to `n` caps this using a Tokio [`Semaphore`] so that
130    /// at most `n` nodes are active at the same time. Useful when downstream
131    /// services impose rate limits.
132    ///
133    /// Returns `self` for method chaining.
134    pub fn with_max_concurrency(mut self, n: usize) -> Self {
135        self.max_concurrency = Some(n);
136        self
137    }
138
139    /// Execute the flow to completion with no external control signals.
140    #[instrument(skip(self, variables), fields(execution_id))]
141    pub async fn run(&self, variables: HashMap<String, Value>) -> Result<FlowResult> {
142        let execution_id = Uuid::new_v4();
143        tracing::Span::current().record("execution_id", execution_id.to_string());
144        // No-op signal channel and a token that is never cancelled.
145        let (_tx, rx) = watch::channel(FlowSignal::Run);
146        let cancel = CancellationToken::new();
147        let context = Arc::new(RwLock::new(HashMap::new()));
148        self.run_seeded(
149            execution_id,
150            variables,
151            rx,
152            cancel,
153            HashMap::new(),
154            HashSet::new(),
155            HashSet::new(),
156            context,
157        )
158        .await
159    }
160
161    /// Resume a flow from a prior (partial or complete) result, skipping nodes
162    /// that already have outputs in `prior`.
163    ///
164    /// A new execution ID is assigned to the resumed run. Nodes listed in
165    /// `prior.completed_nodes` are not re-executed; their outputs from `prior`
166    /// are used directly as inputs for any downstream nodes that still need to run.
167    ///
168    /// # Example
169    ///
170    /// ```rust,no_run
171    /// # use a3s_flow::{DagGraph, FlowRunner, NodeRegistry};
172    /// # use serde_json::json;
173    /// # use std::collections::HashMap;
174    /// # #[tokio::main] async fn main() {
175    /// let def = json!({ "nodes": [{ "id": "a", "type": "noop" }], "edges": [] });
176    /// let dag = DagGraph::from_json(&def).unwrap();
177    /// let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
178    /// let partial = runner.run(HashMap::new()).await.unwrap();
179    /// // Resume with the partial result — completed nodes are skipped.
180    /// let full = runner.resume_from(&partial, HashMap::new()).await.unwrap();
181    /// # }
182    /// ```
183    pub async fn resume_from(
184        &self,
185        prior: &FlowResult,
186        variables: HashMap<String, Value>,
187    ) -> Result<FlowResult> {
188        let execution_id = Uuid::new_v4();
189        let (_tx, rx) = watch::channel(FlowSignal::Run);
190        let cancel = CancellationToken::new();
191        // Seed the context from the prior run so resumed nodes see accumulated state.
192        let context = Arc::new(RwLock::new(prior.context.clone()));
193        self.run_seeded(
194            execution_id,
195            variables,
196            rx,
197            cancel,
198            prior.outputs.clone(),
199            prior.completed_nodes.clone(),
200            prior.skipped_nodes.clone(),
201            context,
202        )
203        .await
204    }
205
206    /// Execute the flow with external pause / resume / terminate control.
207    ///
208    /// `context` is a pre-built `Arc<RwLock<...>>` shared with the calling
209    /// [`FlowEngine`] so that the engine can perform CRUD on context entries
210    /// while the flow is running and read the final state afterwards.
211    ///
212    /// This is the method used by [`FlowEngine`](crate::engine::FlowEngine).
213    /// Prefer using `FlowEngine` rather than calling this directly.
214    pub(crate) async fn run_controlled(
215        &self,
216        execution_id: Uuid,
217        variables: HashMap<String, Value>,
218        signal_rx: watch::Receiver<FlowSignal>,
219        cancel: CancellationToken,
220        context: Arc<RwLock<HashMap<String, Value>>>,
221    ) -> Result<FlowResult> {
222        self.run_seeded(
223            execution_id,
224            variables,
225            signal_rx,
226            cancel,
227            HashMap::new(),
228            HashSet::new(),
229            HashSet::new(),
230            context,
231        )
232        .await
233    }
234
235    // ── Internal implementation ────────────────────────────────────────────
236
237    /// Emits flow lifecycle events around [`execute_waves`].
238    ///
239    /// `initial_*` collections seed execution for partial-resume; pass empty
240    /// collections for a fresh run.
241    #[allow(clippy::too_many_arguments)]
242    async fn run_seeded(
243        &self,
244        execution_id: Uuid,
245        variables: HashMap<String, Value>,
246        signal_rx: watch::Receiver<FlowSignal>,
247        cancel: CancellationToken,
248        initial_outputs: HashMap<String, Value>,
249        initial_completed: HashSet<String>,
250        initial_skipped: HashSet<String>,
251        context: Arc<RwLock<HashMap<String, Value>>>,
252    ) -> Result<FlowResult> {
253        info!(%execution_id, "flow execution started");
254        self.emitter.on_flow_started(execution_id).await;
255
256        let outcome = self
257            .execute_waves(
258                execution_id,
259                variables,
260                signal_rx,
261                cancel,
262                initial_outputs,
263                initial_completed,
264                initial_skipped,
265                context,
266            )
267            .await;
268
269        match &outcome {
270            Ok(result) => {
271                info!(%execution_id, "flow execution complete");
272                self.emitter.on_flow_completed(execution_id, result).await;
273            }
274            Err(FlowError::Terminated) => {
275                info!(%execution_id, "flow execution terminated");
276                self.emitter.on_flow_terminated(execution_id).await;
277            }
278            Err(e) => {
279                tracing::warn!(%execution_id, error = %e, "flow execution failed");
280                self.emitter
281                    .on_flow_failed(execution_id, &e.to_string())
282                    .await;
283            }
284        }
285
286        outcome
287    }
288
    /// Wave-based execution engine — emits node events, no flow lifecycle events.
    ///
    /// A "wave" is the set of remaining nodes whose dependencies are all
    /// complete. Waves repeat until no nodes remain. Pause and terminate are
    /// honoured at wave boundaries and while awaiting node results; a wave in
    /// flight is never interrupted by a pause (only by cancellation).
    #[allow(clippy::too_many_arguments)]
    async fn execute_waves(
        &self,
        execution_id: Uuid,
        variables: HashMap<String, Value>,
        mut signal_rx: watch::Receiver<FlowSignal>,
        cancel: CancellationToken,
        initial_outputs: HashMap<String, Value>,
        initial_completed: HashSet<String>,
        initial_skipped: HashSet<String>,
        context: Arc<RwLock<HashMap<String, Value>>>,
    ) -> Result<FlowResult> {
        // `variables` is mutable so that `"assign"` nodes can inject new values
        // into the running variable scope between waves.
        let mut variables = variables;
        let mut outputs = initial_outputs;
        let mut completed = initial_completed;
        // Nodes whose `run_if` evaluated to false — used to propagate skips.
        let mut skipped = initial_skipped;
        // Only include nodes that haven't completed yet (supports partial resume).
        let mut remaining: Vec<String> = self
            .dag
            .nodes_in_order()
            .map(|n| n.id.clone())
            .filter(|id| !completed.contains(id))
            .collect();

        while !remaining.is_empty() {
            // ── Pause / cancel checkpoint (between waves) ──────────────────
            loop {
                if cancel.is_cancelled() {
                    return Err(FlowError::Terminated);
                }
                // Copy the signal value before matching so the borrow is
                // released before we call signal_rx.changed() below.
                let signal = *signal_rx.borrow();
                match signal {
                    FlowSignal::Run => break,
                    FlowSignal::Pause => {
                        // Block until the signal changes or we are cancelled.
                        tokio::select! {
                            _ = signal_rx.changed() => continue,
                            _ = cancel.cancelled()  => return Err(FlowError::Terminated),
                        }
                    }
                }
            }

            // ── Find nodes ready to run ────────────────────────────────────
            // A node is ready when every dependency is in `completed`
            // (skipped nodes are also marked completed, so they unblock
            // their dependents).
            let (ready, not_ready): (Vec<_>, Vec<_>) = remaining.into_iter().partition(|id| {
                self.dag
                    .dependencies_of(id)
                    .iter()
                    .all(|dep| completed.contains(dep))
            });

            if ready.is_empty() {
                // Should be unreachable for a validated DAG (cycles are the
                // usual cause); surfaced as an internal error rather than
                // spinning forever.
                return Err(FlowError::Internal(
                    "execution stalled: no nodes are ready but not all nodes are done".into(),
                ));
            }

            remaining = not_ready;

            // ── Collect assign-node IDs before consuming `ready` ──────────
            // After the wave completes, these nodes' outputs are merged into
            // the live variable map so that downstream nodes see the new values.
            let assign_node_ids: Vec<String> = ready
                .iter()
                .filter(|id| {
                    self.dag
                        .nodes
                        .get(*id)
                        .map(|n| n.write_to_variables)
                        .unwrap_or(false)
                })
                .cloned()
                .collect();

            // ── Concurrency limiter for this wave ─────────────────────────
            // A fresh semaphore per wave; only created when a cap is set.
            let semaphore = self.max_concurrency.map(|n| Arc::new(Semaphore::new(n)));

            // ── Launch ready nodes concurrently ───────────────────────────
            let mut join_set: JoinSet<(String, Result<Value>)> = JoinSet::new();

            for node_id in ready {
                let node_def = self.dag.nodes[&node_id].clone();

                // Check run_if guard: if the condition fails, skip this node.
                // Skipped nodes record a Null output and still count as
                // completed so downstream nodes become ready.
                if let Some(ref cond) = node_def.run_if {
                    if !cond.evaluate(&outputs, &skipped) {
                        debug!(%node_id, "node skipped (run_if condition false)");
                        self.emitter.on_node_skipped(execution_id, &node_id).await;
                        outputs.insert(node_id.clone(), Value::Null);
                        skipped.insert(node_id.clone());
                        completed.insert(node_id);
                        continue;
                    }
                }

                // Unknown node types abort the whole flow here, before any
                // task for this wave has produced results.
                let node = self.registry.get(&node_def.node_type)?;

                // Inputs are the outputs of this node's direct dependencies,
                // keyed by dependency ID. Dependencies with no recorded
                // output are silently omitted.
                let inputs: HashMap<String, Value> = self
                    .dag
                    .dependencies_of(&node_id)
                    .iter()
                    .filter_map(|dep| outputs.get(dep).map(|v| (dep.clone(), v.clone())))
                    .collect();

                // Each task gets its own snapshot of `variables`; only the
                // shared `context` is live across concurrently running nodes.
                let ctx = ExecContext {
                    data: node_def.data.clone(),
                    inputs,
                    variables: variables.clone(),
                    context: Arc::clone(&context),
                    registry: Arc::clone(&self.registry),
                    flow_store: self.flow_store.clone(),
                };

                let retry = node_def.retry.clone();
                let timeout_ms = node_def.timeout_ms;
                let continue_on_error = node_def.continue_on_error;
                let emitter = Arc::clone(&self.emitter);
                let sem = semaphore.clone();

                debug!(
                    %node_id,
                    node_type = %node_def.node_type,
                    retry = ?retry.as_ref().map(|r| r.max_attempts),
                    timeout_ms,
                    continue_on_error,
                    "executing node"
                );

                // ── Per-node OTel-compatible span ──────────────────────────
                let span = tracing::info_span!(
                    "node.execute",
                    node_id = node_id.as_str(),
                    node_type = node_def.node_type.as_str(),
                    %execution_id,
                );

                join_set.spawn(
                    async move {
                        // Acquire concurrency permit inside the task so all
                        // tasks are spawned immediately but only `max_concurrency`
                        // run at the same time. The permit is released on drop.
                        // NOTE(review): if the semaphore were ever closed,
                        // `.ok()` would yield None and the cap would be
                        // silently bypassed — it is created above and never
                        // closed, so this path looks unreachable; confirm.
                        let _permit = if let Some(ref s) = sem {
                            Some(Arc::clone(s).acquire_owned().await.ok())
                        } else {
                            None
                        };

                        emitter
                            .on_node_started(execution_id, &node_id, &node_def.node_type)
                            .await;

                        // Any failure (including retries exhausted or timeout)
                        // is normalized into NodeFailed carrying this node's ID.
                        let result: Result<Value> =
                            execute_with_policy(node, ctx, retry, timeout_ms)
                                .await
                                .map_err(|e| FlowError::NodeFailed {
                                    node_id: node_id.clone(),
                                    execution_id,
                                    reason: e.to_string(),
                                });

                        // If continue_on_error is set, absorb failure and emit
                        // a completed event with an `__error__` sentinel output.
                        let result: Result<Value> = if continue_on_error {
                            result
                                .or_else(|e| Ok(serde_json::json!({ "__error__": e.to_string() })))
                        } else {
                            result
                        };

                        match &result {
                            Ok(v) => {
                                emitter.on_node_completed(execution_id, &node_id, v).await;
                            }
                            Err(e) => {
                                emitter
                                    .on_node_failed(execution_id, &node_id, &e.to_string())
                                    .await;
                            }
                        }

                        (node_id, result)
                    }
                    .instrument(span),
                );
            }

            // ── Collect results (cancel-aware) ─────────────────────────────
            loop {
                tokio::select! {
                    // Termination signal takes priority over pending node results.
                    _ = cancel.cancelled() => {
                        // Remaining tasks are aborted when join_set is dropped.
                        return Err(FlowError::Terminated);
                    }
                    maybe = join_set.join_next() => {
                        match maybe {
                            None => break, // all nodes in this wave done
                            Some(Ok((node_id, Ok(value)))) => {
                                debug!(%node_id, "node completed");
                                outputs.insert(node_id.clone(), value);
                                completed.insert(node_id);
                            }
                            // First node failure aborts the flow; sibling
                            // tasks are aborted when join_set is dropped.
                            Some(Ok((_, Err(e)))) => return Err(e),
                            Some(Err(join_err)) if join_err.is_cancelled() => {
                                return Err(FlowError::Terminated);
                            }
                            // A panicked node task surfaces as Internal.
                            Some(Err(e)) => return Err(FlowError::Internal(e.to_string())),
                        }
                    }
                }
            }

            // ── Merge assign-node outputs into the live variable map ───────
            // Only non-error outputs are merged (skip `continue_on_error` sentinels).
            // Non-object outputs are ignored entirely.
            for node_id in &assign_node_ids {
                if let Some(Value::Object(obj)) = outputs.get(node_id) {
                    if !obj.contains_key("__error__") {
                        for (k, v) in obj {
                            variables.insert(k.clone(), v.clone());
                        }
                    }
                }
            }
        }

        // NOTE(review): `.unwrap()` panics if the lock was poisoned, i.e. a
        // task panicked while holding it — by this point all node tasks have
        // joined, but confirm no other writer can poison it.
        let context_snapshot = context.read().unwrap().clone();
        Ok(FlowResult {
            execution_id,
            outputs,
            completed_nodes: completed,
            skipped_nodes: skipped,
            context: context_snapshot,
        })
    }
}
530
531// ── Node execution helper ──────────────────────────────────────────────────
532
533/// Execute a node with optional retry and timeout policies.
534///
535/// - Retries up to `retry.max_attempts` times (first attempt included).
536/// - Each retry waits `backoff_ms * 2^(attempt-1)` ms (capped at 64× base).
537/// - Each individual attempt is bounded by `timeout_ms` if set.
538async fn execute_with_policy(
539    node: Arc<dyn Node>,
540    ctx: ExecContext,
541    retry: Option<RetryPolicy>,
542    timeout_ms: Option<u64>,
543) -> Result<Value> {
544    let max_attempts = retry.as_ref().map(|r| r.max_attempts.max(1)).unwrap_or(1);
545    let backoff_ms = retry.as_ref().map(|r| r.backoff_ms).unwrap_or(0);
546
547    let mut last_err = FlowError::Internal("no attempts made".into());
548
549    for attempt in 0..max_attempts {
550        if attempt > 0 && backoff_ms > 0 {
551            // Exponential backoff: base * 2^(attempt-1), capped at base * 64.
552            let multiplier = 1u64 << (attempt - 1).min(6);
553            let delay = backoff_ms.saturating_mul(multiplier);
554            tokio::time::sleep(Duration::from_millis(delay)).await;
555        }
556
557        let fut = node.execute(ctx.clone());
558
559        let result = if let Some(ms) = timeout_ms {
560            tokio::time::timeout(Duration::from_millis(ms), fut)
561                .await
562                .unwrap_or_else(|_| Err(FlowError::Internal(format!("timed out after {ms}ms"))))
563        } else {
564            fut.await
565        };
566
567        match result {
568            Ok(v) => return Ok(v),
569            Err(e) => last_err = e,
570        }
571    }
572
573    Err(last_err)
574}
575
576#[cfg(test)]
577mod tests {
578    use super::*;
579    use crate::graph::DagGraph;
580    use crate::registry::NodeRegistry;
581    use serde_json::json;
582
583    #[tokio::test]
584    async fn runs_linear_flow() {
585        let def = json!({
586            "nodes": [
587                { "id": "a", "type": "noop" },
588                { "id": "b", "type": "noop" },
589                { "id": "c", "type": "noop" }
590            ],
591            "edges": [
592                { "source": "a", "target": "b" },
593                { "source": "b", "target": "c" }
594            ]
595        });
596        let dag = DagGraph::from_json(&def).unwrap();
597        let registry = NodeRegistry::with_defaults();
598        let runner = FlowRunner::new(dag, registry);
599        let result = runner.run(HashMap::new()).await.unwrap();
600
601        assert!(result.outputs.contains_key("a"));
602        assert!(result.outputs.contains_key("b"));
603        assert!(result.outputs.contains_key("c"));
604    }
605
606    #[tokio::test]
607    async fn runs_parallel_fan_out() {
608        let def = json!({
609            "nodes": [
610                { "id": "start", "type": "noop" },
611                { "id": "b",     "type": "noop" },
612                { "id": "c",     "type": "noop" },
613                { "id": "end",   "type": "noop" }
614            ],
615            "edges": [
616                { "source": "start", "target": "b" },
617                { "source": "start", "target": "c" },
618                { "source": "b",     "target": "end" },
619                { "source": "c",     "target": "end" }
620            ]
621        });
622        let dag = DagGraph::from_json(&def).unwrap();
623        let registry = NodeRegistry::with_defaults();
624        let runner = FlowRunner::new(dag, registry);
625        let result = runner.run(HashMap::new()).await.unwrap();
626        assert_eq!(result.outputs.len(), 4);
627    }
628
629    #[tokio::test]
630    async fn variables_available_in_context() {
631        let def = json!({ "nodes": [{ "id": "only", "type": "noop" }], "edges": [] });
632        let dag = DagGraph::from_json(&def).unwrap();
633        let registry = NodeRegistry::with_defaults();
634        let runner = FlowRunner::new(dag, registry);
635
636        let vars = HashMap::from([("env".into(), json!("production"))]);
637        let result = runner.run(vars).await.unwrap();
638        assert!(result.outputs.contains_key("only"));
639    }
640
    #[tokio::test]
    async fn run_if_skips_node_when_if_else_falls_to_else() {
        // "route" if-else: data == 999 → no match → branch = "else"
        // "process" run_if checks branch == "hit" → skipped
        let def = json!({
            "nodes": [
                { "id": "data", "type": "noop" },
                {
                    "id": "route", "type": "if-else",
                    "data": { "cases": [{ "id": "hit", "conditions": [{ "from": "data", "path": "", "op": "eq", "value": 999 }] }] }
                },
                {
                    "id": "process", "type": "noop",
                    "data": { "run_if": { "from": "route", "path": "branch", "op": "eq", "value": "hit" } }
                }
            ],
            "edges": [
                { "source": "data",  "target": "route" },
                { "source": "route", "target": "process" }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();
        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
        let result = runner.run(HashMap::new()).await.unwrap();

        // Skipped nodes record `null` as their output.
        assert_eq!(result.outputs["process"], json!(null));
    }
668
    #[tokio::test]
    async fn run_if_executes_node_when_if_else_matches() {
        // noop outputs {} — if-else matches {} == {} → branch = "hit"
        // so "sink"'s run_if passes and it executes normally.
        let def = json!({
            "nodes": [
                { "id": "src", "type": "noop" },
                {
                    "id": "gate", "type": "if-else",
                    "data": { "cases": [{ "id": "hit", "conditions": [{ "from": "src", "path": "", "op": "eq", "value": {} }] }] }
                },
                {
                    "id": "sink", "type": "noop",
                    "data": { "run_if": { "from": "gate", "path": "branch", "op": "eq", "value": "hit" } }
                }
            ],
            "edges": [
                { "source": "src",  "target": "gate" },
                { "source": "gate", "target": "sink" }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();
        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
        let result = runner.run(HashMap::new()).await.unwrap();

        // An executed node produces a real (object) output, not the null skip marker.
        assert!(result.outputs["sink"].is_object());
        assert_ne!(result.outputs["sink"], json!(null));
    }
696
    #[tokio::test]
    async fn skip_propagates_through_chain() {
        // A → B (run_if fails on missing field) → C (run_if on B which is in skipped set)
        // Both B and C should end up skipped with null outputs.
        let def = json!({
            "nodes": [
                { "id": "a", "type": "noop" },
                {
                    "id": "b", "type": "noop",
                    "data": { "run_if": { "from": "a", "path": "nonexistent_field", "op": "eq", "value": true } }
                },
                {
                    "id": "c", "type": "noop",
                    "data": { "run_if": { "from": "b", "path": "x", "op": "eq", "value": 1 } }
                }
            ],
            "edges": [
                { "source": "a", "target": "b" },
                { "source": "b", "target": "c" }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();
        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
        let result = runner.run(HashMap::new()).await.unwrap();

        assert_eq!(result.outputs["b"], json!(null));
        assert_eq!(result.outputs["c"], json!(null));
    }
724
    #[tokio::test]
    async fn if_else_with_variable_aggregator_fan_in() {
        // route → path_ok (run if "ok") / path_err (run if "else") → merge
        // Exactly one branch runs; the aggregator picks up the survivor.
        let def = json!({
            "nodes": [
                { "id": "src", "type": "noop" },
                {
                    "id": "route", "type": "if-else",
                    "data": { "cases": [{ "id": "ok", "conditions": [{ "from": "src", "path": "", "op": "eq", "value": {} }] }] }
                },
                {
                    "id": "path_ok", "type": "noop",
                    "data": { "run_if": { "from": "route", "path": "branch", "op": "eq", "value": "ok" } }
                },
                {
                    "id": "path_err", "type": "noop",
                    "data": { "run_if": { "from": "route", "path": "branch", "op": "eq", "value": "else" } }
                },
                {
                    "id": "merge", "type": "variable-aggregator",
                    "data": { "inputs": ["path_ok", "path_err"] }
                }
            ],
            "edges": [
                { "source": "src",      "target": "route" },
                { "source": "route",    "target": "path_ok" },
                { "source": "route",    "target": "path_err" },
                { "source": "path_ok",  "target": "merge" },
                { "source": "path_err", "target": "merge" }
            ]
        });
        let dag = DagGraph::from_json(&def).unwrap();
        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
        let result = runner.run(HashMap::new()).await.unwrap();

        // path_ok ran (src == {}), path_err was skipped → merge returns path_ok's output
        assert!(!result.outputs["merge"]["output"].is_null());
        assert_eq!(result.outputs["path_err"], json!(null));
    }
764
765    // ── completed_nodes / skipped_nodes tracking ───────────────────────────
766
767    #[tokio::test]
768    async fn completed_nodes_tracks_all_executed_nodes() {
769        let def = json!({
770            "nodes": [
771                { "id": "a", "type": "noop" },
772                { "id": "b", "type": "noop" }
773            ],
774            "edges": [{ "source": "a", "target": "b" }]
775        });
776        let dag = DagGraph::from_json(&def).unwrap();
777        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
778        let result = runner.run(HashMap::new()).await.unwrap();
779
780        assert!(result.completed_nodes.contains("a"));
781        assert!(result.completed_nodes.contains("b"));
782        assert!(result.skipped_nodes.is_empty());
783    }
784
785    #[tokio::test]
786    async fn skipped_nodes_tracks_run_if_skipped_nodes() {
787        // "a" → "b" with run_if that always fails → "b" is skipped
788        let def = json!({
789            "nodes": [
790                { "id": "a", "type": "noop" },
791                {
792                    "id": "b", "type": "noop",
793                    "data": { "run_if": { "from": "a", "path": "nonexistent", "op": "eq", "value": true } }
794                }
795            ],
796            "edges": [{ "source": "a", "target": "b" }]
797        });
798        let dag = DagGraph::from_json(&def).unwrap();
799        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
800        let result = runner.run(HashMap::new()).await.unwrap();
801
802        assert!(result.completed_nodes.contains("a"));
803        assert!(result.completed_nodes.contains("b"));
804        assert!(result.skipped_nodes.contains("b"));
805        assert!(!result.skipped_nodes.contains("a"));
806    }
807
808    // ── retry policy ───────────────────────────────────────────────────────
809
810    #[tokio::test]
811    async fn retry_succeeds_after_transient_failures() {
812        use crate::node::{ExecContext, Node};
813        use async_trait::async_trait;
814        use std::sync::atomic::{AtomicU32, Ordering};
815
816        // Fails twice, succeeds on the third attempt.
817        struct FlakyNode {
818            call_count: Arc<AtomicU32>,
819        }
820
821        #[async_trait]
822        impl Node for FlakyNode {
823            fn node_type(&self) -> &str {
824                "flaky"
825            }
826
827            async fn execute(&self, _ctx: ExecContext) -> Result<Value> {
828                let n = self.call_count.fetch_add(1, Ordering::SeqCst) + 1;
829                if n < 3 {
830                    Err(FlowError::Internal(format!("transient failure #{n}")))
831                } else {
832                    Ok(json!({ "ok": true }))
833                }
834            }
835        }
836
837        let call_count = Arc::new(AtomicU32::new(0));
838        let mut registry = NodeRegistry::with_defaults();
839        registry.register(Arc::new(FlakyNode {
840            call_count: Arc::clone(&call_count),
841        }));
842
843        let def = json!({
844            "nodes": [{
845                "id": "step",
846                "type": "flaky",
847                "data": { "retry": { "max_attempts": 3, "backoff_ms": 0 } }
848            }],
849            "edges": []
850        });
851        let dag = DagGraph::from_json(&def).unwrap();
852        let runner = FlowRunner::new(dag, registry);
853        let result = runner.run(HashMap::new()).await.unwrap();
854
855        assert_eq!(result.outputs["step"]["ok"], json!(true));
856        assert_eq!(call_count.load(Ordering::SeqCst), 3);
857    }
858
859    #[tokio::test]
860    async fn retry_exhausted_returns_last_error() {
861        use crate::node::{ExecContext, Node};
862        use async_trait::async_trait;
863
864        // Always fails.
865        struct AlwaysFailNode;
866
867        #[async_trait]
868        impl Node for AlwaysFailNode {
869            fn node_type(&self) -> &str {
870                "always-fail"
871            }
872
873            async fn execute(&self, _ctx: ExecContext) -> Result<Value> {
874                Err(FlowError::Internal("permanent failure".into()))
875            }
876        }
877
878        let mut registry = NodeRegistry::with_defaults();
879        registry.register(Arc::new(AlwaysFailNode));
880
881        let def = json!({
882            "nodes": [{
883                "id": "step",
884                "type": "always-fail",
885                "data": { "retry": { "max_attempts": 2, "backoff_ms": 0 } }
886            }],
887            "edges": []
888        });
889        let dag = DagGraph::from_json(&def).unwrap();
890        let runner = FlowRunner::new(dag, registry);
891        let err = runner.run(HashMap::new()).await.unwrap_err();
892
893        assert!(matches!(err, FlowError::NodeFailed { .. }));
894        let msg = err.to_string();
895        assert!(msg.contains("permanent failure"));
896    }
897
898    // ── timeout ────────────────────────────────────────────────────────────
899
900    #[tokio::test]
901    async fn timeout_kills_slow_node() {
902        use crate::node::{ExecContext, Node};
903        use async_trait::async_trait;
904
905        struct SlowNode;
906
907        #[async_trait]
908        impl Node for SlowNode {
909            fn node_type(&self) -> &str {
910                "slow-timeout"
911            }
912
913            async fn execute(&self, _ctx: ExecContext) -> Result<Value> {
914                tokio::time::sleep(Duration::from_millis(500)).await;
915                Ok(json!({}))
916            }
917        }
918
919        let mut registry = NodeRegistry::with_defaults();
920        registry.register(Arc::new(SlowNode));
921
922        // timeout_ms (50ms) is well below node sleep (500ms).
923        let def = json!({
924            "nodes": [{
925                "id": "step",
926                "type": "slow-timeout",
927                "data": { "timeout_ms": 50 }
928            }],
929            "edges": []
930        });
931        let dag = DagGraph::from_json(&def).unwrap();
932        let runner = FlowRunner::new(dag, registry);
933        let err = runner.run(HashMap::new()).await.unwrap_err();
934
935        assert!(matches!(err, FlowError::NodeFailed { .. }));
936        assert!(err.to_string().contains("timed out"));
937    }
938
939    #[tokio::test]
940    async fn timeout_does_not_affect_fast_node() {
941        // noop is instant — a 200ms timeout should never trigger.
942        let def = json!({
943            "nodes": [{
944                "id": "step",
945                "type": "noop",
946                "data": { "timeout_ms": 200 }
947            }],
948            "edges": []
949        });
950        let dag = DagGraph::from_json(&def).unwrap();
951        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
952        let result = runner.run(HashMap::new()).await.unwrap();
953        assert!(result.outputs.contains_key("step"));
954    }
955
956    // ── partial execution resume ────────────────────────────────────────────
957
958    #[tokio::test]
959    async fn resume_from_skips_already_completed_nodes() {
960        use crate::node::{ExecContext, Node};
961        use async_trait::async_trait;
962        use std::sync::atomic::{AtomicU32, Ordering};
963
964        // Counts how many times it is called.
965        struct CountingNode {
966            call_count: Arc<AtomicU32>,
967        }
968
969        #[async_trait]
970        impl Node for CountingNode {
971            fn node_type(&self) -> &str {
972                "counting"
973            }
974
975            async fn execute(&self, _ctx: ExecContext) -> Result<Value> {
976                self.call_count.fetch_add(1, Ordering::SeqCst);
977                Ok(json!({ "counted": true }))
978            }
979        }
980
981        let count_a = Arc::new(AtomicU32::new(0));
982        let count_b = Arc::new(AtomicU32::new(0));
983        let mut registry = NodeRegistry::with_defaults();
984        registry.register(Arc::new(CountingNode {
985            call_count: Arc::clone(&count_a),
986        }));
987
988        // We can't register two distinct "counting" nodes, so use noop for b.
989        let def = json!({
990            "nodes": [
991                { "id": "a", "type": "counting" },
992                { "id": "b", "type": "noop" }
993            ],
994            "edges": [{ "source": "a", "target": "b" }]
995        });
996
997        let dag = DagGraph::from_json(&def).unwrap();
998        let _ = count_b; // unused — b is noop
999        let runner = FlowRunner::new(dag, registry);
1000
1001        // Full first run — counting node executes once.
1002        let first = runner.run(HashMap::new()).await.unwrap();
1003        assert_eq!(count_a.load(Ordering::SeqCst), 1);
1004
1005        // Resume: "a" is already completed — should NOT re-execute.
1006        let resumed = runner.resume_from(&first, HashMap::new()).await.unwrap();
1007        assert_eq!(count_a.load(Ordering::SeqCst), 1); // still 1
1008        assert!(resumed.outputs.contains_key("a"));
1009        assert!(resumed.outputs.contains_key("b"));
1010    }
1011
1012    #[tokio::test]
1013    async fn resume_from_runs_only_pending_nodes() {
1014        // Simulate a partial result where only "a" has completed.
1015        // "b" has not run yet.  resume_from should run "b" only.
1016        use crate::node::{ExecContext, Node};
1017        use async_trait::async_trait;
1018        use std::sync::atomic::{AtomicU32, Ordering};
1019
1020        struct CountNode(Arc<AtomicU32>);
1021
1022        #[async_trait]
1023        impl Node for CountNode {
1024            fn node_type(&self) -> &str {
1025                "count-b"
1026            }
1027            async fn execute(&self, _ctx: ExecContext) -> Result<Value> {
1028                self.0.fetch_add(1, Ordering::SeqCst);
1029                Ok(json!({ "ran": true }))
1030            }
1031        }
1032
1033        let count_b = Arc::new(AtomicU32::new(0));
1034        let mut registry = NodeRegistry::with_defaults();
1035        registry.register(Arc::new(CountNode(Arc::clone(&count_b))));
1036
1037        let def = json!({
1038            "nodes": [
1039                { "id": "a", "type": "noop" },
1040                { "id": "b", "type": "count-b" }
1041            ],
1042            "edges": [{ "source": "a", "target": "b" }]
1043        });
1044        let dag = DagGraph::from_json(&def).unwrap();
1045        let runner = FlowRunner::new(dag, registry);
1046
1047        // Build a partial result where only "a" is done.
1048        let partial = FlowResult {
1049            execution_id: uuid::Uuid::new_v4(),
1050            outputs: HashMap::from([("a".into(), json!({}))]),
1051            completed_nodes: HashSet::from(["a".into()]),
1052            skipped_nodes: HashSet::new(),
1053            context: HashMap::new(),
1054        };
1055
1056        let result = runner.resume_from(&partial, HashMap::new()).await.unwrap();
1057        assert_eq!(count_b.load(Ordering::SeqCst), 1);
1058        assert!(result.outputs["b"]["ran"].as_bool().unwrap());
1059
1060        // Resuming a fully-completed result should not re-run any node.
1061        let full = runner.run(HashMap::new()).await.unwrap();
1062        count_b.store(0, Ordering::SeqCst);
1063        let _ = runner.resume_from(&full, HashMap::new()).await.unwrap();
1064        assert_eq!(count_b.load(Ordering::SeqCst), 0);
1065
1066        let _ = partial; // suppress unused warning
1067    }
1068
1069    // ── continue_on_error ──────────────────────────────────────────────────
1070
1071    #[tokio::test]
1072    async fn continue_on_error_keeps_flow_running_after_node_failure() {
1073        use crate::node::{ExecContext, Node};
1074        use async_trait::async_trait;
1075
1076        struct FailNode;
1077
1078        #[async_trait]
1079        impl Node for FailNode {
1080            fn node_type(&self) -> &str {
1081                "always-fail-coe"
1082            }
1083            async fn execute(&self, _: ExecContext) -> Result<Value> {
1084                Err(FlowError::Internal("boom".into()))
1085            }
1086        }
1087
1088        let mut registry = NodeRegistry::with_defaults();
1089        registry.register(Arc::new(FailNode));
1090
1091        let def = json!({
1092            "nodes": [
1093                {
1094                    "id": "fail",
1095                    "type": "always-fail-coe",
1096                    "data": { "continue_on_error": true }
1097                },
1098                { "id": "after", "type": "noop" }
1099            ],
1100            "edges": [{ "source": "fail", "target": "after" }]
1101        });
1102
1103        let dag = DagGraph::from_json(&def).unwrap();
1104        let result = FlowRunner::new(dag, registry)
1105            .run(HashMap::new())
1106            .await
1107            .unwrap();
1108
1109        // "fail" should have an __error__ key in its output.
1110        assert!(result.outputs["fail"]["__error__"].is_string());
1111        // "after" should still have run.
1112        assert!(result.completed_nodes.contains("after"));
1113    }
1114
1115    #[tokio::test]
1116    async fn continue_on_error_false_halts_flow_on_failure() {
1117        use crate::node::{ExecContext, Node};
1118        use async_trait::async_trait;
1119
1120        struct FailNode2;
1121
1122        #[async_trait]
1123        impl Node for FailNode2 {
1124            fn node_type(&self) -> &str {
1125                "always-fail-halt"
1126            }
1127            async fn execute(&self, _: ExecContext) -> Result<Value> {
1128                Err(FlowError::Internal("halt".into()))
1129            }
1130        }
1131
1132        let mut registry = NodeRegistry::with_defaults();
1133        registry.register(Arc::new(FailNode2));
1134
1135        let def = json!({
1136            "nodes": [
1137                { "id": "fail", "type": "always-fail-halt" },
1138                { "id": "after", "type": "noop" }
1139            ],
1140            "edges": [{ "source": "fail", "target": "after" }]
1141        });
1142
1143        let dag = DagGraph::from_json(&def).unwrap();
1144        let err = FlowRunner::new(dag, registry)
1145            .run(HashMap::new())
1146            .await
1147            .unwrap_err();
1148
1149        assert!(matches!(err, FlowError::NodeFailed { .. }));
1150    }
1151
1152    // ── max_concurrency ────────────────────────────────────────────────────
1153
1154    #[tokio::test]
1155    async fn max_concurrency_limits_parallel_execution() {
1156        use crate::node::{ExecContext, Node};
1157        use async_trait::async_trait;
1158        use std::sync::atomic::{AtomicU32, Ordering};
1159
1160        // Tracks the peak number of concurrently-running nodes.
1161        let active = Arc::new(AtomicU32::new(0));
1162        let peak = Arc::new(AtomicU32::new(0));
1163
1164        struct PeakNode {
1165            active: Arc<AtomicU32>,
1166            peak: Arc<AtomicU32>,
1167        }
1168
1169        #[async_trait]
1170        impl Node for PeakNode {
1171            fn node_type(&self) -> &str {
1172                "peak-tracker"
1173            }
1174            async fn execute(&self, _: ExecContext) -> Result<Value> {
1175                let current = self.active.fetch_add(1, Ordering::SeqCst) + 1;
1176                // Update peak.
1177                let mut prev = self.peak.load(Ordering::SeqCst);
1178                while current > prev {
1179                    match self.peak.compare_exchange_weak(
1180                        prev,
1181                        current,
1182                        Ordering::SeqCst,
1183                        Ordering::SeqCst,
1184                    ) {
1185                        Ok(_) => break,
1186                        Err(actual) => prev = actual,
1187                    }
1188                }
1189                tokio::time::sleep(Duration::from_millis(20)).await;
1190                self.active.fetch_sub(1, Ordering::SeqCst);
1191                Ok(json!({}))
1192            }
1193        }
1194
1195        let mut registry = NodeRegistry::with_defaults();
1196        registry.register(Arc::new(PeakNode {
1197            active: Arc::clone(&active),
1198            peak: Arc::clone(&peak),
1199        }));
1200
1201        // 5 independent nodes, max_concurrency = 2.
1202        let def = json!({
1203            "nodes": [
1204                { "id": "n1", "type": "peak-tracker" },
1205                { "id": "n2", "type": "peak-tracker" },
1206                { "id": "n3", "type": "peak-tracker" },
1207                { "id": "n4", "type": "peak-tracker" },
1208                { "id": "n5", "type": "peak-tracker" }
1209            ],
1210            "edges": []
1211        });
1212
1213        let dag = DagGraph::from_json(&def).unwrap();
1214        let runner = FlowRunner::new(dag, registry).with_max_concurrency(2);
1215        let result = runner.run(HashMap::new()).await.unwrap();
1216
1217        assert_eq!(result.completed_nodes.len(), 5);
1218        assert!(
1219            peak.load(Ordering::SeqCst) <= 2,
1220            "peak concurrency {} exceeded max of 2",
1221            peak.load(Ordering::SeqCst)
1222        );
1223    }
1224
1225    #[tokio::test]
1226    async fn max_concurrency_unlimited_by_default() {
1227        // With no max_concurrency, all 5 independent nodes should be able to
1228        // run concurrently (peak may be ≤ 5, just verify flow completes).
1229        let def = json!({
1230            "nodes": [
1231                { "id": "a", "type": "noop" },
1232                { "id": "b", "type": "noop" },
1233                { "id": "c", "type": "noop" }
1234            ],
1235            "edges": []
1236        });
1237        let dag = DagGraph::from_json(&def).unwrap();
1238        let result = FlowRunner::new(dag, NodeRegistry::with_defaults())
1239            .run(HashMap::new())
1240            .await
1241            .unwrap();
1242        assert_eq!(result.completed_nodes.len(), 3);
1243    }
1244
1245    // ── start / end nodes ──────────────────────────────────────────────────
1246
1247    #[tokio::test]
1248    async fn start_node_resolves_variables_and_end_node_gathers_output() {
1249        let def = json!({
1250            "nodes": [
1251                {
1252                    "id": "start",
1253                    "type": "start",
1254                    "data": {
1255                        "inputs": [
1256                            { "name": "greeting", "type": "string" },
1257                            { "name": "repeat",   "type": "number", "default": 1 }
1258                        ]
1259                    }
1260                },
1261                {
1262                    "id": "end",
1263                    "type": "end",
1264                    "data": {
1265                        "outputs": {
1266                            "greeting": "/start/greeting",
1267                            "repeat":   "/start/repeat"
1268                        }
1269                    }
1270                }
1271            ],
1272            "edges": [{ "source": "start", "target": "end" }]
1273        });
1274        let dag = DagGraph::from_json(&def).unwrap();
1275        let mut vars = HashMap::new();
1276        vars.insert("greeting".to_string(), json!("hello"));
1277        let result = FlowRunner::new(dag, NodeRegistry::with_defaults())
1278            .run(vars)
1279            .await
1280            .unwrap();
1281
1282        // start node resolves greeting and applies default for repeat.
1283        assert_eq!(result.outputs["start"]["greeting"], json!("hello"));
1284        assert_eq!(result.outputs["start"]["repeat"], json!(1));
1285
1286        // end node gathers via JSON pointer.
1287        assert_eq!(result.outputs["end"]["greeting"], json!("hello"));
1288        assert_eq!(result.outputs["end"]["repeat"], json!(1));
1289    }
1290
1291    // ── assign node — variable scope mutation ──────────────────────────────
1292
1293    #[tokio::test]
1294    async fn assign_node_makes_value_visible_to_downstream_nodes() {
1295        // "init" assigns greeting; "read" is a code node that reads it from variables.
1296        let def = json!({
1297            "nodes": [
1298                {
1299                    "id": "init",
1300                    "type": "assign",
1301                    "data": { "assigns": { "greeting": "hello from assign" } }
1302                },
1303                {
1304                    "id": "read",
1305                    "type": "code",
1306                    "data": { "language": "rhai", "code": "variables.greeting" }
1307                }
1308            ],
1309            "edges": [{ "source": "init", "target": "read" }]
1310        });
1311        let dag = DagGraph::from_json(&def).unwrap();
1312        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
1313        let result = runner.run(HashMap::new()).await.unwrap();
1314
1315        assert_eq!(result.outputs["read"]["output"], json!("hello from assign"));
1316    }
1317
1318    #[tokio::test]
1319    async fn assign_node_overwrites_existing_variable() {
1320        let def = json!({
1321            "nodes": [
1322                {
1323                    "id": "overwrite",
1324                    "type": "assign",
1325                    "data": { "assigns": { "x": "new_value" } }
1326                },
1327                {
1328                    "id": "read",
1329                    "type": "code",
1330                    "data": { "language": "rhai", "code": "variables.x" }
1331                }
1332            ],
1333            "edges": [{ "source": "overwrite", "target": "read" }]
1334        });
1335        let dag = DagGraph::from_json(&def).unwrap();
1336        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
1337        let mut vars = HashMap::new();
1338        vars.insert("x".to_string(), json!("old_value"));
1339        let result = runner.run(vars).await.unwrap();
1340
1341        assert_eq!(result.outputs["read"]["output"], json!("new_value"));
1342    }
1343
1344    #[tokio::test]
1345    async fn assign_node_does_not_affect_parallel_siblings() {
1346        // "assign_a" and "noop_b" run in the same wave (no edges between them).
1347        // "read" runs after both — sees the assigned value.
1348        let def = json!({
1349            "nodes": [
1350                {
1351                    "id": "assign_a",
1352                    "type": "assign",
1353                    "data": { "assigns": { "flag": "set" } }
1354                },
1355                { "id": "noop_b", "type": "noop" },
1356                {
1357                    "id": "read",
1358                    "type": "code",
1359                    "data": { "language": "rhai", "code": "variables.flag" }
1360                }
1361            ],
1362            "edges": [
1363                { "source": "assign_a", "target": "read" },
1364                { "source": "noop_b",   "target": "read" }
1365            ]
1366        });
1367        let dag = DagGraph::from_json(&def).unwrap();
1368        let runner = FlowRunner::new(dag, NodeRegistry::with_defaults());
1369        let result = runner.run(HashMap::new()).await.unwrap();
1370
1371        // "read" runs in wave 2; the assign happened in wave 1, so it's visible.
1372        assert_eq!(result.outputs["read"]["output"], json!("set"));
1373    }
1374}