1use async_trait::async_trait;
63use std::sync::Arc;
64use tokio::task::JoinSet;
65
66use crate::{Context, Result, Task, TaskResult, NextAction, GraphError};
67
68#[derive(Clone)]
70pub struct FanOutTask {
71 id: String,
72 children: Vec<Arc<dyn Task>>, prefix: Option<String>, next_action: NextAction, }
76
77impl FanOutTask {
78 pub fn new(id: impl Into<String>, children: Vec<Arc<dyn Task>>) -> Arc<Self> {
80 Arc::new(Self {
81 id: id.into(),
82 children,
83 prefix: None,
84 next_action: NextAction::Continue,
85 })
86 }
87
88 pub fn with_prefix(mut self: Arc<Self>, prefix: impl Into<String>) -> Arc<Self> {
92 Arc::make_mut(&mut self).prefix = Some(prefix.into());
93 self
94 }
95
96 pub fn with_next_action(mut self: Arc<Self>, next: NextAction) -> Arc<Self> {
98 Arc::make_mut(&mut self).next_action = next;
99 self
100 }
101
102 fn key(&self, child_id: &str, field: &str) -> String {
103 if let Some(p) = &self.prefix {
104 format!("{}.{}.{}", p, child_id, field)
105 } else {
106 format!("fanout.{}.{}", child_id, field)
107 }
108 }
109}
110
111#[async_trait]
112impl Task for FanOutTask {
113 fn id(&self) -> &str { &self.id }
114
115 async fn run(&self, context: Context) -> Result<TaskResult> {
116 let mut set = JoinSet::new();
117
118 for child in &self.children {
120 let child = child.clone();
121 let ctx = context.clone();
122 set.spawn(async move {
123 let cid = child.id().to_string();
124 let res = child.run(ctx.clone()).await;
125 (cid, res)
126 });
127 }
128
129 let mut had_error = None;
130 let mut completed = 0usize;
131
132 while let Some(joined) = set.join_next().await {
133 match joined {
134 Err(join_err) => {
135 had_error = Some(GraphError::TaskExecutionFailed(format!(
136 "FanOut child join error: {}", join_err
137 )));
138 }
139 Ok((child_id, outcome)) => match outcome {
140 Err(e) => {
141 had_error = Some(GraphError::TaskExecutionFailed(format!(
142 "FanOut child '{}' failed: {}", child_id, e
143 )));
144 }
145 Ok(tr) => {
146 if let Some(resp) = tr.response.clone() {
148 context.set(self.key(&child_id, "response"), resp).await;
149 }
150 if let Some(status) = tr.status_message.clone() {
151 context.set(self.key(&child_id, "status"), status).await;
152 }
153 context
155 .set(self.key(&child_id, "next_action"), format!("{:?}", tr.next_action))
156 .await;
157 completed += 1;
158 }
159 },
160 }
161 }
162
163 if let Some(err) = had_error {
164 return Err(err);
165 }
166
167 let summary = format!(
168 "FanOutTask '{}' completed {} child task(s)",
169 self.id, completed
170 );
171
172 Ok(TaskResult::new_with_status(
173 Some(summary.clone()),
174 self.next_action.clone(),
175 Some(summary),
176 ))
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use tokio::time::{sleep, Duration};

    /// Child that marks `out.<name>` in the context, yields briefly, then succeeds
    /// with response `"<name> ok"` and `NextAction::End`.
    struct OkTask {
        name: &'static str,
    }

    /// Child that always fails with `TaskExecutionFailed("<name> failed")`.
    struct FailingTask {
        name: &'static str,
    }

    #[async_trait]
    impl Task for OkTask {
        fn id(&self) -> &str {
            self.name
        }

        async fn run(&self, ctx: Context) -> Result<TaskResult> {
            let marker = format!("out.{}", self.name);
            ctx.set(marker, true).await;
            // A short sleep so sibling children actually overlap while running.
            sleep(Duration::from_millis(10)).await;
            let response = format!("{} ok", self.name);
            Ok(TaskResult::new(Some(response), NextAction::End))
        }
    }

    #[async_trait]
    impl Task for FailingTask {
        fn id(&self) -> &str {
            self.name
        }

        async fn run(&self, _ctx: Context) -> Result<TaskResult> {
            let message = format!("{} failed", self.name);
            Err(GraphError::TaskExecutionFailed(message))
        }
    }

    #[tokio::test]
    async fn fanout_all_success_aggregates() {
        let first: Arc<dyn Task> = Arc::new(OkTask { name: "a" });
        let second: Arc<dyn Task> = Arc::new(OkTask { name: "b" });
        let fan = FanOutTask::new("fan", vec![first, second]).with_prefix("agg");

        let ctx = Context::new();
        let outcome = fan.run(ctx.clone()).await.unwrap();

        // The fan-out reports its own (default) next action, not the children's.
        assert_eq!(outcome.next_action, NextAction::Continue);

        for (key, expected) in [("agg.a.response", "a ok"), ("agg.b.response", "b ok")] {
            let recorded: Option<String> = ctx.get(key).await;
            assert_eq!(recorded, Some(expected.to_string()));
        }

        let recorded_next: Option<String> = ctx.get("agg.a.next_action").await;
        assert_eq!(recorded_next, Some(format!("{:?}", NextAction::End)));
    }

    #[tokio::test]
    async fn fanout_failure_bubbles_up() {
        let healthy: Arc<dyn Task> = Arc::new(OkTask { name: "a" });
        let broken: Arc<dyn Task> = Arc::new(FailingTask { name: "bad" });
        let fan = FanOutTask::new("fan", vec![healthy, broken]);

        let ctx = Context::new();
        let failure = fan.run(ctx.clone()).await.err().unwrap();

        match failure {
            GraphError::TaskExecutionFailed(msg) => assert!(msg.contains("bad")),
            other => panic!("Unexpected error variant: {other:?}"),
        }
    }
}