// dag_executor/tasks/loop_task.rs
use std::fmt;
use std::future::Future;
use std::sync::Arc;

use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::FutureExt;

use crate::context::Context;
use crate::error::TaskError;
use crate::tasks::r#trait::{Task, TaskOutput};

12type BodyFn = Box<
13 dyn Fn(Arc<Context>, u32) -> BoxFuture<'static, Result<TaskOutput, TaskError>> + Send + Sync,
14>;
15type BreakFn = Box<dyn Fn(&TaskOutput) -> bool + Send + Sync>;
16
17pub struct LoopTask {
24 id: String,
25 deps: Vec<String>,
26 priority: u8,
27 max_iterations: u32,
28 body: BodyFn,
29 break_when: BreakFn,
30}
31
32impl LoopTask {
33 pub fn new<F, Fut>(id: impl Into<String>, max_iterations: u32, body: F) -> Self
35 where
36 F: Fn(Arc<Context>, u32) -> Fut + Send + Sync + 'static,
37 Fut: Future<Output = Result<TaskOutput, TaskError>> + Send + 'static,
38 {
39 LoopTask {
40 id: id.into(),
41 deps: Vec::new(),
42 priority: 0,
43 max_iterations,
44 body: Box::new(move |ctx, i| body(ctx, i).boxed()),
45 break_when: Box::new(|_| false),
47 }
48 }
49
50 pub fn with_break<F>(mut self, break_when: F) -> Self
52 where
53 F: Fn(&TaskOutput) -> bool + Send + Sync + 'static,
54 {
55 self.break_when = Box::new(break_when);
56 self
57 }
58
59 pub fn with_deps<I, S>(mut self, deps: I) -> Self
61 where
62 I: IntoIterator<Item = S>,
63 S: Into<String>,
64 {
65 self.deps = deps.into_iter().map(Into::into).collect();
66 self
67 }
68
69 pub fn with_priority(mut self, priority: u8) -> Self {
71 self.priority = priority;
72 self
73 }
74}
75
76#[async_trait]
77impl Task for LoopTask {
78 fn id(&self) -> &str {
79 &self.id
80 }
81
82 fn dependencies(&self) -> Vec<String> {
83 self.deps.clone()
84 }
85
86 fn priority(&self) -> u8 {
87 self.priority
88 }
89
90 async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
91 let mut last = serde_json::Value::Null;
92 let mut iterations = 0u32;
93 let mut broke_early = false;
94
95 while iterations < self.max_iterations {
96 if ctx.is_cancelled() {
97 return Err(TaskError::Cancelled);
98 }
99 last = (self.body)(ctx.clone(), iterations).await?;
100 iterations += 1;
101 if (self.break_when)(&last) {
102 broke_early = true;
103 break;
104 }
105 }
106
107 Ok(serde_json::json!({
108 "iterations": iterations,
109 "last": last,
110 "broke_early": broke_early,
111 }))
112 }
113}