use rayon::prelude::*;
pub trait Mergeable: Send + Default {
fn merge(&mut self, other: Self);
}
pub fn fold_reduce<T, I, F, M>(items: I, fold_fn: F, merge_fn: M) -> T
where
T: Send + Default,
I: ParallelIterator,
F: Fn(T, I::Item) -> T + Sync + Send,
M: Fn(T, T) -> T + Sync + Send,
{
items.fold(T::default, fold_fn).reduce(T::default, merge_fn)
}
pub fn fold_reduce_with<T, I, Init, F, M>(items: I, init: Init, fold_fn: F, merge_fn: M) -> T
where
T: Send,
I: ParallelIterator,
Init: Fn() -> T + Sync + Send + Clone,
F: Fn(T, I::Item) -> T + Sync + Send,
M: Fn(T, T) -> T + Sync + Send,
{
items.fold(init.clone(), fold_fn).reduce(init, merge_fn)
}
pub fn parallel_count<T, I, P>(items: I, predicate: P) -> usize
where
T: Send,
I: ParallelIterator<Item = T>,
P: Fn(&T) -> bool + Sync + Send,
{
items
.fold(|| 0usize, |count, item| count + predicate(&item) as usize)
.reduce(|| 0, |a, b| a + b)
}
pub fn parallel_sum<T, I, F>(items: I, extract: F) -> f64
where
T: Send,
I: ParallelIterator<Item = T>,
F: Fn(&T) -> f64 + Sync + Send,
{
items
.fold(|| 0.0f64, |sum, item| sum + extract(&item))
.reduce(|| 0.0, |a, b| a + b)
}
pub fn parallel_sum_i64<T, I, F>(items: I, extract: F) -> i64
where
T: Send,
I: ParallelIterator<Item = T>,
F: Fn(&T) -> i64 + Sync + Send,
{
items
.fold(|| 0i64, |sum, item| sum + extract(&item))
.reduce(|| 0, |a, b| a + b)
}
pub fn parallel_min<T, I, F, V>(items: I, extract: F) -> Option<V>
where
T: Send,
V: Send + Ord + Copy,
I: ParallelIterator<Item = T>,
F: Fn(&T) -> V + Sync + Send,
{
items
.fold(
|| None,
|min: Option<V>, item| {
let val = extract(&item);
Some(match min {
Some(m) if m < val => m,
_ => val,
})
},
)
.reduce(
|| None,
|a, b| match (a, b) {
(Some(va), Some(vb)) => Some(if va < vb { va } else { vb }),
(Some(v), None) | (None, Some(v)) => Some(v),
(None, None) => None,
},
)
}
pub fn parallel_max<T, I, F, V>(items: I, extract: F) -> Option<V>
where
T: Send,
V: Send + Ord + Copy,
I: ParallelIterator<Item = T>,
F: Fn(&T) -> V + Sync + Send,
{
items
.fold(
|| None,
|max: Option<V>, item| {
let val = extract(&item);
Some(match max {
Some(m) if m > val => m,
_ => val,
})
},
)
.reduce(
|| None,
|a, b| match (a, b) {
(Some(va), Some(vb)) => Some(if va > vb { va } else { vb }),
(Some(v), None) | (None, Some(v)) => Some(v),
(None, None) => None,
},
)
}
pub fn parallel_try_collect<T, E, I, F, R>(items: I, process: F) -> (Vec<R>, Vec<E>)
where
T: Send,
E: Send,
R: Send,
I: ParallelIterator<Item = T>,
F: Fn(T) -> Result<R, E> + Sync + Send,
{
items
.fold(
|| (Vec::new(), Vec::new()),
|(mut ok, mut err), item| {
match process(item) {
Ok(r) => ok.push(r),
Err(e) => err.push(e),
}
(ok, err)
},
)
.reduce(
|| (Vec::new(), Vec::new()),
|(mut ok1, mut err1), (ok2, err2)| {
ok1.extend(ok2);
err1.extend(err2);
(ok1, err1)
},
)
}
pub fn parallel_stats<T, I, F>(items: I, extract: F) -> (usize, f64, Option<f64>, Option<f64>)
where
T: Send,
I: ParallelIterator<Item = T>,
F: Fn(&T) -> f64 + Sync + Send,
{
items
.fold(
|| (0usize, 0.0f64, None::<f64>, None::<f64>),
|(count, sum, min, max), item| {
let val = extract(&item);
(
count + 1,
sum + val,
Some(match min {
Some(m) if m < val => m,
_ => val,
}),
Some(match max {
Some(m) if m > val => m,
_ => val,
}),
)
},
)
.reduce(
|| (0, 0.0, None, None),
|(c1, s1, min1, max1), (c2, s2, min2, max2)| {
let min = match (min1, min2) {
(Some(a), Some(b)) => Some(a.min(b)),
(Some(v), None) | (None, Some(v)) => Some(v),
(None, None) => None,
};
let max = match (max1, max2) {
(Some(a), Some(b)) => Some(a.max(b)),
(Some(v), None) | (None, Some(v)) => Some(v),
(None, None) => None,
};
(c1 + c2, s1 + s2, min, max)
},
)
}
pub fn parallel_partition<T, I, K, V, KeyFn, ValFn>(
items: I,
key_fn: KeyFn,
val_fn: ValFn,
) -> std::collections::HashMap<K, Vec<V>>
where
T: Send,
K: Send + Eq + std::hash::Hash + Clone,
V: Send,
I: ParallelIterator<Item = T>,
KeyFn: Fn(&T) -> K + Sync + Send,
ValFn: Fn(T) -> V + Sync + Send,
{
items
.fold(std::collections::HashMap::new, |mut map, item| {
let key = key_fn(&item);
let val = val_fn(item);
map.entry(key).or_insert_with(Vec::new).push(val);
map
})
.reduce(std::collections::HashMap::new, |mut map1, map2| {
for (key, mut values) in map2 {
map1.entry(key).or_insert_with(Vec::new).append(&mut values);
}
map1
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parallel_count() {
let numbers: Vec<i32> = (0..1000).collect();
let even_count = parallel_count(numbers.par_iter(), |n| *n % 2 == 0);
assert_eq!(even_count, 500);
}
#[test]
fn test_parallel_sum() {
let numbers: Vec<i32> = (1..=100).collect();
let total = parallel_sum(numbers.par_iter(), |n| f64::from(**n));
assert!((total - 5050.0).abs() < 0.001);
}
#[test]
fn test_parallel_sum_i64() {
let numbers: Vec<i32> = (1..=100).collect();
let total = parallel_sum_i64(numbers.par_iter(), |n| i64::from(**n));
assert_eq!(total, 5050);
}
#[test]
fn test_parallel_min() {
let numbers: Vec<i32> = vec![5, 3, 8, 1, 9, 2];
let min = parallel_min(numbers.par_iter(), |n| *n);
assert_eq!(min, Some(&1));
let empty: Vec<i32> = vec![];
let min_empty: Option<&i32> = parallel_min(empty.par_iter(), |n| *n);
assert_eq!(min_empty, None);
}
#[test]
fn test_parallel_max() {
let numbers: Vec<i32> = vec![5, 3, 8, 1, 9, 2];
let max = parallel_max(numbers.par_iter(), |n| *n);
assert_eq!(max, Some(&9));
}
#[test]
fn test_parallel_try_collect() {
let items = vec!["1", "two", "3", "four", "5"];
let (successes, errors): (Vec<i32>, Vec<String>) =
parallel_try_collect(items.into_par_iter(), |s| {
s.parse::<i32>().map_err(|e| e.to_string())
});
assert_eq!(successes.len(), 3);
assert!(successes.contains(&1));
assert!(successes.contains(&3));
assert!(successes.contains(&5));
assert_eq!(errors.len(), 2);
}
#[test]
fn test_parallel_stats() {
let numbers: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let (count, sum, min, max) = parallel_stats(numbers.into_par_iter(), |n| *n);
assert_eq!(count, 5);
assert!((sum - 15.0).abs() < 0.001);
assert!((min.unwrap() - 1.0).abs() < 0.001);
assert!((max.unwrap() - 5.0).abs() < 0.001);
}
#[test]
fn test_parallel_partition() {
let items: Vec<(i32, &str)> = vec![(1, "a"), (2, "b"), (1, "c"), (2, "d"), (1, "e")];
let groups = parallel_partition(items.into_par_iter(), |(k, _)| *k, |(_, v)| v);
assert_eq!(groups.get(&1).map(|v| v.len()), Some(3));
assert_eq!(groups.get(&2).map(|v| v.len()), Some(2));
}
#[test]
fn test_fold_reduce() {
let items: Vec<i32> = (1..=10).collect();
let sum: i32 = fold_reduce(items.into_par_iter(), |acc, item| acc + item, |a, b| a + b);
assert_eq!(sum, 55);
}
#[test]
fn test_fold_reduce_with_custom_init() {
let items: Vec<i32> = (1..=10).collect();
let sum: i32 = fold_reduce_with(
items.into_par_iter(),
|| 100, |acc, item| acc + item,
|a, b| a + b - 100, );
assert!(sum >= 55);
}
}