use crate::context::Context;
use crate::error::TaskError;
use crate::tasks::{BasicTask, Task, TaskOutput};
use std::future::Future;
use std::sync::Arc;
/// Builds a fan-out/fan-in task group: `count` workers followed by one aggregator.
///
/// Generated task ids:
/// - `"{prefix}.worker.{i}"` for each worker `i` in `0..count`;
/// - `"{prefix}.aggregate"` for the aggregator, which is always the last element of the
///   returned vec.
///
/// Each worker calls `worker(ctx, i)`, stores the output in the context under
/// `"{prefix}.result.{i}"`, and also returns it as its own task output. When `upstream` is
/// given, every worker depends on that task id.
///
/// The aggregator depends on every worker. It reads the `count` stored results in index
/// order (so `collected[i]` is worker `i`'s result) and passes them to `aggregator`. A result
/// missing from the context is passed as `serde_json::Value::Null` instead of failing.
///
/// When `count == 0` there are no workers. The aggregator then depends directly on
/// `upstream` (if any) and receives an empty vec.
///
/// # Errors
/// A worker whose future returns `Err` propagates it via `?` before writing to the context.
/// The aggregator task's result is whatever the `aggregator` future returns.
pub fn fan_out_in<W, WFut, A, AFut>(
    prefix: &str,
    count: usize,
    upstream: Option<&str>,
    worker: W,
    aggregator: A,
) -> Vec<Arc<dyn Task>>
where
    W: Fn(Arc<Context>, usize) -> WFut + Send + Sync + Clone + 'static,
    WFut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
    A: Fn(Arc<Context>, Vec<TaskOutput>) -> AFut + Send + Sync + 'static,
    AFut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
{
    let mut tasks: Vec<Arc<dyn Task>> = Vec::with_capacity(count + 1);
    let mut worker_ids = Vec::with_capacity(count);
    for i in 0..count {
        let id = format!("{prefix}.worker.{i}");
        worker_ids.push(id.clone());
        let worker = worker.clone();
        let result_key = format!("{prefix}.result.{i}");
        let mut task = BasicTask::new(id, move |ctx: Arc<Context>| {
            // Clone the captures so the `async move` future owns its own copies and the
            // closure itself stays reusable.
            let worker = worker.clone();
            let result_key = result_key.clone();
            async move {
                let out = worker(ctx.clone(), i).await?;
                ctx.set(result_key, out.clone());
                Ok(out)
            }
        });
        if let Some(up) = upstream {
            task = task.with_deps([up]);
        }
        tasks.push(Arc::new(task));
    }

    // With no workers, the aggregator would have no dependencies at all and could run
    // before `upstream` completes. In that case it depends on `upstream` directly.
    let agg_deps: Vec<String> = if worker_ids.is_empty() {
        upstream.map(str::to_owned).into_iter().collect()
    } else {
        worker_ids
    };

    let agg_prefix = prefix.to_string();
    let aggregator = Arc::new(aggregator);
    let agg = BasicTask::new(format!("{prefix}.aggregate"), move |ctx: Arc<Context>| {
        let agg_prefix = agg_prefix.clone();
        let aggregator = aggregator.clone();
        async move {
            // Read results in worker-index order. A missing entry becomes `Null` rather
            // than an error (preserved behaviour).
            let collected: Vec<_> = (0..count)
                .map(|i| {
                    ctx.get(&format!("{agg_prefix}.result.{i}"))
                        .unwrap_or(serde_json::Value::Null)
                })
                .collect();
            aggregator(ctx, collected).await
        }
    })
    .with_deps(agg_deps);
    tasks.push(Arc::new(agg));
    tasks
}
/// Chains `stages` into a strictly sequential list of tasks.
///
/// Each `(id, f)` pair becomes a `BasicTask` with that id. Every stage except the first
/// depends on the stage immediately before it, giving edges `stages[k-1] -> stages[k]`.
/// Tasks are returned in input order, and an empty `stages` yields an empty vec.
///
/// NOTE(review): all stages share one closure type `F` and one future type `Fut`. Distinct
/// closure literals or async fns therefore cannot be mixed in a single call unless they are
/// first erased to a common type.
pub fn pipeline<F, Fut>(stages: Vec<(String, F)>) -> Vec<Arc<dyn Task>>
where
    F: Fn(Arc<Context>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
{
    let mut upstream_id: Option<String> = None;
    stages
        .into_iter()
        .map(|(stage_id, stage_fn)| -> Arc<dyn Task> {
            let stage = BasicTask::new(stage_id.clone(), stage_fn);
            // Make this stage the upstream for the next one, and take the previous id (if
            // any) as this stage's single dependency.
            let stage = match upstream_id.replace(stage_id) {
                Some(prev_id) => stage.with_deps([prev_id]),
                None => stage,
            };
            Arc::new(stage)
        })
        .collect()
}