Skip to main content

dag_executor/advanced/
patterns.rs

1//! Helpers for building common workflow shapes.
2//!
3//! These are convenience constructors that return [`Arc<dyn Task>`] collections
4//! wired together with the right dependencies, so callers don't hand-roll
5//! fan-out/fan-in topologies.
6
7use crate::context::Context;
8use crate::error::TaskError;
9use crate::tasks::{BasicTask, Task, TaskOutput};
10use std::future::Future;
11use std::sync::Arc;
12
13/// Build a **fan-out / fan-in** subgraph.
14///
15/// Creates `count` parallel worker tasks (each depending on `upstream`, if
16/// any), plus a single aggregator task depending on all workers. The worker
17/// closure receives its index; the aggregator closure receives the collected
18/// outputs in index order.
19///
20/// Returns the full set of tasks ready to be added to a [`crate::dag::Dag`].
21pub fn fan_out_in<W, WFut, A, AFut>(
22    prefix: &str,
23    count: usize,
24    upstream: Option<&str>,
25    worker: W,
26    aggregator: A,
27) -> Vec<Arc<dyn Task>>
28where
29    W: Fn(Arc<Context>, usize) -> WFut + Send + Sync + Clone + 'static,
30    WFut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
31    A: Fn(Arc<Context>, Vec<TaskOutput>) -> AFut + Send + Sync + 'static,
32    AFut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
33{
34    let mut tasks: Vec<Arc<dyn Task>> = Vec::with_capacity(count + 1);
35    let mut worker_ids = Vec::with_capacity(count);
36
37    for i in 0..count {
38        let id = format!("{prefix}.worker.{i}");
39        worker_ids.push(id.clone());
40        let worker = worker.clone();
41        // Each worker publishes its output to the blackboard so the aggregator
42        // can collect results without a shared mutable channel.
43        let result_key = format!("{prefix}.result.{i}");
44        let mut task = BasicTask::new(id, move |ctx: Arc<Context>| {
45            let worker = worker.clone();
46            let result_key = result_key.clone();
47            async move {
48                let out = worker(ctx.clone(), i).await?;
49                ctx.set(result_key, out.clone());
50                Ok(out)
51            }
52        });
53        if let Some(up) = upstream {
54            task = task.with_deps([up]);
55        }
56        tasks.push(Arc::new(task));
57    }
58
59    let agg_prefix = prefix.to_string();
60    let aggregator = Arc::new(aggregator);
61    let agg = BasicTask::new(format!("{prefix}.aggregate"), move |ctx: Arc<Context>| {
62        let agg_prefix = agg_prefix.clone();
63        let aggregator = aggregator.clone();
64        async move {
65            let mut collected = Vec::with_capacity(count);
66            for i in 0..count {
67                let v = ctx
68                    .get(&format!("{agg_prefix}.result.{i}"))
69                    .unwrap_or(serde_json::Value::Null);
70                collected.push(v);
71            }
72            aggregator(ctx, collected).await
73        }
74    })
75    .with_deps(worker_ids);
76    tasks.push(Arc::new(agg));
77
78    tasks
79}
80
81/// Build a **linear pipeline**: `stages[0] -> stages[1] -> ...`.
82///
83/// Each stage is a `(id, closure)` pair; stage *n* is made to depend on stage
84/// *n-1*, so they execute strictly in order.
85pub fn pipeline<F, Fut>(stages: Vec<(String, F)>) -> Vec<Arc<dyn Task>>
86where
87    F: Fn(Arc<Context>) -> Fut + Send + Sync + 'static,
88    Fut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
89{
90    let mut tasks: Vec<Arc<dyn Task>> = Vec::with_capacity(stages.len());
91    let mut prev: Option<String> = None;
92    for (id, f) in stages {
93        let mut task = BasicTask::new(id.clone(), f);
94        if let Some(p) = prev.take() {
95            task = task.with_deps([p]);
96        }
97        prev = Some(id);
98        tasks.push(Arc::new(task));
99    }
100    tasks
101}