// dag_executor/tasks/conditional.rs
use std::fmt;
use std::future::Future;
use std::sync::Arc;

use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::FutureExt;

use crate::context::Context;
use crate::error::TaskError;
use crate::tasks::r#trait::{Task, TaskOutput};

12type PredicateFn =
13 Box<dyn Fn(Arc<Context>) -> BoxFuture<'static, Result<bool, TaskError>> + Send + Sync>;
14
15pub struct ConditionalTask {
24 id: String,
25 deps: Vec<String>,
26 priority: u8,
27 predicate: PredicateFn,
28 on_true: String,
29 on_false: String,
30}
31
32impl ConditionalTask {
33 pub fn new<F, Fut>(
37 id: impl Into<String>,
38 on_true: impl Into<String>,
39 on_false: impl Into<String>,
40 predicate: F,
41 ) -> Self
42 where
43 F: Fn(Arc<Context>) -> Fut + Send + Sync + 'static,
44 Fut: Future<Output = Result<bool, TaskError>> + Send + 'static,
45 {
46 ConditionalTask {
47 id: id.into(),
48 deps: Vec::new(),
49 priority: 0,
50 predicate: Box::new(move |ctx| predicate(ctx).boxed()),
51 on_true: on_true.into(),
52 on_false: on_false.into(),
53 }
54 }
55
56 pub fn with_deps<I, S>(mut self, deps: I) -> Self
58 where
59 I: IntoIterator<Item = S>,
60 S: Into<String>,
61 {
62 self.deps = deps.into_iter().map(Into::into).collect();
63 self
64 }
65
66 pub fn with_priority(mut self, priority: u8) -> Self {
68 self.priority = priority;
69 self
70 }
71}
72
73#[async_trait]
74impl Task for ConditionalTask {
75 fn id(&self) -> &str {
76 &self.id
77 }
78
79 fn dependencies(&self) -> Vec<String> {
80 self.deps.clone()
81 }
82
83 fn priority(&self) -> u8 {
84 self.priority
85 }
86
87 async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
88 let taken = (self.predicate)(ctx.clone()).await?;
89 let branch = if taken { &self.on_true } else { &self.on_false };
90 ctx.set(
91 format!("{}.branch", self.id),
92 serde_json::Value::String(branch.clone()),
93 );
94 Ok(serde_json::json!({ "taken": taken, "branch": branch }))
95 }
96}