dag_executor/advanced/
patterns.rs1use crate::context::Context;
8use crate::error::TaskError;
9use crate::tasks::{BasicTask, Task, TaskOutput};
10use std::future::Future;
11use std::sync::Arc;
12
13pub fn fan_out_in<W, WFut, A, AFut>(
22 prefix: &str,
23 count: usize,
24 upstream: Option<&str>,
25 worker: W,
26 aggregator: A,
27) -> Vec<Arc<dyn Task>>
28where
29 W: Fn(Arc<Context>, usize) -> WFut + Send + Sync + Clone + 'static,
30 WFut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
31 A: Fn(Arc<Context>, Vec<TaskOutput>) -> AFut + Send + Sync + 'static,
32 AFut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
33{
34 let mut tasks: Vec<Arc<dyn Task>> = Vec::with_capacity(count + 1);
35 let mut worker_ids = Vec::with_capacity(count);
36
37 for i in 0..count {
38 let id = format!("{prefix}.worker.{i}");
39 worker_ids.push(id.clone());
40 let worker = worker.clone();
41 let result_key = format!("{prefix}.result.{i}");
44 let mut task = BasicTask::new(id, move |ctx: Arc<Context>| {
45 let worker = worker.clone();
46 let result_key = result_key.clone();
47 async move {
48 let out = worker(ctx.clone(), i).await?;
49 ctx.set(result_key, out.clone());
50 Ok(out)
51 }
52 });
53 if let Some(up) = upstream {
54 task = task.with_deps([up]);
55 }
56 tasks.push(Arc::new(task));
57 }
58
59 let agg_prefix = prefix.to_string();
60 let aggregator = Arc::new(aggregator);
61 let agg = BasicTask::new(format!("{prefix}.aggregate"), move |ctx: Arc<Context>| {
62 let agg_prefix = agg_prefix.clone();
63 let aggregator = aggregator.clone();
64 async move {
65 let mut collected = Vec::with_capacity(count);
66 for i in 0..count {
67 let v = ctx
68 .get(&format!("{agg_prefix}.result.{i}"))
69 .unwrap_or(serde_json::Value::Null);
70 collected.push(v);
71 }
72 aggregator(ctx, collected).await
73 }
74 })
75 .with_deps(worker_ids);
76 tasks.push(Arc::new(agg));
77
78 tasks
79}
80
81pub fn pipeline<F, Fut>(stages: Vec<(String, F)>) -> Vec<Arc<dyn Task>>
86where
87 F: Fn(Arc<Context>) -> Fut + Send + Sync + 'static,
88 Fut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
89{
90 let mut tasks: Vec<Arc<dyn Task>> = Vec::with_capacity(stages.len());
91 let mut prev: Option<String> = None;
92 for (id, f) in stages {
93 let mut task = BasicTask::new(id.clone(), f);
94 if let Some(p) = prev.take() {
95 task = task.with_deps([p]);
96 }
97 prev = Some(id);
98 tasks.push(Arc::new(task));
99 }
100 tasks
101}