// a3s_flow/nodes/iteration.rs

//! Built-in `"iteration"` node — runs a sub-flow for every element of an
//! input array, collecting per-iteration outputs.
//!
//! Mirrors Dify's Iteration node. Each element is passed to the sub-flow as a
//! flow variable named `"item"` (plus an `"index"` variable with the 0-based
//! position).
//!
//! Two execution modes are available via `data["mode"]`:
//! - `"parallel"` *(default)* — all iterations run concurrently via Tokio tasks.
//!   A missing `mode`, or any value other than `"sequential"`, selects this mode.
//! - `"sequential"` — iterations run one-at-a-time in order; each iteration
//!   receives the previous iteration's collected output as `"prev_output"` in
//!   its variable scope (`null` for the first item).
//!
//! # Config schema
//!
//! ```json
//! {
//!   "input_selector":  "fetch.body.items",
//!   "output_selector": "summarize.output",
//!   "mode":            "sequential",
//!   "flow": { ... }
//! }
//! ```
//!
//! | Field | Type | Required | Description |
//! |-------|------|:--------:|-------------|
//! | `input_selector` | string | ✅ | Dot path into `inputs` to reach the array |
//! | `output_selector` | string | ✅ | Dot path into sub-flow outputs to collect |
//! | `flow` | object | ✅ | Inline sub-flow definition |
//! | `mode` | string | — | `"parallel"` (default) or `"sequential"`; unrecognised values run as parallel |
//!
//! # Variables injected into each iteration's sub-flow
//!
//! | Variable | Value |
//! |----------|-------|
//! | `item` | The current array element |
//! | `index` | The 0-based position of the element |
//! | `prev_output` | *(sequential only)* The previous iteration's collected output (`null` for the first item) |
//!
//! # Output schema
//!
//! ```json
//! { "output": [ <value from output_selector for iteration 0>, ... ] }
//! ```
//!
//! Results are always returned in the original array order. A `null` is placed
//! for any iteration whose `output_selector` path resolves to nothing.
//!
//! # Example
//!
//! ```json
//! {
//!   "nodes": [
//!     { "id": "fetch", "type": "http-request", "data": { "url": "..." } },
//!     {
//!       "id": "process_all",
//!       "type": "iteration",
//!       "data": {
//!         "input_selector":  "fetch.body.items",
//!         "output_selector": "process.output",
//!         "flow": {
//!           "nodes": [
//!             { "id": "process", "type": "code", "data": { "language": "rhai", "code": "item" } }
//!           ],
//!           "edges": []
//!         }
//!       }
//!     }
//!   ],
//!   "edges": [{ "source": "fetch", "target": "process_all" }]
//! }
//! ```

