1use crate::{
2 error::Error,
3 graph::{remove_node_id, DepGraph, DependencyMap},
4};
5use crossbeam_channel::{Receiver, Sender};
6
7use rayon::iter::{
8 plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer},
9 IndexedParallelIterator, IntoParallelIterator, ParallelIterator,
10};
11use std::cmp;
12
13use std::fmt;
14use std::hash::{Hash, Hasher};
15use std::iter::{DoubleEndedIterator, ExactSizeIterator};
16
17use std::ops;
18use std::sync::{
19 atomic::{AtomicUsize, Ordering},
20 Arc, RwLock,
21};
22use std::thread;
23use std::time::Duration;
24
/// Initial wait used by the dispatcher thread: how long it blocks waiting for a
/// completion notification before re-checking whether the graph can still progress.
/// Can be changed per iterator through `DepGraphParIter::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(1);
27
28impl<I> IntoParallelIterator for DepGraph<I>
30where
31 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
32{
33 type Item = Wrapper<I>;
34 type Iter = DepGraphParIter<I>;
35
36 fn into_par_iter(self) -> Self::Iter {
37 DepGraphParIter::new(self.ready_nodes, self.deps, self.rdeps)
38 }
39}
40
41#[derive(Clone)]
47pub struct Wrapper<I>
48where
49 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
50{
51 inner: I,
53 counter: Arc<AtomicUsize>,
55 item_done_tx: Sender<I>,
57}
58
59impl<I> Wrapper<I>
60where
61 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
62{
63 pub fn new(inner: I, counter: Arc<AtomicUsize>, item_done_tx: Sender<I>) -> Self {
72 (*counter).fetch_add(1, Ordering::SeqCst);
73 Self {
74 inner,
75 counter,
76 item_done_tx,
77 }
78 }
79}
80
81impl<I> Drop for Wrapper<I>
84where
85 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
86{
87 fn drop(&mut self) {
91 (*self.counter).fetch_sub(1, Ordering::SeqCst);
92 self.item_done_tx
93 .send(self.inner.clone())
94 .expect("could not send message")
95 }
96}
97
98impl<I> ops::Deref for Wrapper<I>
102where
103 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
104{
105 type Target = I;
106
107 fn deref(&self) -> &Self::Target {
108 &self.inner
109 }
110}
111
112impl<I> ops::DerefMut for Wrapper<I>
116where
117 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
118{
119 fn deref_mut(&mut self) -> &mut Self::Target {
120 &mut self.inner
121 }
122}
123
124impl<I> Eq for Wrapper<I> where I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static
125{}
126
127impl<I> Hash for Wrapper<I>
128where
129 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
130{
131 fn hash<H: Hasher>(&self, state: &mut H) {
132 self.inner.hash(state)
133 }
134}
135
136impl<I> cmp::PartialEq for Wrapper<I>
137where
138 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
139{
140 fn eq(&self, other: &Self) -> bool {
141 self.inner == other.inner
142 }
143}
144
145pub struct DepGraphParIter<I>
147where
148 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
149{
150 timeout: Arc<RwLock<Duration>>,
151 counter: Arc<AtomicUsize>,
152 item_ready_rx: Receiver<I>,
153 item_done_tx: Sender<I>,
154}
155
156impl<I> DepGraphParIter<I>
157where
158 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
159{
160 pub fn new(ready_nodes: Vec<I>, deps: DependencyMap<I>, rdeps: DependencyMap<I>) -> Self {
165 let timeout = Arc::new(RwLock::new(DEFAULT_TIMEOUT));
166 let counter = Arc::new(AtomicUsize::new(0));
167
168 let (item_ready_tx, item_ready_rx) = crossbeam_channel::unbounded::<I>();
170 let (item_done_tx, item_done_rx) = crossbeam_channel::unbounded::<I>();
171
172 ready_nodes
174 .iter()
175 .for_each(|node| item_ready_tx.send(node.clone()).unwrap());
176
177 let loop_timeout = timeout.clone();
179 let loop_counter = counter.clone();
180
181 thread::spawn(move || {
183 loop {
184 crossbeam_channel::select! {
185 recv(item_done_rx) -> id => {
187 let id = id.unwrap();
188 let next_nodes = remove_node_id::<I>(id, &deps, &rdeps)?;
190
191 next_nodes
193 .iter()
194 .for_each(|node_id| item_ready_tx.send(node_id.clone()).unwrap());
195
196 if deps.read().unwrap().is_empty() {
198 break;
199 }
200 },
201 default(*loop_timeout.read().unwrap()) => {
203 let deps = deps.read().unwrap();
204 let counter_val = loop_counter.load(Ordering::SeqCst);
205 if deps.is_empty() {
206 break;
207 } else if counter_val > 0 {
209 continue;
210 } else {
211 return Err(Error::ResolveGraphError("circular dependency detected"));
212 }
213 },
214 };
215 }
216
217 drop(item_ready_tx);
220 Ok(())
221 });
222
223 DepGraphParIter {
224 timeout,
225 counter,
226
227 item_ready_rx,
228 item_done_tx,
229 }
230 }
231
232 pub fn with_timeout(self, timeout: Duration) -> Self {
233 *self.timeout.write().unwrap() = timeout;
234 self
235 }
236}
237
238impl<I> ParallelIterator for DepGraphParIter<I>
239where
240 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
241{
242 type Item = Wrapper<I>;
243
244 fn drive_unindexed<C>(self, consumer: C) -> C::Result
245 where
246 C: UnindexedConsumer<Self::Item>,
247 {
248 bridge(self, consumer)
249 }
250}
251
252impl<I> IndexedParallelIterator for DepGraphParIter<I>
253where
254 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
255{
256 fn len(&self) -> usize {
257 num_cpus::get()
258 }
259
260 fn drive<C>(self, consumer: C) -> C::Result
261 where
262 C: Consumer<Self::Item>,
263 {
264 bridge(self, consumer)
265 }
266
267 fn with_producer<CB>(self, callback: CB) -> CB::Output
268 where
269 CB: ProducerCallback<Self::Item>,
270 {
271 callback.callback(DepGraphProducer {
272 counter: self.counter.clone(),
273 item_ready_rx: self.item_ready_rx,
274 item_done_tx: self.item_done_tx,
275 })
276 }
277}
278
279struct DepGraphProducer<I>
280where
281 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
282{
283 counter: Arc<AtomicUsize>,
284 item_ready_rx: Receiver<I>,
285 item_done_tx: Sender<I>,
286}
287
288impl<I> Iterator for DepGraphProducer<I>
289where
290 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
291{
292 type Item = Wrapper<I>;
293
294 fn next(&mut self) -> Option<Self::Item> {
295 match self.item_ready_rx.recv() {
297 Ok(item) => Some(Wrapper::new(
298 item,
299 self.counter.clone(),
300 self.item_done_tx.clone(),
301 )),
302 Err(_) => None,
303 }
304 }
305}
306
307impl<I> DoubleEndedIterator for DepGraphProducer<I>
308where
309 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
310{
311 fn next_back(&mut self) -> Option<Self::Item> {
312 self.next()
313 }
314}
315
316impl<I> ExactSizeIterator for DepGraphProducer<I> where
317 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static
318{
319}
320
321impl<I> Producer for DepGraphProducer<I>
322where
323 I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
324{
325 type Item = Wrapper<I>;
326 type IntoIter = Self;
327
328 fn into_iter(self) -> Self::IntoIter {
329 Self {
330 counter: self.counter.clone(),
331 item_ready_rx: self.item_ready_rx.clone(),
332 item_done_tx: self.item_done_tx,
333 }
334 }
335
336 fn split_at(self, _: usize) -> (Self, Self) {
337 (
338 Self {
339 counter: self.counter.clone(),
340 item_ready_rx: self.item_ready_rx.clone(),
341 item_done_tx: self.item_done_tx.clone(),
342 },
343 Self {
344 counter: self.counter.clone(),
345 item_ready_rx: self.item_ready_rx.clone(),
346 item_done_tx: self.item_done_tx,
347 },
348 )
349 }
350}