1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::hash::Hash;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use futures::FutureExt;
8use futures::StreamExt;
9use futures::stream::FuturesUnordered;
10
11use crate::context::Context;
12use crate::task::DynTask;
13use crate::task::Input;
14use crate::task::Task;
15
16mod dag;
17use dag::BuildDagError;
18use dag::Dag;
19use dag::Edge;
20use dag::NodeData;
21
22#[derive(Clone)]
23pub struct Engine<'a, I, D> {
24 dag: Dag<I>,
25 tasks: Arc<HashMap<I, Arc<DynTask<'a, I, D>>>>,
26}
27
28impl<'a, I, D> Engine<'a, I, D> {
29 pub fn new() -> Self {
30 Self {
31 dag: Dag::new(),
32 tasks: Arc::new(HashMap::new()),
33 }
34 }
35
36 pub fn builder() -> EngineBuilder<'a, I, D> {
37 EngineBuilder::new()
38 }
39}
40
41impl<I, D> Default for Engine<'_, I, D> {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl<'a, 'cx, I, D> Engine<'a, I, D>
48where
49 'a: 'cx,
50 I: Clone + Eq + Hash + Send + 'cx,
51 D: Clone + Send + Sync + 'cx,
52{
53 pub async fn run(&self, context: Context<'cx, I, Option<D>>) {
54 let graph = self.dag.graph();
55 let mut in_degrees: HashMap<_, _> = graph
56 .iter()
57 .map(|(node, NodeData { in_neighbors, .. })| (node, in_neighbors.len()))
58 .collect();
59
60 let mut queue: VecDeque<_> = in_degrees
61 .iter()
62 .flat_map(|(&node, &in_degree)| if in_degree > 0 { None } else { Some(node) })
63 .collect();
64
65 while let Some(node) = queue.pop_front() {
66 if let Some(task) = self.tasks.get(node).cloned() {
67 let input = graph[node]
68 .in_neighbors
69 .iter()
70 .flat_map(|in_neighbor| {
71 context
72 .get(in_neighbor)
73 .map(|data| Input::new(in_neighbor.clone(), data))
74 })
75 .collect();
76
77 context.set(
78 node.clone(),
79 async move { task.run(input).await }.boxed().shared(),
80 );
81 }
82
83 for out_neighbor in &graph[node].out_neighbors {
84 let in_degree = in_degrees.get_mut(out_neighbor).unwrap();
85 *in_degree -= 1;
86
87 if *in_degree == 0 {
88 queue.push_back(out_neighbor);
89 }
90 }
91 }
92
93 graph
94 .iter()
95 .flat_map(|(node, _)| context.get(node))
96 .collect::<FuturesUnordered<_>>()
97 .collect::<Vec<_>>()
98 .await;
99 }
100}
101
102#[derive(Clone)]
103pub struct EngineBuilder<'a, I, D> {
104 #[allow(clippy::type_complexity)]
105 tasks: Arc<RwLock<HashMap<I, Box<DynTask<'a, I, D>>>>>,
106}
107
108impl<I, D> EngineBuilder<'_, I, D> {
109 pub fn new() -> Self {
110 Self {
111 tasks: Arc::new(RwLock::new(HashMap::new())),
112 }
113 }
114}
115
116impl<I, D> Default for EngineBuilder<'_, I, D> {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl<I, D> EngineBuilder<'_, I, D>
123where
124 I: Eq + Hash,
125{
126 pub fn exists_task<T>(&self, task: &T) -> bool
127 where
128 T: Task<I, D>,
129 {
130 self.exists_task_by_id(task.id())
131 }
132
133 pub fn exists_task_by_id(&self, id: &I) -> bool {
134 self.tasks.read().unwrap().contains_key(id)
135 }
136
137 pub fn remove_task<T>(&self, task: &T) -> &Self
138 where
139 T: Task<I, D>,
140 {
141 self.remove_task_by_id(task.id())
142 }
143
144 pub fn remove_task_by_id(&self, id: &I) -> &Self {
145 self.tasks.write().unwrap().remove(id);
146 self
147 }
148}
149
150impl<'a, I, D> EngineBuilder<'a, I, D>
151where
152 I: Clone + Eq + Hash,
153{
154 pub fn add_task<T>(&self, task: T) -> &Self
155 where
156 T: Task<I, D> + 'a,
157 {
158 self.tasks
159 .write()
160 .unwrap()
161 .insert(task.id().clone(), DynTask::new_box(task));
162
163 self
164 }
165
166 pub fn build(self) -> Result<Engine<'a, I, D>, BuildEngineError> {
167 let tasks = Arc::into_inner(self.tasks).unwrap().into_inner().unwrap();
168 let mut builder = Dag::builder();
169
170 for id in tasks.keys().cloned() {
171 builder.add_node(id);
172 }
173
174 for (id, task) in &tasks {
175 for dependency in task.dependencies().iter().cloned() {
176 builder.add_edge(Edge::new(dependency, id.clone()));
177 }
178 }
179
180 Ok(Engine {
181 dag: builder.build().map_err(EngineErrorKind::DagBuildFailed)?,
182 tasks: Arc::new(
183 tasks
184 .into_iter()
185 .map(|(id, task)| (id, task.into()))
186 .collect(),
187 ),
188 })
189 }
190}
191
192#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
193#[error(transparent)]
194pub struct BuildEngineError(#[from] EngineErrorKind);
195
196#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
197enum EngineErrorKind {
198 #[error("failed to build DAG")]
199 DagBuildFailed(#[from] BuildDagError),
200}