Skip to main content

dag_executor/tasks/
conditional.rs

1//! A task that evaluates a predicate and records which branch was taken.
2
3use crate::context::Context;
4use crate::error::TaskError;
5use crate::tasks::r#trait::{Task, TaskOutput};
6use async_trait::async_trait;
7use futures::future::BoxFuture;
8use futures::FutureExt;
9use std::future::Future;
10use std::sync::Arc;
11
12type PredicateFn =
13    Box<dyn Fn(Arc<Context>) -> BoxFuture<'static, Result<bool, TaskError>> + Send + Sync>;
14
15/// A branching task.
16///
17/// It evaluates an async predicate and records the chosen branch label on the
18/// context blackboard under the key `{id}.branch`. Downstream tasks can read
19/// that value to decide whether to do real work or short-circuit, implementing
20/// conditional control flow within an otherwise static DAG.
21///
22/// The task's own output is `{"taken": <bool>, "branch": <label>}`.
23pub struct ConditionalTask {
24    id: String,
25    deps: Vec<String>,
26    priority: u8,
27    predicate: PredicateFn,
28    on_true: String,
29    on_false: String,
30}
31
32impl ConditionalTask {
33    /// Create a conditional with the given async predicate.
34    ///
35    /// `on_true`/`on_false` are the branch labels published to the blackboard.
36    pub fn new<F, Fut>(
37        id: impl Into<String>,
38        on_true: impl Into<String>,
39        on_false: impl Into<String>,
40        predicate: F,
41    ) -> Self
42    where
43        F: Fn(Arc<Context>) -> Fut + Send + Sync + 'static,
44        Fut: Future<Output = Result<bool, TaskError>> + Send + 'static,
45    {
46        ConditionalTask {
47            id: id.into(),
48            deps: Vec::new(),
49            priority: 0,
50            predicate: Box::new(move |ctx| predicate(ctx).boxed()),
51            on_true: on_true.into(),
52            on_false: on_false.into(),
53        }
54    }
55
56    /// Declare dependencies.
57    pub fn with_deps<I, S>(mut self, deps: I) -> Self
58    where
59        I: IntoIterator<Item = S>,
60        S: Into<String>,
61    {
62        self.deps = deps.into_iter().map(Into::into).collect();
63        self
64    }
65
66    /// Set scheduling priority.
67    pub fn with_priority(mut self, priority: u8) -> Self {
68        self.priority = priority;
69        self
70    }
71}
72
73#[async_trait]
74impl Task for ConditionalTask {
75    fn id(&self) -> &str {
76        &self.id
77    }
78
79    fn dependencies(&self) -> Vec<String> {
80        self.deps.clone()
81    }
82
83    fn priority(&self) -> u8 {
84        self.priority
85    }
86
87    async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
88        let taken = (self.predicate)(ctx.clone()).await?;
89        let branch = if taken { &self.on_true } else { &self.on_false };
90        ctx.set(
91            format!("{}.branch", self.id),
92            serde_json::Value::String(branch.clone()),
93        );
94        Ok(serde_json::json!({ "taken": taken, "branch": branch }))
95    }
96}