use atomic_refcell::AtomicRefCell;
use futures::future::{Fuse, FusedFuture, join};
use futures::stream::{FusedStream, FuturesOrdered, FuturesUnordered};
use futures::{FutureExt, Stream, StreamExt};
use std::collections::VecDeque;
use std::pin::{Pin, pin};
use std::sync::Arc;
use std::task::Waker;
use std::task::{Context, Poll, Poll::Pending, Poll::Ready};
/// Shared state behind a [`Buffer`]: the queued items plus the flags and waker used to
/// coordinate the stage that pushes items with whoever reads them.
#[derive(Debug)]
struct BufferInner<T> {
/// Items pushed and not yet read, oldest first (pushed at the back, popped at the front).
items: VecDeque<T>,
/// Capacity: `Buffer::push` asserts there are fewer than `limit` items before adding one.
limit: usize,
/// Set by `close` once the producer is finished; `Buffer::push` asserts it is still false.
is_closed: bool,
/// Waker of a producer that stalled waiting for items to be drained; it is taken and
/// woken by the next `poll_next`.
sender_waker: Option<Waker>,
}
/// Cheaply clonable handle to a bounded FIFO queue shared by a pipeline stage, which pushes
/// into it, and the next consumer, which reads it as a [`Stream`].
///
/// The state sits in an `AtomicRefCell` rather than a lock, so a conflicting borrow panics
/// instead of blocking.
#[derive(Debug)]
struct Buffer<T>(Arc<AtomicRefCell<BufferInner<T>>>);
impl<T> Buffer<T> {
fn new(limit: usize) -> Self {
Self(Arc::new(AtomicRefCell::new(BufferInner {
items: VecDeque::new(),
limit,
is_closed: false,
sender_waker: None,
})))
}
fn push(&self, item: T) {
let mut this = self.0.borrow_mut();
assert!(this.items.len() < this.limit);
assert!(!this.is_closed);
this.items.push_back(item);
}
}
impl<T> Clone for Buffer<T> {
fn clone(&self) -> Self {
Self(Arc::clone(&self.0))
}
}
/// Item-type-independent view of a [`Buffer`], so stages with different output types can
/// be stored and inspected together.
trait TypeErasedBuffer {
/// Number of items currently queued.
fn len(&self) -> usize;
/// Whether no items are currently queued.
fn is_empty(&self) -> bool;
/// Marks the buffer as finished; once drained, readers see end-of-stream.
fn close(&self);
/// Stores the producer's waker, to be woken the next time the buffer is read.
/// Implementations may require the buffer to be non-empty.
fn register_sender_waker(&self, waker: Waker);
}
impl<T> TypeErasedBuffer for Buffer<T> {
    fn len(&self) -> usize {
        let inner = self.0.borrow();
        inner.items.len()
    }

    fn is_empty(&self) -> bool {
        self.0.borrow().items.is_empty()
    }

    /// Marks the buffer closed; readers get `None` once the remaining items are drained.
    fn close(&self) {
        let mut inner = self.0.borrow_mut();
        inner.is_closed = true;
    }

    /// Saves `waker` so the next `poll_next` wakes the stalled producer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is empty: with nothing to read, no `poll_next` would pop an
    /// item, and the producer could stay stalled.
    fn register_sender_waker(&self, waker: Waker) {
        let mut this = self.0.borrow_mut();
        // The condition text is part of the panic message, so it is written verbatim.
        assert!(!this.items.is_empty());
        this.sender_waker = Some(waker);
    }
}
// No consumer waker is registered here: every pipeline stage and the final consumer are
// polled by the same task, so upstream progress re-polls the reader anyway.
impl<T> Stream for Buffer<T> {
    type Item = T;