74use async_trait::async_trait;
75use serde_json::{json, Value};
76use std::sync::Arc;
77use tokio::task::JoinSet;
78
79use crate::error::{FlowError, Result};
80use crate::graph::DagGraph;
81use crate::node::{ExecContext, Node};
82use crate::runner::FlowRunner;
83
/// Iteration node — runs a sub-flow for each element of an array (Dify-compatible).
///
/// The node is a stateless unit struct: everything it needs is read from the
/// execution context handed to `execute`, so the derived traits below are
/// trivial and let callers print, copy, and default-construct the node.
#[derive(Debug, Clone, Copy, Default)]
pub struct IterationNode;
86
87/// Resolves a dot-separated path into a JSON value.
88///
89/// `"a.b.c"` into `{"a": {"b": {"c": 42}}}` returns `Some(42)`.
90/// An empty string returns the root value.
91fn resolve_path<'a>(root: &'a Value, path: &str) -> Option<&'a Value> {
92    if path.is_empty() {
93        return Some(root);
94    }
95    let mut cur = root;
96    for segment in path.split('.') {
97        cur = cur.get(segment)?;
98    }
99    Some(cur)
100}
101
102/// Resolves `selector` of the form `"<node_id>.<field>.<subfield>..."` into
103/// `outputs["node_id"]["field"]["subfield"]...`.
104///
105/// If the selector has no dot, it is treated as a node ID and the whole output
106/// for that node is returned.
107fn resolve_selector<'a>(
108    outputs: &'a std::collections::HashMap<String, Value>,
109    selector: &str,
110) -> Option<&'a Value> {
111    let (node_id, rest) = match selector.find('.') {
112        Some(pos) => (&selector[..pos], &selector[pos + 1..]),
113        None => (selector, ""),
114    };
115    let node_out = outputs.get(node_id)?;
116    resolve_path(node_out, rest)
117}
118
119#[async_trait]
120impl Node for IterationNode {
121    fn node_type(&self) -> &str {
122        "iteration"
123    }
124
125    async fn execute(&self, ctx: ExecContext) -> Result<Value> {
126        // ── Parse data ────────────────────────────────────────────────────
127        let input_selector = ctx.data["input_selector"]
128            .as_str()
129            .ok_or_else(|| {
130                FlowError::InvalidDefinition("iteration: missing data.input_selector".into())
131            })?
132            .to_string();
133
134        let output_selector = ctx.data["output_selector"]
135            .as_str()
136            .ok_or_else(|| {
137                FlowError::InvalidDefinition("iteration: missing data.output_selector".into())
138            })?
139            .to_string();
140
141        let sub_flow_def = ctx
142            .data
143            .get("flow")
144            .ok_or_else(|| FlowError::InvalidDefinition("iteration: missing data.flow".into()))?;
145
146        // ── Parse and validate the sub-flow DAG once ──────────────────────
147        let sub_dag = DagGraph::from_json(sub_flow_def)?;
148
149        // ── Resolve the input array ───────────────────────────────────────
150        // input_selector is relative to the combined inputs map, e.g. "fetch.body.items".
151        // We split on the first dot to get the node_id, then path into its output.
152        let items: Vec<Value> = {
153            let (node_id, rest) = match input_selector.find('.') {
154                Some(pos) => (&input_selector[..pos], &input_selector[pos + 1..]),
155                None => (input_selector.as_str(), ""),
156            };
157            let node_out = ctx.inputs.get(node_id).ok_or_else(|| {
158                FlowError::InvalidDefinition(format!(
159                    "iteration: input_selector '{input_selector}' references unknown node '{node_id}'"
160                ))
161            })?;
162            let arr = resolve_path(node_out, rest).ok_or_else(|| {
163                FlowError::InvalidDefinition(format!(
164                    "iteration: path '{rest}' not found in node '{node_id}' output"
165                ))
166            })?;
167            arr.as_array()
168                .ok_or_else(|| {
169                    FlowError::InvalidDefinition(format!(
170                        "iteration: input_selector '{input_selector}' must point to a JSON array"
171                    ))
172                })?
173                .clone()
174        };
175
176        if items.is_empty() {
177            return Ok(json!({ "output": [] }));
178        }
179
180        let mode = ctx.data["mode"].as_str().unwrap_or("parallel");
181        let registry = Arc::clone(&ctx.registry);
182        let base_variables = ctx.variables.clone();
183
184        if mode == "sequential" {
185            // ── Sequential: process items one-at-a-time in order ──────────
186            let mut results = Vec::with_capacity(items.len());
187            let mut prev_output = Value::Null;
188
189            for (index, item) in items.into_iter().enumerate() {
190                let mut vars = base_variables.clone();
191                vars.insert("item".into(), item);
192                vars.insert("index".into(), json!(index));
193                vars.insert("prev_output".into(), prev_output.clone());
194
195                let runner = FlowRunner::with_arc_registry(sub_dag.clone(), Arc::clone(&registry));
196                let sub_result = runner.run(vars).await?;
197
198                let value = resolve_selector(&sub_result.outputs, &output_selector)
199                    .cloned()
200                    .unwrap_or(Value::Null);
201                prev_output = value.clone();
202                results.push(value);
203            }
204
205            Ok(json!({ "output": results }))
206        } else {
207            // ── Parallel (default): launch all items concurrently ─────────
208            let n = items.len();
209            let mut join_set: JoinSet<(usize, Result<std::collections::HashMap<String, Value>>)> =
210                JoinSet::new();
211
212            for (index, item) in items.into_iter().enumerate() {
213                let dag = sub_dag.clone();
214                let reg = Arc::clone(&registry);
215                let mut vars = base_variables.clone();
216                vars.insert("item".into(), item);
217                vars.insert("index".into(), json!(index));
218
219                join_set.spawn(async move {
220                    let runner = FlowRunner::with_arc_registry(dag, reg);
221                    let result: crate::error::Result<_> = runner.run(vars).await.map(|r| r.outputs);
222                    (index, result)
223                });
224            }
225
226            // Collect results in order.
227            let mut results: Vec<Option<Value>> = vec![None; n];
228
229            while let Some(task) = join_set.join_next().await {
230                match task {
231                    Ok((index, Ok(outputs))) => {
232                        let value = resolve_selector(&outputs, &output_selector).cloned();
233                        results[index] = value;
234                    }
235                    Ok((_, Err(e))) => return Err(e),
236                    Err(e) if e.is_cancelled() => return Err(FlowError::Terminated),
237                    Err(e) => return Err(FlowError::Internal(e.to_string())),
238                }
239            }
240
241            let output: Vec<Value> = results
242                .into_iter()
243                .map(|v| v.unwrap_or(Value::Null))
244                .collect();
245
246            Ok(json!({ "output": output }))
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a context carrying only node `data`, with no upstream inputs.
    fn ctx(data: Value) -> ExecContext {
        ctx_with_inputs(data, HashMap::new())
    }

    /// Builds a context carrying node `data` and the given upstream `inputs`.
    /// Variables start empty and all remaining fields take their defaults.
    fn ctx_with_inputs(data: Value, inputs: HashMap<String, Value>) -> ExecContext {
        ExecContext {
            data,
            inputs,
            variables: HashMap::new(),
            ..Default::default()
        }
    }

    /// Upstream inputs consisting of a single node named `src` with output `value`.
    fn src_input(value: Value) -> HashMap<String, Value> {
        HashMap::from([("src".into(), value)])
    }

    /// Sub-flow with a single Rhai `code` node `id` that evaluates `code`.
    fn rhai_flow(id: &str, code: &str) -> Value {
        json!({
            "nodes": [
                {
                    "id": id,
                    "type": "code",
                    "data": { "language": "rhai", "code": code }
                }
            ],
            "edges": []
        })
    }

    /// Sub-flow with a single `noop` node `id`.
    fn noop_flow(id: &str) -> Value {
        json!({ "nodes": [{ "id": id, "type": "noop" }], "edges": [] })
    }

    #[tokio::test]
    async fn iterates_over_array_and_collects_outputs() {
        // Sub-flow: single "code" node that returns { output: item * 2 }.
        // `item` is injected as a flow variable, accessible via `variables.item` in Rhai.
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src.items",
                    "output_selector": "double.output",
                    "flow": rhai_flow("double", "variables.item * 2")
                }),
                src_input(json!({ "items": [1, 2, 3] })),
            ))
            .await
            .unwrap();

        let arr = out["output"].as_array().unwrap();
        assert_eq!(arr.len(), 3);
        // Order is preserved.
        assert_eq!(arr[0], json!(2));
        assert_eq!(arr[1], json!(4));
        assert_eq!(arr[2], json!(6));
    }

    #[tokio::test]
    async fn empty_array_returns_empty_output() {
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "noop",
                    "flow": noop_flow("noop")
                }),
                src_input(json!([])),
            ))
            .await
            .unwrap();
        assert_eq!(out["output"], json!([]));
    }

    #[tokio::test]
    async fn index_variable_injected() {
        // `index` is injected as a flow variable, accessible via `variables.index` in Rhai.
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "idx.output",
                    "flow": rhai_flow("idx", "variables.index")
                }),
                src_input(json!(["a", "b", "c"])),
            ))
            .await
            .unwrap();

        let arr = out["output"].as_array().unwrap();
        assert_eq!(arr[0], json!(0));
        assert_eq!(arr[1], json!(1));
        assert_eq!(arr[2], json!(2));
    }

    #[tokio::test]
    async fn rejects_missing_input_selector() {
        let node = IterationNode;
        let err = node
            .execute(ctx(json!({
                "output_selector": "x",
                "flow": noop_flow("n")
            })))
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_missing_output_selector() {
        let node = IterationNode;
        let err = node
            .execute(ctx(json!({
                "input_selector": "src",
                "flow": noop_flow("n")
            })))
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_non_array_input() {
        let node = IterationNode;
        let err = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "n",
                    "flow": noop_flow("n")
                }),
                src_input(json!("not an array")),
            ))
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    // ── Sequential mode ────────────────────────────────────────────────────

    #[tokio::test]
    async fn sequential_mode_processes_in_order() {
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "step.output",
                    "mode": "sequential",
                    "flow": rhai_flow("step", "variables.item * 10")
                }),
                src_input(json!([1, 2, 3])),
            ))
            .await
            .unwrap();

        let arr = out["output"].as_array().unwrap();
        assert_eq!(arr, &[json!(10), json!(20), json!(30)]);
    }

    #[tokio::test]
    async fn sequential_mode_injects_prev_output() {
        // Smoke test: in sequential mode every iteration runs with an extra
        // `prev_output` variable in scope and the flow still completes.
        // NOTE(review): the script only returns `variables.index`, so the
        // *value* of `prev_output` is not asserted here. Strengthening this
        // test needs a script that actually reads `variables.prev_output`.
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "step.output",
                    "mode": "sequential",
                    // Return the index as a simple marker.
                    "flow": rhai_flow("step", "variables.index")
                }),
                src_input(json!(["a", "b", "c"])),
            ))
            .await
            .unwrap();

        let arr = out["output"].as_array().unwrap();
        assert_eq!(arr, &[json!(0), json!(1), json!(2)]);
    }

    #[tokio::test]
    async fn sequential_mode_empty_array_returns_empty() {
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "n",
                    "mode": "sequential",
                    "flow": noop_flow("n")
                }),
                src_input(json!([])),
            ))
            .await
            .unwrap();
        assert_eq!(out["output"], json!([]));
    }

    #[tokio::test]
    async fn unknown_mode_defaults_to_parallel() {
        // Any unrecognised mode string falls back to parallel.
        let node = IterationNode;
        let out = node
            .execute(ctx_with_inputs(
                json!({
                    "input_selector":  "src",
                    "output_selector": "step.output",
                    "mode": "turbo",
                    "flow": rhai_flow("step", "variables.item")
                }),
                src_input(json!([7, 8])),
            ))
            .await
            .unwrap();

        // Parallel results are written back by index, so the input order must
        // hold without sorting; sorting here would hide an ordering regression.
        let arr = out["output"].as_array().unwrap();
        assert_eq!(arr, &[json!(7), json!(8)]);
    }
}