use crate::runtime;
use futures::{Stream, Future};
use std::pin::Pin;
use std::sync::Arc;
use async_channel::{bounded, Sender, Receiver};
use parking_lot::Mutex;
/// A heap-allocated, pinned, `Send` future that resolves to `T` and may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Extension methods that turn a [`Stream`] into a lazily-configured parallel job.
///
/// Each method only wraps the stream and the callback in an adapter type; no item is
/// pulled from the stream until the adapter's own async method (`execute` / `collect`)
/// is awaited.
pub trait ParStreamExt: Stream + Sized {
    /// Builds a [`ParForEach`] that will call `f` on every item of the stream.
    ///
    /// `f` is shared between workers, hence the `Send + Sync + 'static` bounds.
    fn par_for_each<F, Fut>(self, f: F) -> ParForEach<Self, F>
    where
        F: Fn(Self::Item) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
        Self::Item: Send + 'static;

    /// Builds a [`ParMap`] that will transform every item of the stream with `f`,
    /// producing values of type `R`.
    fn par_map<F, Fut, R>(self, f: F) -> ParMap<Self, F>
    where
        F: Fn(Self::Item) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: Send + 'static,
        Self::Item: Send + 'static;

    /// Builds a [`ParFilter`] that will keep the items for which `f` resolves to `true`.
    ///
    /// `Fut` is a single `'static` type, so the future returned by `f` cannot borrow
    /// from the `&Self::Item` it was given; copy or clone whatever the predicate needs
    /// before entering the async block.
    fn par_filter<F, Fut>(self, f: F) -> ParFilter<Self, F>
    where
        F: Fn(&Self::Item) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = bool> + Send + 'static,
        Self::Item: Send + 'static;
}
/// Blanket implementation: every `Send + 'static` stream of `Send + 'static` items
/// gets the parallel adapters.
///
/// All three constructors use the same defaults: a hand-off buffer of 1000 items and
/// one worker per CPU reported by `num_cpus::get()`.
impl<S> ParStreamExt for S
where
    S: Stream + Send + 'static,
    S::Item: Send + 'static,
{
    /// Wraps `self` and `f` in a [`ParForEach`]; no work is started here.
    fn par_for_each<F, Fut>(self, f: F) -> ParForEach<Self, F>
    where
        F: Fn(Self::Item) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let func = Arc::new(f);
        let max_concurrency = num_cpus::get();
        ParForEach {
            stream: self,
            func,
            buffer_size: 1000,
            max_concurrency,
        }
    }

    /// Wraps `self` and `f` in a [`ParMap`]; no work is started here.
    fn par_map<F, Fut, R>(self, f: F) -> ParMap<Self, F>
    where
        F: Fn(Self::Item) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: Send + 'static,
    {
        let func = Arc::new(f);
        let max_concurrency = num_cpus::get();
        ParMap {
            stream: self,
            func,
            buffer_size: 1000,
            max_concurrency,
        }
    }

    /// Wraps `self` and `f` in a [`ParFilter`]; no work is started here.
    fn par_filter<F, Fut>(self, f: F) -> ParFilter<Self, F>
    where
        F: Fn(&Self::Item) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = bool> + Send + 'static,
    {
        let func = Arc::new(f);
        let max_concurrency = num_cpus::get();
        ParFilter {
            stream: self,
            func,
            buffer_size: 1000,
            max_concurrency,
        }
    }
}
/// A pending parallel "for each" over a stream.
///
/// Holds the source stream, the shared callback and the tuning knobs. Nothing runs
/// until `execute()` is awaited, hence `#[must_use]`.
#[must_use = "a ParForEach does nothing until `.execute()` is awaited"]
pub struct ParForEach<S, F> {
    /// Source of the items to process.
    stream: S,
    /// Callback applied to each item; shared between workers through the `Arc`.
    func: Arc<F>,
    /// Capacity of the bounded channel that hands items from the stream to the workers.
    buffer_size: usize,
    /// Number of worker jobs submitted to the runtime pool.
    max_concurrency: usize,
}
impl<S, F, Fut> ParForEach<S, F>
where
    S: Stream + Send + 'static,
    S::Item: Send + 'static,
    F: Fn(S::Item) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    /// Runs the callback on every item of the stream using `max_concurrency` workers.
    ///
    /// Items are handed to the workers through a bounded channel of `buffer_size`
    /// slots, so the stream is only polled as fast as the workers drain it. Each worker
    /// is a job submitted to `rt.pool` that drives its futures with
    /// `futures::executor::block_on` until the channel is closed and empty.
    ///
    /// The returned future resolves once the stream is exhausted and every worker has
    /// either signalled completion or dropped its completion sender. If all workers
    /// disappear early, feeding stops and the remaining items are not processed.
    ///
    /// NOTE(review): the return value of `runtime::with_current_runtime` is discarded;
    /// confirm how it behaves when no runtime is installed.
    pub async fn execute(self) {
        use futures::StreamExt;
        use crate::runtime;

        let (tx, rx): (Sender<S::Item>, Receiver<S::Item>) = bounded(self.buffer_size);
        let func = self.func.clone();

        let completion_signals: Vec<_> = (0..self.max_concurrency)
            .map(|_| {
                let rx = rx.clone();
                let func = func.clone();
                let (done_tx, done_rx) = bounded(1);
                runtime::with_current_runtime(|rt| {
                    rt.pool.execute(move || {
                        futures::executor::block_on(async {
                            // `recv` fails only once the channel is closed and drained.
                            while let Ok(item) = rx.recv().await {
                                func(item).await;
                            }
                            let _ = done_tx.try_send(());
                        });
                    });
                });
                done_rx
            })
            .collect();

        // Only the workers may keep the receiving side alive. Holding on to this
        // receiver would keep the channel open even after every worker is gone, and
        // the `send` below would then block forever once the buffer is full.
        drop(rx);

        let mut stream = Box::pin(self.stream);
        while let Some(item) = stream.next().await {
            // An error means no worker is left to receive; stop feeding.
            if tx.send(item).await.is_err() {
                break;
            }
        }

        // Closing the sending side lets the workers drain the buffer and exit.
        drop(tx);

        for done_rx in completion_signals {
            // `Err` means the worker dropped its sender without signalling; it is
            // finished either way.
            let _ = done_rx.recv().await;
        }
    }
}
/// A pending parallel map over a stream.
///
/// Holds the source stream, the shared mapping function and the tuning knobs. Nothing
/// runs until `collect()` is awaited, hence `#[must_use]`.
#[must_use = "a ParMap does nothing until `.collect()` is awaited"]
pub struct ParMap<S, F> {
    /// Source of the items to transform.
    stream: S,
    /// Mapping function; shared between workers through the `Arc`.
    func: Arc<F>,
    /// Capacity of the bounded channel that hands items from the stream to the workers.
    buffer_size: usize,
    /// Number of worker jobs submitted to the runtime pool.
    max_concurrency: usize,
}
impl<S, F, Fut, R> ParMap<S, F>
where
    S: Stream + Send + 'static,
    S::Item: Send + 'static,
    F: Fn(S::Item) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = R> + Send + 'static,
    R: Send + 'static,
{
    /// Maps every item of the stream through the callback and returns the results in
    /// the order the items were produced by the stream.
    ///
    /// Items are tagged with their position, handed to `max_concurrency` workers
    /// through a bounded channel of `buffer_size` slots, and the tagged results are
    /// sorted by position at the end. Each worker is a job submitted to `rt.pool`
    /// that drives its futures with `futures::executor::block_on`.
    ///
    /// If all workers disappear early, feeding stops and the returned vector contains
    /// only the results that were actually produced.
    ///
    /// NOTE(review): the return value of `runtime::with_current_runtime` is discarded;
    /// confirm how it behaves when no runtime is installed.
    pub async fn collect(self) -> Vec<R> {
        use futures::StreamExt;
        use crate::runtime;

        let results = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx): (Sender<(usize, S::Item)>, Receiver<(usize, S::Item)>) =
            bounded(self.buffer_size);
        let func = self.func.clone();

        let completion_signals: Vec<_> = (0..self.max_concurrency)
            .map(|_| {
                let rx = rx.clone();
                let func = func.clone();
                let results = results.clone();
                let (done_tx, done_rx) = bounded(1);
                runtime::with_current_runtime(|rt| {
                    rt.pool.execute(move || {
                        futures::executor::block_on(async {
                            // `recv` fails only once the channel is closed and drained.
                            while let Ok((idx, item)) = rx.recv().await {
                                let result = func(item).await;
                                // The guard is a temporary, so the lock is never held
                                // across an `.await`.
                                results.lock().push((idx, result));
                            }
                            let _ = done_tx.try_send(());
                        });
                    });
                });
                done_rx
            })
            .collect();

        // Only the workers may keep the receiving side alive. Holding on to this
        // receiver would keep the channel open even after every worker is gone, and
        // the `send` below would then block forever once the buffer is full.
        drop(rx);

        // `enumerate` tags each item with its position so order can be restored later.
        let mut indexed_stream = Box::pin(self.stream.enumerate());
        while let Some(indexed_item) = indexed_stream.next().await {
            // An error means no worker is left to receive; stop feeding.
            if tx.send(indexed_item).await.is_err() {
                break;
            }
        }

        // Closing the sending side lets the workers drain the buffer and exit.
        drop(tx);

        for done_rx in completion_signals {
            // `Err` means the worker dropped its sender without signalling; it is
            // finished either way.
            let _ = done_rx.recv().await;
        }

        // A worker may signal completion before its closure (and its `Arc` clone) has
        // been dropped, so fall back to emptying the vector through the lock.
        let mut result_vec = match Arc::try_unwrap(results) {
            Ok(mutex) => mutex.into_inner(),
            Err(arc) => {
                let mut guard = arc.lock();
                std::mem::take(&mut *guard)
            }
        };

        // Indices are unique, so an unstable sort yields the same order as a stable one.
        result_vec.sort_unstable_by_key(|(idx, _)| *idx);
        result_vec.into_iter().map(|(_, r)| r).collect()
    }
}
/// A pending parallel filter over a stream.
///
/// Holds the source stream, the shared predicate and the tuning knobs. Nothing runs
/// until `collect()` is awaited, hence `#[must_use]`.
#[must_use = "a ParFilter does nothing until `.collect()` is awaited"]
pub struct ParFilter<S, F> {
    /// Source of the items to test.
    stream: S,
    /// Predicate; shared between workers through the `Arc`.
    func: Arc<F>,
    /// Capacity of the bounded channel that hands items from the stream to the workers.
    buffer_size: usize,
    /// Number of worker jobs submitted to the runtime pool.
    max_concurrency: usize,
}
impl<S, F, Fut> ParFilter<S, F>
where
    S: Stream + Send + 'static,
    S::Item: Send + 'static,
    F: Fn(&S::Item) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = bool> + Send + 'static,
{
    /// Evaluates the predicate on every item of the stream and returns the items for
    /// which it resolved to `true`, in the order the stream produced them.
    ///
    /// Items are tagged with their position, handed to `max_concurrency` workers
    /// through a bounded channel of `buffer_size` slots, and the surviving items are
    /// sorted by position at the end. Each worker is a job submitted to `rt.pool`
    /// that drives its futures with `futures::executor::block_on`.
    ///
    /// Items are moved, never cloned, so `S::Item` does not need to be `Clone`.
    ///
    /// If all workers disappear early, feeding stops and the returned vector contains
    /// only the items that were actually tested and accepted.
    ///
    /// NOTE(review): the return value of `runtime::with_current_runtime` is discarded;
    /// confirm how it behaves when no runtime is installed.
    pub async fn collect(self) -> Vec<S::Item> {
        use futures::StreamExt;
        use crate::runtime;

        let kept = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx): (Sender<(usize, S::Item)>, Receiver<(usize, S::Item)>) =
            bounded(self.buffer_size);
        let func = self.func.clone();

        let completion_signals: Vec<_> = (0..self.max_concurrency)
            .map(|_| {
                let rx = rx.clone();
                let func = func.clone();
                let kept = kept.clone();
                let (done_tx, done_rx) = bounded(1);
                runtime::with_current_runtime(|rt| {
                    rt.pool.execute(move || {
                        futures::executor::block_on(async {
                            // `recv` fails only once the channel is closed and drained.
                            while let Ok((idx, item)) = rx.recv().await {
                                if func(&item).await {
                                    // The guard is a temporary, so the lock is never
                                    // held across an `.await`.
                                    kept.lock().push((idx, item));
                                }
                            }
                            let _ = done_tx.try_send(());
                        });
                    });
                });
                done_rx
            })
            .collect();

        // Only the workers may keep the receiving side alive. Holding on to this
        // receiver would keep the channel open even after every worker is gone, and
        // the `send` below would then block forever once the buffer is full.
        drop(rx);

        // `enumerate` tags each item with its position so order can be restored later.
        let mut indexed_stream = Box::pin(self.stream.enumerate());
        while let Some(indexed_item) = indexed_stream.next().await {
            // An error means no worker is left to receive; stop feeding.
            if tx.send(indexed_item).await.is_err() {
                break;
            }
        }

        // Closing the sending side lets the workers drain the buffer and exit.
        drop(tx);

        for done_rx in completion_signals {
            // `Err` means the worker dropped its sender without signalling; it is
            // finished either way.
            let _ = done_rx.recv().await;
        }

        // A worker may signal completion before its closure (and its `Arc` clone) has
        // been dropped, so fall back to emptying the vector through the lock.
        let mut kept_vec = match Arc::try_unwrap(kept) {
            Ok(mutex) => mutex.into_inner(),
            Err(arc) => {
                let mut guard = arc.lock();
                std::mem::take(&mut *guard)
            }
        };

        // Workers finish in arbitrary order; restore the stream order. Indices are
        // unique, so an unstable sort is sufficient.
        kept_vec.sort_unstable_by_key(|(idx, _)| *idx);
        kept_vec.into_iter().map(|(_, item)| item).collect()
    }
}
#[cfg(test)]
mod tests {
    //! Smoke tests for the three parallel adapters.
    //!
    //! Every test resets the runtime with `shutdown()` + `init()` before running and
    //! shuts it down again afterwards.
    //!
    //! NOTE(review): if `crate::runtime` is process-global, these tests can interfere
    //! with each other when the harness runs them in parallel — confirm, or run with
    //! `--test-threads=1`.
    use super::*;
    use futures::stream;

