async_shared_timeout/wrapper/mod.rs
1use core::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6#[cfg(feature = "std")]
7use std::sync::Arc;
8#[cfg(all(feature = "futures-io", feature = "read-write"))]
9mod futures_io_read_write;
10#[cfg(feature = "stream")]
11mod stream;
12#[cfg(all(feature = "tokio", feature = "read-write"))]
13mod tokio_read_write;
14#[cfg(all(feature = "std", unix))]
15use std::os::unix::io::{AsRawFd, RawFd};
16
17use crate::{runtime::Runtime, Timeout};
18
19#[derive(Clone)]
20enum CowTimeout<'a, R: Runtime> {
21 #[cfg(feature = "std")]
22 Arc(Arc<Timeout<R>>),
23 Ref(&'a Timeout<R>),
24}
25impl<'a, R: Runtime> AsRef<Timeout<R>> for CowTimeout<'a, R> {
26 fn as_ref(&self) -> &Timeout<R> {
27 match self {
28 #[cfg(feature = "std")]
29 Self::Arc(x) => x,
30 Self::Ref(x) => x,
31 }
32 }
33}
34
pin_project_lite::pin_project! {
    /// A wrapper that wraps a future, a stream or an async reader/writer and resets the timeout
    /// upon a new event.
    ///
    /// **WARNING: THIS WILL NOT TIME OUT AUTOMATICALLY. THE TIMEOUT MUST BE AWAITED SOMEWHERE ELSE.**
    /// See the example below.
    ///
    /// - In case of a [future](core::future::Future), the timeout will be reset upon future
    ///   completion.
    /// - In case of an [`AsyncRead`](tokio::io::AsyncRead) object, the timeout will be reset upon a
    ///   successful read or seek.
    /// - In case of an [`AsyncWrite`](tokio::io::AsyncWrite) object, the timeout will be reset upon
    ///   a successful write. It is not reset upon a shutdown or a flush; please open an issue if
    ///   you think this should change.
    /// - In case of a [`Stream`](futures_core::Stream) object, the timeout will be reset upon
    ///   stream advancement.
    ///
    /// Since [`Wrapper::new`] accepts a shared reference to `Timeout`, you can make multiple
    /// objects use a single timeout. This means the timeout will only expire once *all* objects
    /// have stopped producing new events.
    ///
    /// # Example
    /// ```
    /// # async fn wrapper() -> std::io::Result<()> {
    /// # let remote_stream = tokio_test::io::Builder::new().build();
    /// # let local_stream = tokio_test::io::Builder::new().build();
    /// // Proxy with timeout
    /// use std::{io, time::Duration};
    /// use async_shared_timeout::{runtime, Timeout, Wrapper};
    ///
    /// let runtime = runtime::Tokio::new();
    /// let timeout_dur = Duration::from_secs(10);
    /// let timeout = Timeout::new(runtime, timeout_dur);
    /// let mut remote_stream = Wrapper::new(remote_stream, &timeout);
    /// let mut local_stream = Wrapper::new(local_stream, &timeout);
    /// let (copied_a_to_b, copied_b_to_a) = tokio::select! {
    ///     _ = timeout.wait() => {
    ///         return Err(io::Error::new(io::ErrorKind::TimedOut, "stream timeout"))
    ///     }
    ///     x = tokio::io::copy_bidirectional(&mut remote_stream, &mut local_stream) => {
    ///         x?
    ///     }
    /// };
    /// # drop((copied_a_to_b, copied_b_to_a));
    /// # Ok(())
    /// # }
    /// ```
    #[cfg_attr(docsrs, doc(cfg(feature = "wrapper")))]
    pub struct Wrapper<'a, R: Runtime, T> {
        // The wrapped object. `#[pin]` makes it structurally pinned, so `project()` yields a
        // `Pin<&mut T>` that can be polled.
        #[pin]
        inner: T,
        // The timeout reset on events: either borrowed (`Wrapper::new`) or shared through an
        // `Arc` (`Wrapper::new_arc`, `std` feature only).
        timeout: CowTimeout<'a, R>,
    }
}
88
/// An alias for [`Wrapper`] using the tokio runtime.
#[cfg(feature = "tokio")]
// Like the other feature-gated items in this module, surface the gate on docs.rs.
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub type TokioWrapper<'a, T> = Wrapper<'a, crate::runtime::Tokio, T>;
92
93impl<'a, R: Runtime, T> Wrapper<'a, R, T> {
94 /// Create a wrapper around an object that will update the given timeout upon successful
95 /// operations
96 ///
97 /// # Arguments
98 ///
99 /// - `inner` - the object to be wrapped
100 /// - `timeout` - a reference to the timeout to be used for operations on `inner`
101 /// - `default_timeout` - on a successful operation, `timeout` will be [reset](`Timeout::reset`) to this value
102 #[must_use]
103 pub fn new(inner: T, timeout: &'a Timeout<R>) -> Self {
104 Self {
105 inner,
106 timeout: CowTimeout::Ref(timeout),
107 }
108 }
109 /// The timeout reference
110 pub fn timeout(&self) -> &Timeout<R> {
111 self.timeout.as_ref()
112 }
113 /// A reference to the underlying object
114 pub fn inner(&self) -> &T {
115 &self.inner
116 }
117 /// A mutable reference to the underlying object
118 pub fn inner_mut(&mut self) -> &mut T {
119 &mut self.inner
120 }
121}
122
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<R: Runtime, T> Wrapper<'static, R, T> {
    /// Create a wrapper that shares ownership of `timeout` through an [`Arc`] instead of
    /// borrowing it, so the wrapper is not tied to any borrow's lifetime.
    /// See [`Wrapper::new`] for more info.
    #[must_use]
    pub fn new_arc(inner: T, timeout: Arc<Timeout<R>>) -> Self {
        let shared = CowTimeout::Arc(timeout);
        Self {
            timeout: shared,
            inner,
        }
    }
}
136
137impl<T, R: Runtime> AsRef<T> for Wrapper<'_, R, T> {
138 fn as_ref(&self) -> &T {
139 &self.inner
140 }
141}
142
143impl<T, R: Runtime> AsMut<T> for Wrapper<'_, R, T> {
144 fn as_mut(&mut self) -> &mut T {
145 &mut self.inner
146 }
147}
148
149impl<R: Runtime, T: Future> Future for Wrapper<'_, R, T> {
150 type Output = T::Output;
151
152 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153 let pinned = self.project();
154 match pinned.inner.poll(cx) {
155 Poll::Ready(x) => {
156 pinned.timeout.as_ref().reset();
157 Poll::Ready(x)
158 }
159 Poll::Pending => Poll::Pending,
160 }
161 }
162}
163
#[cfg(all(feature = "std", unix))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "std", unix))))]
impl<R: Runtime, T: AsRawFd> AsRawFd for Wrapper<'_, R, T> {
    /// Forwards to the raw file descriptor of the wrapped object.
    fn as_raw_fd(&self) -> RawFd {
        T::as_raw_fd(&self.inner)
    }
}