use crate::executor;
use std::marker::PhantomData;
/// A borrowed view over a slice exposing `map`, `filter`, `for_each`, `sum`,
/// `reduce` and `collect`.
///
/// The struct itself carries no trait bounds. The `T: Sync` requirement lives
/// on the `impl` blocks that need it, so the type can be named for any `T`.
#[must_use = "a ParallelVec does nothing until one of its methods is called"]
#[derive(Debug)]
pub struct ParallelVec<'a, T> {
    /// The borrowed elements every operation reads from.
    data: &'a [T],
}
impl<'a, T: Sync> ParallelVec<'a, T> {
    /// Wraps `data` without copying it.
    pub fn new(data: &'a [T]) -> Self {
        Self { data }
    }

    /// Lazily attaches the element-wise transform `f`.
    ///
    /// This only builds a [`ParallelMap`]. `f` is not called until
    /// [`ParallelMap::collect`] or [`ParallelMap::sum`] consumes it.
    ///
    /// NOTE(review): the consuming methods on `ParallelMap` additionally require
    /// `R: 'static`, so a non-`'static` `R` yields a value that cannot be
    /// consumed. The bound is left as-is here to stay backward-compatible.
    #[must_use = "`map` is lazy; call `collect` or `sum` on the result"]
    pub fn map<R, F>(self, f: F) -> ParallelMap<'a, T, R, F>
    where
        R: Send,
        F: Fn(&T) -> R + Send + Sync,
    {
        ParallelMap {
            data: self.data,
            f,
            _phantom: PhantomData,
        }
    }

    /// Lazily attaches the predicate `f`.
    ///
    /// This only builds a [`ParallelFilter`]. `f` is not called until
    /// [`ParallelFilter::collect`] or [`ParallelFilter::map`] consumes it.
    #[must_use = "`filter` is lazy; call `collect` or `map` on the result"]
    pub fn filter<F>(self, f: F) -> ParallelFilter<'a, T, F>
    where
        F: Fn(&T) -> bool + Send + Sync,
    {
        ParallelFilter {
            data: self.data,
            f,
        }
    }

    /// Calls `f` on every element by delegating to `executor::parallel_for_each`.
    ///
    /// NOTE(review): the invocation order, and which thread runs each call,
    /// is decided by `executor` and is not visible from this file. Treat it as
    /// unspecified unless `executor` documents otherwise.
    pub fn for_each<F>(self, f: F)
    where
        F: Fn(&T) + Send + Sync,
    {
        executor::parallel_for_each(self.data, f);
    }

    /// Adds all elements together by delegating to `executor::parallel_sum`.
    pub fn sum(self) -> T
    where
        T: Clone + Send + std::iter::Sum,
    {
        executor::parallel_sum(self.data)
    }

    /// Folds all elements into one with `f` by delegating to
    /// `executor::parallel_reduce`.
    ///
    /// NOTE(review): this presumably returns `None` for an empty slice. If the
    /// executor combines partial results in arbitrary groupings, `f` should be
    /// associative for a deterministic result. Confirm both against `executor`.
    pub fn reduce<F>(self, f: F) -> Option<T>
    where
        T: Clone + Send,
        F: Fn(T, T) -> T + Send + Sync,
    {
        executor::parallel_reduce(self.data, f)
    }

