1use std::thread;
6use std::sync::{Arc, Mutex};
7use std::cmp::Ordering;
8use crate::executor::{num_cpus, calculate_chunk_size};
9
10pub fn parallel_sort<T>(items: &mut [T])
12where
13 T: Ord + Send,
14{
15 parallel_sort_by(items, |a, b| a.cmp(b));
16}
17
18pub fn parallel_sort_by<T, F>(items: &mut [T], compare: F)
20where
21 T: Send,
22 F: Fn(&T, &T) -> Ordering + Send + Sync + Copy,
23{
24 let len = items.len();
25 if len <= 1 {
26 return;
27 }
28
29 if len < 10_000 {
31 items.sort_by(compare);
32 return;
33 }
34
35 parallel_merge_sort(items, compare);
36}
37
38fn parallel_merge_sort<T, F>(items: &mut [T], compare: F)
39where
40 T: Send,
41 F: Fn(&T, &T) -> Ordering + Send + Sync + Copy,
42{
43 let len = items.len();
44 if len <= 10_000 {
45 items.sort_by(compare);
46 return;
47 }
48
49 let mid = len / 2;
50 let (left, right) = items.split_at_mut(mid);
51
52 thread::scope(|s| {
53 s.spawn(move || parallel_merge_sort(left, compare));
54 parallel_merge_sort(right, compare);
55 });
56
57 merge(items, mid, compare);
58}
59
/// Stably merges the two adjacent sorted runs `items[..mid]` and
/// `items[mid..]` in place, so that all of `items` is ordered by `compare`.
///
/// Precondition: each run is already sorted according to `compare`. Elements
/// that compare equal keep their relative order, so left-run elements come
/// before equal right-run elements.
///
/// `mid == 0` or `mid >= items.len()` (one run empty) is a no-op.
fn merge<T, F>(items: &mut [T], mid: usize, compare: F)
where
    T: Send,
    F: Fn(&T, &T) -> Ordering,
{
    let len = items.len();

    // One run is empty, or the last element of the left run is not greater than
    // the first of the right run: the concatenation is already sorted.
    if mid == 0 || mid >= len || compare(&items[mid - 1], &items[mid]) != Ordering::Greater {
        return;
    }

    // The std stable sort is correct and stable for any input. Its mergesort-style
    // run detection recognises the two long pre-sorted runs, so this is
    // effectively a buffered merge rather than a full re-sort. The previous
    // rotation-based merge was O(n * m) at best and mis-tracked the start of
    // the right run after each rotation.
    items.sort_by(compare);
}
86
87pub fn parallel_partition_advanced<T, F>(items: &[T], predicate: F) -> (Vec<T>, Vec<T>)
89where
90 T: Clone + Send + Sync,
91 F: Fn(&T) -> bool + Send + Sync,
92{
93 let len = items.len();
94 if len == 0 {
95 return (Vec::new(), Vec::new());
96 }
97
98 let num_threads = num_cpus();
99 let chunk_size = calculate_chunk_size(len, num_threads);
100
101 if chunk_size >= len {
102 let (true_items, false_items): (Vec<_>, Vec<_>) = items.iter()
103 .cloned()
104 .partition(|item| predicate(item));
105 return (true_items, false_items);
106 }
107
108 let predicate = Arc::new(predicate);
109 let results = Arc::new(Mutex::new(Vec::new()));
110
111 thread::scope(|s| {
112 for (idx, chunk) in items.chunks(chunk_size).enumerate() {
113 let predicate = Arc::clone(&predicate);
114 let results = Arc::clone(&results);
115 s.spawn(move || {
116 let (true_items, false_items): (Vec<_>, Vec<_>) = chunk.iter()
117 .cloned()
118 .partition(|item| predicate(item));
119 results.lock().unwrap().push((idx, true_items, false_items));
120 });
121 }
122 });
123
124 let mut collected = Arc::try_unwrap(results)
125 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
126 .into_inner()
127 .unwrap_or_else(|_| panic!("Failed to acquire lock"));
128
129 collected.sort_by_key(|(idx, _, _)| *idx);
130
131 let (all_true, all_false): (Vec<_>, Vec<_>) = collected
132 .into_iter()
133 .map(|(_, t, f)| (t, f))
134 .unzip();
135
136 (
137 all_true.into_iter().flatten().collect(),
138 all_false.into_iter().flatten().collect(),
139 )
140}
141
142pub fn parallel_zip<T, U, F, R>(left: &[T], right: &[U], f: F) -> Vec<R>
144where
145 T: Sync,
146 U: Sync,
147 R: Send,
148 F: Fn(&T, &U) -> R + Send + Sync,
149{
150 let len = left.len().min(right.len());
151 if len == 0 {
152 return Vec::new();
153 }
154
155 let num_threads = num_cpus();
156 let chunk_size = calculate_chunk_size(len, num_threads);
157
158 if chunk_size >= len {
159 return left.iter()
160 .zip(right.iter())
161 .map(|(l, r)| f(l, r))
162 .collect();
163 }
164
165 let f = Arc::new(f);
166 let results = Arc::new(Mutex::new(Vec::new()));
167
168 thread::scope(|s| {
169 for (idx, (left_chunk, right_chunk)) in left[..len].chunks(chunk_size)
170 .zip(right[..len].chunks(chunk_size))
171 .enumerate()
172 {
173 let f = Arc::clone(&f);
174 let results = Arc::clone(&results);
175 s.spawn(move || {
176 let chunk_results: Vec<_> = left_chunk.iter()
177 .zip(right_chunk.iter())
178 .map(|(l, r)| f(l, r))
179 .collect();
180 results.lock().unwrap().push((idx, chunk_results));
181 });
182 }
183 });
184
185 let mut collected = Arc::try_unwrap(results)
186 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
187 .into_inner()
188 .unwrap_or_else(|_| panic!("Failed to acquire lock"));
189
190 collected.sort_by_key(|(idx, _)| *idx);
191 collected.into_iter().flat_map(|(_, results)| results).collect()
192}
193
/// Applies `f` to consecutive, non-overlapping `chunk_size`-element chunks of
/// `items` in parallel and returns the per-chunk results in chunk order.
///
/// The chunks are exactly those produced by `items.chunks(chunk_size)`, so the
/// last one may be shorter than `chunk_size`. Returns an empty vector when
/// `items` is empty or `chunk_size` is 0.
///
/// At most `min(available_parallelism, number_of_chunks)` worker threads are
/// spawned. Each worker processes a contiguous run of chunks, so the thread
/// count does not grow with the number of chunks.
///
/// # Panics
///
/// A panic raised by `f` on any chunk is propagated to the caller.
pub fn parallel_chunks<T, F, R>(items: &[T], chunk_size: usize, f: F) -> Vec<Vec<R>>
where
    T: Sync,
    R: Send,
    F: Fn(&[T]) -> Vec<R> + Send + Sync,
{
    if items.is_empty() || chunk_size == 0 {
        return Vec::new();
    }

    let num_chunks = items.chunks(chunk_size).len();
    let workers = thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1)
        .min(num_chunks);
    // Ceiling division without the overflow risk of `(a + b - 1) / b`.
    let chunks_per_worker = num_chunks / workers + usize::from(num_chunks % workers != 0);
    // A multiple of `chunk_size`, so re-chunking each group by `chunk_size`
    // yields exactly the chunks of `items.chunks(chunk_size)`. If the product
    // saturates, it already exceeds `items.len()` and a single group covers
    // everything.
    let group_len = chunks_per_worker.saturating_mul(chunk_size);

    let f = &f;
    thread::scope(|s| {
        // Spawn every worker before joining any, so they run concurrently.
        let handles: Vec<_> = items
            .chunks(group_len)
            .map(|group| s.spawn(move || group.chunks(chunk_size).map(f).collect::<Vec<_>>()))
            .collect();

        // Groups are contiguous and joined in spawn order, so flattening keeps
        // the per-chunk results in chunk order.
        handles
            .into_iter()
            .flat_map(|handle| {
                handle
                    .join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
            })
            .collect()
    })
}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random values (LCG), so the tests need no external crates.
    fn pseudo_random(n: usize) -> Vec<u32> {
        let mut state: u32 = 0x2545_F491;
        (0..n)
            .map(|_| {
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                state >> 8
            })
            .collect()
    }

    fn is_non_decreasing<T: Ord>(data: &[T]) -> bool {
        data.windows(2).all(|w| w[0] <= w[1])
    }

    #[test]
    fn test_parallel_sort() {
        let mut data = vec![5, 2, 8, 1, 9, 3, 7, 4, 6];
        parallel_sort(&mut data);
        assert_eq!(data, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_parallel_sort_empty_and_single() {
        let mut empty: Vec<i32> = Vec::new();
        parallel_sort(&mut empty);
        assert!(empty.is_empty());

        let mut one = vec![42];
        parallel_sort(&mut one);
        assert_eq!(one, vec![42]);
    }

    #[test]
    fn test_parallel_sort_large() {
        let mut data: Vec<i32> = (0..100_000).rev().collect();
        parallel_sort(&mut data);
        assert!(is_non_decreasing(&data));
    }

    #[test]
    fn test_parallel_sort_large_unordered() {
        // Interleaved halves exercise the merge step. Reversed input does not,
        // because its halves need a single block swap.
        let mut data = pseudo_random(100_000);
        let mut expected = data.clone();
        expected.sort();
        parallel_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_parallel_sort_by_descending() {
        let mut data = pseudo_random(50_000);
        let mut expected = data.clone();
        expected.sort_by(|a, b| b.cmp(a));
        parallel_sort_by(&mut data, |a, b| b.cmp(a));
        assert_eq!(data, expected);
    }

    #[test]
    fn test_parallel_partition_advanced() {
        let data: Vec<i32> = (0..10_000).collect();
        let (evens, odds) = parallel_partition_advanced(&data, |x| x % 2 == 0);
        assert_eq!(evens, (0..10_000).step_by(2).collect::<Vec<_>>());
        assert_eq!(odds, (1..10_000).step_by(2).collect::<Vec<_>>());

        let (matching, rest) = parallel_partition_advanced(&Vec::<i32>::new(), |_| true);
        assert!(matching.is_empty() && rest.is_empty());
    }

    #[test]
    fn test_parallel_zip() {
        let left = vec![1, 2, 3, 4, 5];
        let right = vec![10, 20, 30, 40, 50];
        let result = parallel_zip(&left, &right, |a, b| a + b);
        assert_eq!(result, vec![11, 22, 33, 44, 55]);
    }

    #[test]
    fn test_parallel_zip_unequal_lengths() {
        let left: Vec<i32> = (0..20_000).collect();
        let right: Vec<i32> = (0..15_000).map(|x| x * 10).collect();
        let result = parallel_zip(&left, &right, |a, b| a + b);
        assert_eq!(result.len(), 15_000);
        assert!(result
            .iter()
            .enumerate()
            .all(|(i, &v)| v == i as i32 * 11));
    }

    #[test]
    fn test_parallel_chunks() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let results = parallel_chunks(&data, 3, |chunk| {
            chunk.iter().map(|x| x * 2).collect()
        });
        assert_eq!(results.len(), 4);
        assert_eq!(results[0], vec![2, 4, 6]);
        assert_eq!(results[1], vec![8, 10, 12]);
        assert_eq!(results[2], vec![14, 16, 18]);
        assert_eq!(results[3], vec![20]);
    }
}
267
268
269
270
271