async_shared_timeout/
lib.rs

1#![cfg_attr(
2    all(
3        not(feature = "std"),
4        not(feature = "async-io"),
5        not(feature = "tokio")
6    ),
7    no_std
8)]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11//! A crate that offers a way to create a timeout that can be reset and shared.
12//! Additionally, stream timeout is offered under a feature flag.
13//!
14//! # Feature flags:
15//!
16//! **Wrapper**
17//!
18//! - `wrapper` - enable a wrapper around types that you can use for easier resetting. By default,
19//!               only future support is enabled (reset the timer upon future completion).
20//! - `read-write` - enable async `Read`/`Write` trait support for the wrapper (reset the timer
21//!                  upon successful read/write operations)
22//! - `stream` - enable `Stream` support for the wrapper (reset the timer upon stream advancement).
23//!
24//! **Integration with other runtimes**
25//!
26//! - `std` (enabled by default) - enable `std` integration. Currently it's only used to enable
27//!                                `Arc` and `AsRawFd` support for the wrapper.
28//! - `tokio` (enabled by default) - [`tokio`](https://docs.rs/tokio) support
29//! - `async-io` - support [`async-io`](https://docs.rs/async-io) as the timer runtime.
30//! - `futures-io` - support [`futures-io`](https://docs.rs/futures-io) traits.
31//! - `async-std` - [`async-std`](https://docs.rs/async-std) support (enables `async-io` and `futures-io`).
32//!
33//! See struct documentation for examples.
34use core::{
35    future::Future,
36    pin::Pin,
37    sync::atomic::Ordering,
38    task::{Context, Poll},
39    time::Duration,
40};
41use portable_atomic::AtomicU64;
42
43pub mod runtime;
44use runtime::{Instant, Runtime, Sleep};
45
46/// A shared timeout.
47///
48/// # Example
49///
50/// ```
51/// # async fn read_command() -> Option<&'static str> { Some("command") }
52/// # async fn example_fn() {
53/// use std::time::Duration;
54///
55/// let timeout_secs = Duration::from_secs(10);
56/// // Use the tokio runtime
57/// let runtime = async_shared_timeout::runtime::Tokio::new();
58/// let timeout = async_shared_timeout::Timeout::new(runtime, timeout_secs);
59/// tokio::select! {
60///     _ = timeout.wait() => {
61///         println!("timeout expired!");
62///     }
63///     _ = async {
64///         while let Some(cmd) = read_command().await {
65///             println!("command received: {:?}", cmd);
66///             timeout.reset();
67///         }
68///     } => {
69///         println!("no more commands!");
70///     }
71/// }
72/// # }
73/// ```
74#[derive(Debug)]
75pub struct Timeout<R: Runtime> {
76    runtime: R,
77    epoch: R::Instant,
78    timeout_from_epoch_ns: AtomicU64,
79    default_timeout: AtomicU64,
80}
81
/// An alias for [`Timeout`] using the tokio runtime
#[cfg(feature = "tokio")]
pub type TokioTimeout = Timeout<runtime::Tokio>;