    /// Returns a reference to every element, in slice order.
    ///
    /// This is a plain sequential iteration; `executor` is not involved.
    #[must_use]
    pub fn collect(self) -> Vec<&'a T> {
        self.data.iter().collect()
    }
}
/// A pending element-wise transformation of a borrowed slice, built by
/// `ParallelVec::map`. No work happens until `collect` or `sum` is called.
///
/// The marker is `PhantomData<fn() -> R>` rather than `PhantomData<R>`. This
/// type only *produces* `R` values; it never stores or drops one. Its auto
/// traits (`Send`/`Sync`) and drop-check behaviour therefore should not depend
/// on `R`. Variance in `R` stays covariant.
#[must_use = "ParallelMap is lazy and does nothing until `collect` or `sum` is called"]
pub struct ParallelMap<'a, T, R, F> {
    /// The borrowed input elements.
    data: &'a [T],
    /// The transform applied to each element.
    f: F,
    /// Records the output type `R` without claiming ownership of one.
    _phantom: PhantomData<fn() -> R>,
}
impl<'a, T, R, F> ParallelMap<'a, T, R, F>
where
    T: Sync,
    R: Send + 'static,
    F: Fn(&T) -> R + Send + Sync,
{
    /// Applies the stored transform to every element through
    /// `executor::parallel_map` and returns the produced values.
    pub fn collect(self) -> Vec<R> {
        let Self { data, f, .. } = self;
        executor::parallel_map(data, f)
    }

    /// Maps every element, then adds the mapped values together.
    ///
    /// The mapping step goes through [`Self::collect`]. The final addition is
    /// an ordinary sequential `Iterator::sum` over the collected `Vec`.
    pub fn sum(self) -> R
    where
        R: std::iter::Sum,
    {
        self.collect().into_iter().sum()
    }
}
/// A pending predicate over a borrowed slice, built by `ParallelVec::filter`.
/// No work happens until `collect` or `map` is called.
#[must_use = "ParallelFilter is lazy and does nothing until `collect` or `map` is called"]
pub struct ParallelFilter<'a, T, F> {
    /// The borrowed input elements.
    data: &'a [T],
    /// The predicate deciding which elements are kept.
    f: F,
}
impl<'a, T, F> ParallelFilter<'a, T, F>
where
    T: Sync,
    F: Fn(&T) -> bool + Send + Sync,
{
    /// Keeps the elements for which the predicate returns `true`, as borrows
    /// into the original slice, by delegating to `executor::parallel_filter`.
    pub fn collect(self) -> Vec<&'a T> {
        let Self { data, f } = self;
        executor::parallel_filter(data, f)
    }

    /// Filters eagerly, then maps every surviving element with `f2`.
    ///
    /// Unlike `ParallelVec::map`, this returns the final `Vec<R>` directly.
    /// The process takes two passes:
    /// 1. [`Self::collect`] materialises the survivors as a `Vec<&T>`.
    /// 2. `executor::parallel_map` runs `f2` over that vector.
    pub fn map<R, F2>(self, f2: F2) -> Vec<R>
    where
        R: Send + 'static,
        F2: Fn(&T) -> R + Send + Sync,
    {
        let survivors: Vec<&'a T> = self.collect();
        // Each element of `survivors` is itself a `&T`, so the closure receives `&&T`.
        executor::parallel_map(&survivors, |&survivor| f2(survivor))
    }
}
/// Types that can be borrowed as a [`ParallelVec`].
pub trait IntoParallelVec<T>
where
    T: Sync,
{
    /// Borrows `self` as a [`ParallelVec`]. The view lives as long as the borrow.
    fn par_vec(&self) -> ParallelVec<'_, T>;
}
impl<T> IntoParallelVec<T> for Vec<T>
where
    T: Sync,
{
    /// Views the vector's contents as a slice and wraps it; nothing is copied.
    fn par_vec(&self) -> ParallelVec<'_, T> {
        ParallelVec::new(self.as_slice())
    }
}
impl<T> IntoParallelVec<T> for [T]
where
    T: Sync,
{
    /// Wraps the slice itself; nothing is copied.
    fn par_vec(&self) -> ParallelVec<'_, T> {
        let slice: &[T] = self;
        ParallelVec::new(slice)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_vec_map() {
        let data: Vec<i32> = (0..10000).collect();
        let results = data.par_vec().map(|&x| x * 2).collect();
        assert_eq!(results.len(), 10000);
        assert_eq!(results[0], 0);
        assert_eq!(results[9999], 19998);
    }

    #[test]
    fn test_parallel_vec_filter() {
        let data: Vec<i32> = (0..10000).collect();
        let results = data.par_vec().filter(|&x| x % 2 == 0).collect();
        assert_eq!(results.len(), 5000);
        // The count alone would not catch a predicate applied to the wrong elements.
        assert!(results.iter().all(|&&x| x % 2 == 0));
    }

    #[test]
    fn test_parallel_filter_then_map() {
        let data: Vec<i32> = (0..100).collect();
        let mut tripled: Vec<i32> = data.par_vec().filter(|&x| x % 2 == 0).map(|&x| x * 3);
        // Sort so the check does not depend on the executor's output order.
        tripled.sort_unstable();
        let expected: Vec<i32> = (0..100).filter(|x| x % 2 == 0).map(|x| x * 3).collect();
        assert_eq!(tripled, expected);
    }

    #[test]
    fn test_parallel_vec_collect_preserves_order() {
        let data = vec![3, 1, 2];
        let refs = data.par_vec().collect();
        assert_eq!(refs, vec![&3, &1, &2]);
    }

    #[test]
    fn test_slice_par_vec() {
        let arr = [5, 6, 7];
        let refs = arr[..].par_vec().collect();
        assert_eq!(refs, vec![&5, &6, &7]);
    }

    #[test]
    fn test_parallel_vec_sum() {
        let data: Vec<i32> = (1..=100).collect();
        let result = data.par_vec().sum();
        assert_eq!(result, 5050);
    }

    #[test]
    fn test_parallel_vec_reduce() {
        let data: Vec<i32> = (1..=10).collect();
        let result = data.par_vec().reduce(|a, b| a + b);
        assert_eq!(result, Some(55));
    }

    #[test]
    fn test_parallel_vec_chain() {
        let data: Vec<i32> = (0..1000).collect();
        let result: i32 = data.par_vec().map(|&x| x * 2).sum();
        let expected: i32 = (0..1000).map(|x| x * 2).sum();
        assert_eq!(result, expected);
    }
}