fast_ordered_buffer/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(clippy::all, clippy::pedantic)]
3#![allow(clippy::uninlined_format_args)]
4
5use std::cmp::Ordering;
6use std::collections::BinaryHeap;
7use std::num::NonZeroUsize;
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
14use futures::Sink;
15use futures::stream::{Fuse, FuturesUnordered, Stream, StreamExt};
16use pin_project_lite::pin_project;
17
/// A completed output parked in the reorder heap until its turn comes up.
///
/// Equality and ordering look at `id` only, and the ordering is reversed so that the entry
/// with the smallest `id` counts as the "largest" item of the [`BinaryHeap`] and is
/// therefore the one popped first.
struct Pending<O> {
    // Sequence number of the future that produced `output`.
    id: usize,
    // The value that future resolved to.
    output: O,
}
24
25#[cfg_attr(test, mutants::skip)]
26impl<O> PartialEq for Pending<O> {
27    fn eq(&self, other: &Self) -> bool {
28        self.id == other.id
29    }
30}
31
32impl<O> Eq for Pending<O> {}
33
34impl<O> Ord for Pending<O> {
35    fn cmp(&self, other: &Self) -> Ordering {
36        // We flip the comparison so that lower IDs have higher priority.
37        other.id.cmp(&self.id)
38    }
39}
40
41impl<O> PartialOrd for Pending<O> {
42    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43        Some(self.cmp(other))
44    }
45}
46
47pin_project! {
48    pub struct IdentifiableFuture<Fut> {
49        id: usize,
50        #[pin]
51        fut: Fut,
52    }
53}
54
55impl<F> Future for IdentifiableFuture<F>
56where
57    F: Future,
58{
59    type Output = (usize, F::Output);
60
61    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
62        let this = self.project();
63        this.fut.poll(cx).map(|x| (*this.id, x))
64    }
65}
66
67pin_project! {
68    pub struct FastBufferOrdered<St>
69    where
70        St: Stream,
71        St::Item: Future,
72    {
73        #[pin]
74        stream: Fuse<St>,
75        in_progress_queue: FuturesUnordered<IdentifiableFuture<St::Item>>,
76        max: Option<NonZeroUsize>,
77        next_id: usize,
78        pending_release: BinaryHeap<Pending<<St::Item as Future>::Output>>,
79        waiting_for: usize,
80    }
81}
82
83impl<St> FastBufferOrdered<St>
84where
85    St: Stream,
86    St::Item: Future,
87{
88    pub fn new(stream: St, n: Option<usize>) -> Self {
89        Self {
90            stream: stream.fuse(),
91            in_progress_queue: FuturesUnordered::new(),
92            max: n.and_then(NonZeroUsize::new),
93            next_id: 0,
94            pending_release: BinaryHeap::new(),
95            waiting_for: 0,
96        }
97    }
98}
99
100impl<St> Stream for FastBufferOrdered<St>
101where
102    St: Stream,
103    St::Item: Future,
104{
105    type Item = <St::Item as Future>::Output;
106
107    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        let mut this = self.project();
109
110        // First up, try to spawn off as many futures as possible by filling up
111        // our queue of futures.
112        while this
113            .max
114            .map(|max| this.in_progress_queue.len() < max.get())
115            .unwrap_or(true)
116        {
117            match this.stream.as_mut().poll_next(cx) {
118                Poll::Ready(Some(fut)) => {
119                    let fut = IdentifiableFuture {
120                        id: *this.next_id,
121                        fut,
122                    };
123                    this.in_progress_queue.push(fut);
124                    *this.next_id += 1;
125                }
126                Poll::Ready(None) | Poll::Pending => break,
127            }
128        }
129
130        // Attempt to pull the next value from the in_progress_queue
131        while let Poll::Ready(Some((id, output))) = this.in_progress_queue.poll_next_unpin(cx) {
132            if id == *this.waiting_for {
133                *this.waiting_for += 1;
134                return Poll::Ready(Some(output));
135            }
136            this.pending_release.push(Pending { id, output });
137        }
138
139        if let Some(next) = this.pending_release.peek() {
140            if next.id == *this.waiting_for {
141                *this.waiting_for += 1;
142                return Poll::Ready(Some(this.pending_release.pop().unwrap().output));
143            }
144        }
145
146        // If more values are still coming from the stream, we're not done yet
147        if this.stream.is_done() && this.in_progress_queue.is_empty() {
148            Poll::Ready(None)
149        } else {
150            Poll::Pending
151        }
152    }
153
154    fn size_hint(&self) -> (usize, Option<usize>) {
155        let queue_len = self.in_progress_queue.len() + self.pending_release.len();
156        let (lower, upper) = self.stream.size_hint();
157        let lower = lower.saturating_add(queue_len);
158        let upper = match upper {
159            Some(x) => x.checked_add(queue_len),
160            None => None,
161        };
162        (lower, upper)
163    }
164}
165
166pub trait FobStreamExt: Stream {
167    fn fast_ordered_buffer(self, n: usize) -> FastBufferOrdered<Self>
168    where
169        Self: Sized,
170        Self::Item: Future,
171    {
172        FastBufferOrdered::new(self, Some(n))
173    }
174}
175impl<T: Stream> FobStreamExt for T {}
176
177// Forwarding impl of Sink from the underlying stream
178#[cfg_attr(test, mutants::skip)]
179impl<S, Item> Sink<Item> for FastBufferOrdered<S>
180where
181    S: Stream + Sink<Item>,
182    S::Item: Future,
183{
184    type Error = S::Error;
185
186    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187        self.project().stream.poll_ready(cx)
188    }
189
190    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
191        self.project().stream.start_send(item)
192    }
193
194    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
195        self.project().stream.poll_flush(cx)
196    }
197
198    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        self.project().stream.poll_close(cx)
200    }
201}