1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// SPDX-FileCopyrightText: The futures-stream-ext authors
// SPDX-License-Identifier: MPL-2.0

use std::{
    num::NonZeroUsize,
    pin::Pin,
    task::{ready, Context, Poll},
};

use futures_core::stream::Stream;
use pin_project_lite::pin_project;

mod interval;
#[cfg(feature = "tokio")]
pub use self::interval::IntervalThrottler;
pub use self::interval::ThrottleIntervalConfig;

/// Callbacks for throttling a stream.
///
/// The throttler itself is a [`Stream`] of `()` ticks. Each time it yields,
/// the current interval is considered elapsed and the most recent pending
/// item of the input stream (if any) is released by the throttled stream.
///
/// Note that the throttled stream treats `Poll::Ready(None)` from the
/// throttler exactly like a tick; the value yielded by the throttler is
/// discarded.
pub trait Throttler<T>: Stream<Item = ()> {
    /// An item has been received from the input stream while no other item
    /// was pending.
    ///
    /// This is only invoked for the first item received after construction
    /// or after the previous `throttle_ready` call. Further items received
    /// before the throttler yields replace the pending item without another
    /// invocation of this method.
    ///
    /// Afterwards `throttle_ready` will be called with `Some` once the
    /// throttler yields, i.e. when the current interval has elapsed. The item
    /// that will be yielded may still change until `throttle_ready` is called.
    ///
    /// Exception: if the input stream terminates while an item is pending,
    /// that final item is emitted as soon as the throttler yields, *without*
    /// a call to `throttle_ready`.
    ///
    /// The `cx` argument is only provided for consistency and can safely
    /// be ignored in most cases. The throttler will be polled before the
    /// current `poll_next` call of the throttled stream returns, after the
    /// ready items of the input stream have been drained (up to the
    /// configured per-call limit).
    fn throttle_pending(self: Pin<&mut Self>, cx: &mut Context<'_>);

    /// The current interval has elapsed.
    ///
    /// Called each time the throttler yields while the input stream is still
    /// active. `next_item` is the pending item of the throttled input stream
    /// that is about to be yielded, or `None` if no item has been received
    /// from the input stream during the last interval.
    fn throttle_ready(self: Pin<&mut Self>, cx: &mut Context<'_>, next_item: Option<&T>);
}

/// Lifecycle of a throttled stream.
///
/// The state only ever moves forward:
/// `Streaming` -> `Finishing` -> `Finished`.
#[derive(Clone, Copy, Debug)]
enum State {
    /// The input stream is still being polled for new items.
    Streaming,
    /// The input stream has ended; a pending item may still be waiting for
    /// the throttler before the end of the stream is reported.
    Finishing,
    /// The end of the stream has been reported; polling again panics.
    Finished,
}

pin_project! {
    /// Throttled stream.
    ///
    /// Created by [`Throttle::new`]. Yields at most one item of the input
    /// stream per tick of the throttler, namely the most recent one received
    /// since the previous tick; older items of the same interval are dropped.
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled or .awaited"]
    pub struct Throttle<S: Stream, T: Throttler<<S as Stream>::Item>> {
        // The throttled input stream.
        #[pin]
        stream: S,
        // Decides when the current interval has elapsed and receives the
        // `throttle_pending` / `throttle_ready` callbacks.
        #[pin]
        throttler: T,
        // Upper bound for the number of ready items pulled from `stream`
        // within a single `poll_next` call.
        poll_next_max_ready_count: NonZeroUsize,
        // Current lifecycle state (streaming, finishing, finished).
        state: State,
        // Most recent item received from `stream` that has not been yielded yet.
        pending: Option<S::Item>,
    }
}

impl<S, T> Throttle<S, T>
where
    S: Stream,
    T: Throttler<<S as Stream>::Item>,
{
    /// Wraps `stream` so that its items are released at the pace of `throttler`.
    ///
    /// Only the most recent item received during an interval is yielded;
    /// earlier items of the same interval are replaced and dropped.
    ///
    /// `poll_next_max_ready_count` caps how many ready items are pulled from
    /// `stream` within a single `poll_next` call. When the cap is reached the
    /// task wakes itself and stops polling `stream` for this call, which
    /// prevents an endless loop on an input that is always ready.
    #[allow(clippy::needless_pass_by_value)]
    pub const fn new(stream: S, throttler: T, poll_next_max_ready_count: NonZeroUsize) -> Self {
        Self {
            pending: None,
            state: State::Streaming,
            poll_next_max_ready_count,
            throttler,
            stream,
        }
    }
}

impl<S, T> Stream for Throttle<S, T>
where
    S: Stream,
    T: Throttler<<S as Stream>::Item>,
{
    type Item = S::Item;

    /// Yields the most recent input item each time the throttler ticks.
    ///
    /// Each call first drains up to `poll_next_max_ready_count` ready items
    /// from the input stream, keeping only the latest one, and then polls the
    /// throttler. Once the input stream has ended, a still-pending item is
    /// emitted on the next throttler tick before the end of the stream is
    /// reported.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready(None)`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        if matches!(this.state, State::Streaming) {
            // Poll the inner stream while it yields items. We want to receive
            // the most recent item that is ready.
            let max_ready_count = this.poll_next_max_ready_count.get();
            let mut ready_count = 0;
            loop {
                match this.stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(item)) => {
                        if this.pending.is_none() {
                            // First item since construction or since the last tick.
                            this.throttler.as_mut().throttle_pending(cx);
                        }
                        // Replaces (and drops) any older pending item.
                        *this.pending = Some(item);
                        debug_assert!(ready_count < max_ready_count);
                        ready_count += 1;
                        if ready_count >= max_ready_count {
                            // Stop polling the inner stream to prevent endless loops
                            // for streams that are always ready.
                            // Wake ourselves up to ensure that polling the stream
                            // continues after polling the throttler.
                            cx.waker().wake_by_ref();
                            break;
                        }
                    }
                    Poll::Ready(None) => {
                        *this.state = State::Finishing;
                        break;
                    }
                    // The inner stream has registered the waker.
                    Poll::Pending => break,
                }
            }
        }

        // Poll the throttler.
        match this.state {
            State::Streaming => {
                // `ready!` discards the throttler's item, so a terminated
                // throttler (`Ready(None)`) is treated like a tick.
                ready!(this.throttler.as_mut().poll_next(cx));
                let next_item = this.pending.take();
                this.throttler
                    .as_mut()
                    .throttle_ready(cx, next_item.as_ref());
                match next_item {
                    Some(item) => Poll::Ready(Some(item)),
                    // Nothing was received during the elapsed interval. The inner
                    // stream returned `Pending` during this call, so its waker is
                    // registered and the next input item wakes us up again.
                    None => Poll::Pending,
                }
            }
            State::Finishing => {
                if this.pending.is_none() {
                    // The final state transition.
                    *this.state = State::Finished;
                    return Poll::Ready(None);
                }
                ready!(this.throttler.as_mut().poll_next(cx));
                let last_item = this.pending.take();
                // Wake ourselves up for the final state transition from `Finishing`
                // to `Finished` that becomes ready immediately.
                cx.waker().wake_by_ref();
                Poll::Ready(last_item)
            }
            State::Finished => panic!("stream polled after completion"),
        }
    }
}