graph_flow/
fanout.rs

1//! FanOutTask – a composite task that runs multiple child tasks in parallel
2//!
3//! This task provides simple parallelism within a single graph node. It executes
4//! a fixed set of child tasks concurrently, waits for them to finish, aggregates
5//! their responses into the shared `Context`, and then returns control back to
6//! the graph with `NextAction::Continue` (by default).
7//!
8//! Design goals:
9//! - Keep engine changes minimal (no changes to `Graph` needed)
10//! - Keep semantics simple and predictable
11//! - Make context aggregation explicit and easy to consume by downstream tasks
12//!
13//! Important caveats:
14//! - Child tasks' `NextAction` is ignored by `FanOutTask`. Children are treated as
15//!   units-of-work that produce outputs and/or write to context, not as control-flow
16//!   steps. The `FanOutTask` itself controls the next step of the graph.
//! - All children share the same `Context` (concurrent writes must be coordinated
//!   by the user). Aggregated child outputs are always stored under
//!   `"<prefix>.<child_id>.*"`; the prefix defaults to `fanout` and can be changed with `with_prefix`.
20//! - Error policy is conservative: if any child fails, `FanOutTask` fails.
21//!
22//! Example:
23//! ```rust
24//! use graph_flow::{Context, Task, TaskResult, NextAction};
25//! use graph_flow::fanout::FanOutTask;
26//! use async_trait::async_trait;
27//! use std::sync::Arc;
28//!
29//! struct ChildA;
30//! struct ChildB;
31//!
32//! #[async_trait]
33//! impl Task for ChildA {
34//!     fn id(&self) -> &str { "child_a" }
35//!     async fn run(&self, ctx: Context) -> graph_flow::Result<TaskResult> {
36//!         ctx.set("a", 1_i32).await;
37//!         Ok(TaskResult::new(Some("A done".to_string()), NextAction::End))
38//!     }
39//! }
40//!
41//! #[async_trait]
42//! impl Task for ChildB {
43//!     fn id(&self) -> &str { "child_b" }
44//!     async fn run(&self, ctx: Context) -> graph_flow::Result<TaskResult> {
45//!         ctx.set("b", 2_i32).await;
46//!         Ok(TaskResult::new(Some("B done".to_string()), NextAction::End))
47//!     }
48//! }
49//!
50//! # #[tokio::main]
51//! # async fn main() -> graph_flow::Result<()> {
52//! let fan = FanOutTask::new("fan", vec![Arc::new(ChildA), Arc::new(ChildB)])
53//!     .with_prefix("fanout");
54//! let ctx = Context::new();
55//! let _ = fan.run(ctx.clone()).await?;
56//! // Aggregated entries under prefix:
57//! // fanout.child_a.response, fanout.child_b.response
58//! # Ok(())
59//! # }
60//! ```
61
62use async_trait::async_trait;
63use std::sync::Arc;
64use tokio::task::JoinSet;
65
66use crate::{Context, Result, Task, TaskResult, NextAction, GraphError};
67
68/// Composite task that executes multiple child tasks concurrently and aggregates results.
69#[derive(Clone)]
70pub struct FanOutTask {
71    id: String,
72    children: Vec<Arc<dyn Task>>, // executed in parallel
73    prefix: Option<String>,        // context aggregation prefix
74    next_action: NextAction,       // default: Continue
75}
76
77impl FanOutTask {
78    /// Create a new `FanOutTask` with an explicit id and a list of child tasks.
79    pub fn new(id: impl Into<String>, children: Vec<Arc<dyn Task>>) -> Arc<Self> {
80        Arc::new(Self {
81            id: id.into(),
82            children,
83            prefix: None,
84            next_action: NextAction::Continue,
85        })
86    }
87
88    /// Set a context prefix for storing aggregated child results.
89    ///
90    /// Aggregation keys will be written as `<prefix>.<child_id>.<field>`.
91    pub fn with_prefix(mut self: Arc<Self>, prefix: impl Into<String>) -> Arc<Self> {
92        Arc::make_mut(&mut self).prefix = Some(prefix.into());
93        self
94    }
95
96    /// Override the `NextAction` returned by the `FanOutTask` (default: `Continue`).
97    pub fn with_next_action(mut self: Arc<Self>, next: NextAction) -> Arc<Self> {
98        Arc::make_mut(&mut self).next_action = next;
99        self
100    }
101
102    fn key(&self, child_id: &str, field: &str) -> String {
103        if let Some(p) = &self.prefix {
104            format!("{}.{}.{}", p, child_id, field)
105        } else {
106            format!("fanout.{}.{}", child_id, field)
107        }
108    }
109}
110
111#[async_trait]
112impl Task for FanOutTask {
113    fn id(&self) -> &str { &self.id }
114
115    async fn run(&self, context: Context) -> Result<TaskResult> {
116        let mut set = JoinSet::new();
117
118        // Spawn children concurrently
119        for child in &self.children {
120            let child = child.clone();
121            let ctx = context.clone();
122            set.spawn(async move {
123                let cid = child.id().to_string();
124                let res = child.run(ctx.clone()).await;
125                (cid, res)
126            });
127        }
128
129        let mut had_error = None;
130        let mut completed = 0usize;
131
132        while let Some(joined) = set.join_next().await {
133            match joined {
134                Err(join_err) => {
135                    had_error = Some(GraphError::TaskExecutionFailed(format!(
136                        "FanOut child join error: {}", join_err
137                    )));
138                }
139                Ok((child_id, outcome)) => match outcome {
140                    Err(e) => {
141                        had_error = Some(GraphError::TaskExecutionFailed(format!(
142                            "FanOut child '{}' failed: {}", child_id, e
143                        )));
144                    }
145                    Ok(tr) => {
146                        // Store child outputs under prefixed keys
147                        if let Some(resp) = tr.response.clone() {
148                            context.set(self.key(&child_id, "response"), resp).await;
149                        }
150                        if let Some(status) = tr.status_message.clone() {
151                            context.set(self.key(&child_id, "status"), status).await;
152                        }
153                        // Always store the reported next_action for diagnostics
154                        context
155                            .set(self.key(&child_id, "next_action"), format!("{:?}", tr.next_action))
156                            .await;
157                        completed += 1;
158                    }
159                },
160            }
161        }
162
163        if let Some(err) = had_error {
164            return Err(err);
165        }
166
167        let summary = format!(
168            "FanOutTask '{}' completed {} child task(s)",
169            self.id, completed
170        );
171
172        Ok(TaskResult::new_with_status(
173            Some(summary.clone()),
174            self.next_action.clone(),
175            Some(summary),
176        ))
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use tokio::time::{sleep, Duration};

    /// Child that sets a flag in the context, pauses briefly, then succeeds.
    struct OkTask {
        name: &'static str,
    }

    /// Child that fails immediately without touching the context.
    struct FailingTask {
        name: &'static str,
    }

    #[async_trait]
    impl Task for OkTask {
        fn id(&self) -> &str {
            self.name
        }

        async fn run(&self, ctx: Context) -> Result<TaskResult> {
            let flag_key = format!("out.{}", self.name);
            ctx.set(flag_key, true).await;
            sleep(Duration::from_millis(10)).await;
            let response = format!("{} ok", self.name);
            Ok(TaskResult::new(Some(response), NextAction::End))
        }
    }

    #[async_trait]
    impl Task for FailingTask {
        fn id(&self) -> &str {
            self.name
        }

        async fn run(&self, _ctx: Context) -> Result<TaskResult> {
            let message = format!("{} failed", self.name);
            Err(GraphError::TaskExecutionFailed(message))
        }
    }

    /// Successful children have their responses and next actions stored under the prefix.
    #[tokio::test]
    async fn fanout_all_success_aggregates() {
        let children: Vec<Arc<dyn Task>> = vec![
            Arc::new(OkTask { name: "a" }),
            Arc::new(OkTask { name: "b" }),
        ];
        let fan = FanOutTask::new("fan", children).with_prefix("agg");

        let ctx = Context::new();
        let result = fan.run(ctx.clone()).await.unwrap();
        assert_eq!(result.next_action, NextAction::Continue);

        let response_a: Option<String> = ctx.get("agg.a.response").await;
        let response_b: Option<String> = ctx.get("agg.b.response").await;
        assert_eq!(response_a, Some("a ok".to_string()));
        assert_eq!(response_b, Some("b ok".to_string()));

        // The child's own next action is recorded for diagnostics.
        let recorded_action: Option<String> = ctx.get("agg.a.next_action").await;
        assert_eq!(recorded_action, Some(format!("{:?}", NextAction::End)));
    }

    /// A single failing child makes the whole fan-out fail and names that child.
    #[tokio::test]
    async fn fanout_failure_bubbles_up() {
        let children: Vec<Arc<dyn Task>> = vec![
            Arc::new(OkTask { name: "a" }),
            Arc::new(FailingTask { name: "bad" }),
        ];
        let fan = FanOutTask::new("fan", children);

        let failure = fan.run(Context::new()).await.err().unwrap();
        match failure {
            GraphError::TaskExecutionFailed(msg) => assert!(msg.contains("bad")),
            other => panic!("Unexpected error variant: {other:?}"),
        }
    }
}