ordered_parallel_iterator/
lib.rs1use crossbeam::channel::bounded;
33use crossbeam::deque::{Steal, Stealer, Worker};
34use std::sync::atomic::{AtomicBool, Ordering};
35
36use std::marker::Send;
37use std::marker::Sync;
38use std::sync::Arc;
39use std::thread;
40use std::thread::JoinHandle;
41
42use std_semaphore::Semaphore;
43
/// An iterator that transforms the items of a producer on background threads
/// and yields the results in the order in which the producer emitted them.
///
/// A scheduler thread pulls items from the producer and spawns one worker
/// thread per item; the handles of those workers are queued in FIFO order and
/// joined one at a time by `Iterator::next`.
pub struct OrderedParallelIterator<O> {
    /// Handle of the scheduler thread; `None` once it has been taken and joined.
    scheduler_thread: Option<JoinHandle<()>>,
    /// Stealing end of the FIFO queue of worker-thread handles that the
    /// scheduler fills.
    tasks: Stealer<JoinHandle<O>>,
    /// Permits that throttle the scheduler: it acquires one before spawning a
    /// worker, and every call to `next` releases one.
    semaphore: Arc<Semaphore>,
    /// Set to `true` at construction; cleared once the scheduler has stopped
    /// pulling items from the producer.
    running: Arc<AtomicBool>,
}
50
51impl<O> OrderedParallelIterator<O> {
52 pub fn new<PC, XC, P, X, I>(producer_ctor: PC, xform_ctor: XC) -> Self
53 where
54 PC: 'static + Send + FnOnce() -> P,
55 XC: 'static + Send + Sync + Fn() -> X,
56 X: FnMut(I) -> O,
57 I: 'static + Send,
58 O: 'static + Send,
59 P: IntoIterator<Item = I>,
60 {
61 let semaphore = Arc::new(Semaphore::new(num_cpus::get() as isize));
62 let (tx, rx) = bounded(num_cpus::get());
63 let semaphore_copy = semaphore.clone();
64 let xform_ctor = Arc::new(xform_ctor);
65 let running_flag = Arc::new(AtomicBool::new(true));
66 let running = running_flag.clone();
67 let scheduler_thread = Some(thread::spawn(move || {
68 let tasks = Worker::new_fifo();
69 let mut first = true;
70 for e in producer_ctor() {
71 semaphore_copy.acquire();
72 let xform_ctor = xform_ctor.clone();
73 let worker_thread = thread::spawn(move || {
74 let mut xform = xform_ctor();
75 xform(e)
76 });
77 tasks.push(worker_thread);
78 if first {
79 let stealer = tasks.stealer();
80 tx.send(stealer).unwrap();
81 first = false;
82 }
83 }
84 running_flag.store(false, Ordering::Relaxed);
85 if first {
86 let stealer = tasks.stealer();
88 tx.send(stealer).unwrap();
89 }
90 }));
91
92 let tasks = rx.recv().unwrap();
93
94 Self {
95 scheduler_thread,
96 tasks,
97 semaphore,
98 running,
99 }
100 }
101}
102
103impl<T> Iterator for OrderedParallelIterator<T> {
104 type Item = T;
105
106 fn next(&mut self) -> Option<T> {
107 self.semaphore.release();
108 loop {
109 let item = self.tasks.steal();
110 match item {
111 Steal::Success(x) => {
112 return Some(x.join().expect("Cannot get data from thread"));
113 }
114 Steal::Empty => {
115 if !self.running.load(Ordering::Relaxed) {
116 break;
117 }
118 }
119 Steal::Retry => (),
120 }
121 }
122
123 self.scheduler_thread
124 .take()
125 .unwrap()
126 .join()
127 .expect("The scheduler thread has paniced.");
128
129 None
130 }
131}
132
#[cfg(test)]
mod tests {
    use crate::OrderedParallelIterator;

    /// Transformation used by the tests: returns its argument plus one.
    fn run_me(x: usize) -> usize {
        x + 1
    }

    /// Ten inputs must come back transformed and in producer order.
    #[test]
    fn it_works() {
        let mut results = OrderedParallelIterator::new(|| 0..10, || run_me);
        for expected in 1..=10 {
            assert_eq!(results.next(), Some(expected));
        }
    }

    /// An empty producer must not yield any item.
    #[test]
    fn empty() {
        let mut results = OrderedParallelIterator::new(|| 0..0, || run_me);
        if results.next().is_some() {
            panic!("Must not reach this point");
        }
    }
}
155}