    /// Pops the oldest item, or reports end-of-stream once the buffer is closed and drained.
    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.0.borrow_mut();
        // A stalled producer only registers while items are queued, so this read frees a
        // slot for it; wake it so it can retry.
        if let Some(producer) = this.sender_waker.take() {
            assert!(!this.items.is_empty());
            producer.wake();
        }
        match this.items.pop_front() {
            Some(item) => Ready(Some(item)),
            None if this.is_closed => Ready(None),
            None => Pending,
        }
    }
}
impl<T> FusedStream for Buffer<T> {
    /// Terminated exactly when `poll_next` would return `Ready(None)`: closed and drained.
    fn is_terminated(&self) -> bool {
        let inner = self.0.borrow();
        inner.is_closed && inner.items.is_empty()
    }
}
pin_project_lite::pin_project! {
// One stage of an `AsyncPipeline`: the fused future that does the stage's work, plus a
// handle to the buffer that future writes its results into. `poll_stages` reads that
// buffer to check that a finished stage left no unconsumed inputs behind.
struct PipelineStage<Fut, T> {
// Fused so it can be polled again after completing (returning `Pending`) and so
// completion can be queried with `is_terminated`.
#[pin]
future: Fuse<Fut>,
// Another handle to the buffer the stage's future pushes into.
outputs: Buffer<T>,
}
}
/// Item-type-independent view of a [`PipelineStage`], so stages with different output
/// types can live in one `Vec`.
trait TypeErasedStage {
/// The stage's work as a pinned future that resolves once the stage is finished.
fn future(self: Pin<&mut Self>) -> Pin<&mut dyn Future<Output = ()>>;
/// Whether the stage's future has already completed.
fn is_done(&self) -> bool;
/// The buffer this stage writes its results into.
fn outputs_buffer(&self) -> &dyn TypeErasedBuffer;
}
impl<Fut: Future<Output = ()>, T> TypeErasedStage for PipelineStage<Fut, T> {
fn future(self: Pin<&mut Self>) -> Pin<&mut dyn Future<Output = ()>> {
self.project().future
}
fn is_done(&self) -> bool {
self.future.is_terminated()
}
fn outputs_buffer(&self) -> &dyn TypeErasedBuffer {
&self.outputs
}
}
fn poll_stages<'a>(
mut stages: Vec<Pin<Box<dyn TypeErasedStage + 'a>>>,
) -> impl Future<Output = ()> + 'a {
std::future::poll_fn(move |cx| {
if stages.is_empty() {
return Ready(());
}
for i in 0..stages.len() {
let (prev_slice, rest_slice) = stages.split_at_mut(i);
let previous_stage = prev_slice.last().map(|s| &**s);
let inputs = previous_stage.map(TypeErasedStage::outputs_buffer);
let current_stage = &mut rest_slice[0];
if current_stage.as_mut().future().poll(cx).is_ready() {
assert!(
previous_stage.is_none_or(TypeErasedStage::is_done),
"later stage ({}) finished before previous ({})",
i,
i - 1,
);
assert!(
inputs.is_none_or(TypeErasedBuffer::is_empty),
"stage finished with leftover inputs"
);
}
}
if stages.last().unwrap().is_done() {
Ready(())
} else {
Pending
}
})
}
/// The set of in-flight futures for one stage, which yields their outputs either in the
/// order the futures were pushed (`Ordered`) or in the order they complete (`Unordered`).
enum Executor<Fut: Future> {
Ordered(FuturesOrdered<Fut>),
Unordered(FuturesUnordered<Fut>),
}
/// Chooses which [`Executor`] variant `Executor::new` builds.
enum ExecutorKind {
/// Outputs follow input order (`FuturesOrdered`).
Ordered,
/// Outputs follow completion order (`FuturesUnordered`).
Unordered,
}
impl<Fut: Future> Executor<Fut> {
fn new(kind: ExecutorKind) -> Self {
match kind {
ExecutorKind::Ordered => Self::Ordered(FuturesOrdered::new()),
ExecutorKind::Unordered => Self::Unordered(FuturesUnordered::new()),
}
}
fn len(&self) -> usize {
match self {
Self::Ordered(futures) => futures.len(),
Self::Unordered(futures) => futures.len(),
}
}
fn push(&mut self, fut: Fut) {
match self {
Self::Ordered(futures) => {
futures.push_back(fut);
}
Self::Unordered(futures) => {
futures.push(fut);
}
}
}
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Fut::Output>> {
match self {
Self::Ordered(futures) => Pin::new(futures).poll_next(cx),
Self::Unordered(futures) => Pin::new(futures).poll_next(cx),
}
}
}
/// Runs `f` on every item of `inputs` with bounded concurrency and pushes each `Some`
/// result into `outputs`. `None` results are discarded.
///
/// `limit` bounds the in-flight futures and the results waiting in `outputs` together.
/// When that budget is used up and results are waiting, the task's waker is registered on
/// `outputs` so reading from it resumes this stage. `outputs` is closed when the future
/// resolves, which happens once `inputs` is exhausted and nothing is in flight.
fn filter_map<T, U, S, F, Fut>(
    mut inputs: Pin<&mut S>,
    mut f: F,
    outputs: Buffer<U>,
    limit: usize,
    kind: ExecutorKind,
) -> impl Future<Output = ()>
where
    S: Stream<Item = T> + FusedStream,
    F: FnMut(T) -> Fut,
    Fut: Future<Output = Option<U>>,
{
    let mut in_flight = Executor::new(kind);
    std::future::poll_fn(move |cx| {
        // Repeat until a full pass neither admits an input nor completes a future.
        let mut progressed = true;
        while progressed {
            progressed = false;
            let has_room = in_flight.len() + outputs.len() < limit;
            if has_room {
                if let Ready(Some(input)) = inputs.as_mut().poll_next(cx) {
                    in_flight.push(f(input));
                    progressed = true;
                }
            } else if !outputs.is_empty() {
                // Budget exhausted by waiting results: resume once the reader drains one.
                outputs.register_sender_waker(cx.waker().clone());
            }
            match in_flight.poll_next(cx) {
                Ready(Some(Some(output))) => {
                    outputs.push(output);
                    progressed = true;
                }
                Ready(Some(None)) => progressed = true,
                Ready(None) | Pending => {}
            }
        }
        let finished = inputs.is_terminated() && in_flight.len() == 0;
        if finished {
            outputs.close();
            Ready(())
        } else {
            Pending
        }
    })
}
/// A chain of concurrent processing stages that all run on the task that awaits the
/// pipeline's consuming method (`for_each`, `for_each_concurrent` or `collect`).
pub struct AsyncPipeline<'a, S: Stream + 'a> {
/// The stream the next stage or the final consumer reads from: either the original
/// source or the output buffer of the most recent stage, possibly wrapped by `adapt_stream`.
outputs: S,
/// Every stage added so far, upstream first.
stages: Vec<Pin<Box<dyn TypeErasedStage + 'a>>>,
}
impl<'a, I: Iterator> AsyncPipeline<'a, futures::stream::Iter<I>> {
pub fn from_iter(iter: impl IntoIterator<IntoIter = I>) -> Self {
Self::from_stream(futures::stream::iter(iter))
}
}
impl<'a, S: Stream> AsyncPipeline<'a, S> {
pub fn from_stream(stream: S) -> Self {
Self {
outputs: stream,
stages: Vec::new(),
}
}
pub async fn for_each(self, mut f: impl AsyncFnMut(S::Item)) {
join(poll_stages(self.stages), async {
let mut outputs = pin!(self.outputs);
while let Some(item) = outputs.next().await {
f(item).await;
}
})
.await;
}
pub async fn for_each_concurrent(self, f: impl AsyncFn(S::Item), limit: usize) {
let mut inputs = pin!(self.outputs.fuse());
let mut executor = FuturesUnordered::new();
join(poll_stages(self.stages), async {
loop {
let mut keep_looping = false;
if executor.len() < limit
&& let Ready(Some(input)) = futures::poll!(inputs.next())
{
executor.push(f(input));
keep_looping = true;
}
if let Ready(Some(())) = futures::poll!(executor.next()) {
keep_looping = true;
}
if keep_looping {
continue;
} else if inputs.is_terminated() && executor.is_empty() {
return;
} else {
futures::pending!();
}
}
})
.await;
}
pub async fn collect<C: Default + Extend<S::Item>>(self) -> C {
let mut collection = C::default();
self.for_each(async |item| {
collection.extend(std::iter::once(item));
})
.await;
collection
}
pub fn adapt_stream<F, S2>(self, f: F) -> AsyncPipeline<'a, S2>
where
F: FnOnce(S) -> S2,
S2: Stream,
{
AsyncPipeline {
outputs: f(self.outputs),
stages: self.stages,
}
}
pub fn map_concurrent<F, Fut, U>(
self,
mut f: F,
limit: usize,
) -> AsyncPipeline<'a, impl Stream<Item = U>>
where
F: FnMut(S::Item) -> Fut + 'a,
Fut: Future<Output = U> + 'a,
U: 'a,
{
self.filter_map_concurrent(
move |item| {
let fut = f(item);
async { Some(fut.await) }
},
limit,
)
}
pub fn map_unordered<F, Fut, U>(
self,
mut f: F,
limit: usize,
) -> AsyncPipeline<'a, impl Stream<Item = U>>
where
F: FnMut(S::Item) -> Fut + 'a,
Fut: Future<Output = U> + 'a,
U: 'a,
{
self.filter_map_unordered(
move |item| {
let fut = f(item);
async { Some(fut.await) }
},
limit,
)
}
fn filter_map_inner<F, Fut, U>(
mut self,
f: F,
limit: usize,
kind: ExecutorKind,
) -> AsyncPipeline<'a, impl Stream<Item = U>>
where
F: FnMut(S::Item) -> Fut + 'a,
Fut: Future<Output = Option<U>> + 'a,
U: 'a,
{
let buffer = Buffer::<U>::new(limit);
let buffer_clone = buffer.clone();
self.stages.push(Box::pin(PipelineStage {
outputs: buffer.clone(),
future: async move {
let inputs = pin!(self.outputs.fuse());
filter_map(inputs, f, buffer_clone, limit, kind).await;
}
.fuse(),
}));
AsyncPipeline {
outputs: buffer,
stages: self.stages,
}
}
pub fn filter_map_concurrent<F, Fut, U>(
self,
f: F,
limit: usize,
) -> AsyncPipeline<'a, impl Stream<Item = U>>
where
F: FnMut(S::Item) -> Fut + 'a,
Fut: Future<Output = Option<U>> + 'a,
U: 'a,
{
self.filter_map_inner(f, limit, ExecutorKind::Ordered)
}
pub fn filter_map_unordered<F, Fut, U>(
self,
f: F,
limit: usize,
) -> AsyncPipeline<'a, impl Stream<Item = U>>
where
F: FnMut(S::Item) -> Fut + 'a,
Fut: Future<Output = Option<U>> + 'a,
U: 'a,
{
self.filter_map_inner(f, limit, ExecutorKind::Unordered)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering::Relaxed};
    use tokio::sync::Mutex;
    use tokio::sync::mpsc::unbounded_channel;
    use tokio::time::{Duration, sleep};
    use tokio_stream::wrappers::UnboundedReceiverStream;

