dag_flow/
engine.rs

1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::hash::Hash;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use futures::FutureExt;
8use futures::StreamExt;
9use futures::stream::FuturesUnordered;
10
11use crate::context::Context;
12use crate::task::DynTask;
13use crate::task::Input;
14use crate::task::Task;
15
16mod dag;
17use dag::BuildDagError;
18use dag::Dag;
19use dag::Edge;
20use dag::NodeData;
21
22#[derive(Clone)]
23pub struct Engine<'a, I, D> {
24    dag: Dag<I>,
25    tasks: Arc<HashMap<I, Arc<DynTask<'a, I, D>>>>,
26}
27
28impl<'a, I, D> Engine<'a, I, D> {
29    pub fn new() -> Self {
30        Self {
31            dag: Dag::new(),
32            tasks: Arc::new(HashMap::new()),
33        }
34    }
35
36    pub fn builder() -> EngineBuilder<'a, I, D> {
37        EngineBuilder::new()
38    }
39}
40
41impl<I, D> Default for Engine<'_, I, D> {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl<'a, 'cx, I, D> Engine<'a, I, D>
48where
49    'a: 'cx,
50    I: Clone + Eq + Hash + Send + 'cx,
51    D: Clone + Send + Sync + 'cx,
52{
53    pub async fn run(&self, context: Context<'cx, I, Option<D>>) {
54        let graph = self.dag.graph();
55        let mut in_degrees: HashMap<_, _> = graph
56            .iter()
57            .map(|(node, NodeData { in_neighbors, .. })| (node, in_neighbors.len()))
58            .collect();
59
60        let mut queue: VecDeque<_> = in_degrees
61            .iter()
62            .flat_map(|(&node, &in_degree)| if in_degree > 0 { None } else { Some(node) })
63            .collect();
64
65        while let Some(node) = queue.pop_front() {
66            if let Some(task) = self.tasks.get(node).cloned() {
67                let input = graph[node]
68                    .in_neighbors
69                    .iter()
70                    .flat_map(|in_neighbor| {
71                        context
72                            .get(in_neighbor)
73                            .map(|data| Input::new(in_neighbor.clone(), data))
74                    })
75                    .collect();
76
77                context.set(
78                    node.clone(),
79                    async move { task.run(input).await }.boxed().shared(),
80                );
81            }
82
83            for out_neighbor in &graph[node].out_neighbors {
84                let in_degree = in_degrees.get_mut(out_neighbor).unwrap();
85                *in_degree -= 1;
86
87                if *in_degree == 0 {
88                    queue.push_back(out_neighbor);
89                }
90            }
91        }
92
93        graph
94            .iter()
95            .flat_map(|(node, _)| context.get(node))
96            .collect::<FuturesUnordered<_>>()
97            .collect::<Vec<_>>()
98            .await;
99    }
100}
101
102#[derive(Clone)]
103pub struct EngineBuilder<'a, I, D> {
104    #[allow(clippy::type_complexity)]
105    tasks: Arc<RwLock<HashMap<I, Box<DynTask<'a, I, D>>>>>,
106}
107
108impl<I, D> EngineBuilder<'_, I, D> {
109    pub fn new() -> Self {
110        Self {
111            tasks: Arc::new(RwLock::new(HashMap::new())),
112        }
113    }
114}
115
116impl<I, D> Default for EngineBuilder<'_, I, D> {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl<I, D> EngineBuilder<'_, I, D>
123where
124    I: Eq + Hash,
125{
126    pub fn exists_task<T>(&self, task: &T) -> bool
127    where
128        T: Task<I, D>,
129    {
130        self.exists_task_by_id(task.id())
131    }
132
133    pub fn exists_task_by_id(&self, id: &I) -> bool {
134        self.tasks.read().unwrap().contains_key(id)
135    }
136
137    pub fn remove_task<T>(&self, task: &T) -> &Self
138    where
139        T: Task<I, D>,
140    {
141        self.remove_task_by_id(task.id())
142    }
143
144    pub fn remove_task_by_id(&self, id: &I) -> &Self {
145        self.tasks.write().unwrap().remove(id);
146        self
147    }
148}
149
150impl<'a, I, D> EngineBuilder<'a, I, D>
151where
152    I: Clone + Eq + Hash,
153{
154    pub fn add_task<T>(&self, task: T) -> &Self
155    where
156        T: Task<I, D> + 'a,
157    {
158        self.tasks
159            .write()
160            .unwrap()
161            .insert(task.id().clone(), DynTask::new_box(task));
162
163        self
164    }
165
166    pub fn build(self) -> Result<Engine<'a, I, D>, BuildEngineError> {
167        let tasks = Arc::into_inner(self.tasks).unwrap().into_inner().unwrap();
168        let mut builder = Dag::builder();
169
170        for id in tasks.keys().cloned() {
171            builder.add_node(id);
172        }
173
174        for (id, task) in &tasks {
175            for dependency in task.dependencies().iter().cloned() {
176                builder.add_edge(Edge::new(dependency, id.clone()));
177            }
178        }
179
180        Ok(Engine {
181            dag: builder.build().map_err(EngineErrorKind::DagBuildFailed)?,
182            tasks: Arc::new(
183                tasks
184                    .into_iter()
185                    .map(|(id, task)| (id, task.into()))
186                    .collect(),
187            ),
188        })
189    }
190}
191
192#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
193#[error(transparent)]
194pub struct BuildEngineError(#[from] EngineErrorKind);
195
196#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
197enum EngineErrorKind {
198    #[error("failed to build DAG")]
199    DagBuildFailed(#[from] BuildDagError),
200}