1use std::thread;
6use std::sync::{Arc, Mutex};
7use std::sync::atomic::{AtomicBool, Ordering};
8
/// Default minimum number of items per chunk, used by `get_min_chunk_size` when the
/// `AVILA_MIN_CHUNK_SIZE` environment variable is missing or does not parse as a `usize`.
const MIN_CHUNK_SIZE: usize = 1024;

/// Multiplier applied to the thread count in `calculate_chunk_size` to cap how many chunks
/// a slice is divided into.
const MAX_CHUNKS_PER_THREAD: usize = 8;
15
/// Reports how many threads the parallel helpers should assume are available.
///
/// The value comes from [`std::thread::available_parallelism`]; when that query fails the
/// function falls back to `1`.
pub fn num_cpus() -> usize {
    match thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
22
23pub fn get_min_chunk_size() -> usize {
25 std::env::var("AVILA_MIN_CHUNK_SIZE")
26 .ok()
27 .and_then(|s| s.parse().ok())
28 .unwrap_or(MIN_CHUNK_SIZE)
29}
30
31pub fn calculate_chunk_size(len: usize, num_threads: usize) -> usize {
33 let min_chunk = get_min_chunk_size();
34 let max_chunks = num_threads * MAX_CHUNKS_PER_THREAD;
35 let chunk_size = (len + max_chunks - 1) / max_chunks;
36 chunk_size.max(min_chunk)
37}
38
39pub fn parallel_for_each<T, F>(items: &[T], f: F)
41where
42 T: Sync,
43 F: Fn(&T) + Sync + Send,
44{
45 let len = items.len();
46 if len == 0 {
47 return;
48 }
49
50 let num_threads = num_cpus();
51 let chunk_size = calculate_chunk_size(len, num_threads);
52
53 if chunk_size >= len {
54 for item in items {
56 f(item);
57 }
58 return;
59 }
60
61 let f = Arc::new(f);
63
64 thread::scope(|s| {
66 for chunk in items.chunks(chunk_size) {
67 let f = Arc::clone(&f);
68 s.spawn(move || {
69 for item in chunk {
70 f(item);
71 }
72 });
73 }
74 });
75}
76
77pub fn parallel_map<T, R, F>(items: &[T], f: F) -> Vec<R>
79where
80 T: Sync,
81 R: Send + 'static,
82 F: Fn(&T) -> R + Sync + Send,
83{
84 let len = items.len();
85 if len == 0 {
86 return Vec::new();
87 }
88
89 let num_threads = num_cpus();
90 let chunk_size = calculate_chunk_size(len, num_threads);
91
92 if chunk_size >= len {
93 return items.iter().map(&f).collect();
95 }
96
97 let f = Arc::new(f);
99 let chunk_results = Arc::new(Mutex::new(Vec::new()));
101
102 thread::scope(|s| {
103 let mut start_idx = 0;
104 for chunk in items.chunks(chunk_size) {
105 let f = Arc::clone(&f);
106 let chunk_results = Arc::clone(&chunk_results);
107 let chunk_start = start_idx;
108 start_idx += chunk.len();
109
110 s.spawn(move || {
111 let results: Vec<R> = chunk.iter().map(|item| f(item)).collect();
112 chunk_results.lock().unwrap().push((chunk_start, results));
113 });
114 }
115 });
116
117 let mut sorted_chunks = Arc::try_unwrap(chunk_results)
119 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
120 .into_inner()
121 .unwrap_or_else(|_| panic!("Failed to acquire lock"));
122
123 sorted_chunks.sort_by_key(|(idx, _)| *idx);
124
125 let mut results = Vec::with_capacity(len);
127 for (_, chunk) in sorted_chunks {
128 results.extend(chunk);
129 }
130 results
131}
132
133pub fn parallel_filter<T, F>(items: &[T], f: F) -> Vec<&T>
135where
136 T: Sync,
137 F: Fn(&T) -> bool + Sync + Send,
138{
139 let len = items.len();
140 if len == 0 {
141 return Vec::new();
142 }
143
144 let num_threads = num_cpus();
145 let chunk_size = calculate_chunk_size(len, num_threads);
146
147 if chunk_size >= len {
148 return items.iter().filter(|item| f(item)).collect();
150 }
151
152 let f = Arc::new(f);
154 let results = Arc::new(Mutex::new(Vec::new()));
156
157 thread::scope(|s| {
158 for chunk in items.chunks(chunk_size) {
159 let f = Arc::clone(&f);
160 let results = Arc::clone(&results);
161 s.spawn(move || {
162 let chunk_results: Vec<&T> = chunk.iter().filter(|item| f(item)).collect();
163 results.lock().unwrap().extend(chunk_results);
164 });
165 }
166 });
167
168 Arc::try_unwrap(results)
170 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
171 .into_inner()
172 .unwrap_or_else(|_| panic!("Failed to acquire lock"))
173}
174
175pub fn parallel_reduce<T, F>(items: &[T], reduce_op: F) -> Option<T>
177where
178 T: Clone + Send + Sync,
179 F: Fn(T, T) -> T + Sync + Send,
180{
181 let len = items.len();
182 if len == 0 {
183 return None;
184 }
185
186 let num_threads = num_cpus();
187 let chunk_size = calculate_chunk_size(len, num_threads);
188
189 if chunk_size >= len {
190 return items.iter().cloned().reduce(|a, b| reduce_op(a, b));
192 }
193
194 let reduce_op = Arc::new(reduce_op);
196 let results = Arc::new(Mutex::new(Vec::new()));
198
199 thread::scope(|s| {
200 for chunk in items.chunks(chunk_size) {
201 let reduce_op = Arc::clone(&reduce_op);
202 let results = Arc::clone(&results);
203 s.spawn(move || {
204 if let Some(chunk_result) = chunk.iter().cloned().reduce(|a, b| reduce_op(a, b)) {
205 results.lock().unwrap().push(chunk_result);
206 }
207 });
208 }
209 });
210
211 let final_results = Arc::try_unwrap(results)
213 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
214 .into_inner()
215 .unwrap_or_else(|_| panic!("Failed to acquire lock"));
216 final_results.into_iter().reduce(|a, b| reduce_op(a, b))
217}
218
219pub fn parallel_find<T, F>(items: &[T], predicate: F) -> Option<T>
221where
222 T: Clone + Send + Sync,
223 F: Fn(&T) -> bool + Sync + Send,
224{
225 let len = items.len();
226 if len == 0 {
227 return None;
228 }
229
230 let num_threads = num_cpus();
231 let chunk_size = calculate_chunk_size(len, num_threads);
232
233 if chunk_size >= len {
234 return items.iter().find(|item| predicate(item)).cloned();
236 }
237
238 let predicate = Arc::new(predicate);
240 let result = Arc::new(Mutex::new(None));
241 let found_flag = Arc::new(AtomicBool::new(false)); thread::scope(|s| {
244 for chunk in items.chunks(chunk_size) {
245 let predicate = Arc::clone(&predicate);
246 let result = Arc::clone(&result);
247 let found_flag = Arc::clone(&found_flag);
248 s.spawn(move || {
249 if found_flag.load(Ordering::Relaxed) {
251 return;
252 }
253 if let Some(found) = chunk.iter().find(|item| predicate(item)) {
254 found_flag.store(true, Ordering::Relaxed);
255 let mut res = result.lock().unwrap();
256 if res.is_none() {
257 *res = Some(found.clone());
258 }
259 }
260 });
261 }
262 });
263
264 Arc::try_unwrap(result)
265 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
266 .into_inner()
267 .unwrap_or_else(|_| panic!("Failed to acquire lock"))
268}
269
270pub fn parallel_count<T, F>(items: &[T], predicate: F) -> usize
272where
273 T: Sync,
274 F: Fn(&T) -> bool + Sync + Send,
275{
276 let len = items.len();
277 if len == 0 {
278 return 0;
279 }
280
281 let num_threads = num_cpus();
282 let chunk_size = calculate_chunk_size(len, num_threads);
283
284 if chunk_size >= len {
285 return items.iter().filter(|item| predicate(item)).count();
287 }
288
289 let predicate = Arc::new(predicate);
291 let counts = Arc::new(Mutex::new(Vec::new()));
292
293 thread::scope(|s| {
294 for chunk in items.chunks(chunk_size) {
295 let predicate = Arc::clone(&predicate);
296 let counts = Arc::clone(&counts);
297 s.spawn(move || {
298 let count = chunk.iter().filter(|item| predicate(item)).count();
299 counts.lock().unwrap().push(count);
300 });
301 }
302 });
303
304 Arc::try_unwrap(counts)
305 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
306 .into_inner()
307 .unwrap_or_else(|_| panic!("Failed to acquire lock"))
308 .into_iter()
309 .sum()
310}
311
312pub fn parallel_partition<T, F>(items: &[T], predicate: F) -> (Vec<T>, Vec<T>)
314where
315 T: Clone + Send + Sync,
316 F: Fn(&T) -> bool + Sync + Send,
317{
318 let len = items.len();
319 if len == 0 {
320 return (Vec::new(), Vec::new());
321 }
322
323 let num_threads = num_cpus();
324 let chunk_size = calculate_chunk_size(len, num_threads);
325
326 if chunk_size >= len {
327 let mut true_vec = Vec::new();
329 let mut false_vec = Vec::new();
330 for item in items {
331 if predicate(item) {
332 true_vec.push(item.clone());
333 } else {
334 false_vec.push(item.clone());
335 }
336 }
337 return (true_vec, false_vec);
338 }
339
340 let predicate = Arc::new(predicate);
342 let results = Arc::new(Mutex::new(Vec::new()));
343
344 thread::scope(|s| {
345 for chunk in items.chunks(chunk_size) {
346 let predicate = Arc::clone(&predicate);
347 let results = Arc::clone(&results);
348 s.spawn(move || {
349 let mut true_vec = Vec::new();
350 let mut false_vec = Vec::new();
351 for item in chunk {
352 if predicate(item) {
353 true_vec.push(item.clone());
354 } else {
355 false_vec.push(item.clone());
356 }
357 }
358 results.lock().unwrap().push((true_vec, false_vec));
359 });
360 }
361 });
362
363 let chunk_results = Arc::try_unwrap(results)
364 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
365 .into_inner()
366 .unwrap_or_else(|_| panic!("Failed to acquire lock"));
367
368 let mut final_true = Vec::new();
369 let mut final_false = Vec::new();
370 for (true_vec, false_vec) in chunk_results {
371 final_true.extend(true_vec);
372 final_false.extend(false_vec);
373 }
374 (final_true, final_false)
375}
376
377pub fn parallel_sum<T>(items: &[T]) -> T
379where
380 T: Clone + Send + Sync + std::iter::Sum,
381{
382 let len = items.len();
383 if len == 0 {
384 panic!("Cannot sum empty collection");
385 }
386
387 let num_threads = num_cpus();
388 let chunk_size = calculate_chunk_size(len, num_threads);
389
390 if chunk_size >= len {
391 return items.iter().cloned().sum();
393 }
394
395 let results = Arc::new(Mutex::new(Vec::new()));
397
398 thread::scope(|s| {
399 for chunk in items.chunks(chunk_size) {
400 let results = Arc::clone(&results);
401 s.spawn(move || {
402 let chunk_sum: T = chunk.iter().cloned().sum();
403 results.lock().unwrap().push(chunk_sum);
404 });
405 }
406 });
407
408 Arc::try_unwrap(results)
410 .unwrap_or_else(|_| panic!("Failed to unwrap Arc"))
411 .into_inner()
412 .unwrap_or_else(|_| panic!("Failed to acquire lock"))
413 .into_iter()
414 .sum()
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Element count for the "large" tests. With the default minimum chunk size this is
    /// big enough that the helpers split the input and take their threaded path; the small
    /// fixtures below only exercise the sequential fallback.
    const LARGE_LEN: u64 = 100_000;

    /// Builds the ascending sequence `0..LARGE_LEN`.
    fn large_input() -> Vec<u64> {
        (0..LARGE_LEN).collect()
    }

    #[test]
    fn test_parallel_for_each() {
        let data = vec![1, 2, 3, 4, 5];
        let counter = Arc::new(Mutex::new(0));

        parallel_for_each(&data, |_| {
            *counter.lock().unwrap() += 1;
        });

        assert_eq!(*counter.lock().unwrap(), 5);
    }

    #[test]
    fn test_parallel_for_each_large() {
        let data = large_input();
        let visited = std::sync::atomic::AtomicUsize::new(0);

        parallel_for_each(&data, |_| {
            visited.fetch_add(1, Ordering::Relaxed);
        });

        assert_eq!(visited.into_inner(), data.len());
    }

    #[test]
    fn test_parallel_map() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_map(&data, |x| x * 2);

        // The output is in input order, so no sorting is needed before comparing.
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    #[test]
    fn test_parallel_map_large_preserves_order() {
        let data = large_input();
        let result = parallel_map(&data, |x| x * 2);

        let expected: Vec<u64> = data.iter().map(|x| x * 2).collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_parallel_map_empty() {
        let data: Vec<u64> = Vec::new();
        assert!(parallel_map(&data, |x| x * 2).is_empty());
    }

    #[test]
    fn test_parallel_filter() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let result = parallel_filter(&data, |x| *x % 2 == 0);

        let mut values: Vec<i32> = result.into_iter().copied().collect();
        values.sort();
        assert_eq!(values, vec![2, 4, 6]);
    }

    #[test]
    fn test_parallel_filter_large() {
        let data = large_input();
        let result = parallel_filter(&data, |x| *x % 2 == 0);

        // Sorted so the assertion does not depend on how chunk results are merged.
        let mut values: Vec<u64> = result.into_iter().copied().collect();
        values.sort_unstable();
        let expected: Vec<u64> = data.iter().copied().filter(|x| x % 2 == 0).collect();
        assert_eq!(values, expected);
    }

    #[test]
    fn test_parallel_reduce() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_reduce(&data, |a, b| a + b);
        assert_eq!(result, Some(15));
    }

    #[test]
    fn test_parallel_reduce_large() {
        let data = large_input();
        let result = parallel_reduce(&data, |a, b| a + b);
        assert_eq!(result, Some(4_999_950_000));
    }

    #[test]
    fn test_parallel_reduce_empty() {
        let data: Vec<u64> = Vec::new();
        assert_eq!(parallel_reduce(&data, |a, b| a + b), None);
    }

    #[test]
    fn test_parallel_sum() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_sum(&data);
        assert_eq!(result, 15);
    }

    #[test]
    fn test_parallel_sum_large() {
        let data = large_input();
        assert_eq!(parallel_sum(&data), 4_999_950_000);
    }

    #[test]
    fn test_parallel_find() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let result = parallel_find(&data, |x| *x > 5);
        assert!(result.is_some());
        assert!(result.unwrap() > 5);
    }

    #[test]
    fn test_parallel_find_large() {
        let data = large_input();

        let hit = parallel_find(&data, |x| *x >= LARGE_LEN / 2);
        assert!(hit.is_some());
        assert!(hit.unwrap() >= LARGE_LEN / 2);

        let miss = parallel_find(&data, |x| *x >= LARGE_LEN);
        assert_eq!(miss, None);
    }

    #[test]
    fn test_parallel_count() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let count = parallel_count(&data, |x| *x % 2 == 0);
        assert_eq!(count, 5);
    }

    #[test]
    fn test_parallel_count_large() {
        let data = large_input();
        assert_eq!(parallel_count(&data, |x| *x % 2 == 0), data.len() / 2);
    }

    #[test]
    fn test_parallel_partition() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let (evens, odds) = parallel_partition(&data, |x| *x % 2 == 0);
        assert_eq!(evens.len(), 5);
        assert_eq!(odds.len(), 5);
        assert!(evens.iter().all(|x| x % 2 == 0));
        assert!(odds.iter().all(|x| x % 2 == 1));
    }

    #[test]
    fn test_parallel_partition_large() {
        let data = large_input();
        let (evens, odds) = parallel_partition(&data, |x| *x % 2 == 0);
        assert_eq!(evens.len(), data.len() / 2);
        assert_eq!(odds.len(), data.len() / 2);
        assert!(evens.iter().all(|x| x % 2 == 0));
        assert!(odds.iter().all(|x| x % 2 == 1));
    }
}