    /// Every item of the stream must be visited exactly once.
    #[test]
    fn test_par_for_each() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        crate::runtime::shutdown();
        crate::runtime::init().unwrap();

        let counter = Arc::new(AtomicUsize::new(0));
        let items = stream::iter(0..10);
        let counter_clone = counter.clone();
        let fut = items
            .par_for_each(move |_| {
                let c = counter_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::Relaxed);
                }
            })
            .execute();
        futures::executor::block_on(fut);

        assert_eq!(counter.load(Ordering::Relaxed), 10);
        crate::runtime::shutdown();
    }

    /// `ParMap::collect` sorts results by input position, so the exact output
    /// sequence can be asserted.
    #[test]
    fn test_par_map() {
        crate::runtime::shutdown();
        crate::runtime::init().unwrap();

        let items = stream::iter(0..10);
        let fut = items.par_map(|x| async move { x * 2 }).collect();
        let results = futures::executor::block_on(fut);

        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        crate::runtime::shutdown();
    }

    /// Exactly the even numbers must survive the filter.
    #[test]
    fn test_par_filter() {
        crate::runtime::shutdown();
        crate::runtime::init().unwrap();

        let items = stream::iter(0..10);
        let fut = items
            .par_filter(|x| {
                // The returned future must be `'static`, so it cannot capture the
                // borrowed `x`; evaluate the predicate before entering the async block.
                let is_even = *x % 2 == 0;
                async move { is_even }
            })
            .collect();
        let mut results = futures::executor::block_on(fut);

        // Compare as a sorted list so the check does not depend on output order.
        results.sort_unstable();
        assert_eq!(results, vec![0, 2, 4, 6, 8]);
        crate::runtime::shutdown();
    }
}