1mod graph;
2
3use std::{ fmt, mem, error };
4use std::hash::{ Hash, BuildHasher };
5use std::vec::IntoIter;
6use std::pin::Pin;
7use std::marker::Unpin;
8use std::task::{ Context, Poll };
9use std::collections::hash_map::RandomState;
10use num_traits::{ CheckedAdd, One };
11use futures::stream::futures_unordered::FuturesUnordered;
12use futures::channel::oneshot;
13use futures::lock::BiLock;
14use futures::prelude::*;
15use crate::graph::Graph;
16pub use crate::graph::Index;
17
18
19pub struct TaskGraph<T, I = u32, S = RandomState> {
20 dag: Graph<State<T>, I, S>,
21 pending: Vec<IndexFuture<T, I>>
22}
23
24enum State<T> {
25 Pending {
26 count: usize,
27 task: T
28 },
29 Running,
30}
31
32impl<T, I, S> Default for TaskGraph<T, I, S>
33where
34 I: Default + Hash + PartialEq + Eq,
35 S: Default + BuildHasher + Unpin
36{
37 fn default() -> TaskGraph<T, I, S> {
38 TaskGraph { dag: Graph::default(), pending: Vec::new() }
39 }
40}
41
42impl<T> TaskGraph<T> {
43 pub fn new() -> Self {
44 TaskGraph::default()
45 }
46}
47
48impl<T, I, S> TaskGraph<T, I, S>
49where
50 T: Future + Unpin,
51 I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
52 S: BuildHasher + Unpin
53{
54 pub fn add_task(&mut self, deps: &[Index<I>], task: T) -> Result<Index<I>, Error<T>> {
55 let mut count = 0;
56 for dep in deps {
57 if dep >= &self.dag.last {
58 return Err(Error::WouldCycle(task));
59 }
60
61 if self.dag.contains(dep) {
62 count += 1;
63 }
64 }
65
66 if count == 0 {
67 match self.dag.add_node(State::Running) {
68 Ok(index) => {
69 self.pending.push(IndexFuture::new(index.clone(), task));
70 Ok(index)
71 },
72 Err(_) => Err(Error::Exhausted(task))
73 }
74 } else {
75 match self.dag.add_node(State::Pending { count, task }) {
76 Ok(index) => {
77 for parent in deps {
78 self.dag.add_edge(parent, index.clone());
79 }
80 Ok(index)
81 },
82 Err(State::Pending { task, .. }) => Err(Error::Exhausted(task)),
83 Err(State::Running) => unreachable!()
84 }
85 }
86 }
87
88 pub fn execute(mut self) -> (Sender<T, I, S>, Execute<T, I, S>) {
89 let queue = FuturesUnordered::new();
90 for fut in self.pending.drain(..) {
91 queue.push(fut);
92 }
93 let (g1, g2) = BiLock::new(self);
94 let (tx, rx) = oneshot::channel();
95 (
96 Sender { inner: g1, tx },
97 Execute { inner: g2, done: Vec::new(), is_canceled: false, queue, rx }
98 )
99 }
100
101 fn walk(&mut self, index: &Index<I>) -> TaskWalker<'_, T, I, S> {
102 let walker = self.dag.walk(index);
103 TaskWalker { dag: &mut self.dag, walker }
104 }
105}
106
107pub struct Sender<T, I=u32, S=RandomState> {
108 inner: BiLock<TaskGraph<T, I, S>>,
109 tx: oneshot::Sender<()>
110}
111
112impl<T, I, S> Sender<T, I, S>
113where
114 T: Future + Unpin,
115 I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
116 S: BuildHasher + Unpin
117{
118 #[inline]
119 pub fn add_task<'a>(&'a self, deps: &'a [Index<I>], task: T)
120 -> impl Future<Output = Result<Index<I>, Error<T>>> + 'a
121 {
122 async move {
123 let mut graph = self.inner.lock().await;
124 graph.add_task(deps, task)
125 }
126 }
127
128 pub fn abort(self) {
129 let _ = self.tx.send(());
130 }
131}
132
133pub struct Execute<T, I=u32, S=RandomState> {
134 inner: BiLock<TaskGraph<T, I, S>>,
135 queue: FuturesUnordered<IndexFuture<T, I>>,
136 done: Vec<Index<I>>,
137 rx: oneshot::Receiver<()>,
138 is_canceled: bool
139}
140
141impl<T, I, S> Execute<T, I, S>
142where
143 T: Future + Unpin,
144 I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
145 S: BuildHasher + Unpin
146{
147 fn enqueue(&mut self, cx: &mut Context<'_>) -> Poll<()> {
148 let mut graph = match self.inner.poll_lock(cx) {
149 Poll::Ready(graph) => graph,
150 Poll::Pending => return Poll::Pending
151 };
152
153 for fut in graph.pending.drain(..) {
154 self.queue.push(fut);
155 }
156
157 for index in self.done.drain(..) {
158 for fut in graph.walk(&index) {
159 self.queue.push(fut);
160 }
161 graph.dag.remove_node(&index);
162 }
163
164 Poll::Ready(())
165 }
166}
167
168impl<F, I, S> Stream for Execute<F, I, S>
169where
170 F: Future + Unpin,
171 I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
172 S: BuildHasher + Unpin
173{
174 type Item = (Index<I>, F::Output);
175
176 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177 match Pin::new(&mut self.rx).poll(cx) {
178 Poll::Pending => (),
179 Poll::Ready(Ok(())) => return Poll::Ready(None),
180 Poll::Ready(Err(_)) => {
181 self.is_canceled = true;
182 }
183 }
184
185 if let Poll::Pending = self.enqueue(cx) {
186 return Poll::Pending;
187 }
188
189 match Pin::new(&mut self.queue).poll_next(cx) {
190 Poll::Ready(Some((i, item))) => {
191 self.done.push(i.clone());
192 Poll::Ready(Some((i, item)))
193 },
194 Poll::Ready(None) if self.is_canceled => Poll::Ready(None),
195 Poll::Ready(None) | Poll::Pending => Poll::Pending
196 }
197 }
198}
199
/// A task future tagged with its graph index.
///
/// When polled to completion it resolves to `(index, output)`, so the executor
/// knows which node finished.
struct IndexFuture<F, I> {
    /// Index of this task's node in the graph.
    index: Index<I>,
    /// The wrapped task.
    fut: F
}
204
205impl<F, I> IndexFuture<F, I> {
206 pub fn new(index: Index<I>, fut: F) -> IndexFuture<F, I> {
207 IndexFuture { index, fut }
208 }
209}
210
211impl<F, I> Future for IndexFuture<F, I>
212where
213 F: Future + Unpin,
214 I: Clone + Unpin
215{
216 type Output = (Index<I>, F::Output);
217
218 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 let IndexFuture { index, fut } = self.get_mut();
220
221 match Pin::new(fut).poll(cx) {
222 Poll::Ready(item) => Poll::Ready((index.clone(), item)),
223 Poll::Pending => Poll::Pending,
224 }
225 }
226}
227
/// Iterator returned by `TaskGraph::walk`.
///
/// For each child index produced by `walker`, it decrements that child's pending
/// dependency count. It yields the child's task once the count reaches zero.
struct TaskWalker<'a, T, I, S> {
    /// Graph whose child nodes are updated in place.
    dag: &'a mut Graph<State<T>, I, S>,
    /// Child indices of the finished node, as returned by `Graph::walk`.
    walker: IntoIter<Index<I>>
}
232
233impl<'a, T, I, S> Iterator for TaskWalker<'a, T, I, S>
234where
235 I: CheckedAdd + One + Hash + PartialEq + Eq + Clone,
236 S: BuildHasher + Unpin
237{
238 type Item = IndexFuture<T, I>;
239
240 fn next(&mut self) -> Option<Self::Item> {
241 while let Some(index) = self.walker.next() {
242 let state = match self.dag.get_node_mut(&index) {
243 Some(node) => node,
244 None => continue
245 };
246
247 if let State::Pending { count, .. } = state {
248 *count -= 1;
249 }
250
251 match state {
252 State::Pending { count: 0, .. } => (),
253 _ => continue
254 }
255
256 if let State::Pending { task, .. } = mem::replace(state, State::Running) {
257 return Some(IndexFuture::new(index, task));
258 }
259 }
260
261 None
262 }
263}
264
/// Reasons a task could not be added to the graph.
///
/// Both variants hand the rejected task back so the caller can reuse it.
pub enum Error<T> {
    /// A dependency index is not less than the graph's `last` index.
    WouldCycle(T),
    /// The graph could not allocate a node index for the task.
    Exhausted(T)
}

// Written by hand rather than derived, so that `T` is not required to be `Debug`.
impl<T> fmt::Debug for Error<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Error::WouldCycle(_) => "WouldCycle",
            Error::Exhausted(_) => "Exhausted"
        };
        f.debug_struct(name).finish()
    }
}

impl<T> fmt::Display for Error<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Error::WouldCycle(_) => "would cycle",
            Error::Exhausted(_) => "index exhausted"
        })
    }
}

impl<T> error::Error for Error<T> {}