#[cfg(feature = "tokio")]
impl TokioTimeout {
    /// Create a new timeout that expires after `default_timeout`, creating a runtime with
    /// [`runtime::Tokio::new`]
    ///
    /// This is a shorthand for `Timeout::new(runtime::Tokio::new(), default_timeout)`.
    ///
    /// # Panics
    /// Panics if `default_timeout` is longer than ~584 years
    #[must_use]
    pub fn new_tokio(default_timeout: Duration) -> Self {
        // Delegate to the generic constructor so both entry points always stay in sync.
        Self::new(runtime::Tokio::new(), default_timeout)
    }
}
104
105impl<R: Runtime> Timeout<R> {
106    /// Create a new timeout that expires after `default_timeout`
107    ///
108    /// # Panics
109    /// Panics if `default_timeout` is longer than ~584 years
110    #[must_use]
111    pub fn new(runtime: R, default_timeout: Duration) -> Self {
112        let epoch = runtime.now();
113        let default_timeout = u64::try_from(default_timeout.as_nanos()).unwrap();
114        Self {
115            runtime,
116            epoch,
117            timeout_from_epoch_ns: default_timeout.into(),
118            default_timeout: default_timeout.into(),
119        }
120    }
121
122    fn elapsed(&self) -> Duration {
123        self.runtime.now().duration_since(&self.epoch)
124    }
125
126    /// Reset the timeout to the default time.
127    ///
128    /// This function is cheap to call.
129    ///
130    /// # Panics
131    /// Panics if over ~584 years have elapsed since the timer started.
132    pub fn reset(&self) {
133        self.timeout_from_epoch_ns.store(
134            u64::try_from(self.elapsed().as_nanos()).unwrap()
135                + self.default_timeout.load(Ordering::Acquire),
136            Ordering::Release,
137        );
138    }
139
140    /// The default timeout. Timeout will be reset to this value upon a successful operation.
141    pub fn default_timeout(&self) -> Duration {
142        Duration::from_nanos(self.default_timeout.load(Ordering::Acquire))
143    }
144    /// Change the default timeout.
145    ///
146    /// Warning: if this timeout is shorter than previous one, it will only update after the
147    /// previous timeout has expired!
148    ///
149    /// Additionally, this won't automatically reset the timeout - it will only affect the next
150    /// reset.
151    ///
152    /// # Panics
153    /// Panics if `default_timeout` is longer than ~584 years
154    pub fn set_default_timeout(&self, default_timeout: Duration) {
155        self.default_timeout.store(
156            u64::try_from(default_timeout.as_nanos()).unwrap(),
157            Ordering::Release,
158        );
159    }
160
161    fn timeout_duration(&self) -> Option<Duration> {
162        let elapsed_nanos = u64::try_from(self.elapsed().as_nanos()).unwrap();
163        let target_nanos = self.timeout_from_epoch_ns.load(Ordering::Acquire);
164        (elapsed_nanos < target_nanos).then(|| Duration::from_nanos(target_nanos - elapsed_nanos))
165    }
166
167    /// Wait for the timeout to expire
168    ///
169    /// This is a function that's expensive to start, so for best performance, only call it once
170    /// per timer - launch it separately and call [`reset`](Timeout::reset) from the
171    /// other futures (see the example in top-level documentation).
172    pub async fn wait(&self) {
173        pin_project_lite::pin_project! {
174            struct SleepFuture<F: Sleep> {
175                #[pin]
176                inner: F,
177            }
178        }
179
180        impl<F: Sleep> Future for SleepFuture<F> {
181            type Output = ();
182            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183                self.project().inner.poll_sleep(cx)
184            }
185        }
186        if let Some(timeout) = self.timeout_duration() {
187            let future = self.runtime.create_sleep(timeout);
188            let mut future = SleepFuture { inner: future };
189            // SAFETY: the original future binding is shadowed,
190            // so the unpinned binding can never be accessed again.
191            // This is exactly the same code as the tokio::pin! macro
192            let future = &mut unsafe { Pin::new_unchecked(&mut future) };
193            while let Some(instant) = self.timeout_duration() {
194                future.as_mut().project().inner.reset(instant);
195                future.as_mut().await;
196            }
197        }
198    }
199}
200
201#[cfg(feature = "wrapper")]
202mod wrapper;
203#[cfg(feature = "wrapper")]
204pub use wrapper::Wrapper;
205#[cfg(all(feature = "wrapper", feature = "tokio"))]
206pub use wrapper::TokioWrapper;
207
#[cfg(test)]
mod tests {
    use tokio::time::Instant;

    use crate::*;

    /// A timer that is never reset must not complete before its timeout has passed.
    #[test]
    fn test_expiry() {
        let began = Instant::now();
        tokio_test::block_on(async {
            let one_second_timer = Timeout::new(runtime::Tokio::new(), Duration::from_secs(1));
            one_second_timer.wait().await;
        });
        let took = began.elapsed();
        assert!(took >= Duration::from_secs(1));
    }

    /// Resetting the timer half-way through must keep it from firing before the
    /// competing two-second task completes.
    #[test]
    fn test_non_expiry() {
        let began = Instant::now();
        let work_finished_first = tokio_test::block_on(async {
            let timer = Timeout::new(runtime::Tokio::new(), Duration::from_secs(2));
            // Sleeps for one second, resets the timer, then sleeps for one more second.
            let work = async {
                tokio::time::sleep(Duration::from_secs(1)).await;
                timer.reset();
                tokio::time::sleep(Duration::from_secs(1)).await;
            };
            tokio::select! {
                _ = timer.wait() => false,
                _ = work => true,
            }
        });
        assert!(work_finished_first);
        assert!(began.elapsed() >= Duration::from_secs(2));
    }
}