gix_features/parallel/in_parallel.rs
1use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering};
2
3use crate::parallel::{num_threads, Reduce};
4
/// A scope to start threads within.
///
/// This is a plain alias for [`std::thread::Scope`], re-exported so callers can name the scope type
/// through this module instead of referring to the standard library directly.
pub type Scope<'scope, 'env> = std::thread::Scope<'scope, 'env>;
7
/// Runs `left` and `right` in parallel, each on its own named scoped thread, returning their outputs
/// as `(left_output, right_output)` once both are done.
///
/// # Panics
///
/// * If one of the threads cannot be spawned.
/// * If `left` or `right` panics. The original panic payload is re-raised on the calling thread so the
///   caller sees the actual panic message rather than a generic `Result::unwrap()` failure. If both
///   panic, the payload of `left` is the one that is propagated.
pub fn join<O1: Send, O2: Send>(left: impl FnOnce() -> O1 + Send, right: impl FnOnce() -> O2 + Send) -> (O1, O2) {
    std::thread::scope(|s| {
        let left = std::thread::Builder::new()
            .name("gitoxide.join.left".into())
            .spawn_scoped(s, left)
            .expect("valid name");
        let right = std::thread::Builder::new()
            .name("gitoxide.join.right".into())
            .spawn_scoped(s, right)
            .expect("valid name");

        // `left` is joined first; if it panicked we unwind right away and the scope itself
        // waits for `right` to finish before the panic continues to propagate.
        let left_output = left
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
        let right_output = right
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
        (left_output, right_output)
    })
}
22
/// Calls `f` with a thread scope and returns whatever `f` returns.
///
/// Threads spawned into the scope cannot outlive this function call, which is why they may borrow
/// data from the caller without requiring the `'static` lifetime.
///
/// Note that the threads should not rely on actual parallelism as threading might be turned off entirely,
/// hence they should not be connected to each other with channels as that would deadlock in single-threaded mode.
pub fn threads<'env, F, R>(f: F) -> R
where
    F: for<'scope> FnOnce(&'scope std::thread::Scope<'scope, 'env>) -> R,
{
    use std::thread;

    thread::scope(f)
}
34
/// Returns a fresh [`std::thread::Builder`], allowing a thread to be configured (e.g. named) before it is
/// spawned, for instance into a scope.
pub fn build_thread() -> std::thread::Builder {
    use std::thread::Builder;

    Builder::new()
}
39
40/// Read items from `input` and `consume` them in multiple threads,
41/// whose output is collected by a `reducer`. Its task is to
42/// aggregate these outputs into the final result returned by this function, with the benefit of not having to be thread-safe.
43///
44/// * if `thread_limit` is `Some`, then the given number of threads will be used. If `None`, all logical cores will be used.
45/// * `new_thread_state(thread_number) -> State` produces thread-local state once per thread to be passed to `consume`
46/// * `consume(Item, &mut State) -> Output` produces an output given an input obtained by `input` along with mutable state initially
47/// created by `new_thread_state(…)`.
48/// * For `reducer`, see the [`Reduce`] trait
49pub fn in_parallel<I, S, O, R>(
50 input: impl Iterator<Item = I> + Send,
51 thread_limit: Option<usize>,
52 new_thread_state: impl FnOnce(usize) -> S + Send + Clone,
53 consume: impl FnMut(I, &mut S) -> O + Send + Clone,
54 mut reducer: R,
55) -> Result<<R as Reduce>::Output, <R as Reduce>::Error>
56where
57 R: Reduce<Input = O>,
58 I: Send,
59 O: Send,
60{
61 let num_threads = num_threads(thread_limit);
62 std::thread::scope(move |s| {
63 let receive_result = {
64 let (send_input, receive_input) = crossbeam_channel::bounded::<I>(num_threads);
65 let (send_result, receive_result) = crossbeam_channel::bounded::<O>(num_threads);
66 for thread_id in 0..num_threads {
67 std::thread::Builder::new()
68 .name(format!("gitoxide.in_parallel.produce.{thread_id}"))
69 .spawn_scoped(s, {
70 let send_result = send_result.clone();
71 let receive_input = receive_input.clone();
72 let new_thread_state = new_thread_state.clone();
73 let mut consume = consume.clone();
74 move || {
75 let mut state = new_thread_state(thread_id);
76 for item in receive_input {
77 if send_result.send(consume(item, &mut state)).is_err() {
78 break;
79 }
80 }
81 }
82 })
83 .expect("valid name");
84 }
85 std::thread::Builder::new()
86 .name("gitoxide.in_parallel.feed".into())
87 .spawn_scoped(s, move || {
88 for item in input {
89 if send_input.send(item).is_err() {
90 break;
91 }
92 }
93 })
94 .expect("valid name");
95 receive_result
96 };
97
98 for item in receive_result {
99 drop(reducer.feed(item)?);
100 }
101 reducer.finalize()
102 })
103}
104
105/// Read items from `input` and `consume` them in multiple threads,
106/// whose output is collected by a `reducer`. Its task is to
107/// aggregate these outputs into the final result returned by this function with the benefit of not having to be thread-safe.
108/// Call `finalize` to finish the computation, once per thread, if there was no error sending results earlier.
109///
110/// * if `thread_limit` is `Some`, then the given number of threads will be used. If `None`, all logical cores will be used.
111/// * `new_thread_state(thread_number) -> State` produces thread-local state once per thread to be passed to `consume`
112/// * `consume(Item, &mut State) -> Output` produces an output given an input obtained by `input` along with mutable state initially
113/// created by `new_thread_state(…)`.
114/// * `finalize(State) -> Output` is called to potentially process remaining work that was placed in `State`.
115/// * For `reducer`, see the [`Reduce`] trait
116pub fn in_parallel_with_finalize<I, S, O, R>(
117 input: impl Iterator<Item = I> + Send,
118 thread_limit: Option<usize>,
119 new_thread_state: impl FnOnce(usize) -> S + Send + Clone,
120 consume: impl FnMut(I, &mut S) -> O + Send + Clone,
121 finalize: impl FnOnce(S) -> O + Send + Clone,
122 mut reducer: R,
123) -> Result<<R as Reduce>::Output, <R as Reduce>::Error>
124where
125 R: Reduce<Input = O>,
126 I: Send,
127 O: Send,
128{
129 let num_threads = num_threads(thread_limit);
130 std::thread::scope(move |s| {
131 let receive_result = {
132 let (send_input, receive_input) = crossbeam_channel::bounded::<I>(num_threads);
133 let (send_result, receive_result) = crossbeam_channel::bounded::<O>(num_threads);
134 for thread_id in 0..num_threads {
135 std::thread::Builder::new()
136 .name(format!("gitoxide.in_parallel.produce.{thread_id}"))
137 .spawn_scoped(s, {
138 let send_result = send_result.clone();
139 let receive_input = receive_input.clone();
140 let new_thread_state = new_thread_state.clone();
141 let mut consume = consume.clone();
142 let finalize = finalize.clone();
143 move || {
144 let mut state = new_thread_state(thread_id);
145 let mut can_send = true;
146 for item in receive_input {
147 if send_result.send(consume(item, &mut state)).is_err() {
148 can_send = false;
149 break;
150 }
151 }
152 if can_send {
153 send_result.send(finalize(state)).ok();
154 }
155 }
156 })
157 .expect("valid name");
158 }
159 std::thread::Builder::new()
160 .name("gitoxide.in_parallel.feed".into())
161 .spawn_scoped(s, move || {
162 for item in input {
163 if send_input.send(item).is_err() {
164 break;
165 }
166 }
167 })
168 .expect("valid name");
169 receive_result
170 };
171
172 for item in receive_result {
173 drop(reducer.feed(item)?);
174 }
175 reducer.finalize()
176 })
177}
178
179/// An experiment to have fine-grained per-item parallelization with built-in aggregation via thread state.
180/// This is only good for operations where near-random access isn't detrimental, so it's not usually great
181/// for file-io as it won't make use of sorted inputs well.
182/// Note that `periodic` is not guaranteed to be called in case other threads come up first and finish too fast.
183/// `consume(&mut item, &mut stat, &Scope, &threads_available, &should_interrupt)` is called for performing the actual computation.
184/// Note that `threads_available` should be decremented to start a thread that can steal your own work (as stored in `item`),
185/// which allows callees to implement their own work-stealing in case the work is distributed unevenly.
186/// Work stealing should only start after having processed at least one item to give all threads naturally operating on the slice
187/// some time to start. Starting threads while slice-workers are still starting up would lead to over-allocation of threads,
188/// which is why the number of threads left may turn negative. Once threads are started and stopped, be sure to adjust
189/// the thread-count accordingly.
190// TODO: better docs
191pub fn in_parallel_with_slice<I, S, R, E>(
192 input: &mut [I],
193 thread_limit: Option<usize>,
194 new_thread_state: impl FnOnce(usize) -> S + Send + Clone,
195 consume: impl FnMut(&mut I, &mut S, &AtomicIsize, &AtomicBool) -> Result<(), E> + Send + Clone,
196 mut periodic: impl FnMut() -> Option<std::time::Duration> + Send,
197 state_to_rval: impl FnOnce(S) -> R + Send + Clone,
198) -> Result<Vec<R>, E>
199where
200 I: Send,
201 E: Send,
202 R: Send,
203{
204 let num_threads = num_threads(thread_limit);
205 let mut results = Vec::with_capacity(num_threads);
206 let stop_everything = &AtomicBool::default();
207 let index = &AtomicUsize::default();
208 let threads_left = &AtomicIsize::new(num_threads as isize);
209
210 std::thread::scope({
211 move |s| {
212 std::thread::Builder::new()
213 .name("gitoxide.in_parallel_with_slice.watch-interrupts".into())
214 .spawn_scoped(s, {
215 move || loop {
216 if stop_everything.load(Ordering::Relaxed) {
217 break;
218 }
219
220 match periodic() {
221 Some(duration) => std::thread::sleep(duration),
222 None => {
223 stop_everything.store(true, Ordering::Relaxed);
224 break;
225 }
226 }
227 }
228 })
229 .expect("valid name");
230
231 let input_len = input.len();
232 struct Input<I>(*mut I)
233 where
234 I: Send;
235
236 // SAFETY: I is Send, and we only use the pointer for creating new
237 // pointers (within the input slice) from the threads.
238 #[allow(unsafe_code)]
239 unsafe impl<I> Send for Input<I> where I: Send {}
240
241 let threads: Vec<_> = (0..num_threads)
242 .map(|thread_id| {
243 std::thread::Builder::new()
244 .name(format!("gitoxide.in_parallel_with_slice.produce.{thread_id}"))
245 .spawn_scoped(s, {
246 let new_thread_state = new_thread_state.clone();
247 let state_to_rval = state_to_rval.clone();
248 let mut consume = consume.clone();
249 let input = Input(input.as_mut_ptr());
250 move || {
251 let _ = &input;
252 threads_left.fetch_sub(1, Ordering::SeqCst);
253 let mut state = new_thread_state(thread_id);
254 let res = (|| {
255 while let Ok(input_index) =
256 index.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
257 (x < input_len).then_some(x + 1)
258 })
259 {
260 if stop_everything.load(Ordering::Relaxed) {
261 break;
262 }
263 // SAFETY: our atomic counter for `input_index` is only ever incremented, yielding
264 // each item exactly once.
265 let item = {
266 #[allow(unsafe_code)]
267 unsafe {
268 &mut *input.0.add(input_index)
269 }
270 };
271 if let Err(err) = consume(item, &mut state, threads_left, stop_everything) {
272 stop_everything.store(true, Ordering::Relaxed);
273 return Err(err);
274 }
275 }
276 Ok(state_to_rval(state))
277 })();
278 threads_left.fetch_add(1, Ordering::SeqCst);
279 res
280 }
281 })
282 .expect("valid name")
283 })
284 .collect();
285 for thread in threads {
286 match thread.join() {
287 Ok(res) => {
288 results.push(res?);
289 }
290 Err(err) => {
291 // a panic happened, stop the world gracefully (even though we panic later)
292 stop_everything.store(true, Ordering::Relaxed);
293 std::panic::resume_unwind(err);
294 }
295 }
296 }
297
298 stop_everything.store(true, Ordering::Relaxed);
299 Ok(results)
300 }
301 })
302}