// dag_flow/engine.rs

1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::hash::Hash;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use futures::FutureExt;
8use futures::StreamExt;
9use futures::stream::FuturesUnordered;
10
11use crate::context::Context;
12use crate::task::DynTask;
13use crate::task::Task;
14
15mod dag;
16use dag::BuildDagError;
17use dag::Dag;
18use dag::Edge;
19use dag::NodeData;
20
/// A task-execution engine that runs tasks in dependency (DAG) order.
///
/// Built via [`Engine::builder`]; `run` schedules each task only after all of
/// its dependencies have been scheduled. Cloning shares the task map behind
/// an [`Arc`] rather than copying it.
#[derive(Clone)]
pub struct Engine<'a, I, D> {
    // Dependency graph over task ids, assembled by `EngineBuilder::build`.
    dag: Dag<I>,
    // Task implementations keyed by id; `Arc`-wrapped so `run` can hand a
    // cheap clone of each task to the future it spawns for it.
    tasks: Arc<HashMap<I, Arc<DynTask<'a, I, D>>>>,
}
26
27impl<'a, I, D> Engine<'a, I, D> {
28    pub fn new() -> Self {
29        Self {
30            dag: Dag::new(),
31            tasks: Arc::new(HashMap::new()),
32        }
33    }
34
35    pub fn builder() -> EngineBuilder<'a, I, D> {
36        EngineBuilder::new()
37    }
38}
39
40impl<I, D> Default for Engine<'_, I, D> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
impl<'a, 'cx, I, D> Engine<'a, I, D>
where
    'a: 'cx,
    I: Clone + Eq + Hash + Send + 'cx,
    D: Clone + Send + Sync + 'cx,
{
    /// Schedules every task into `context` in topological (dependency) order,
    /// then concurrently drives the futures of "auto" tasks to completion.
    ///
    /// Each task's output is stored in `context` as a lazy, shared, boxed
    /// future — nothing executes until something awaits it, and dependents
    /// share one computation rather than re-running the task.
    pub async fn run(&self, context: Context<'cx, I, Option<D>>) {
        let graph = self.dag.graph();
        // Kahn's algorithm: count each node's unscheduled dependencies.
        let mut in_degrees: HashMap<_, _> = graph
            .iter()
            .map(|(node, NodeData { in_neighbors, .. })| (node, in_neighbors.len()))
            .collect();

        // Seed the work queue with nodes that have no dependencies.
        let mut queue: VecDeque<_> = in_degrees
            .iter()
            .flat_map(|(&node, &in_degree)| (in_degree == 0).then_some(node))
            .collect();

        while let Some(node) = queue.pop_front() {
            if let Some(task) = self.tasks.get(node).cloned() {
                // Collect the shared output futures of this node's
                // dependencies. Topological order guarantees each dependency
                // with a task was already `set` into the context; `flat_map`
                // silently skips dependencies with no context entry.
                let inputs = graph[node]
                    .in_neighbors
                    .iter()
                    .flat_map(|in_neighbor| {
                        context
                            .get(in_neighbor)
                            .map(|data| (in_neighbor.clone(), data))
                    })
                    .collect();

                // Store this task's output as a lazily-evaluated shared
                // future; it must be set BEFORE any successor is dequeued so
                // the successor can find it.
                context.set(
                    node.clone(),
                    async move { task.run(inputs).await }.boxed().shared(),
                );
            }

            // This node is scheduled: unblock its successors, enqueueing any
            // whose dependencies are now all scheduled.
            for out_neighbor in &graph[node].out_neighbors {
                let Some(in_degree) = in_degrees.get_mut(out_neighbor) else {
                    continue;
                };

                *in_degree -= 1;
                if *in_degree == 0 {
                    queue.push_back(out_neighbor);
                }
            }
        }

        // Await only the tasks flagged `is_auto` (presumably "run without an
        // explicit consumer" — confirm against Task docs), concurrently via
        // FuturesUnordered. Their dependencies execute transitively through
        // the shared futures captured as inputs above.
        graph
            .iter()
            .flat_map(|(node, _)| {
                self.tasks
                    .get(node)
                    .filter(|task| task.is_auto())
                    .and_then(|_| context.get(node))
            })
            .collect::<FuturesUnordered<_>>()
            .collect::<Vec<_>>()
            .await;
    }
}
107
/// Accumulates tasks before assembling them into a validated [`Engine`].
///
/// Clones share the same underlying registry (`Arc<RwLock<..>>`), so a task
/// added or removed through one clone is visible through every clone.
#[derive(Clone)]
pub struct EngineBuilder<'a, I, D> {
    // Registered tasks keyed by id; boxed until `build` converts them to
    // `Arc`s for the engine. The alias makes the nested type unavoidable.
    #[allow(clippy::type_complexity)]
    tasks: Arc<RwLock<HashMap<I, Box<DynTask<'a, I, D>>>>>,
}
113
114impl<I, D> EngineBuilder<'_, I, D> {
115    pub fn new() -> Self {
116        Self {
117            tasks: Arc::new(RwLock::new(HashMap::new())),
118        }
119    }
120}
121
122impl<I, D> Default for EngineBuilder<'_, I, D> {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl<I, D> EngineBuilder<'_, I, D>
129where
130    I: Eq + Hash,
131{
132    pub fn exists_task<T>(&self, task: &T) -> bool
133    where
134        T: Task<I, D>,
135    {
136        self.exists_task_by_id(&task.id())
137    }
138
139    pub fn exists_task_by_id(&self, id: &I) -> bool {
140        self.tasks.read().unwrap().contains_key(id)
141    }
142
143    pub fn remove_task<T>(&self, task: &T) -> &Self
144    where
145        T: Task<I, D>,
146    {
147        self.remove_task_by_id(&task.id())
148    }
149
150    pub fn remove_task_by_id(&self, id: &I) -> &Self {
151        self.tasks.write().unwrap().remove(id);
152        self
153    }
154}
155
156impl<'a, I, D> EngineBuilder<'a, I, D>
157where
158    I: Eq + Hash,
159{
160    pub fn add_task<T>(&self, task: T) -> &Self
161    where
162        T: Task<I, D> + 'a,
163    {
164        self.tasks
165            .write()
166            .unwrap()
167            .insert(task.id(), DynTask::new_box(task));
168
169        self
170    }
171}
172
173impl<'a, I, D> EngineBuilder<'a, I, D>
174where
175    I: Clone + Eq + Hash,
176{
177    pub fn build(self) -> Result<Engine<'a, I, D>, BuildEngineError> {
178        let tasks = Arc::into_inner(self.tasks).unwrap().into_inner().unwrap();
179        let mut builder = Dag::builder();
180
181        for id in tasks.keys().cloned() {
182            builder.add_node(id);
183        }
184
185        for (id, task) in &tasks {
186            for dependency in task.dependencies() {
187                builder.add_edge(Edge::new(dependency, id.clone()));
188            }
189        }
190
191        Ok(Engine {
192            dag: builder.build().map_err(EngineErrorKind::DagBuildFailed)?,
193            tasks: Arc::new(
194                tasks
195                    .into_iter()
196                    .map(|(id, task)| (id, task.into()))
197                    .collect(),
198            ),
199        })
200    }
201}
202
/// Error returned by [`EngineBuilder::build`] when the engine cannot be
/// constructed. `#[error(transparent)]` forwards `Display` to the private
/// kind, keeping the detail enum out of the public API.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
#[error(transparent)]
pub struct BuildEngineError(#[from] EngineErrorKind);
206
/// Private failure detail behind [`BuildEngineError`]; kept separate so new
/// variants can be added without breaking the public error type.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
enum EngineErrorKind {
    // The assembled graph was rejected (e.g. contains a cycle).
    #[error("failed to build DAG: {0}")]
    DagBuildFailed(#[from] BuildDagError),
}