git_features/parallel/
in_parallel.rs1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2
3use crate::parallel::{num_threads, Reduce};
4
5pub fn join<O1: Send, O2: Send>(left: impl FnOnce() -> O1 + Send, right: impl FnOnce() -> O2 + Send) -> (O1, O2) {
7 crossbeam_utils::thread::scope(|s| {
8 let left = s
9 .builder()
10 .name("gitoxide.join.left".into())
11 .spawn(|_| left())
12 .expect("valid name");
13 let right = s
14 .builder()
15 .name("gitoxide.join.right".into())
16 .spawn(|_| right())
17 .expect("valid name");
18 (left.join().unwrap(), right.join().unwrap())
19 })
20 .unwrap()
21}
22
23pub fn threads<'env, F, R>(f: F) -> std::thread::Result<R>
29where
30 F: FnOnce(&crossbeam_utils::thread::Scope<'env>) -> R,
31{
32 crossbeam_utils::thread::scope(f)
33}
34
35pub fn in_parallel<I, S, O, R>(
45 input: impl Iterator<Item = I> + Send,
46 thread_limit: Option<usize>,
47 new_thread_state: impl Fn(usize) -> S + Send + Clone,
48 consume: impl Fn(I, &mut S) -> O + Send + Clone,
49 mut reducer: R,
50) -> Result<<R as Reduce>::Output, <R as Reduce>::Error>
51where
52 R: Reduce<Input = O>,
53 I: Send,
54 O: Send,
55{
56 let num_threads = num_threads(thread_limit);
57 crossbeam_utils::thread::scope(move |s| {
58 let receive_result = {
59 let (send_input, receive_input) = crossbeam_channel::bounded::<I>(num_threads);
60 let (send_result, receive_result) = crossbeam_channel::bounded::<O>(num_threads);
61 for thread_id in 0..num_threads {
62 s.builder()
63 .name(format!("gitoxide.in_parallel.produce.{thread_id}"))
64 .spawn({
65 let send_result = send_result.clone();
66 let receive_input = receive_input.clone();
67 let new_thread_state = new_thread_state.clone();
68 let consume = consume.clone();
69 move |_| {
70 let mut state = new_thread_state(thread_id);
71 for item in receive_input {
72 if send_result.send(consume(item, &mut state)).is_err() {
73 break;
74 }
75 }
76 }
77 })
78 .expect("valid name");
79 }
80 s.builder()
81 .name("gitoxide.in_parallel.feed".into())
82 .spawn(move |_| {
83 for item in input {
84 if send_input.send(item).is_err() {
85 break;
86 }
87 }
88 })
89 .expect("valid name");
90 receive_result
91 };
92
93 for item in receive_result {
94 drop(reducer.feed(item)?);
95 }
96 reducer.finalize()
97 })
98 .expect("no panic")
99}
100
101pub fn in_parallel_with_slice<I, S, R, E>(
107 input: &mut [I],
108 thread_limit: Option<usize>,
109 new_thread_state: impl FnMut(usize) -> S + Send + Clone,
110 consume: impl FnMut(&mut I, &mut S) -> Result<(), E> + Send + Clone,
111 mut periodic: impl FnMut() -> Option<std::time::Duration> + Send,
112 state_to_rval: impl FnOnce(S) -> R + Send + Clone,
113) -> Result<Vec<R>, E>
114where
115 I: Send,
116 E: Send,
117 R: Send,
118{
119 let num_threads = num_threads(thread_limit);
120 let mut results = Vec::with_capacity(num_threads);
121 let stop_everything = &AtomicBool::default();
122 let index = &AtomicUsize::default();
123
124 crossbeam_utils::thread::scope({
126 move |s| {
127 s.builder()
128 .name("gitoxide.in_parallel_with_slice.watch-interrupts".into())
129 .spawn({
130 move |_| loop {
131 if stop_everything.load(Ordering::Relaxed) {
132 break;
133 }
134
135 match periodic() {
136 Some(duration) => std::thread::sleep(duration),
137 None => {
138 stop_everything.store(true, Ordering::Relaxed);
139 break;
140 }
141 }
142 }
143 })
144 .expect("valid name");
145
146 let input_len = input.len();
147 struct Input<I>(*mut [I])
148 where
149 I: Send;
150
151 #[allow(unsafe_code)]
153 unsafe impl<I> Send for Input<I> where I: Send {}
154
155 let threads: Vec<_> = (0..num_threads)
156 .map(|thread_id| {
157 s.builder()
158 .name(format!("gitoxide.in_parallel_with_slice.produce.{thread_id}"))
159 .spawn({
160 let mut new_thread_state = new_thread_state.clone();
161 let state_to_rval = state_to_rval.clone();
162 let mut consume = consume.clone();
163 let input = Input(input as *mut [I]);
164 move |_| {
165 let mut state = new_thread_state(thread_id);
166 while let Ok(input_index) =
167 index.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
168 (x < input_len).then_some(x + 1)
169 })
170 {
171 if stop_everything.load(Ordering::Relaxed) {
172 break;
173 }
174 let item = {
177 #[allow(unsafe_code)]
178 unsafe {
179 &mut (&mut *input.0)[input_index]
180 }
181 };
182 if let Err(err) = consume(item, &mut state) {
183 stop_everything.store(true, Ordering::Relaxed);
184 return Err(err);
185 }
186 }
187 Ok(state_to_rval(state))
188 }
189 })
190 .expect("valid name")
191 })
192 .collect();
193 for thread in threads {
194 match thread.join() {
195 Ok(res) => {
196 results.push(res?);
197 }
198 Err(err) => {
199 stop_everything.store(true, Ordering::Relaxed);
201 std::panic::resume_unwind(err);
202 }
203 }
204 }
205
206 stop_everything.store(true, Ordering::Relaxed);
207 Ok(results)
208 }
209 })
210 .expect("no panic")
211}