#![doc(
html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
html_favicon_url = "https://commonware.xyz/favicon.ico"
)]
#![cfg_attr(not(any(feature = "std", test)), no_std)]
commonware_macros::stability_scope!(BETA {
use cfg_if::cfg_if;
use core::fmt;
cfg_if! {
if #[cfg(feature = "std")] {
use rayon::{
iter::{IntoParallelIterator, ParallelIterator},
ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder,
};
use std::{num::NonZeroUsize, sync::Arc};
} else {
extern crate alloc;
use alloc::vec::Vec;
}
}
/// A pluggable way of executing fold-style computations over a collection.
///
/// Implementors supply [`Strategy::fold_init`], [`Strategy::join`] and
/// [`Strategy::parallelism_hint`]; every other method is a convenience layered on top of
/// `fold_init`.
pub trait Strategy: Clone + Send + Sync + fmt::Debug + 'static {
    /// Folds every item of `iter` into a single value of type `R`.
    ///
    /// * `init` creates a scratch value of type `T`; a mutable reference to it is handed to
    ///   each `fold_op` call.
    /// * `identity` creates a starting accumulator.
    /// * `fold_op` merges one item into an accumulator.
    /// * `reduce_op` merges two accumulators.
    ///
    /// How the input is partitioned, and therefore how often `init`, `identity` and
    /// `reduce_op` are invoked, is left to the implementation.
    fn fold_init<I, INIT, T, R, ID, F, RD>(
        &self,
        iter: I,
        init: INIT,
        identity: ID,
        fold_op: F,
        reduce_op: RD,
    ) -> R
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        INIT: Fn() -> T + Send + Sync,
        T: Send,
        R: Send,
        ID: Fn() -> R + Send + Sync,
        F: Fn(R, &mut T, I::Item) -> R + Send + Sync,
        RD: Fn(R, R) -> R + Send + Sync;

    /// Folds every item of `iter` into a single value of type `R` without scratch state.
    ///
    /// Delegates to [`Strategy::fold_init`] with a unit scratch value that `fold_op` never
    /// sees.
    fn fold<I, R, ID, F, RD>(&self, iter: I, identity: ID, fold_op: F, reduce_op: RD) -> R
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        R: Send,
        ID: Fn() -> R + Send + Sync,
        F: Fn(R, I::Item) -> R + Send + Sync,
        RD: Fn(R, R) -> R + Send + Sync,
    {
        // No scratch state is required, so a unit value is threaded through instead.
        let no_state = || ();
        self.fold_init(
            iter,
            no_state,
            identity,
            |accumulated, _, element| fold_op(accumulated, element),
            reduce_op,
        )
    }

    /// Applies `map_op` to every item of `iter` and gathers the outputs into a `Vec`.
    ///
    /// Built on [`Strategy::fold`]: each accumulator is a `Vec` that mapped items are pushed
    /// onto, and two accumulators are merged by appending the second to the first.
    fn map_collect_vec<I, F, T>(&self, iter: I, map_op: F) -> Vec<T>
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        F: Fn(I::Item) -> T + Send + Sync,
        T: Send,
    {
        self.fold(
            iter,
            Vec::new,
            |mut collected, element| {
                let mapped = map_op(element);
                collected.push(mapped);
                collected
            },
            |mut head, tail| {
                head.extend(tail);
                head
            },
        )
    }

    /// Like [`Strategy::map_collect_vec`], but `map_op` also receives a mutable reference to a
    /// scratch value produced by `init`.
    ///
    /// Built on [`Strategy::fold_init`], so the number of `init` calls follows that method's
    /// implementation.
    fn map_init_collect_vec<I, INIT, T, F, R>(&self, iter: I, init: INIT, map_op: F) -> Vec<R>
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        INIT: Fn() -> T + Send + Sync,
        T: Send,
        F: Fn(&mut T, I::Item) -> R + Send + Sync,
        R: Send,
    {
        self.fold_init(
            iter,
            init,
            Vec::new,
            |mut collected, scratch, element| {
                let mapped = map_op(scratch, element);
                collected.push(mapped);
                collected
            },
            |mut head, tail| {
                head.extend(tail);
                head
            },
        )
    }

    /// Maps every item to a `(key, Option<value>)` pair and splits the outcome in two.
    ///
    /// Returns `(values, keys)`: the first vector holds the value of every pair whose option
    /// was `Some`, the second holds the key of every pair whose option was `None`. The key of
    /// a `Some` pair is discarded.
    fn map_partition_collect_vec<I, F, K, U>(&self, iter: I, map_op: F) -> (Vec<U>, Vec<K>)
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        F: Fn(I::Item) -> (K, Option<U>) + Send + Sync,
        K: Send,
        U: Send,
    {
        self.fold(
            iter,
            || (Vec::new(), Vec::new()),
            |(mut kept, mut rejected), element| {
                let (key, outcome) = map_op(element);
                if let Some(value) = outcome {
                    kept.push(value);
                } else {
                    rejected.push(key);
                }
                (kept, rejected)
            },
            |(mut kept_head, mut rejected_head), (kept_tail, rejected_tail)| {
                kept_head.extend(kept_tail);
                rejected_head.extend(rejected_tail);
                (kept_head, rejected_head)
            },
        )
    }

    /// Runs the closures `a` and `b` and returns both of their results as a pair.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send;

    /// Reports how much parallelism this strategy offers, as a plain count.
    fn parallelism_hint(&self) -> usize;
}
/// A [`Strategy`] that performs all work inline on the calling thread.
///
/// It carries no state: folds are a plain iterator fold and `join` runs its two closures one
/// after the other.
#[derive(Clone, Debug, Default)]
pub struct Sequential;
impl Strategy for Sequential {
    /// Folds `iter` front to back on the calling thread.
    ///
    /// `init` and `identity` are each invoked exactly once and `reduce_op` is never called,
    /// because there is only ever a single accumulator.
    fn fold_init<I, INIT, T, R, ID, F, RD>(
        &self,
        iter: I,
        init: INIT,
        identity: ID,
        fold_op: F,
        _reduce_op: RD,
    ) -> R
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        INIT: Fn() -> T + Send + Sync,
        T: Send,
        R: Send,
        ID: Fn() -> R + Send + Sync,
        F: Fn(R, &mut T, I::Item) -> R + Send + Sync,
        RD: Fn(R, R) -> R + Send + Sync,
    {
        // One scratch value is shared by every `fold_op` invocation.
        let mut scratch = init();
        let items = iter.into_iter();
        let start = identity();
        items.fold(start, |accumulated, element| {
            fold_op(accumulated, &mut scratch, element)
        })
    }

