dag_flow/
engine.rs

1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::hash::Hash;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use futures::FutureExt;
8use futures::StreamExt;
9use futures::stream::FuturesUnordered;
10
11use crate::context::Context;
12use crate::task::DynTask;
13use crate::task::Task;
14
15mod dag;
16use dag::BuildDagError;
17use dag::Dag;
18use dag::Edge;
19use dag::NodeData;
20
/// A compiled task graph: an immutable DAG over task ids plus the task
/// implementations. All state is behind `Arc`, so cloning an `Engine` is cheap.
#[derive(Clone)]
pub struct Engine<'a, I, D> {
    // Dependency graph over task ids of type `I`; built by `EngineBuilder::build`.
    dag: Dag<I>,
    // Task implementations keyed by id. `D` is the task output/data type.
    tasks: Arc<HashMap<I, Arc<DynTask<'a, I, D>>>>,
}
26
27impl<'a, I, D> Engine<'a, I, D> {
28    pub fn new() -> Self {
29        Self {
30            dag: Dag::new(),
31            tasks: Arc::new(HashMap::new()),
32        }
33    }
34
35    pub fn builder() -> EngineBuilder<'a, I, D> {
36        EngineBuilder::new()
37    }
38}
39
40impl<I, D> Default for Engine<'_, I, D> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl<'a, 'cx, I, D> Engine<'a, I, D>
47where
48    'a: 'cx,
49    I: Clone + Eq + Hash + Send + 'cx,
50    D: Clone + Send + Sync + 'cx,
51{
52    pub async fn run(&self, context: Context<'cx, I, Option<D>>) {
53        let graph = self.dag.graph();
54        let mut in_degrees: HashMap<_, _> = graph
55            .iter()
56            .map(|(node, NodeData { in_neighbors, .. })| (node, in_neighbors.len()))
57            .collect();
58
59        let mut queue: VecDeque<_> = in_degrees
60            .iter()
61            .flat_map(|(&node, &in_degree)| if in_degree > 0 { None } else { Some(node) })
62            .collect();
63
64        while let Some(node) = queue.pop_front() {
65            if let Some(task) = self.tasks.get(node).cloned() {
66                let inputs = graph[node]
67                    .in_neighbors
68                    .iter()
69                    .flat_map(|in_neighbor| {
70                        context
71                            .get(in_neighbor)
72                            .map(|data| (in_neighbor.clone(), data))
73                    })
74                    .collect();
75
76                context.set(
77                    node.clone(),
78                    async move { task.run(inputs).await }.boxed().shared(),
79                );
80            }
81
82            for out_neighbor in &graph[node].out_neighbors {
83                let in_degree = in_degrees.get_mut(out_neighbor).unwrap();
84                *in_degree -= 1;
85
86                if *in_degree == 0 {
87                    queue.push_back(out_neighbor);
88                }
89            }
90        }
91
92        graph
93            .iter()
94            .flat_map(|(node, _)| {
95                if self.tasks.get(node)?.is_auto() {
96                    context.get(node)
97                } else {
98                    None
99                }
100            })
101            .collect::<FuturesUnordered<_>>()
102            .collect::<Vec<_>>()
103            .await;
104    }
105}
106
/// Cheaply-cloneable collector of tasks used to construct an [`Engine`];
/// clones share the same underlying task map.
#[derive(Clone)]
pub struct EngineBuilder<'a, I, D> {
    // Tasks keyed by id, behind `Arc<RwLock<..>>` so that clones of the
    // builder can register and remove tasks through a shared map.
    #[allow(clippy::type_complexity)]
    tasks: Arc<RwLock<HashMap<I, Box<DynTask<'a, I, D>>>>>,
}
112
113impl<I, D> EngineBuilder<'_, I, D> {
114    pub fn new() -> Self {
115        Self {
116            tasks: Arc::new(RwLock::new(HashMap::new())),
117        }
118    }
119}
120
121impl<I, D> Default for EngineBuilder<'_, I, D> {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127impl<I, D> EngineBuilder<'_, I, D>
128where
129    I: Eq + Hash,
130{
131    pub fn exists_task<T>(&self, task: &T) -> bool
132    where
133        T: Task<I, D>,
134    {
135        self.exists_task_by_id(&task.id())
136    }
137
138    pub fn exists_task_by_id(&self, id: &I) -> bool {
139        self.tasks.read().unwrap().contains_key(id)
140    }
141
142    pub fn remove_task<T>(&self, task: &T) -> &Self
143    where
144        T: Task<I, D>,
145    {
146        self.remove_task_by_id(&task.id())
147    }
148
149    pub fn remove_task_by_id(&self, id: &I) -> &Self {
150        self.tasks.write().unwrap().remove(id);
151        self
152    }
153}
154
155impl<'a, I, D> EngineBuilder<'a, I, D>
156where
157    I: Eq + Hash,
158{
159    pub fn add_task<T>(&self, task: T) -> &Self
160    where
161        T: Task<I, D> + 'a,
162    {
163        self.tasks
164            .write()
165            .unwrap()
166            .insert(task.id(), DynTask::new_box(task));
167
168        self
169    }
170}
171
172impl<'a, I, D> EngineBuilder<'a, I, D>
173where
174    I: Clone + Eq + Hash,
175{
176    pub fn build(self) -> Result<Engine<'a, I, D>, BuildEngineError> {
177        let tasks = Arc::into_inner(self.tasks).unwrap().into_inner().unwrap();
178        let mut builder = Dag::builder();
179
180        for id in tasks.keys().cloned() {
181            builder.add_node(id);
182        }
183
184        for (id, task) in &tasks {
185            for dependency in task.dependencies() {
186                builder.add_edge(Edge::new(dependency, id.clone()));
187            }
188        }
189
190        Ok(Engine {
191            dag: builder.build().map_err(EngineErrorKind::DagBuildFailed)?,
192            tasks: Arc::new(
193                tasks
194                    .into_iter()
195                    .map(|(id, task)| (id, task.into()))
196                    .collect(),
197            ),
198        })
199    }
200}
201
/// Opaque error returned by [`EngineBuilder::build`]; wraps the private
/// [`EngineErrorKind`] so internal variants can change without breaking callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
#[error(transparent)]
pub struct BuildEngineError(#[from] EngineErrorKind);
205
/// Internal error detail behind [`BuildEngineError`]; kept private so new
/// variants can be added without a breaking change.
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
enum EngineErrorKind {
    // The task dependency graph could not be built into a valid DAG.
    #[error("failed to build DAG")]
    DagBuildFailed(#[from] BuildDagError),
}