1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::hash::Hash;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use futures::FutureExt;
8use futures::StreamExt;
9use futures::stream::FuturesUnordered;
10
11use crate::context::Context;
12use crate::task::DynTask;
13use crate::task::Task;
14
15mod dag;
16use dag::BuildDagError;
17use dag::Dag;
18use dag::Edge;
19use dag::NodeData;
20
21#[derive(Clone)]
22pub struct Engine<'a, I, D> {
23 dag: Dag<I>,
24 tasks: Arc<HashMap<I, Arc<DynTask<'a, I, D>>>>,
25}
26
27impl<'a, I, D> Engine<'a, I, D> {
28 pub fn new() -> Self {
29 Self {
30 dag: Dag::new(),
31 tasks: Arc::new(HashMap::new()),
32 }
33 }
34
35 pub fn builder() -> EngineBuilder<'a, I, D> {
36 EngineBuilder::new()
37 }
38}
39
40impl<I, D> Default for Engine<'_, I, D> {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl<'a, 'cx, I, D> Engine<'a, I, D>
47where
48 'a: 'cx,
49 I: Clone + Eq + Hash + Send + 'cx,
50 D: Clone + Send + Sync + 'cx,
51{
52 pub async fn run(&self, context: Context<'cx, I, Option<D>>) {
53 let graph = self.dag.graph();
54 let mut in_degrees: HashMap<_, _> = graph
55 .iter()
56 .map(|(node, NodeData { in_neighbors, .. })| (node, in_neighbors.len()))
57 .collect();
58
59 let mut queue: VecDeque<_> = in_degrees
60 .iter()
61 .flat_map(|(&node, &in_degree)| if in_degree > 0 { None } else { Some(node) })
62 .collect();
63
64 while let Some(node) = queue.pop_front() {
65 if let Some(task) = self.tasks.get(node).cloned() {
66 let inputs = graph[node]
67 .in_neighbors
68 .iter()
69 .flat_map(|in_neighbor| {
70 context
71 .get(in_neighbor)
72 .map(|data| (in_neighbor.clone(), data))
73 })
74 .collect();
75
76 context.set(
77 node.clone(),
78 async move { task.run(inputs).await }.boxed().shared(),
79 );
80 }
81
82 for out_neighbor in &graph[node].out_neighbors {
83 let in_degree = in_degrees.get_mut(out_neighbor).unwrap();
84 *in_degree -= 1;
85
86 if *in_degree == 0 {
87 queue.push_back(out_neighbor);
88 }
89 }
90 }
91
92 graph
93 .iter()
94 .flat_map(|(node, _)| {
95 if self.tasks.get(node)?.is_auto() {
96 context.get(node)
97 } else {
98 None
99 }
100 })
101 .collect::<FuturesUnordered<_>>()
102 .collect::<Vec<_>>()
103 .await;
104 }
105}
106
107#[derive(Clone)]
108pub struct EngineBuilder<'a, I, D> {
109 #[allow(clippy::type_complexity)]
110 tasks: Arc<RwLock<HashMap<I, Box<DynTask<'a, I, D>>>>>,
111}
112
113impl<I, D> EngineBuilder<'_, I, D> {
114 pub fn new() -> Self {
115 Self {
116 tasks: Arc::new(RwLock::new(HashMap::new())),
117 }
118 }
119}
120
121impl<I, D> Default for EngineBuilder<'_, I, D> {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127impl<I, D> EngineBuilder<'_, I, D>
128where
129 I: Eq + Hash,
130{
131 pub fn exists_task<T>(&self, task: &T) -> bool
132 where
133 T: Task<I, D>,
134 {
135 self.exists_task_by_id(&task.id())
136 }
137
138 pub fn exists_task_by_id(&self, id: &I) -> bool {
139 self.tasks.read().unwrap().contains_key(id)
140 }
141
142 pub fn remove_task<T>(&self, task: &T) -> &Self
143 where
144 T: Task<I, D>,
145 {
146 self.remove_task_by_id(&task.id())
147 }
148
149 pub fn remove_task_by_id(&self, id: &I) -> &Self {
150 self.tasks.write().unwrap().remove(id);
151 self
152 }
153}
154
155impl<'a, I, D> EngineBuilder<'a, I, D>
156where
157 I: Eq + Hash,
158{
159 pub fn add_task<T>(&self, task: T) -> &Self
160 where
161 T: Task<I, D> + 'a,
162 {
163 self.tasks
164 .write()
165 .unwrap()
166 .insert(task.id(), DynTask::new_box(task));
167
168 self
169 }
170}
171
172impl<'a, I, D> EngineBuilder<'a, I, D>
173where
174 I: Clone + Eq + Hash,
175{
176 pub fn build(self) -> Result<Engine<'a, I, D>, BuildEngineError> {
177 let tasks = Arc::into_inner(self.tasks).unwrap().into_inner().unwrap();
178 let mut builder = Dag::builder();
179
180 for id in tasks.keys().cloned() {
181 builder.add_node(id);
182 }
183
184 for (id, task) in &tasks {
185 for dependency in task.dependencies() {
186 builder.add_edge(Edge::new(dependency, id.clone()));
187 }
188 }
189
190 Ok(Engine {
191 dag: builder.build().map_err(EngineErrorKind::DagBuildFailed)?,
192 tasks: Arc::new(
193 tasks
194 .into_iter()
195 .map(|(id, task)| (id, task.into()))
196 .collect(),
197 ),
198 })
199 }
200}
201
202#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
203#[error(transparent)]
204pub struct BuildEngineError(#[from] EngineErrorKind);
205
206#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
207enum EngineErrorKind {
208 #[error("failed to build DAG")]
209 DagBuildFailed(#[from] BuildDagError),
210}