    /// Runs `a` to completion, then `b`, and returns both results.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        let first = a();
        let second = b();
        (first, second)
    }

    /// Always `1`: this strategy never runs anything concurrently.
    fn parallelism_hint(&self) -> usize {
        1
    }
}
});
commonware_macros::stability_scope!(BETA, cfg(feature = "std") {
/// A rayon thread pool behind an [`Arc`], so clones share the same underlying pool.
pub type ThreadPool = Arc<RThreadPool>;
/// A [`Strategy`] that runs its work inside a rayon thread pool.
///
/// Cloning copies only the shared [`ThreadPool`] handle, so every clone uses the same pool.
#[derive(Clone, Debug)]
pub struct Rayon {
    /// The pool all work of this strategy is installed on.
    thread_pool: ThreadPool,
}
impl Rayon {
    /// Builds a fresh rayon pool configured with `num_threads` threads and wraps it.
    ///
    /// # Errors
    ///
    /// Returns the [`ThreadPoolBuildError`] reported by rayon when the pool cannot be built.
    pub fn new(num_threads: NonZeroUsize) -> Result<Self, ThreadPoolBuildError> {
        let pool = ThreadPoolBuilder::new()
            .num_threads(num_threads.get())
            .build()?;
        Ok(Self::with_pool(Arc::new(pool)))
    }

    /// Wraps an already constructed, shared [`ThreadPool`].
    pub const fn with_pool(thread_pool: ThreadPool) -> Self {
        Self { thread_pool }
    }
}
impl Strategy for Rayon {
    /// Folds `iter` inside this strategy's thread pool.
    ///
    /// The input is first collected into a `Vec`, which is then consumed through rayon's
    /// `into_par_iter`. Rayon's `fold` produces one `(scratch, accumulator)` pair per split
    /// it chooses, created by calling `init` and then `identity`; the scratch values are
    /// dropped and the accumulators are merged with `reduce_op`.
    fn fold_init<I, INIT, T, R, ID, F, RD>(
        &self,
        iter: I,
        init: INIT,
        identity: ID,
        fold_op: F,
        reduce_op: RD,
    ) -> R
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        INIT: Fn() -> T + Send + Sync,
        T: Send,
        R: Send,
        ID: Fn() -> R + Send + Sync,
        F: Fn(R, &mut T, I::Item) -> R + Send + Sync,
        RD: Fn(R, R) -> R + Send + Sync,
    {
        self.thread_pool.install(|| {
            // Materialize the input so rayon receives a `Vec` to split.
            let buffered = iter.into_iter().collect::<Vec<I::Item>>();

            // Each split carries its own scratch value next to its accumulator.
            let partials = buffered
                .into_par_iter()
                .fold(
                    || (init(), identity()),
                    |(mut scratch, accumulated), element| {
                        let updated = fold_op(accumulated, &mut scratch, element);
                        (scratch, updated)
                    },
                )
                .map(|(_, accumulated)| accumulated);

            partials.reduce(&identity, reduce_op)
        })
    }

    /// Runs `a` and `b` through rayon's `join` inside this strategy's thread pool.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        // `ThreadPool::join` is documented as `install(|| rayon::join(a, b))`.
        self.thread_pool.join(a, b)
    }

    /// Returns the number of threads currently in the underlying pool.
    fn parallelism_hint(&self) -> usize {
        self.thread_pool.current_num_threads()
    }
}
});
#[cfg(test)]
mod test {
    use crate::{Rayon, Sequential, Strategy};
    use core::num::NonZeroUsize;
    use proptest::prelude::*;

