dep_graph/
graph_par.rs

1use crate::{
2    error::Error,
3    graph::{remove_node_id, DepGraph, DependencyMap},
4};
5use crossbeam_channel::{Receiver, Sender};
6
7use rayon::iter::{
8    plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer},
9    IndexedParallelIterator, IntoParallelIterator, ParallelIterator,
10};
11use std::cmp;
12
13use std::fmt;
14use std::hash::{Hash, Hasher};
15use std::iter::{DoubleEndedIterator, ExactSizeIterator};
16
17use std::ops;
18use std::sync::{
19    atomic::{AtomicUsize, Ordering},
20    Arc, RwLock,
21};
22use std::thread;
23use std::time::Duration;
24
/// Default timeout in milliseconds
///
/// Used by the dispatcher thread as its polling interval when no completed
/// item arrives; can be overridden via `DepGraphParIter::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(1000);
27
28/// Add into_par_iter() to DepGraph
29impl<I> IntoParallelIterator for DepGraph<I>
30where
31    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
32{
33    type Item = Wrapper<I>;
34    type Iter = DepGraphParIter<I>;
35
36    fn into_par_iter(self) -> Self::Iter {
37        DepGraphParIter::new(self.ready_nodes, self.deps, self.rdeps)
38    }
39}
40
/// Wrapper for an item
///
/// This is used to pass items through parallel iterators. When the wrapper is
/// dropped, we decrement the processing `counter` and notify the dispatcher
/// thread through the `item_done_tx` channel.
#[derive(Clone)]
pub struct Wrapper<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    // Wrapped item, exposed through the Deref/DerefMut impls below.
    inner: I,
    // Shared count of items currently being processed; incremented in
    // `Wrapper::new`, decremented on drop. The dispatcher reads it to tell
    // "still working" apart from "stuck" (circular dependency).
    counter: Arc<AtomicUsize>,
    // Channel to notify that the item is done processing (upon drop).
    item_done_tx: Sender<I>,
}
58
59impl<I> Wrapper<I>
60where
61    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
62{
63    /// Create a new Wrapper item
64    ///
65    /// This needs a reference to the processing counter to keep count of the
66    /// number of items currently processed (used to check for circular
67    /// dependencies) and the item done channel to notify the dispatcher
68    /// thread.
69    ///
70    /// Upon creating of a `Wrapper`, we also increment the processing counter.
71    pub fn new(inner: I, counter: Arc<AtomicUsize>, item_done_tx: Sender<I>) -> Self {
72        (*counter).fetch_add(1, Ordering::SeqCst);
73        Self {
74            inner,
75            counter,
76            item_done_tx,
77        }
78    }
79}
80
81/// Drop implementation to decrement the processing counter and notify the
82/// dispatcher thread.
83impl<I> Drop for Wrapper<I>
84where
85    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
86{
87    /// Triggered when the wrapper is dropped.
88    ///
89    /// This will decrement the processing counter and notify the dispatcher thread.
90    fn drop(&mut self) {
91        (*self.counter).fetch_sub(1, Ordering::SeqCst);
92        self.item_done_tx
93            .send(self.inner.clone())
94            .expect("could not send message")
95    }
96}
97
/// Dereference implementation to access the inner item
///
/// This allows read-only access to the wrapped item using `(*wrapper)` or
/// auto-deref (calling `I`'s methods directly on the wrapper).
impl<I> ops::Deref for Wrapper<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    type Target = I;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
111
/// Mutable dereference implementation to access the inner item
///
/// This allows mutating the wrapped item in place using `(*wrapper) = ...`
/// or auto-deref.
impl<I> ops::DerefMut for Wrapper<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
123
/// Marker impl: equality on `Wrapper` is total because `I: Eq` and the
/// `PartialEq` impl below compares only `inner`.
impl<I> Eq for Wrapper<I> where I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static
{}
126
/// Hash delegates to the wrapped item, consistent with `PartialEq` below:
/// both ignore `counter` and `item_done_tx`.
impl<I> Hash for Wrapper<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.hash(state)
    }
}
135
/// Two wrappers are equal when their wrapped items are equal; the shared
/// counter and channel handles do not participate in comparison.
impl<I> cmp::PartialEq for Wrapper<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}
144
/// Parallel iterator for DepGraph
///
/// Client side of the dispatcher thread spawned in [`DepGraphParIter::new`]:
/// producers pull ready nodes from `item_ready_rx` and report completion
/// through `item_done_tx` (via `Wrapper`'s `Drop`).
pub struct DepGraphParIter<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    // Dispatcher poll interval, shared with the dispatcher thread so
    // `with_timeout` can change it after the thread has started.
    timeout: Arc<RwLock<Duration>>,
    // Number of items currently being processed (shared with every Wrapper).
    counter: Arc<AtomicUsize>,
    // Receiving end for nodes whose dependencies are all resolved.
    item_ready_rx: Receiver<I>,
    // Sending end cloned into each Wrapper to report completed nodes.
    item_done_tx: Sender<I>,
}
155
impl<I> DepGraphParIter<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    /// Create a new parallel iterator
    ///
    /// This will create a dispatcher thread and two crossbeam channels:
    /// one carrying nodes that have become ready (dispatcher -> producers)
    /// and one carrying completed nodes (Wrapper drop -> dispatcher).
    ///
    /// The dispatcher loop ends in one of three ways: the dependency map
    /// becomes empty (success), or it times out with nothing in flight
    /// (circular dependency -> `Err`), which in either case drops
    /// `item_ready_tx` and thereby terminates the consuming iterators.
    /// NOTE(review): the thread's `Result` is never joined, so the
    /// circular-dependency error is only observable as channel closure.
    pub fn new(ready_nodes: Vec<I>, deps: DependencyMap<I>, rdeps: DependencyMap<I>) -> Self {
        let timeout = Arc::new(RwLock::new(DEFAULT_TIMEOUT));
        let counter = Arc::new(AtomicUsize::new(0));

        // Create communication channels: ready nodes out, done nodes in.
        let (item_ready_tx, item_ready_rx) = crossbeam_channel::unbounded::<I>();
        let (item_done_tx, item_done_rx) = crossbeam_channel::unbounded::<I>();

        // Seed the ready channel with the nodes that have no dependencies.
        ready_nodes
            .iter()
            .for_each(|node| item_ready_tx.send(node.clone()).unwrap());

        // Clone Arcs for the dispatcher thread (originals are stored in Self).
        let loop_timeout = timeout.clone();
        let loop_counter = counter.clone();

        // Start dispatcher thread
        thread::spawn(move || {
            loop {
                crossbeam_channel::select! {
                    // Grab a processed node ID (sent by Wrapper::drop)
                    recv(item_done_rx) -> id => {
                        let id = id.unwrap();
                        // Remove the node from all reverse dependencies;
                        // yields the nodes that just became unblocked.
                        let next_nodes = remove_node_id::<I>(id, &deps, &rdeps)?;

                        // Send the next available nodes to the channel.
                        next_nodes
                            .iter()
                            .for_each(|node_id| item_ready_tx.send(node_id.clone()).unwrap());

                        // If there are no more nodes, leave the loop
                        if deps.read().unwrap().is_empty() {
                            break;
                        }
                    },
                    // No completion within the (adjustable) timeout window.
                    default(*loop_timeout.read().unwrap()) => {
                        let deps = deps.read().unwrap();
                        let counter_val = loop_counter.load(Ordering::SeqCst);
                        if deps.is_empty() {
                            break;
                        // There are still some items processing: keep waiting.
                        } else if counter_val > 0 {
                            continue;
                        } else {
                            // Nothing in flight but nodes remain blocked:
                            // nothing can ever complete them.
                            return Err(Error::ResolveGraphError("circular dependency detected"));
                        }
                    },
                };
            }

            // Drop channel
            // This closes `item_ready_rx`, so producer iterators return None
            // and the rayon bridge winds down.
            drop(item_ready_tx);
            Ok(())
        });

        DepGraphParIter {
            timeout,
            counter,

            item_ready_rx,
            item_done_tx,
        }
    }

    /// Override the dispatcher's polling timeout (default: `DEFAULT_TIMEOUT`).
    ///
    /// Takes effect immediately — the dispatcher re-reads the shared value
    /// on every `select!` iteration.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        *self.timeout.write().unwrap() = timeout;
        self
    }
}
237
/// rayon entry point: delegate to `bridge`, which drives this iterator
/// through the indexed `Producer` machinery below.
impl<I> ParallelIterator for DepGraphParIter<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    type Item = Wrapper<I>;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge(self, consumer)
    }
}
251
252impl<I> IndexedParallelIterator for DepGraphParIter<I>
253where
254    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
255{
256    fn len(&self) -> usize {
257        num_cpus::get()
258    }
259
260    fn drive<C>(self, consumer: C) -> C::Result
261    where
262        C: Consumer<Self::Item>,
263    {
264        bridge(self, consumer)
265    }
266
267    fn with_producer<CB>(self, callback: CB) -> CB::Output
268    where
269        CB: ProducerCallback<Self::Item>,
270    {
271        callback.callback(DepGraphProducer {
272            counter: self.counter.clone(),
273            item_ready_rx: self.item_ready_rx,
274            item_done_tx: self.item_done_tx,
275        })
276    }
277}
278
/// rayon producer backing `DepGraphParIter`.
///
/// Every split shares the same channels and counter, so all producers pull
/// from one ready queue and report to one dispatcher.
struct DepGraphProducer<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    // Shared in-flight counter handed to each Wrapper.
    counter: Arc<AtomicUsize>,
    // Shared queue of nodes whose dependencies are resolved.
    item_ready_rx: Receiver<I>,
    // Completion channel handed to each Wrapper.
    item_done_tx: Sender<I>,
}
287
288impl<I> Iterator for DepGraphProducer<I>
289where
290    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
291{
292    type Item = Wrapper<I>;
293
294    fn next(&mut self) -> Option<Self::Item> {
295        // TODO: Check until there is an item available
296        match self.item_ready_rx.recv() {
297            Ok(item) => Some(Wrapper::new(
298                item,
299                self.counter.clone(),
300                self.item_done_tx.clone(),
301            )),
302            Err(_) => None,
303        }
304    }
305}
306
/// Required by rayon's `Producer` bounds. Items come from an unordered
/// channel, so iterating "from the back" is the same as from the front.
impl<I> DoubleEndedIterator for DepGraphProducer<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.next()
    }
}
315
/// Required by rayon's `Producer` bounds.
///
/// NOTE(review): no `len()` override is provided, so the default falls back
/// to `size_hint()` (which is `(0, None)` here) — calling `len()` directly
/// would panic. Presumably rayon's bridge never calls it on this producer;
/// verify before relying on `len()`.
impl<I> ExactSizeIterator for DepGraphProducer<I> where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static
{
}
320
321impl<I> Producer for DepGraphProducer<I>
322where
323    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
324{
325    type Item = Wrapper<I>;
326    type IntoIter = Self;
327
328    fn into_iter(self) -> Self::IntoIter {
329        Self {
330            counter: self.counter.clone(),
331            item_ready_rx: self.item_ready_rx.clone(),
332            item_done_tx: self.item_done_tx,
333        }
334    }
335
336    fn split_at(self, _: usize) -> (Self, Self) {
337        (
338            Self {
339                counter: self.counter.clone(),
340                item_ready_rx: self.item_ready_rx.clone(),
341                item_done_tx: self.item_done_tx.clone(),
342            },
343            Self {
344                counter: self.counter.clone(),
345                item_ready_rx: self.item_ready_rx.clone(),
346                item_done_tx: self.item_done_tx,
347            },
348        )
349    }
350}