futures_dagtask/
lib.rs

1mod graph;
2
3use std::{ fmt, mem, error };
4use std::hash::{ Hash, BuildHasher };
5use std::vec::IntoIter;
6use std::pin::Pin;
7use std::marker::Unpin;
8use std::task::{ Context, Poll };
9use std::collections::hash_map::RandomState;
10use num_traits::{ CheckedAdd, One };
11use futures::stream::futures_unordered::FuturesUnordered;
12use futures::channel::oneshot;
13use futures::lock::BiLock;
14use futures::prelude::*;
15use crate::graph::Graph;
16pub use crate::graph::Index;
17
18
19pub struct TaskGraph<T, I = u32, S = RandomState> {
20    dag: Graph<State<T>, I, S>,
21    pending: Vec<IndexFuture<T, I>>
22}
23
/// Scheduling state of a single node in the task graph.
enum State<T> {
    /// The task has not started; `count` dependencies are still unfinished.
    Pending {
        count: usize,
        task: T,
    },
    /// The task has been handed to the run queue (or is about to be). The node
    /// stays in the graph until the task finishes and its dependents are released.
    Running,
}
31
32impl<T, I, S> Default for TaskGraph<T, I, S>
33where
34    I: Default + Hash + PartialEq + Eq,
35    S: Default + BuildHasher + Unpin
36{
37    fn default() -> TaskGraph<T, I, S> {
38        TaskGraph { dag: Graph::default(), pending: Vec::new() }
39    }
40}
41
42impl<T> TaskGraph<T> {
43    pub fn new() -> Self {
44        TaskGraph::default()
45    }
46}
47
48impl<T, I, S> TaskGraph<T, I, S>
49where
50    T: Future + Unpin,
51    I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
52    S: BuildHasher + Unpin
53{
54    pub fn add_task(&mut self, deps: &[Index<I>], task: T) -> Result<Index<I>, Error<T>> {
55        let mut count = 0;
56        for dep in deps {
57            if dep >= &self.dag.last {
58                return Err(Error::WouldCycle(task));
59            }
60
61            if self.dag.contains(dep) {
62                count += 1;
63            }
64        }
65
66        if count == 0 {
67            match self.dag.add_node(State::Running) {
68                Ok(index) => {
69                    self.pending.push(IndexFuture::new(index.clone(), task));
70                    Ok(index)
71                },
72                Err(_) => Err(Error::Exhausted(task))
73            }
74        } else {
75            match self.dag.add_node(State::Pending { count, task }) {
76                Ok(index) => {
77                    for parent in deps {
78                        self.dag.add_edge(parent, index.clone());
79                    }
80                    Ok(index)
81                },
82                Err(State::Pending { task, .. }) => Err(Error::Exhausted(task)),
83                Err(State::Running) => unreachable!()
84            }
85        }
86    }
87
88    pub fn execute(mut self) -> (Sender<T, I, S>, Execute<T, I, S>) {
89        let queue = FuturesUnordered::new();
90        for fut in self.pending.drain(..) {
91            queue.push(fut);
92        }
93        let (g1, g2) = BiLock::new(self);
94        let (tx, rx) = oneshot::channel();
95        (
96            Sender { inner: g1, tx },
97            Execute { inner: g2, done: Vec::new(), is_canceled: false, queue, rx }
98        )
99    }
100
101    fn walk(&mut self, index: &Index<I>) -> TaskWalker<'_, T, I, S> {
102        let walker = self.dag.walk(index);
103        TaskWalker { dag: &mut self.dag, walker }
104    }
105}
106
107pub struct Sender<T, I=u32, S=RandomState> {
108    inner: BiLock<TaskGraph<T, I, S>>,
109    tx: oneshot::Sender<()>
110}
111
112impl<T, I, S> Sender<T, I, S>
113where
114    T: Future + Unpin,
115    I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
116    S: BuildHasher + Unpin
117{
118    #[inline]
119    pub fn add_task<'a>(&'a self, deps: &'a [Index<I>], task: T)
120        -> impl Future<Output = Result<Index<I>, Error<T>>> + 'a
121    {
122        async move {
123            let mut graph = self.inner.lock().await;
124            graph.add_task(deps, task)
125        }
126    }
127
128    pub fn abort(self) {
129        let _ = self.tx.send(());
130    }
131}
132
133pub struct Execute<T, I=u32, S=RandomState> {
134    inner: BiLock<TaskGraph<T, I, S>>,
135    queue: FuturesUnordered<IndexFuture<T, I>>,
136    done: Vec<Index<I>>,
137    rx: oneshot::Receiver<()>,
138    is_canceled: bool
139}
140
141impl<T, I, S> Execute<T, I, S>
142where
143    T: Future + Unpin,
144    I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
145    S: BuildHasher + Unpin
146{
147    fn enqueue(&mut self, cx: &mut Context<'_>) -> Poll<()> {
148        let mut graph = match self.inner.poll_lock(cx) {
149            Poll::Ready(graph) => graph,
150            Poll::Pending => return Poll::Pending
151        };
152
153        for fut in graph.pending.drain(..) {
154            self.queue.push(fut);
155        }
156
157        for index in self.done.drain(..) {
158            for fut in graph.walk(&index) {
159                self.queue.push(fut);
160            }
161            graph.dag.remove_node(&index);
162        }
163
164        Poll::Ready(())
165    }
166}
167
168impl<F, I, S> Stream for Execute<F, I, S>
169where
170    F: Future + Unpin,
171    I: CheckedAdd + One + Hash + PartialEq + Eq + PartialOrd + Clone + Unpin,
172    S: BuildHasher + Unpin
173{
174    type Item = (Index<I>, F::Output);
175
176    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177        match Pin::new(&mut self.rx).poll(cx) {
178            Poll::Pending => (),
179            Poll::Ready(Ok(())) => return Poll::Ready(None),
180            Poll::Ready(Err(_)) => {
181                self.is_canceled = true;
182            }
183        }
184
185        if let Poll::Pending = self.enqueue(cx) {
186            return Poll::Pending;
187        }
188
189        match Pin::new(&mut self.queue).poll_next(cx) {
190            Poll::Ready(Some((i, item))) => {
191                self.done.push(i.clone());
192                Poll::Ready(Some((i, item)))
193            },
194            Poll::Ready(None) if self.is_canceled => Poll::Ready(None),
195            Poll::Ready(None) | Poll::Pending => Poll::Pending
196        }
197    }
198}
199
200struct IndexFuture<F, I> {
201    index: Index<I>,
202    fut: F
203}
204
205impl<F, I> IndexFuture<F, I> {
206    pub fn new(index: Index<I>, fut: F) -> IndexFuture<F, I> {
207        IndexFuture { index, fut }
208    }
209}
210
211impl<F, I> Future for IndexFuture<F, I>
212where
213    F: Future + Unpin,
214    I: Clone + Unpin
215{
216    type Output = (Index<I>, F::Output);
217
218    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219        let IndexFuture { index, fut } = self.get_mut();
220
221        match Pin::new(fut).poll(cx) {
222            Poll::Ready(item) => Poll::Ready((index.clone(), item)),
223            Poll::Pending => Poll::Pending,
224        }
225    }
226}
227
228struct TaskWalker<'a, T, I, S> {
229    dag: &'a mut Graph<State<T>, I, S>,
230    walker: IntoIter<Index<I>>
231}
232
233impl<'a, T, I, S> Iterator for TaskWalker<'a, T, I, S>
234where
235    I: CheckedAdd + One + Hash + PartialEq + Eq + Clone,
236    S: BuildHasher + Unpin
237{
238    type Item = IndexFuture<T, I>;
239
240    fn next(&mut self) -> Option<Self::Item> {
241        while let Some(index) = self.walker.next() {
242            let state = match self.dag.get_node_mut(&index) {
243                Some(node) => node,
244                None => continue
245            };
246
247            if let State::Pending { count, .. } = state {
248                *count -= 1;
249            }
250
251            match state {
252                State::Pending { count: 0, .. } => (),
253                _ => continue
254            }
255
256            if let State::Pending { task, .. } = mem::replace(state, State::Running) {
257                return Some(IndexFuture::new(index, task));
258            }
259        }
260
261        None
262    }
263}
264
/// Error returned when a task cannot be added to a [`TaskGraph`].
///
/// Both variants hand the rejected task back to the caller.
pub enum Error<T> {
    /// A dependency index has not been issued by the graph yet.
    WouldCycle(T),
    /// The graph failed to add a new node (reported as index exhaustion).
    Exhausted(T)
}

impl<T> fmt::Debug for Error<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The payload is the rejected task, which need not implement `Debug`, so
        // only the variant name is printed.
        let name = match self {
            Error::WouldCycle(_) => "WouldCycle",
            Error::Exhausted(_) => "Exhausted",
        };
        f.write_str(name)
    }
}

impl<T> fmt::Display for Error<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Error::WouldCycle(_) => "would cycle",
            Error::Exhausted(_) => "index exhausted",
        })
    }
}

impl<T> error::Error for Error<T> {}