// futures_ext/stream/stream_with_timeout.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under both the MIT license found in the
 * LICENSE-MIT file in the root directory of this source tree and the Apache
 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
 * of this source tree.
 */

use std::pin::Pin;
use std::time::Duration;

use futures::future::Future;
use futures::stream::FusedStream;
use futures::stream::Stream;
use futures::task::Context;
use futures::task::Poll;
use pin_project::pin_project;
use thiserror::Error;
use tokio::time::Sleep;

21/// Error returned when a StreamWithTimeout exceeds its deadline.
22#[derive(Debug, Error)]
23#[error("Stream timeout with duration {:?} was exceeded", .0)]
24pub struct StreamTimeoutError(Duration);
25
26/// A stream that must finish within a given duration, or it will error during poll (i.e. it must
27/// yield None). The clock starts counting the first time the stream is polled.
28#[pin_project]
29pub struct StreamWithTimeout<S> {
30    #[pin]
31    inner: S,
32    duration: Duration,
33    done: bool,
34    #[pin]
35    deadline: Option<Sleep>,
36}
37
38impl<S> StreamWithTimeout<S> {
39    /// Create a new [StreamWithTimeout].
40    pub fn new(inner: S, duration: Duration) -> Self {
41        Self {
42            inner,
43            duration,
44            done: false,
45            deadline: None,
46        }
47    }
48}
49
50impl<S: Stream> Stream for StreamWithTimeout<S> {
51    type Item = Result<<S as Stream>::Item, StreamTimeoutError>;
52
53    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54        let mut this = self.project();
55
56        if *this.done {
57            return Poll::Ready(None);
58        }
59
60        let duration = *this.duration;
61
62        if this.deadline.is_none() {
63            this.deadline.set(Some(tokio::time::sleep(duration)));
64        }
65
66        // NOTE: This unwrap() is safe as we just set the value.
67        match this.deadline.as_pin_mut().unwrap().poll(cx) {
68            Poll::Ready(()) => {
69                *this.done = true;
70                return Poll::Ready(Some(Err(StreamTimeoutError(duration))));
71            }
72            Poll::Pending => {
73                // Continue
74            }
75        }
76
77        // Keep track of whether the stream has finished, so that we don't attempt to poll the
78        // deadline later if the stream has indeed finished already.
79        let res = futures::ready!(this.inner.poll_next(cx));
80        if res.is_none() {
81            *this.done = true;
82        }
83
84        Poll::Ready(Ok(res).transpose())
85    }
86}
87
#[cfg(test)]
mod test {
    use anyhow::Error;
    use futures::stream::StreamExt;
    use futures::stream::TryStreamExt;

    use super::*;

    /// Deadline applied to every stream under test.
    const TIMEOUT: Duration = Duration::from_secs(1);

    /// How far the paused clock is moved forward; deliberately longer than `TIMEOUT`.
    const JUMP: Duration = Duration::from_secs(2);

    /// Moving the clock past the deadline between two items yields a timeout error, after
    /// which the stream is finished.
    #[tokio::test]
    async fn test_stream_timeout() -> Result<(), Error> {
        tokio::time::pause();

        let inner = async_stream::stream! {
            yield Result::<(), Error>::Ok(());
            tokio::time::advance(JUMP).await;
            yield Result::<(), Error>::Ok(());
        };
        let mut wrapped = StreamWithTimeout::new(inner.boxed(), TIMEOUT).boxed();

        let first = wrapped.try_next().await?;
        assert!(first.is_some());

        let second = wrapped.try_next().await;
        assert!(second.is_err());

        let third = wrapped.try_next().await?;
        assert!(third.is_none());

        Ok(())
    }

    /// A stream that ends before its deadline ends cleanly and stays ended once the deadline
    /// has passed.
    #[tokio::test]
    async fn test_stream_done_before_timeout() -> Result<(), Error> {
        tokio::time::pause();

        let inner = async_stream::stream! {
            yield Result::<(), Error>::Ok(());
            yield Result::<(), Error>::Ok(());
        };
        let mut wrapped = StreamWithTimeout::new(inner.boxed(), TIMEOUT).boxed();

        for _ in 0..2 {
            assert!(wrapped.try_next().await?.is_some());
        }
        assert!(wrapped.try_next().await?.is_none());

        tokio::time::advance(JUMP).await;

        assert!(wrapped.try_next().await?.is_none());

        Ok(())
    }

    /// Time that passes before the first poll does not count towards the deadline.
    #[tokio::test]
    async fn test_clock_starts_at_poll() -> Result<(), Error> {
        tokio::time::pause();

        let inner = async_stream::stream! {
            yield Result::<(), Error>::Ok(());
            yield Result::<(), Error>::Ok(());
        };
        let mut wrapped = StreamWithTimeout::new(inner.boxed(), TIMEOUT).boxed();

        tokio::time::advance(JUMP).await;

        for _ in 0..2 {
            assert!(wrapped.try_next().await?.is_some());
        }
        assert!(wrapped.try_next().await?.is_none());

        Ok(())
    }
}