    /// Builds the four-thread rayon strategy used as the parallel side of comparisons.
    fn parallel_strategy() -> Rayon {
        let threads = NonZeroUsize::new(4).unwrap();
        Rayon::new(threads).unwrap()
    }

    /// Reduce step shared by the properties below: appends `tail` onto `head`.
    fn concat(mut head: Vec<i32>, tail: Vec<i32>) -> Vec<i32> {
        head.extend(tail);
        head
    }

    proptest! {
        // The rayon strategy must produce the same ordered output as the sequential one.
        #[test]
        fn parallel_fold_init_matches_sequential(data in prop::collection::vec(any::<i32>(), 0..500)) {
            let expected: Vec<i32> = Sequential.fold_init(
                &data,
                || (),
                Vec::new,
                |mut out, _, &value| {
                    out.push(value.wrapping_mul(2));
                    out
                },
                concat,
            );
            let actual: Vec<i32> = parallel_strategy().fold_init(
                &data,
                || (),
                Vec::new,
                |mut out, _, &value| {
                    out.push(value.wrapping_mul(2));
                    out
                },
                concat,
            );
            prop_assert_eq!(expected, actual);
        }

        // `fold` is `fold_init` with a unit scratch value.
        #[test]
        fn fold_equals_fold_init(data in prop::collection::vec(any::<i32>(), 0..500)) {
            let strategy = Sequential;
            let from_fold: Vec<i32> = strategy.fold(
                &data,
                Vec::new,
                |mut out, &value| {
                    out.push(value);
                    out
                },
                concat,
            );
            let from_fold_init: Vec<i32> = strategy.fold_init(
                &data,
                || (),
                Vec::new,
                |mut out, _, &value| {
                    out.push(value);
                    out
                },
                concat,
            );
            prop_assert_eq!(from_fold, from_fold_init);
        }

        // `map_collect_vec` matches a hand-written push-based `fold`.
        #[test]
        fn map_collect_vec_equals_fold(data in prop::collection::vec(any::<i32>(), 0..500)) {
            let strategy = Sequential;
            let triple = |&value: &i32| value.wrapping_mul(3);
            let from_map: Vec<i32> = strategy.map_collect_vec(&data, triple);
            let from_fold: Vec<i32> = strategy.fold(
                &data,
                Vec::new,
                |mut out, element| {
                    out.push(triple(element));
                    out
                },
                concat,
            );
            prop_assert_eq!(from_map, from_fold);
        }

        // `map_init_collect_vec` matches a hand-written push-based `fold_init`.
        #[test]
        fn map_init_collect_vec_equals_fold_init(data in prop::collection::vec(any::<i32>(), 0..500)) {
            let strategy = Sequential;
            let from_map: Vec<i32> = strategy.map_init_collect_vec(
                &data,
                || 0i32,
                |seen, &value| {
                    *seen += 1;
                    value.wrapping_add(*seen)
                },
            );
            let from_fold_init: Vec<i32> = strategy.fold_init(
                &data,
                || 0i32,
                Vec::new,
                |mut out, seen, &value| {
                    *seen += 1;
                    out.push(value.wrapping_add(*seen));
                    out
                },
                concat,
            );
            prop_assert_eq!(from_map, from_fold_init);
        }

        // Even inputs yield a doubled value; odd inputs are reported back by key.
        #[test]
        fn map_partition_collect_vec_returns_valid_results(data in prop::collection::vec(any::<i32>(), 0..500)) {
            let strategy = Sequential;
            let classify = |&value: &i32| {
                let doubled = if value % 2 == 0 {
                    Some(value.wrapping_mul(2))
                } else {
                    None
                };
                (value, doubled)
            };
            let (results, filtered) = strategy.map_partition_collect_vec(data.iter(), classify);

            let expected_results: Vec<i32> = data
                .iter()
                .copied()
                .filter(|value| *value % 2 == 0)
                .map(|value| value.wrapping_mul(2))
                .collect();
            prop_assert_eq!(results, expected_results);

            let expected_filtered: Vec<i32> = data
                .iter()
                .copied()
                .filter(|value| *value % 2 != 0)
                .collect();
            prop_assert_eq!(filtered, expected_filtered);
        }
    }
}