gix_features/parallel/
in_parallel.rs

use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering};

use crate::parallel::{num_threads, Reduce};

/// A scope to start threads within.
pub type Scope<'scope, 'env> = std::thread::Scope<'scope, 'env>;

/// Runs `left` and `right` in parallel, returning their output when both are done.
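///
/// A minimal usage sketch, assuming this function is re-exported as `gix_features::parallel::join`:
///
/// ```
/// let (a, b) = gix_features::parallel::join(|| 1 + 1, || 2 + 2);
/// assert_eq!((a, b), (2, 4));
/// ```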
pub fn join<O1: Send, O2: Send>(left: impl FnOnce() -> O1 + Send, right: impl FnOnce() -> O2 + Send) -> (O1, O2) {
    std::thread::scope(|s| {
        let left = std::thread::Builder::new()
            .name("gitoxide.join.left".into())
            .spawn_scoped(s, left)
            .expect("valid name");
        let right = std::thread::Builder::new()
            .name("gitoxide.join.right".into())
            .spawn_scoped(s, right)
            .expect("valid name");
        (left.join().unwrap(), right.join().unwrap())
    })
}

/// Runs `f` with a scope to be used for spawning threads that will not outlive the function call.
/// That way it's possible to handle threads without needing the `'static` lifetime for the data they interact with.
///
/// Note that the threads should not rely on actual parallelism, as threading might be turned off entirely.
/// Hence they should not be connected to each other via channels, as that would deadlock in single-threaded mode.
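///
/// A minimal usage sketch, assuming this function is re-exported as `gix_features::parallel::threads`:
///
/// ```
/// let mut counts = [0u32; 4];
/// gix_features::parallel::threads(|scope| {
///     for chunk in counts.chunks_mut(2) {
///         // scoped threads may borrow `counts` and are joined when the scope ends
///         scope.spawn(move || chunk.iter_mut().for_each(|count| *count += 1));
///     }
/// });
/// assert_eq!(counts, [1; 4]);
/// ```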
pub fn threads<'env, F, R>(f: F) -> R
where
    F: for<'scope> FnOnce(&'scope std::thread::Scope<'scope, 'env>) -> R,
{
    std::thread::scope(f)
}

/// Create a builder for threads which allows them to be spawned into a scope and configured prior to spawning.
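///
/// A minimal usage sketch, assuming this function is re-exported as `gix_features::parallel::build_thread`;
/// the thread name is purely illustrative:
///
/// ```
/// let answer = std::thread::scope(|scope| {
///     gix_features::parallel::build_thread()
///         .name("gitoxide.example".into())
///         .spawn_scoped(scope, || 42)
///         .expect("valid name")
///         .join()
///         .expect("no panic")
/// });
/// assert_eq!(answer, 42);
/// ```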
pub fn build_thread() -> std::thread::Builder {
    std::thread::Builder::new()
}

/// Read items from `input` and `consume` them in multiple threads,
/// whose output is collected by a `reducer`. Its task is to
/// aggregate these outputs into the final result returned by this function, with the benefit of not having to be thread-safe.
///
/// * if `thread_limit` is `Some`, then the given number of threads will be used. If `None`, all logical cores will be used.
/// * `new_thread_state(thread_number) -> State` produces thread-local state once per thread to be passed to `consume`
/// * `consume(Item, &mut State) -> Output` produces an output given an input obtained from `input` along with mutable state initially
///   created by `new_thread_state(…)`.
/// * For `reducer`, see the [`Reduce`] trait
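///
/// A minimal usage sketch, assuming the re-exports `gix_features::parallel::{in_parallel, Reduce}`
/// and that [`Reduce`] provides the associated types and methods used below:
///
/// ```
/// // Note: the associated types below follow the assumed shape of the `Reduce` trait.
/// struct Sum(usize);
///
/// impl gix_features::parallel::Reduce for Sum {
///     type Input = usize;
///     type FeedProduce = ();
///     type Output = usize;
///     type Error = std::convert::Infallible;
///
///     fn feed(&mut self, item: usize) -> Result<(), Self::Error> {
///         self.0 += item;
///         Ok(())
///     }
///
///     fn finalize(self) -> Result<usize, Self::Error> {
///         Ok(self.0)
///     }
/// }
///
/// let total = gix_features::parallel::in_parallel(
///     1..=100usize,
///     None,                                    // use all logical cores
///     |_thread_id| (),                         // no per-thread state needed
///     |item: usize, _state: &mut ()| item * 2, // the per-item computation
///     Sum(0),
/// )
/// .expect("never fails as the reducer is infallible");
/// assert_eq!(total, 10_100);
/// ```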
pub fn in_parallel<I, S, O, R>(
    input: impl Iterator<Item = I> + Send,
    thread_limit: Option<usize>,
    new_thread_state: impl FnOnce(usize) -> S + Send + Clone,
    consume: impl FnMut(I, &mut S) -> O + Send + Clone,
    mut reducer: R,
) -> Result<<R as Reduce>::Output, <R as Reduce>::Error>
where
    R: Reduce<Input = O>,
    I: Send,
    O: Send,
{
    let num_threads = num_threads(thread_limit);
    std::thread::scope(move |s| {
        let receive_result = {
            let (send_input, receive_input) = crossbeam_channel::bounded::<I>(num_threads);
            let (send_result, receive_result) = crossbeam_channel::bounded::<O>(num_threads);
            for thread_id in 0..num_threads {
                std::thread::Builder::new()
                    .name(format!("gitoxide.in_parallel.produce.{thread_id}"))
                    .spawn_scoped(s, {
                        let send_result = send_result.clone();
                        let receive_input = receive_input.clone();
                        let new_thread_state = new_thread_state.clone();
                        let mut consume = consume.clone();
                        move || {
                            let mut state = new_thread_state(thread_id);
                            for item in receive_input {
                                if send_result.send(consume(item, &mut state)).is_err() {
                                    break;
                                }
                            }
                        }
                    })
                    .expect("valid name");
            }
            std::thread::Builder::new()
                .name("gitoxide.in_parallel.feed".into())
                .spawn_scoped(s, move || {
                    for item in input {
                        if send_input.send(item).is_err() {
                            break;
                        }
                    }
                })
                .expect("valid name");
            receive_result
        };

        for item in receive_result {
            drop(reducer.feed(item)?);
        }
        reducer.finalize()
    })
}

/// Read items from `input` and `consume` them in multiple threads,
/// whose output is collected by a `reducer`. Its task is to
/// aggregate these outputs into the final result returned by this function, with the benefit of not having to be thread-safe.
/// `finalize` is called once per thread to finish the computation, provided no error occurred while sending results earlier.
///
/// * if `thread_limit` is `Some`, then the given number of threads will be used. If `None`, all logical cores will be used.
/// * `new_thread_state(thread_number) -> State` produces thread-local state once per thread to be passed to `consume`
/// * `consume(Item, &mut State) -> Output` produces an output given an input obtained from `input` along with mutable state initially
///   created by `new_thread_state(…)`.
/// * `finalize(State) -> Output` is called to potentially process remaining work that was placed in `State`.
/// * For `reducer`, see the [`Reduce`] trait
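///
/// A minimal usage sketch which counts items per thread and emits each count once via `finalize`,
/// assuming the re-exports `gix_features::parallel::{in_parallel_with_finalize, Reduce}` and the
/// [`Reduce`] associated types used below:
///
/// ```
/// // Note: the associated types below follow the assumed shape of the `Reduce` trait.
/// struct Sum(usize);
///
/// impl gix_features::parallel::Reduce for Sum {
///     type Input = usize;
///     type FeedProduce = ();
///     type Output = usize;
///     type Error = std::convert::Infallible;
///
///     fn feed(&mut self, item: usize) -> Result<(), Self::Error> {
///         self.0 += item;
///         Ok(())
///     }
///
///     fn finalize(self) -> Result<usize, Self::Error> {
///         Ok(self.0)
///     }
/// }
///
/// let total = gix_features::parallel::in_parallel_with_finalize(
///     0..100usize,
///     None,
///     |_thread_id| 0usize,                 // per-thread state: a counter
///     |_item: usize, seen: &mut usize| {
///         *seen += 1;                      // do the per-item work, here just counting
///         0                                // nothing to report per item
///     },
///     |seen: usize| seen,                  // flush the per-thread count once at the end
///     Sum(0),
/// )
/// .expect("never fails as the reducer is infallible");
/// assert_eq!(total, 100);
/// ```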
pub fn in_parallel_with_finalize<I, S, O, R>(
    input: impl Iterator<Item = I> + Send,
    thread_limit: Option<usize>,
    new_thread_state: impl FnOnce(usize) -> S + Send + Clone,
    consume: impl FnMut(I, &mut S) -> O + Send + Clone,
    finalize: impl FnOnce(S) -> O + Send + Clone,
    mut reducer: R,
) -> Result<<R as Reduce>::Output, <R as Reduce>::Error>
where
    R: Reduce<Input = O>,
    I: Send,
    O: Send,
{
    let num_threads = num_threads(thread_limit);
    std::thread::scope(move |s| {
        let receive_result = {
            let (send_input, receive_input) = crossbeam_channel::bounded::<I>(num_threads);
            let (send_result, receive_result) = crossbeam_channel::bounded::<O>(num_threads);
            for thread_id in 0..num_threads {
                std::thread::Builder::new()
                    .name(format!("gitoxide.in_parallel.produce.{thread_id}"))
                    .spawn_scoped(s, {
                        let send_result = send_result.clone();
                        let receive_input = receive_input.clone();
                        let new_thread_state = new_thread_state.clone();
                        let mut consume = consume.clone();
                        let finalize = finalize.clone();
                        move || {
                            let mut state = new_thread_state(thread_id);
                            let mut can_send = true;
                            for item in receive_input {
                                if send_result.send(consume(item, &mut state)).is_err() {
                                    can_send = false;
                                    break;
                                }
                            }
                            if can_send {
                                send_result.send(finalize(state)).ok();
                            }
                        }
                    })
                    .expect("valid name");
            }
            std::thread::Builder::new()
                .name("gitoxide.in_parallel.feed".into())
                .spawn_scoped(s, move || {
                    for item in input {
                        if send_input.send(item).is_err() {
                            break;
                        }
                    }
                })
                .expect("valid name");
            receive_result
        };

        for item in receive_result {
            drop(reducer.feed(item)?);
        }
        reducer.finalize()
    })
}

/// An experiment to have fine-grained per-item parallelization with built-in aggregation via thread state.
/// This is only good for operations where near-random access isn't detrimental, so it's not usually great
/// for file IO as it won't make use of sorted inputs well.
/// Note that `periodic` is not guaranteed to be called if other threads come up first and finish too quickly.
/// `consume(&mut item, &mut state, &threads_available, &should_interrupt)` is called to perform the actual computation.
/// Note that `threads_available` should be decremented to start a thread that can steal your own work (as stored in `item`),
/// which allows callees to implement their own work-stealing in case the work is distributed unevenly.
/// Work-stealing should only start after having processed at least one item to give all threads naturally operating on the slice
/// some time to start. Starting threads while slice-workers are still starting up would lead to over-allocation of threads,
/// which is why the number of threads left may turn negative. Once threads are started and stopped, be sure to adjust
/// the thread-count accordingly.
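///
/// A minimal usage sketch, assuming this function is re-exported as `gix_features::parallel::in_parallel_with_slice`:
///
/// ```
/// use std::sync::atomic::{AtomicBool, AtomicIsize};
///
/// let mut items = [1u32, 2, 3, 4];
/// let per_thread_counts = gix_features::parallel::in_parallel_with_slice(
///     &mut items,
///     None,
///     |_thread_id| 0u32,                             // per-thread state: the number of items seen
///     |item: &mut u32, seen: &mut u32, _threads_left: &AtomicIsize, _interrupt: &AtomicBool| {
///         *item *= 10;
///         *seen += 1;
///         Ok::<_, std::convert::Infallible>(())
///     },
///     || Some(std::time::Duration::from_millis(50)), // keep the watcher thread idle
///     |seen: u32| seen,                              // turn the per-thread state into its result
/// )
/// .expect("never fails as `consume` is infallible");
/// assert_eq!(items, [10, 20, 30, 40]);
/// assert_eq!(per_thread_counts.iter().sum::<u32>(), 4);
/// ```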
// TODO: better docs
pub fn in_parallel_with_slice<I, S, R, E>(
    input: &mut [I],
    thread_limit: Option<usize>,
    new_thread_state: impl FnOnce(usize) -> S + Send + Clone,
    consume: impl FnMut(&mut I, &mut S, &AtomicIsize, &AtomicBool) -> Result<(), E> + Send + Clone,
    mut periodic: impl FnMut() -> Option<std::time::Duration> + Send,
    state_to_rval: impl FnOnce(S) -> R + Send + Clone,
) -> Result<Vec<R>, E>
where
    I: Send,
    E: Send,
    R: Send,
{
    let num_threads = num_threads(thread_limit);
    let mut results = Vec::with_capacity(num_threads);
    let stop_everything = &AtomicBool::default();
    let index = &AtomicUsize::default();
    let threads_left = &AtomicIsize::new(num_threads as isize);

    std::thread::scope({
        move |s| {
            std::thread::Builder::new()
                .name("gitoxide.in_parallel_with_slice.watch-interrupts".into())
                .spawn_scoped(s, {
                    move || loop {
                        if stop_everything.load(Ordering::Relaxed) {
                            break;
                        }

                        match periodic() {
                            Some(duration) => std::thread::sleep(duration),
                            None => {
                                stop_everything.store(true, Ordering::Relaxed);
                                break;
                            }
                        }
                    }
                })
                .expect("valid name");

            let input_len = input.len();
            struct Input<I>(*mut I)
            where
                I: Send;

            // SAFETY: I is Send, and we only use the pointer for creating new
            // pointers (within the input slice) from the threads.
            #[allow(unsafe_code)]
            unsafe impl<I> Send for Input<I> where I: Send {}

            let threads: Vec<_> = (0..num_threads)
                .map(|thread_id| {
                    std::thread::Builder::new()
                        .name(format!("gitoxide.in_parallel_with_slice.produce.{thread_id}"))
                        .spawn_scoped(s, {
                            let new_thread_state = new_thread_state.clone();
                            let state_to_rval = state_to_rval.clone();
                            let mut consume = consume.clone();
                            let input = Input(input.as_mut_ptr());
                            move || {
                                let _ = &input;
                                threads_left.fetch_sub(1, Ordering::SeqCst);
                                let mut state = new_thread_state(thread_id);
                                let res = (|| {
                                    while let Ok(input_index) =
                                        index.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
                                            (x < input_len).then_some(x + 1)
                                        })
                                    {
                                        if stop_everything.load(Ordering::Relaxed) {
                                            break;
                                        }
                                        // SAFETY: our atomic counter for `input_index` is only ever incremented, yielding
                                        //         each item exactly once.
                                        let item = {
                                            #[allow(unsafe_code)]
                                            unsafe {
                                                &mut *input.0.add(input_index)
                                            }
                                        };
                                        if let Err(err) = consume(item, &mut state, threads_left, stop_everything) {
                                            stop_everything.store(true, Ordering::Relaxed);
                                            return Err(err);
                                        }
                                    }
                                    Ok(state_to_rval(state))
                                })();
                                threads_left.fetch_add(1, Ordering::SeqCst);
                                res
                            }
                        })
                        .expect("valid name")
                })
                .collect();
            for thread in threads {
                match thread.join() {
                    Ok(res) => {
                        results.push(res?);
                    }
                    Err(err) => {
                        // a panic happened, stop the world gracefully (even though we panic later)
                        stop_everything.store(true, Ordering::Relaxed);
                        std::panic::resume_unwind(err);
                    }
                }
            }

            stop_everything.store(true, Ordering::Relaxed);
            Ok(results)
        }
    })
}