use async_trait::async_trait;
use std::sync::Arc;
use tokio::task::JoinSet;
use crate::{Context, Result, Task, TaskResult, NextAction, GraphError};
/// Composite task that runs a set of child tasks concurrently and joins on them.
///
/// When run, each child is spawned onto a Tokio [`JoinSet`] with its own clone of
/// the [`Context`]; the outputs of successful children are then written back into
/// that context under keys built from `prefix` and the child's id.
///
/// `Clone` is derived because the `with_*` builders go through [`Arc::make_mut`],
/// which needs to be able to copy the value when the `Arc` is shared.
#[derive(Clone)]
pub struct FanOutTask {
    /// Identifier reported through [`Task::id`].
    id: String,
    /// Child tasks that are spawned concurrently by [`Task::run`].
    children: Vec<Arc<dyn Task>>,
    /// Optional namespace for aggregated context keys; `fanout` is used when unset.
    prefix: Option<String>,
    /// Action handed back to the caller once every child has succeeded.
    next_action: NextAction,
}
impl FanOutTask {
    /// Builds a fan-out task over `children`, identified by `id`.
    ///
    /// The new task has no key prefix and reports [`NextAction::Continue`] on
    /// success; use [`with_prefix`](Self::with_prefix) and
    /// [`with_next_action`](Self::with_next_action) to change either.
    pub fn new(id: impl Into<String>, children: Vec<Arc<dyn Task>>) -> Arc<Self> {
        let task = Self {
            id: id.into(),
            children,
            prefix: None,
            next_action: NextAction::Continue,
        };
        Arc::new(task)
    }

    /// Replaces the default `fanout` namespace used for aggregated context keys.
    pub fn with_prefix(mut self: Arc<Self>, prefix: impl Into<String>) -> Arc<Self> {
        let namespace = prefix.into();
        // `make_mut` mutates in place when this is the only handle and works on a
        // private copy otherwise, so other holders of the `Arc` are unaffected.
        let inner = Arc::make_mut(&mut self);
        inner.prefix = Some(namespace);
        self
    }

    /// Sets the action returned from `run` once every child has succeeded.
    pub fn with_next_action(mut self: Arc<Self>, next: NextAction) -> Arc<Self> {
        let inner = Arc::make_mut(&mut self);
        inner.next_action = next;
        self
    }

    /// Builds the context key `<namespace>.<child_id>.<field>`, where the
    /// namespace is the configured prefix or `fanout` when none was set.
    fn key(&self, child_id: &str, field: &str) -> String {
        let namespace = self.prefix.as_deref().unwrap_or("fanout");
        format!("{namespace}.{child_id}.{field}")
    }
}
#[async_trait]
impl Task for FanOutTask {
fn id(&self) -> &str { &self.id }
async fn run(&self, context: Context) -> Result<TaskResult> {
let mut set = JoinSet::new();
for child in &self.children {
let child = child.clone();
let ctx = context.clone();
set.spawn(async move {
let cid = child.id().to_string();
let res = child.run(ctx.clone()).await;
(cid, res)
});
}
let mut had_error = None;
let mut completed = 0usize;
while let Some(joined) = set.join_next().await {
match joined {
Err(join_err) => {
had_error = Some(GraphError::TaskExecutionFailed(format!(
"FanOut child join error: {}", join_err
)));
}
Ok((child_id, outcome)) => match outcome {
Err(e) => {
had_error = Some(GraphError::TaskExecutionFailed(format!(
"FanOut child '{}' failed: {}", child_id, e
)));
}
Ok(tr) => {
if let Some(resp) = tr.response.clone() {
context.set(self.key(&child_id, "response"), resp).await;
}
if let Some(status) = tr.status_message.clone() {
context.set(self.key(&child_id, "status"), status).await;
}
context
.set(self.key(&child_id, "next_action"), format!("{:?}", tr.next_action))
.await;
completed += 1;
}
},
}
}
if let Some(err) = had_error {
return Err(err);
}
let summary = format!(
"FanOutTask '{}' completed {} child task(s)",
self.id, completed
);
Ok(TaskResult::new_with_status(
Some(summary.clone()),
self.next_action.clone(),
Some(summary),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use tokio::time::{sleep, Duration};

    /// Child that sets a flag in the context, pauses briefly, then succeeds.
    struct OkTask {
        name: &'static str,
    }

    /// Child that fails immediately without touching the context.
    struct FailingTask {
        name: &'static str,
    }

    #[async_trait]
    impl Task for OkTask {
        fn id(&self) -> &str {
            self.name
        }

        async fn run(&self, ctx: Context) -> Result<TaskResult> {
            let flag_key = format!("out.{}", self.name);
            ctx.set(flag_key, true).await;
            // Short pause so the children overlap in time.
            sleep(Duration::from_millis(10)).await;
            let response = format!("{} ok", self.name);
            Ok(TaskResult::new(Some(response), NextAction::End))
        }
    }

    #[async_trait]
    impl Task for FailingTask {
        fn id(&self) -> &str {
            self.name
        }

        async fn run(&self, _ctx: Context) -> Result<TaskResult> {
            let reason = format!("{} failed", self.name);
            Err(GraphError::TaskExecutionFailed(reason))
        }
    }

    /// Successful children have their outputs stored under the configured prefix.
    #[tokio::test]
    async fn fanout_all_success_aggregates() {
        let first: Arc<dyn Task> = Arc::new(OkTask { name: "a" });
        let second: Arc<dyn Task> = Arc::new(OkTask { name: "b" });
        let fan = FanOutTask::new("fan", vec![first, second]).with_prefix("agg");
        let ctx = Context::new();

        let outcome = fan.run(ctx.clone()).await.unwrap();
        assert_eq!(outcome.next_action, NextAction::Continue);

        for (key, expected) in [("agg.a.response", "a ok"), ("agg.b.response", "b ok")] {
            let stored: Option<String> = ctx.get(key).await;
            assert_eq!(stored, Some(expected.to_string()));
        }

        let stored_action: Option<String> = ctx.get("agg.a.next_action").await;
        assert_eq!(stored_action, Some(format!("{:?}", NextAction::End)));
    }

    /// A failing child turns the whole fan-out into an error naming that child.
    #[tokio::test]
    async fn fanout_failure_bubbles_up() {
        let healthy: Arc<dyn Task> = Arc::new(OkTask { name: "a" });
        let broken: Arc<dyn Task> = Arc::new(FailingTask { name: "bad" });
        let fan = FanOutTask::new("fan", vec![healthy, broken]);
        let ctx = Context::new();

        let outcome = fan.run(ctx.clone()).await;
        let failure = outcome.err().unwrap();
        match failure {
            GraphError::TaskExecutionFailed(message) => assert!(message.contains("bad")),
            unexpected => panic!("Unexpected error variant: {unexpected:?}"),
        }
    }
}