    #[tokio::test]
    async fn test_for_each() {
        let inputs = [0, 1, 2, 3, 4];
        let mut v = Vec::new();
        AsyncPipeline::from_iter(&inputs)
            .adapt_stream(|s| {
                s.then(async |x| {
                    sleep(Duration::from_millis(1)).await;
                    x + 1
                })
            })
            .map_concurrent(
                async |x| {
                    sleep(Duration::from_millis(1)).await;
                    10 * x
                },
                3,
            )
            .for_each(async |x| {
                v.push(x);
            })
            .await;
        assert_eq!(v, vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_for_each_concurrent() {
        let v = Mutex::new(Vec::new());
        AsyncPipeline::from_iter(0..5)
            .adapt_stream(|s| {
                s.then(async |x| {
                    sleep(Duration::from_millis(1)).await;
                    x + 1
                })
            })
            .map_concurrent(
                async |x| {
                    sleep(Duration::from_millis(1)).await;
                    10 * x
                },
                3,
            )
            .for_each_concurrent(
                async |x| {
                    v.lock().await.push(x);
                },
                3,
            )
            .await;
        assert_eq!(v.into_inner(), vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_collect() {
        let v: Vec<_> = AsyncPipeline::from_iter(0..5)
            .adapt_stream(|s| {
                s.then(async |x| {
                    sleep(Duration::from_millis(1)).await;
                    x + 1
                })
                .then(async |x| {
                    sleep(Duration::from_millis(1)).await;
                    10 * x
                })
            })
            .collect()
            .await;
        assert_eq!(v, vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_max_in_flight() {
        static ELEMENTS_IN_FLIGHT: AtomicU32 = AtomicU32::new(0);
        let mut i = 0;
        AsyncPipeline::from_iter(std::iter::from_fn(|| {
            if i < 10 {
                let in_flight = ELEMENTS_IN_FLIGHT.fetch_add(1, Relaxed);
                assert_eq!(in_flight, 0, "too many elements in flight at i = {i}");
                i += 1;
                Some(i)
            } else {
                None
            }
        }))
        .for_each(async |i| {
            let in_flight = ELEMENTS_IN_FLIGHT.fetch_sub(1, Relaxed);
            assert_eq!(in_flight, 1, "too many elements in flight at i = {i}");
            sleep(Duration::from_millis(1)).await;
        })
        .await;
    }

    #[tokio::test]
    async fn test_map_concurrent() {
        static FUTURES_IN_FLIGHT: AtomicU32 = AtomicU32::new(0);
        static MAX_IN_FLIGHT: AtomicU32 = AtomicU32::new(0);
        let v: Vec<i32> = AsyncPipeline::from_iter(0..10)
            .map_concurrent(
                async |i| {
                    let in_flight = FUTURES_IN_FLIGHT.fetch_add(1, Relaxed);
                    MAX_IN_FLIGHT.fetch_max(in_flight + 1, Relaxed);
                    sleep(Duration::from_millis(1)).await;
                    FUTURES_IN_FLIGHT.fetch_sub(1, Relaxed);
                    2 * i
                },
                3,
            )
            .collect()
            .await;
        assert_eq!(v, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        assert_eq!(MAX_IN_FLIGHT.load(Relaxed), 3);
    }

    #[tokio::test]
    async fn test_map_unordered() {
        static FUTURES_IN_FLIGHT: AtomicU32 = AtomicU32::new(0);
        static MAX_IN_FLIGHT: AtomicU32 = AtomicU32::new(0);
        let mut v: Vec<i32> = AsyncPipeline::from_iter(0..10)
            .map_unordered(
                async |i| {
                    let in_flight = FUTURES_IN_FLIGHT.fetch_add(1, Relaxed);
                    MAX_IN_FLIGHT.fetch_max(in_flight + 1, Relaxed);
                    sleep(Duration::from_millis(1)).await;
                    FUTURES_IN_FLIGHT.fetch_sub(1, Relaxed);
                    2 * i
                },
                3,
            )
            .collect()
            .await;
        // `map_unordered` yields results in completion order, which timer scheduling does not
        // guarantee to match input order, so compare the results as a set.
        v.sort_unstable();
        assert_eq!(v, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
        assert_eq!(MAX_IN_FLIGHT.load(Relaxed), 3);
    }

    #[tokio::test]
    async fn test_deadlocks() {
        async fn foo(i: i32) -> i32 {
            static LOCK: Mutex<()> = Mutex::const_new(());
            println!("locking foo({i})");
            // The lock must actually be acquired (awaited) and held across the sleep;
            // otherwise `_guard` is just an unpolled future and the test exercises no
            // locking. Holding it serializes every `foo` call, so this test takes a few
            // seconds.
            let _guard = LOCK.lock().await;
            println!("sleeping foo({i})");
            sleep(Duration::from_millis(rand::random_range(0..10))).await;
            println!("waking foo({i})");
            i + 1
        }
        let v = Mutex::new(Vec::new());
        AsyncPipeline::from_iter(0..100)
            .map_concurrent(async |i| foo(i).await, 10)
            .map_unordered(async |i| foo(i).await, 10)
            .filter_map_concurrent(async |i| Some(foo(i).await), 10)
            .filter_map_unordered(async |i| Some(foo(i).await), 10)
            .for_each_concurrent(
                async |i| {
                    futures::join!(foo(i), foo(i), foo(i), foo(i), foo(i));
                    v.lock().await.push(foo(i).await);
                },
                10,
            )
            .await;
        let mut v = v.into_inner();
        v.sort();
        assert_eq!(v[..], (5..105).collect::<Vec<_>>());
    }

    #[tokio::test]
    async fn test_channel() {
        let lock = Mutex::new(());
        let atomic1 = AtomicU32::new(0);
        let atomic2 = AtomicU32::new(0);
        let num_jobs: usize = 1000;
        let (sender, receiver) = unbounded_channel::<()>();
        let pipeline = AsyncPipeline::from_stream(UnboundedReceiverStream::new(receiver))
            .for_each_concurrent(
                async |_| {
                    atomic1.fetch_add(1, Relaxed);
                    let _guard = lock.lock().await;
                    atomic2.fetch_add(1, Relaxed);
                },
                num_jobs,
            );
        join(pipeline, async {
            let _guard = lock.lock().await;
            for _ in 0..num_jobs {
                sender.send(()).unwrap();
            }
            while atomic1.load(Relaxed) != num_jobs as u32 {
                sleep(Duration::from_millis(1)).await;
            }
            assert_eq!(atomic2.load(Relaxed), 0);
            drop(_guard);
            while atomic2.load(Relaxed) != num_jobs as u32 {
                sleep(Duration::from_millis(1)).await;
            }
            drop(sender);
        })
        .await;
    }

    #[tokio::test]
    async fn test_dont_preallocate_buffers() {
        AsyncPipeline::from_iter(0..10)
            .map_concurrent(async |x| x + 1, usize::MAX)
            .for_each(async |_| {})
            .await;
    }
}