Skip to main content

dag_executor/tasks/
loop_task.rs

//! A task that repeats a body until a break condition fires or an iteration cap is reached.
2
3use crate::context::Context;
4use crate::error::TaskError;
5use crate::tasks::r#trait::{Task, TaskOutput};
6use async_trait::async_trait;
7use futures::future::BoxFuture;
8use futures::FutureExt;
9use std::future::Future;
10use std::sync::Arc;
11
12type BodyFn = Box<
13    dyn Fn(Arc<Context>, u32) -> BoxFuture<'static, Result<TaskOutput, TaskError>> + Send + Sync,
14>;
15type BreakFn = Box<dyn Fn(&TaskOutput) -> bool + Send + Sync>;
16
/// A task that runs `body` repeatedly until `break_when` returns true or the
/// iteration count reaches `max_iterations`.
///
/// The body receives the zero-based iteration index. The task output is
/// `{"iterations": <n>, "last": <last body output>, "broke_early": <bool>}`;
/// `last` is JSON `null` when no iteration ran, and `broke_early` is true
/// whenever `break_when` fired, even on the final permitted iteration.
/// Cancellation is honored between iterations: it is checked before each body
/// call, never while a body future is running.
pub struct LoopTask {
    /// Identifier reported through `Task::id`.
    id: String,
    /// Ids of upstream tasks, reported through `Task::dependencies`.
    deps: Vec<String>,
    /// Scheduling priority reported through `Task::priority` (0 unless overridden).
    priority: u8,
    /// Upper bound on the number of body invocations in one `execute` call.
    max_iterations: u32,
    /// Type-erased body, called with the context and the zero-based iteration index.
    body: BodyFn,
    /// Early-exit predicate applied to each body output (never fires unless overridden).
    break_when: BreakFn,
}
31
32impl LoopTask {
33    /// Create a loop task.
34    pub fn new<F, Fut>(id: impl Into<String>, max_iterations: u32, body: F) -> Self
35    where
36        F: Fn(Arc<Context>, u32) -> Fut + Send + Sync + 'static,
37        Fut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
38    {
39        LoopTask {
40            id: id.into(),
41            deps: Vec::new(),
42            priority: 0,
43            max_iterations,
44            body: Box::new(move |ctx, i| body(ctx, i).boxed()),
45            // By default, never break early — run all iterations.
46            break_when: Box::new(|_| false),
47        }
48    }
49
50    /// Set the early-exit predicate, evaluated against each body output.
51    pub fn with_break<F>(mut self, break_when: F) -> Self
52    where
53        F: Fn(&TaskOutput) -> bool + Send + Sync + 'static,
54    {
55        self.break_when = Box::new(break_when);
56        self
57    }
58
59    /// Declare dependencies.
60    pub fn with_deps<I, S>(mut self, deps: I) -> Self
61    where
62        I: IntoIterator<Item = S>,
63        S: Into<String>,
64    {
65        self.deps = deps.into_iter().map(Into::into).collect();
66        self
67    }
68
69    /// Set scheduling priority.
70    pub fn with_priority(mut self, priority: u8) -> Self {
71        self.priority = priority;
72        self
73    }
74}
75
76#[async_trait]
77impl Task for LoopTask {
78    fn id(&self) -> &str {
79        &self.id
80    }
81
82    fn dependencies(&self) -> Vec<String> {
83        self.deps.clone()
84    }
85
86    fn priority(&self) -> u8 {
87        self.priority
88    }
89
90    async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
91        let mut last = serde_json::Value::Null;
92        let mut iterations = 0u32;
93        let mut broke_early = false;
94
95        while iterations < self.max_iterations {
96            if ctx.is_cancelled() {
97                return Err(TaskError::Cancelled);
98            }
99            last = (self.body)(ctx.clone(), iterations).await?;
100            iterations += 1;
101            if (self.break_when)(&last) {
102                broke_early = true;
103                break;
104            }
105        }
106
107        Ok(serde_json::json!({
108            "iterations": iterations,
109            "last": last,
110            "broke_early": broke_early,
111        }))
112    }
113}