async_shared_timeout/wrapper/mod.rs

1use core::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6#[cfg(feature = "std")]
7use std::sync::Arc;
8#[cfg(all(feature = "futures-io", feature = "read-write"))]
9mod futures_io_read_write;
10#[cfg(feature = "stream")]
11mod stream;
12#[cfg(all(feature = "tokio", feature = "read-write"))]
13mod tokio_read_write;
14#[cfg(all(feature = "std", unix))]
15use std::os::unix::io::{AsRawFd, RawFd};
16
17use crate::{runtime::Runtime, Timeout};
18
19#[derive(Clone)]
20enum CowTimeout<'a, R: Runtime> {
21    #[cfg(feature = "std")]
22    Arc(Arc<Timeout<R>>),
23    Ref(&'a Timeout<R>),
24}
25impl<'a, R: Runtime> AsRef<Timeout<R>> for CowTimeout<'a, R> {
26    fn as_ref(&self) -> &Timeout<R> {
27        match self {
28            #[cfg(feature = "std")]
29            Self::Arc(x) => x,
30            Self::Ref(x) => x,
31        }
32    }
33}
34
35pin_project_lite::pin_project! {
36    /// A wrapper that wraps a future, a stream or an async reader/writer and resets the timeout
37    /// upon a new event.
38    ///
39    /// **WARNING: THIS WILL NOT TIME OUT AUTOMATICALLY. THE TIMEOUT MUST BE AWAITED SOMEWHERE ELSE.**
40    /// See example below.
41    ///
42    /// - In case of a [future](core::future::Future), timeout will be reset upon future completion
43    /// - In case of an [`AsyncRead`](tokio::io::AsyncRead) object, timeout will be reset upon a
44    ///   successful read or seek.
45    /// - In case of an [`AsyncWrite`](tokio::io::AsyncWrite) object, timeout will be reset upon a
46    ///   successful write. It will not be reset upon a shutdown or a flush, please notify me if you
47    ///   think this should be changed!
48    /// - In case of a [`Stream`](futures_core::Stream) object, timeout will be reset upon stream
49    ///   advancement.
50    ///
51    /// Since [`Wrapper::new`] accepts a shared reference to `Timeout`, you can make multiple
52    /// objects use a single timeout. This means the timeout will only expire when *all* objects
53    /// stopped having new events.
54    ///
55    /// # Example
56    /// ```
57    /// # async fn wrapper() -> std::io::Result<()> {
58    /// # let remote_stream = tokio_test::io::Builder::new().build();
59    /// # let local_stream = tokio_test::io::Builder::new().build();
60    /// // Proxy with timeout
61    /// use std::{io, time::Duration};
62    /// use async_shared_timeout::{runtime, Timeout, Wrapper};
63    ///
64    /// let runtime = runtime::Tokio::new();
65    /// let timeout_dur = Duration::from_secs(10);
66    /// let timeout = Timeout::new(runtime, timeout_dur);
67    /// let mut remote_stream = Wrapper::new(remote_stream, &timeout);
68    /// let mut local_stream = Wrapper::new(local_stream, &timeout);
69    /// let (copied_a_to_b, copied_b_to_a) = tokio::select! {
70    ///     _ = timeout.wait() => {
71    ///         return Err(io::Error::new(io::ErrorKind::TimedOut, "stream timeout"))
72    ///     }
73    ///     x = tokio::io::copy_bidirectional(&mut remote_stream, &mut local_stream) => {
74    ///         x?
75    ///     }
76    /// };
77    /// # drop((copied_a_to_b, copied_b_to_a));
78    /// # Ok(())
79    /// # }
80    /// ```
81    #[cfg_attr(docsrs, doc(cfg(feature = "wrapper")))]
82    pub struct Wrapper<'a, R: Runtime, T> {
83        #[pin]
84        inner: T,
85        timeout: CowTimeout<'a, R>,
86    }
87}
88
/// An alias for [`Wrapper`] using the tokio runtime.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub type TokioWrapper<'a, T> = Wrapper<'a, crate::runtime::Tokio, T>;
92
93impl<'a, R: Runtime, T> Wrapper<'a, R, T> {
94    /// Create a wrapper around an object that will update the given timeout upon successful
95    /// operations
96    ///
97    /// # Arguments
98    ///
99    /// - `inner` - the object to be wrapped
100    /// - `timeout` - a reference to the timeout to be used for operations on `inner`
101    /// - `default_timeout` - on a successful operation, `timeout` will be [reset](`Timeout::reset`) to this value
102    #[must_use]
103    pub fn new(inner: T, timeout: &'a Timeout<R>) -> Self {
104        Self {
105            inner,
106            timeout: CowTimeout::Ref(timeout),
107        }
108    }
109    /// The timeout reference
110    pub fn timeout(&self) -> &Timeout<R> {
111        self.timeout.as_ref()
112    }
113    /// A reference to the underlying object
114    pub fn inner(&self) -> &T {
115        &self.inner
116    }
117    /// A mutable reference to the underlying object
118    pub fn inner_mut(&mut self) -> &mut T {
119        &mut self.inner
120    }
121}
122
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<R: Runtime, T> Wrapper<'static, R, T> {
    /// Create a wrapper that holds the timeout behind an `Arc` instead of borrowing it, so the
    /// resulting wrapper is not tied to any borrow lifetime.
    ///
    /// See [`Wrapper::new`] for more info.
    #[must_use]
    pub fn new_arc(inner: T, timeout: Arc<Timeout<R>>) -> Self {
        let timeout = CowTimeout::Arc(timeout);
        Self { inner, timeout }
    }
}
136
137impl<T, R: Runtime> AsRef<T> for Wrapper<'_, R, T> {
138    fn as_ref(&self) -> &T {
139        &self.inner
140    }
141}
142
143impl<T, R: Runtime> AsMut<T> for Wrapper<'_, R, T> {
144    fn as_mut(&mut self) -> &mut T {
145        &mut self.inner
146    }
147}
148
149impl<R: Runtime, T: Future> Future for Wrapper<'_, R, T> {
150    type Output = T::Output;
151
152    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153        let pinned = self.project();
154        match pinned.inner.poll(cx) {
155            Poll::Ready(x) => {
156                pinned.timeout.as_ref().reset();
157                Poll::Ready(x)
158            }
159            Poll::Pending => Poll::Pending,
160        }
161    }
162}
163
#[cfg(all(feature = "std", unix))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", unix))))]
impl<R, T> AsRawFd for Wrapper<'_, R, T>
where
    R: Runtime,
    T: AsRawFd,
{
    /// Forwards to the raw file descriptor of the wrapped object.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}