Skip to main content

grafeo_core/execution/parallel/
fold.rs

1//! Parallel fold-reduce utilities using Rayon.
2//!
3//! This module provides ergonomic patterns for parallel data processing
4//! using Rayon's `fold` and `reduce` operations. These patterns are useful
5//! for batch operations like import, export, and statistics collection.
6//!
7//! # Why Fold-Reduce?
8//!
9//! The fold-reduce pattern provides:
10//! - **No contention**: Each thread has its own accumulator
11//! - **Work-stealing**: Rayon handles load balancing automatically
12//! - **Composable**: Easy to combine multiple aggregations
13//!
14//! # Example
15//!
16//! ```ignore
17//! use grafeo_core::execution::parallel::fold::{parallel_count, parallel_sum};
18//! use rayon::prelude::*;
19//!
20//! let numbers: Vec<i32> = (0..1000).collect();
21//!
22//! // Count even numbers
23//! let even_count = parallel_count(numbers.par_iter(), |n| *n % 2 == 0);
24//!
25//! // Sum all numbers
26//! let total: f64 = parallel_sum(numbers.par_iter(), |n| *n as f64);
27//! ```
28
29use rayon::prelude::*;
30
/// Trait for types that can be merged in parallel fold-reduce operations.
///
/// Implement this for custom accumulator types that need to be combined
/// after parallel processing. The `Default` bound supplies the identity
/// value used to seed each thread-local accumulator; `Send` lets
/// accumulators move between Rayon worker threads.
pub trait Mergeable: Send + Default {
    /// Merges another instance into this one, consuming `other`.
    fn merge(&mut self, other: Self);
}
39
40/// Execute parallel fold-reduce with custom accumulator.
41///
42/// This is the most general form of parallel aggregation:
43/// 1. Each thread gets its own accumulator (created by `T::default`)
44/// 2. `fold_fn` processes items into thread-local accumulators
45/// 3. `merge_fn` combines accumulators from different threads
46///
47/// # Example
48///
49/// ```ignore
50/// use grafeo_core::execution::parallel::fold::fold_reduce;
51/// use rayon::prelude::*;
52///
53/// let items = vec![1, 2, 3, 4, 5];
54/// let sum = fold_reduce(
55///     items.into_par_iter(),
56///     |acc, item| acc + item,
57///     |a, b| a + b,
58/// );
59/// assert_eq!(sum, 15);
60/// ```
61pub fn fold_reduce<T, I, F, M>(items: I, fold_fn: F, merge_fn: M) -> T
62where
63    T: Send + Default,
64    I: ParallelIterator,
65    F: Fn(T, I::Item) -> T + Sync + Send,
66    M: Fn(T, T) -> T + Sync + Send,
67{
68    items.fold(T::default, fold_fn).reduce(T::default, merge_fn)
69}
70
71/// Fold-reduce with a custom identity/factory function.
72///
73/// Use this when `T::default()` isn't suitable for your accumulator.
74pub fn fold_reduce_with<T, I, Init, F, M>(items: I, init: Init, fold_fn: F, merge_fn: M) -> T
75where
76    T: Send,
77    I: ParallelIterator,
78    Init: Fn() -> T + Sync + Send + Clone,
79    F: Fn(T, I::Item) -> T + Sync + Send,
80    M: Fn(T, T) -> T + Sync + Send,
81{
82    items.fold(init.clone(), fold_fn).reduce(init, merge_fn)
83}
84
85/// Count items matching a predicate in parallel.
86///
87/// Efficiently counts matching items using fold-reduce,
88/// with no lock contention between threads.
89///
90/// # Example
91///
92/// ```ignore
93/// use grafeo_core::execution::parallel::fold::parallel_count;
94/// use rayon::prelude::*;
95///
96/// let numbers: Vec<i32> = (0..1000).collect();
97/// let even_count = parallel_count(numbers.par_iter(), |n| *n % 2 == 0);
98/// assert_eq!(even_count, 500);
99/// ```
100pub fn parallel_count<T, I, P>(items: I, predicate: P) -> usize
101where
102    T: Send,
103    I: ParallelIterator<Item = T>,
104    P: Fn(&T) -> bool + Sync + Send,
105{
106    items
107        .fold(|| 0usize, |count, item| count + predicate(&item) as usize)
108        .reduce(|| 0, |a, b| a + b)
109}
110
111/// Sum values extracted from items in parallel.
112///
113/// # Example
114///
115/// ```ignore
116/// use grafeo_core::execution::parallel::fold::parallel_sum;
117/// use rayon::prelude::*;
118///
119/// let items = vec![(1, "a"), (2, "b"), (3, "c")];
120/// let sum = parallel_sum(items.par_iter(), |(n, _)| *n as f64);
121/// assert_eq!(sum, 6.0);
122/// ```
123pub fn parallel_sum<T, I, F>(items: I, extract: F) -> f64
124where
125    T: Send,
126    I: ParallelIterator<Item = T>,
127    F: Fn(&T) -> f64 + Sync + Send,
128{
129    items
130        .fold(|| 0.0f64, |sum, item| sum + extract(&item))
131        .reduce(|| 0.0, |a, b| a + b)
132}
133
134/// Sum integers extracted from items in parallel.
135pub fn parallel_sum_i64<T, I, F>(items: I, extract: F) -> i64
136where
137    T: Send,
138    I: ParallelIterator<Item = T>,
139    F: Fn(&T) -> i64 + Sync + Send,
140{
141    items
142        .fold(|| 0i64, |sum, item| sum + extract(&item))
143        .reduce(|| 0, |a, b| a + b)
144}
145
146/// Find minimum value in parallel.
147///
148/// Returns `None` if the iterator is empty.
149pub fn parallel_min<T, I, F, V>(items: I, extract: F) -> Option<V>
150where
151    T: Send,
152    V: Send + Ord + Copy,
153    I: ParallelIterator<Item = T>,
154    F: Fn(&T) -> V + Sync + Send,
155{
156    items
157        .fold(
158            || None,
159            |min: Option<V>, item| {
160                let val = extract(&item);
161                Some(match min {
162                    Some(m) if m < val => m,
163                    _ => val,
164                })
165            },
166        )
167        .reduce(
168            || None,
169            |a, b| match (a, b) {
170                (Some(va), Some(vb)) => Some(if va < vb { va } else { vb }),
171                (Some(v), None) | (None, Some(v)) => Some(v),
172                (None, None) => None,
173            },
174        )
175}
176
177/// Find maximum value in parallel.
178///
179/// Returns `None` if the iterator is empty.
180pub fn parallel_max<T, I, F, V>(items: I, extract: F) -> Option<V>
181where
182    T: Send,
183    V: Send + Ord + Copy,
184    I: ParallelIterator<Item = T>,
185    F: Fn(&T) -> V + Sync + Send,
186{
187    items
188        .fold(
189            || None,
190            |max: Option<V>, item| {
191                let val = extract(&item);
192                Some(match max {
193                    Some(m) if m > val => m,
194                    _ => val,
195                })
196            },
197        )
198        .reduce(
199            || None,
200            |a, b| match (a, b) {
201                (Some(va), Some(vb)) => Some(if va > vb { va } else { vb }),
202                (Some(v), None) | (None, Some(v)) => Some(v),
203                (None, None) => None,
204            },
205        )
206}
207
208/// Collect results with errors separated.
209///
210/// Processes items in parallel, collecting successful results and errors
211/// into separate vectors. This is useful for batch operations where you
212/// want to continue processing even if some items fail.
213///
214/// # Example
215///
216/// ```ignore
217/// use grafeo_core::execution::parallel::fold::parallel_try_collect;
218/// use rayon::prelude::*;
219///
220/// let items = vec!["1", "two", "3", "four"];
221/// let (successes, errors) = parallel_try_collect(
222///     items.into_par_iter(),
223///     |s| s.parse::<i32>().map_err(|e| e.to_string()),
224/// );
225///
226/// assert_eq!(successes, vec![1, 3]);
227/// assert_eq!(errors.len(), 2);
228/// ```
229pub fn parallel_try_collect<T, E, I, F, R>(items: I, process: F) -> (Vec<R>, Vec<E>)
230where
231    T: Send,
232    E: Send,
233    R: Send,
234    I: ParallelIterator<Item = T>,
235    F: Fn(T) -> Result<R, E> + Sync + Send,
236{
237    items
238        .fold(
239            || (Vec::new(), Vec::new()),
240            |(mut ok, mut err), item| {
241                match process(item) {
242                    Ok(r) => ok.push(r),
243                    Err(e) => err.push(e),
244                }
245                (ok, err)
246            },
247        )
248        .reduce(
249            || (Vec::new(), Vec::new()),
250            |(mut ok1, mut err1), (ok2, err2)| {
251                ok1.extend(ok2);
252                err1.extend(err2);
253                (ok1, err1)
254            },
255        )
256}
257
258/// Compute multiple aggregations in a single parallel pass.
259///
260/// Returns (count, sum, min, max) for the extracted values.
261pub fn parallel_stats<T, I, F>(items: I, extract: F) -> (usize, f64, Option<f64>, Option<f64>)
262where
263    T: Send,
264    I: ParallelIterator<Item = T>,
265    F: Fn(&T) -> f64 + Sync + Send,
266{
267    items
268        .fold(
269            || (0usize, 0.0f64, None::<f64>, None::<f64>),
270            |(count, sum, min, max), item| {
271                let val = extract(&item);
272                (
273                    count + 1,
274                    sum + val,
275                    Some(match min {
276                        Some(m) if m < val => m,
277                        _ => val,
278                    }),
279                    Some(match max {
280                        Some(m) if m > val => m,
281                        _ => val,
282                    }),
283                )
284            },
285        )
286        .reduce(
287            || (0, 0.0, None, None),
288            |(c1, s1, min1, max1), (c2, s2, min2, max2)| {
289                let min = match (min1, min2) {
290                    (Some(a), Some(b)) => Some(a.min(b)),
291                    (Some(v), None) | (None, Some(v)) => Some(v),
292                    (None, None) => None,
293                };
294                let max = match (max1, max2) {
295                    (Some(a), Some(b)) => Some(a.max(b)),
296                    (Some(v), None) | (None, Some(v)) => Some(v),
297                    (None, None) => None,
298                };
299                (c1 + c2, s1 + s2, min, max)
300            },
301        )
302}
303
304/// Partition items into groups based on a key extractor.
305///
306/// Groups items with the same key into separate vectors.
307/// The keys must be hashable and cloneable.
308///
309/// # Example
310///
311/// ```ignore
312/// use grafeo_core::execution::parallel::fold::parallel_partition;
313/// use rayon::prelude::*;
314///
315/// let items = vec![(1, "a"), (2, "b"), (1, "c"), (2, "d")];
316/// let groups = parallel_partition(items.into_par_iter(), |(k, _)| *k);
317///
318/// assert_eq!(groups.get(&1).unwrap(), &vec!["a", "c"]);
319/// assert_eq!(groups.get(&2).unwrap(), &vec!["b", "d"]);
320/// ```
321pub fn parallel_partition<T, I, K, V, KeyFn, ValFn>(
322    items: I,
323    key_fn: KeyFn,
324    val_fn: ValFn,
325) -> std::collections::HashMap<K, Vec<V>>
326where
327    T: Send,
328    K: Send + Eq + std::hash::Hash + Clone,
329    V: Send,
330    I: ParallelIterator<Item = T>,
331    KeyFn: Fn(&T) -> K + Sync + Send,
332    ValFn: Fn(T) -> V + Sync + Send,
333{
334    items
335        .fold(std::collections::HashMap::new, |mut map, item| {
336            let key = key_fn(&item);
337            let val = val_fn(item);
338            map.entry(key).or_insert_with(Vec::new).push(val);
339            map
340        })
341        .reduce(std::collections::HashMap::new, |mut map1, map2| {
342            for (key, mut values) in map2 {
343                map1.entry(key).or_insert_with(Vec::new).append(&mut values);
344            }
345            map1
346        })
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_count() {
        let numbers: Vec<i32> = (0..1000).collect();
        let even_count = parallel_count(numbers.par_iter(), |n| *n % 2 == 0);
        assert_eq!(even_count, 500);
    }

    #[test]
    fn test_parallel_sum() {
        let numbers: Vec<i32> = (1..=100).collect();
        let total = parallel_sum(numbers.par_iter(), |n| f64::from(**n));
        assert!((total - 5050.0).abs() < 0.001);
    }

    #[test]
    fn test_parallel_sum_i64() {
        let numbers: Vec<i32> = (1..=100).collect();
        let total = parallel_sum_i64(numbers.par_iter(), |n| i64::from(**n));
        assert_eq!(total, 5050);
    }

    #[test]
    fn test_parallel_min() {
        let numbers: Vec<i32> = vec![5, 3, 8, 1, 9, 2];
        let min = parallel_min(numbers.par_iter(), |n| *n);
        assert_eq!(min, Some(&1));

        // Empty input must yield None.
        let empty: Vec<i32> = vec![];
        let min_empty: Option<&i32> = parallel_min(empty.par_iter(), |n| *n);
        assert_eq!(min_empty, None);
    }

    #[test]
    fn test_parallel_max() {
        let numbers: Vec<i32> = vec![5, 3, 8, 1, 9, 2];
        let max = parallel_max(numbers.par_iter(), |n| *n);
        assert_eq!(max, Some(&9));
    }

    #[test]
    fn test_parallel_try_collect() {
        let items = vec!["1", "two", "3", "four", "5"];
        let (successes, errors): (Vec<i32>, Vec<String>) =
            parallel_try_collect(items.into_par_iter(), |s| {
                s.parse::<i32>().map_err(|e| e.to_string())
            });

        // Order is not guaranteed under parallel reduction, so check
        // membership rather than sequence.
        assert_eq!(successes.len(), 3);
        assert!(successes.contains(&1));
        assert!(successes.contains(&3));
        assert!(successes.contains(&5));
        assert_eq!(errors.len(), 2);
    }

    #[test]
    fn test_parallel_stats() {
        let numbers: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (count, sum, min, max) = parallel_stats(numbers.into_par_iter(), |n| *n);

        assert_eq!(count, 5);
        assert!((sum - 15.0).abs() < 0.001);
        assert!((min.unwrap() - 1.0).abs() < 0.001);
        assert!((max.unwrap() - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_parallel_partition() {
        let items: Vec<(i32, &str)> = vec![(1, "a"), (2, "b"), (1, "c"), (2, "d"), (1, "e")];
        let groups = parallel_partition(items.into_par_iter(), |(k, _)| *k, |(_, v)| v);

        assert_eq!(groups.get(&1).map(|v| v.len()), Some(3));
        assert_eq!(groups.get(&2).map(|v| v.len()), Some(2));
    }

    #[test]
    fn test_fold_reduce() {
        let items: Vec<i32> = (1..=10).collect();
        let sum: i32 = fold_reduce(items.into_par_iter(), |acc, item| acc + item, |a, b| a + b);
        assert_eq!(sum, 55);
    }

    #[test]
    fn test_fold_reduce_with_custom_init() {
        let items: Vec<i32> = (1..=10).collect();
        let sum: i32 = fold_reduce_with(
            items.into_par_iter(),
            || 100, // Every accumulator (fold or reduce identity) starts at 100.
            |acc, item| acc + item,
            |a, b| a + b - 100, // Each merge removes one duplicated 100.
        );
        // Invariant: every accumulator has the form 100 + partial_sum
        // (identities are 100 + 0), and merging preserves it:
        // (100 + a) + (100 + b) - 100 = 100 + (a + b). The result is
        // therefore deterministic no matter how Rayon splits the work:
        // 100 + (1 + 2 + ... + 10) = 155.
        assert_eq!(sum, 155);
    }
}