use std::future::Future;
use futures::{
future::{self, BoxFuture},
stream::FuturesUnordered,
FutureExt, Stream, StreamExt,
};
use tokio::{
sync::mpsc::{self, Receiver},
task::{JoinError, JoinHandle},
};
use crate::{
concurrency::Concurrency,
pumps::{
and_then::AndThenPump,
catch::CatchPump,
filter_map::FilterMapPump,
flatten::{FlattenConcurrency, FlattenPump},
flatten_iter::FlattenIterPump,
map::MapPump,
map_err::MapErrPump,
map_ok::MapOkPump,
try_filter_map::TryFilterMapPump,
},
Pump,
};
/// A chain of asynchronous processing stages whose current output is read from
/// a tokio channel.
///
/// Constructors and [`Pipeline::pump`] collect the `JoinHandle` of every task they
/// spawn, so that [`Pipeline::build`] can await them all and [`Pipeline::abort`]
/// can abort them all.
pub struct Pipeline<Out> {
/// Receiving half of the channel that carries the items produced by the last stage.
pub(crate) output_receiver: Receiver<Out>,
/// Join handles of the tasks that were spawned while assembling this pipeline.
handles: FuturesUnordered<JoinHandle<()>>,
}
/// Starts a pipeline from an existing tokio `mpsc` receiver.
impl<Out> From<Receiver<Out>> for Pipeline<Out> {
    /// Uses `receiver` directly as the pipeline output. No task is spawned, so the
    /// resulting pipeline starts with an empty set of join handles.
    fn from(receiver: Receiver<Out>) -> Self {
        let handles = FuturesUnordered::new();
        Self {
            output_receiver: receiver,
            handles,
        }
    }
}
impl<Out> Pipeline<Out>
where
Out: Send + 'static,
{
pub fn from_stream(stream: impl Stream<Item = Out> + Send + 'static) -> Self {
let (output_sender, output_receiver) = mpsc::channel(1);
let h = tokio::spawn(async move {
tokio::pin!(stream);
while let Some(output) = stream.next().await {
if let Err(_e) = output_sender.send(output).await {
break;
}
}
});
Pipeline {
output_receiver,
handles: [h].into_iter().collect(),
}
}
#[allow(clippy::should_implement_trait)] pub fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = Out> + Send + 'static,
<I as IntoIterator>::IntoIter: std::marker::Send,
{
let (output_sender, output_receiver) = mpsc::channel(1);
let h = tokio::spawn(async move {
let iter = iter.into_iter();
for output in iter {
if let Err(_e) = output_sender.send(output).await {
break;
};
}
});
Pipeline {
output_receiver,
handles: [h].into_iter().collect(),
}
}
/// Appends an arbitrary [`Pump`] stage to the pipeline.
///
/// The current output receiver is handed to `pump.spawn`, which returns the
/// receiver of the new stage together with the join handle of the task it
/// started. The returned pipeline reads from that new receiver and tracks the
/// new handle in addition to all handles collected so far.
pub fn pump<P, T>(self, pump: P) -> Pipeline<T>
where
    P: Pump<Out, T>,
{
    let Pipeline {
        output_receiver: upstream,
        handles,
    } = self;

    let (downstream, stage_handle) = pump.spawn(upstream);
    // `FuturesUnordered::push` takes `&self`, so no mutable binding is needed.
    handles.push(stage_handle);

    Pipeline {
        output_receiver: downstream,
        handles,
    }
}
/// Appends a [`MapPump`] stage configured with `map_fn` and `concurrency`.
///
/// `map_fn` turns an item into a future resolving to `T`, and the resulting
/// pipeline yields `T` items. How many of those futures run at once, and in which
/// order results are emitted, is decided by `MapPump` based on `concurrency`.
pub fn map<F, Fut, T>(self, map_fn: F, concurrency: Concurrency) -> Pipeline<T>
where
    F: Fn(Out) -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send + 'static,
    T: Send + 'static,
    Out: Send + 'static,
{
    let stage = MapPump {
        map_fn,
        concurrency,
    };
    self.pump(stage)
}
/// Appends a [`FilterMapPump`] stage configured with `map_fn` and `concurrency`.
///
/// `map_fn` turns an item into a future resolving to `Option<T>`, and the
/// resulting pipeline yields bare `T` items. Presumably `None` results are
/// dropped and `Some` payloads are forwarded; see `FilterMapPump` for the exact
/// semantics.
pub fn filter_map<F, Fut, T>(self, map_fn: F, concurrency: Concurrency) -> Pipeline<T>
where
    F: FnMut(Out) -> Fut + Send + 'static,
    Fut: Future<Output = Option<T>> + Send + 'static,
    T: Send + 'static,
    Out: Send + 'static,
{
    let stage = FilterMapPump {
        map_fn,
        concurrency,
    };
    self.pump(stage)
}
/// Appends an `EnumeratePump` stage, turning every item into a `(usize, Out)`
/// pair.
///
/// The `usize` is presumably a running item index; see `EnumeratePump` for the
/// starting value.
pub fn enumerate(self) -> Pipeline<(usize, Out)> {
    use crate::pumps::enumerate::EnumeratePump;

    self.pump(EnumeratePump)
}
/// Appends a `BatchPump { n }` stage, so the pipeline yields `Vec<Out>` items.
///
/// `n` is presumably the batch size; see `BatchPump` for how a trailing partial
/// batch is handled.
pub fn batch(self, n: usize) -> Pipeline<Vec<Out>> {
    use crate::pumps::batch::BatchPump;

    let stage = BatchPump { n };
    self.pump(stage)
}
/// Appends a batching stage driven by a user-supplied state machine.
///
/// `while_fn` receives the current state and a reference to the incoming item
/// and returns `Some(new_state)` or `None`. This method is a thin adapter over
/// `BatchWhileWithExpiryPump`: every `Some(new_state)` is paired with
/// `future::pending()`, a future that never completes, and `None` is passed
/// through unchanged. How the pump interprets `Some`/`None` and the paired
/// future (presumably "keep batching" / "close the batch" / "expiry timer") is
/// defined by `BatchWhileWithExpiryPump`.
///
/// `state_init` is the initial state given to the pump.
pub fn batch_while<F, State>(self, state_init: State, mut while_fn: F) -> Pipeline<Vec<Out>>
where
    F: FnMut(State, &Out) -> Option<State> + Send + 'static,
    State: Send + Clone + 'static,
{
    self.pump(
        crate::pumps::batch_while_with_expiry::BatchWhileWithExpiryPump {
            state_init,
            while_fn: move |state: State, x: &Out| {
                // The closure already owns `state`, so it is moved into
                // `while_fn` directly instead of being cloned first.
                while_fn(state, x).map(|new_state| (new_state, future::pending()))
            },
        },
    )
}
/// Appends a `BatchWhileWithExpiryPump` stage, so the pipeline yields
/// `Vec<Out>` items.
///
/// `while_fn` receives the current state and a reference to the incoming item
/// and returns either `None` or `Some((new_state, future))`. Judging by the
/// pump's name the future acts as an expiry signal for the batch in progress;
/// the exact semantics are defined by `BatchWhileWithExpiryPump`.
///
/// `state_init` is the initial state given to the pump.
pub fn batch_while_with_expiry<F, Fut, State>(
    self,
    state_init: State,
    while_fn: F,
) -> Pipeline<Vec<Out>>
where
    F: FnMut(State, &Out) -> Option<(State, Fut)> + Send + 'static,
    Fut: Future<Output = ()> + Send,
    State: Send + Clone + 'static,
{
    use crate::pumps::batch_while_with_expiry::BatchWhileWithExpiryPump;

    let stage = BatchWhileWithExpiryPump {
        state_init,
        while_fn,
    };
    self.pump(stage)
}
/// Appends a `SkipPump { n }` stage; the item type is unchanged.
///
/// Presumably the first `n` items are discarded; see `SkipPump`.
pub fn skip(self, n: usize) -> Pipeline<Out> {
    use crate::pumps::skip::SkipPump;

    let stage = SkipPump { n };
    self.pump(stage)
}
/// Appends a `TakePump { n }` stage; the item type is unchanged.
///
/// Presumably only the first `n` items are forwarded; see `TakePump`.
pub fn take(self, n: usize) -> Pipeline<Out> {
    use crate::pumps::take::TakePump;

    let stage = TakePump { n };
    self.pump(stage)
}
/// Appends a `Backpressure { n }` stage; the item type is unchanged.
///
/// `n` is presumably the number of items the stage may buffer before upstream
/// stages have to wait; see `Backpressure`.
pub fn backpressure(self, n: usize) -> Pipeline<Out> {
    use crate::pumps::backpressure::Backpressure;

    let stage = Backpressure { n };
    self.pump(stage)
}
/// Appends a `BackpressureWithReliefValve { n }` stage; the item type is
/// unchanged.
///
/// `n` is presumably a buffer size, and the "relief valve" presumably sheds
/// items instead of blocking when the buffer is full; confirm against
/// `BackpressureWithReliefValve`.
pub fn backpressure_with_relief_valve(self, n: usize) -> Pipeline<Out> {
    use crate::pumps::backpressure_with_relief_valve::BackpressureWithReliefValve;

    let stage = BackpressureWithReliefValve { n };
    self.pump(stage)
}
pub fn build(mut self) -> (Receiver<Out>, BoxFuture<'static, Result<(), JoinError>>) {
let join_result = async move {
while let Some(res) = self.handles.next().await {
match res {
Ok(_) => continue,
Err(e) => return Err(e),
}
}
Ok(())
};
(self.output_receiver, join_result.boxed())
}
/// Consumes the pipeline and calls `abort` on every tracked join handle.
///
/// The output receiver is dropped together with the pipeline.
pub fn abort(self) {
    self.handles
        .into_iter()
        .for_each(|task| task.abort());
}
}
/// Error-diverting stages for pipelines that carry `Result` items.
impl<Err, Out> Pipeline<Result<Out, Err>>
where
    Err: Send + Sync + 'static,
    Out: Send + Sync + 'static,
{
    /// Appends a [`CatchPump`] with `abort_on_error` set to `false`.
    ///
    /// The resulting pipeline yields bare `Out` items; `err_channel` is handed to
    /// the pump, presumably as the destination for the `Err` values it removes
    /// from the flow.
    pub fn catch(self, err_channel: tokio::sync::mpsc::Sender<Err>) -> Pipeline<Out> {
        self.catch_with(err_channel, false)
    }

    /// Appends a [`CatchPump`] with `abort_on_error` set to `true`.
    ///
    /// Same wiring as [`Pipeline::catch`]; the flag presumably makes the stage
    /// stop at the first error. See `CatchPump` for the exact behaviour.
    pub fn catch_abort(self, err_channel: tokio::sync::mpsc::Sender<Err>) -> Pipeline<Out> {
        self.catch_with(err_channel, true)
    }

    /// Shared wiring for `catch` and `catch_abort`.
    fn catch_with(
        self,
        err_channel: tokio::sync::mpsc::Sender<Err>,
        abort_on_error: bool,
    ) -> Pipeline<Out> {
        self.pump(CatchPump {
            err_channel,
            abort_on_error,
        })
    }
}
/// Flattening for pipelines whose items are themselves pipelines.
impl<Out> Pipeline<Pipeline<Out>>
where
    Out: Send + Sync + 'static,
{
    /// Appends a [`FlattenPump`] configured with `concurrency`, producing a
    /// pipeline of the inner pipelines' `Out` items.
    ///
    /// How inner pipelines are scheduled relative to each other is decided by
    /// `FlattenPump` from the given [`FlattenConcurrency`].
    pub fn flatten(self, concurrency: FlattenConcurrency) -> Pipeline<Out> {
        let stage = FlattenPump { concurrency };
        self.pump(stage)
    }
}
/// Flattening for pipelines whose items are iterable collections.
impl<Out, In: IntoIterator<Item = Out>> Pipeline<In>
where
    In: Send + 'static,
    Out: Send + Sync + 'static,
    <In as IntoIterator>::IntoIter: Send,
{
    /// Appends a [`FlattenIterPump`], producing a pipeline of the `Out` elements
    /// contained in the incoming `In` items.
    pub fn flatten_iter(self) -> Pipeline<Out> {
        let stage = FlattenIterPump {};
        self.pump(stage)
    }
}
impl<OutOk, OutErr> Pipeline<Result<OutOk, OutErr>> {
/// Appends a [`MapOkPump`] stage configured with `map_fn` and `concurrency`.
///
/// `map_fn` maps an `OutOk` value to a future resolving to `T`; the resulting
/// pipeline yields `Result<T, OutErr>`, so the error type is untouched.
/// Presumably `Err` items bypass `map_fn`; see `MapOkPump`.
pub fn map_ok<F, Fut, T>(
    self,
    map_fn: F,
    concurrency: Concurrency,
) -> Pipeline<Result<T, OutErr>>
where
    F: Fn(OutOk) -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send,
    T: Send + 'static,
    OutErr: Send + 'static,
    OutOk: Send + 'static,
{
    let stage = MapOkPump {
        map_fn,
        concurrency,
    };
    self.pump(stage)
}
/// Appends a [`MapErrPump`] stage configured with `map_fn` and `concurrency`.
///
/// `map_fn` maps an `OutErr` value to a future resolving to `T`; the resulting
/// pipeline yields `Result<OutOk, T>`, so the success type is untouched.
/// Presumably `Ok` items bypass `map_fn`; see `MapErrPump`.
pub fn map_err<F, Fut, T>(
    self,
    map_fn: F,
    concurrency: Concurrency,
) -> Pipeline<Result<OutOk, T>>
where
    F: Fn(OutErr) -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send,
    T: Send + 'static,
    OutErr: Send + 'static,
    OutOk: Send + 'static,
{
    let stage = MapErrPump {
        map_fn,
        concurrency,
    };
    self.pump(stage)
}
/// Appends an [`AndThenPump`] stage configured with `map_fn` and `concurrency`.
///
/// `map_fn` maps an `OutOk` value to a future resolving to `Result<T, OutErr>`,
/// i.e. a fallible step sharing the pipeline's error type. The resulting
/// pipeline yields `Result<T, OutErr>`. Presumably incoming `Err` items are
/// forwarded without calling `map_fn`; see `AndThenPump`.
pub fn and_then<F, Fut, T>(
    self,
    map_fn: F,
    concurrency: Concurrency,
) -> Pipeline<Result<T, OutErr>>
where
    F: Fn(OutOk) -> Fut + Send + 'static,
    Fut: Future<Output = Result<T, OutErr>> + Send,
    T: Send + 'static,
    OutErr: Send + 'static,
    OutOk: Send + 'static,
{
    let stage = AndThenPump {
        map_fn,
        concurrency,
    };
    self.pump(stage)
}
/// Appends a [`TryFilterMapPump`] stage configured with `map_fn` and
/// `concurrency`.
///
/// `map_fn` maps an `OutOk` value to a future resolving to
/// `Result<Option<T>, OutErr>`; the resulting pipeline yields
/// `Result<T, OutErr>`. Presumably `Ok(None)` results are dropped while
/// `Ok(Some(_))` and `Err(_)` are forwarded; see `TryFilterMapPump`.
pub fn try_filter_map<F, Fut, T>(
    self,
    map_fn: F,
    concurrency: Concurrency,
) -> Pipeline<Result<T, OutErr>>
where
    F: FnMut(OutOk) -> Fut + Send + 'static,
    Fut: Future<Output = Result<Option<T>, OutErr>> + Send,
    T: Send + 'static,
    OutErr: Send + 'static,
    OutOk: Send + 'static,
{
    let stage = TryFilterMapPump {
        map_fn,
        concurrency,
    };
    self.pump(stage)
}
}
#[cfg(test)]
mod tests {
use futures::{stream, SinkExt};
use super::*;
/// Test helper: an async identity function that resolves immediately to `x`.
async fn async_job(x: i32) -> i32 {
    std::future::ready(x).await
}
/// Test helper: keeps even numbers (`Some(x)`) and rejects odd ones (`None`).
async fn async_filter_map(x: i32) -> Option<i32> {
    match x % 2 {
        0 => Some(x),
        _ => None,
    }
}
/// End-to-end run: two map stages around a backpressure stage, then a serial
/// filter that keeps even numbers only.
#[tokio::test]
async fn test_pipeline() {
    let (tx, rx) = mpsc::channel(100);

    let (mut out, done) = Pipeline::from(rx)
        .map(async_job, Concurrency::concurrent_unordered(2))
        .backpressure(100)
        .map(async_job, Concurrency::concurrent_unordered(2))
        .filter_map(async_filter_map, Concurrency::serial())
        .build();

    for n in 1..=4 {
        tx.send(n).await.unwrap();
    }

    // Only the even inputs are expected to come out.
    for even in [2, 4] {
        assert_eq!(out.recv().await, Some(even));
    }

    // Closing the input must let every stage finish cleanly.
    drop(tx);
    assert_eq!(out.recv().await, None);
    assert!(matches!(done.await, Ok(())));
}
/// A panic inside a stage must close the output channel and surface as an
/// error from the join future returned by `build`.
#[tokio::test]
async fn panic_handling() {
    let (tx, rx) = mpsc::channel(100);

    // Passes values through unchanged, except for `2`, which panics.
    let panics_on_two = |x: i32| async move {
        if x == 2 {
            panic!("2 is not supported");
        }
        x
    };

    let (mut out, done) = Pipeline::from(rx)
        .map(async_job, Concurrency::concurrent_unordered(2))
        .backpressure(100)
        .map(panics_on_two, Concurrency::concurrent_unordered(2))
        .build();

    for n in 1..=3 {
        tx.send(n).await.unwrap();
    }

    assert_eq!(out.recv().await, Some(1));
    // Once the stage has panicked the channel is closed, and it stays closed.
    for _ in 0..2 {
        assert_eq!(out.recv().await, None);
    }

    assert!(done.await.is_err());
}
/// `from_stream` forwards every stream item in order and then closes the channel.
#[tokio::test]
async fn test_from_stream() {
    let source = stream::iter(vec![1, 2, 3]);
    let pipeline = Pipeline::from_stream(source).map(async_job, Concurrency::serial());
    let mut out = pipeline.output_receiver;

    for expected in [Some(1), Some(2), Some(3), None] {
        assert_eq!(out.recv().await, expected);
    }
}
/// A `futures` channel receiver works as a stream source; the output closes
/// once the sender has been dropped.
#[tokio::test]
async fn test_from_futures_channel() {
    let (mut tx, rx) = futures::channel::mpsc::channel(100);
    for n in 1..=3 {
        tx.send(n).await.unwrap();
    }

    let pipeline = Pipeline::from_stream(rx).map(async_job, Concurrency::serial());
    let mut out = pipeline.output_receiver;

    for expected in 1..=3 {
        assert_eq!(out.recv().await, Some(expected));
    }

    drop(tx);
    assert_eq!(out.recv().await, None);
}
/// `from_iter` forwards every element in order and then closes the channel.
#[tokio::test]
async fn test_from_iter() {
    let source = vec![1, 2, 3];
    let pipeline = Pipeline::from_iter(source).map(async_job, Concurrency::serial());
    let mut out = pipeline.output_receiver;

    for expected in [Some(1), Some(2), Some(3), None] {
        assert_eq!(out.recv().await, expected);
    }
}
/// A user-defined `Pump` can be plugged in through `Pipeline::pump`.
#[tokio::test]
async fn test_custom_pump() {
    /// Emits every incoming item twice.
    pub struct CustomPump;

    impl<In> Pump<In, In> for CustomPump
    where
        In: Send + Sync + Clone + 'static,
    {
        fn spawn(self, mut input_receiver: Receiver<In>) -> (Receiver<In>, JoinHandle<()>) {
            let (tx, rx) = mpsc::channel(1);
            let task = tokio::spawn(async move {
                while let Some(item) = input_receiver.recv().await {
                    // Stop as soon as either copy cannot be delivered; `||`
                    // short-circuits, so the second send is skipped after a failure.
                    if tx.send(item.clone()).await.is_err() || tx.send(item).await.is_err() {
                        break;
                    }
                }
            });
            (rx, task)
        }
    }

    let (tx, rx) = mpsc::channel(100);
    let (mut out, done) = Pipeline::from(rx).pump(CustomPump).build();

    tx.send(1).await.unwrap();
    // One input, two identical outputs.
    for _ in 0..2 {
        assert_eq!(out.recv().await, Some(1));
    }

    drop(tx);
    assert_eq!(out.recv().await, None);
    done.await.unwrap();
}
}