futures_ext/stream/
yield_periodically.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under both the MIT license found in the
5 * LICENSE-MIT file in the root directory of this source tree and the Apache
6 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
7 * of this source tree.
8 */
9
10use std::pin::Pin;
11use std::time::Duration;
12use std::time::Instant;
13
14use futures::stream::Stream;
15use futures::task::Context;
16use futures::task::Poll;
17use pin_project::pin_project;
18
19/// A stream that will yield control back to the caller if it runs for more than a given duration
20/// without yielding (i.e. returning Poll::Pending).  The clock starts counting the first time the
21/// stream is polled, and is reset every time the stream yields.
22#[pin_project]
23pub struct YieldPeriodically<S> {
24    #[pin]
25    inner: S,
26    /// Default budget.
27    budget: Duration,
28    /// Budget left for the current iteration.
29    current_budget: Duration,
30    /// Whether the next iteration must yield because the budget was exceeded.
31    must_yield: bool,
32}
33
34impl<S> YieldPeriodically<S> {
35    /// Create a new [YieldPeriodically].
36    pub fn new(inner: S, budget: Duration) -> Self {
37        Self {
38            inner,
39            budget,
40            current_budget: budget,
41            must_yield: false,
42        }
43    }
44}
45
46impl<S: Stream> Stream for YieldPeriodically<S> {
47    type Item = <S as Stream>::Item;
48
49    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
50        let this = self.project();
51
52        if *this.must_yield {
53            *this.must_yield = false;
54            cx.waker().wake_by_ref();
55            return Poll::Pending;
56        }
57
58        let now = Instant::now();
59        let res = this.inner.poll_next(cx);
60
61        if res.is_pending() {
62            *this.current_budget = *this.budget;
63            return res;
64        }
65
66        let elapsed = now.elapsed();
67
68        match this.current_budget.checked_sub(elapsed) {
69            Some(new_budget) => *this.current_budget = new_budget,
70            None => {
71                *this.must_yield = true;
72                *this.current_budget = *this.budget;
73            }
74        };
75
76        res
77    }
78}
79
#[cfg(test)]
mod test {
    use futures::stream::StreamExt;

    use super::*;

    /// Drives a CPU-bound stream by hand. It must hand control back (return `Poll::Pending`)
    /// within a reasonable margin of its budget, and resume producing items when polled again.
    #[test]
    fn test_yield_happens() {
        // Every item costs ~1ms of blocking "work".
        let busy = futures::stream::repeat(()).inspect(|_| {
            std::thread::sleep(Duration::from_millis(1));
        });
        let stream = YieldPeriodically::new(busy, Duration::from_millis(100));
        futures::pin_mut!(stream);

        let waker = futures::task::noop_waker();
        let mut cx = futures::task::Context::from_waker(&waker);
        let deadline = Duration::from_millis(200);

        // First slice: keeps producing items until the budget forces a yield.
        let started = Instant::now();
        while stream.as_mut().poll_next(&mut cx).is_ready() {
            assert!(started.elapsed() < deadline, "Stream did not yield in time");
        }

        // Second slice: must produce at least one item, then yield again.
        let started = Instant::now();
        let mut items_after_yield = 0usize;
        while stream.as_mut().poll_next(&mut cx).is_ready() {
            items_after_yield += 1;
            assert!(started.elapsed() < deadline, "Stream did not yield in time");
        }
        assert!(items_after_yield > 0, "Stream did not unpause");
    }

    /// A forced yield must wake the task again. Otherwise this test hangs forever, because
    /// nothing else would ever poll the stream.
    #[tokio::test]
    async fn test_yield_registers_for_wakeup() {
        // Simulate ~1ms of CPU work per item.
        let work = |_: &()| std::thread::sleep(Duration::from_millis(1));
        let source = futures::stream::repeat(()).inspect(work).take(30);

        YieldPeriodically::new(source, Duration::from_millis(10))
            .collect::<Vec<_>>()
            .await;
    }
}