1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use core::{future::Future, pin::Pin, task::{Poll, Context}};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(all(feature = "tokio", feature = "read-write"))]
mod tokio_read_write;
#[cfg(all(feature = "futures-io", feature = "read-write"))]
mod futures_io_read_write;
#[cfg(feature = "stream")]
mod stream;

use crate::{Timeout, runtime::Runtime};

/// A handle to a [`Timeout`] that is either borrowed or (with the `std` feature) shared through
/// an [`Arc`]. Only shared access is ever handed out (see the `AsRef` impl).
enum CowTimeout<'a, R: Runtime> {
    /// Shared ownership of the timeout; constructed by `Wrapper::new_arc`.
    #[cfg(feature = "std")]
    Arc(Arc<Timeout<R>>),
    /// A borrowed timeout; constructed by `Wrapper::new`.
    Ref(&'a Timeout<R>),
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add an `R: Clone` bound,
// but cloning an `Arc` or copying a shared reference never needs to clone `R`.
impl<R: Runtime> Clone for CowTimeout<'_, R> {
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "std")]
            Self::Arc(x) => Self::Arc(Arc::clone(x)),
            Self::Ref(x) => Self::Ref(*x),
        }
    }
}
impl<R: Runtime> AsRef<Timeout<R>> for CowTimeout<'_, R> {
    /// Borrow the underlying timeout, whichever way it is held.
    fn as_ref(&self) -> &Timeout<R> {
        match *self {
            #[cfg(feature = "std")]
            Self::Arc(ref shared) => &**shared,
            Self::Ref(borrowed) => borrowed,
        }
    }
}

pin_project_lite::pin_project! {
    /// A wrapper that wraps a future, a stream or an async reader/writer and resets the timeout
    /// upon a new event.
    ///
    /// **WARNING: THIS WILL NOT TIME OUT AUTOMATICALLY. THE TIMEOUT MUST BE AWAITED SOMEWHERE ELSE.**
    /// See example below.
    ///
    /// - In case of a [future](core::future::Future), timeout will be reset upon future completion
    /// - In case of an [`AsyncRead`](tokio::io::AsyncRead) object, timeout will be reset upon a
    ///   successful read or seek.
    /// - In case of an [`AsyncWrite`](tokio::io::AsyncWrite) object, timeout will be reset upon a
    ///   successful write. It will not be reset upon a shutdown or a flush, please notify me if you
    ///   think this should be changed!
    /// - In case of a [`Stream`](futures_core::Stream) object, timeout will be reset upon stream
    ///   advancement.
    ///
    /// Since [`Wrapper::new`] accepts a shared reference to `Timeout`, you can make multiple
    /// objects use a single timeout. This means the timeout will only expire when *all* objects
    /// stopped having new events.
    ///
    /// # Example
    /// ```
    /// # async fn wrapper() -> std::io::Result<()> {
    /// # let remote_stream = tokio_test::io::Builder::new().build();
    /// # let local_stream = tokio_test::io::Builder::new().build();
    /// // Proxy with timeout
    /// use std::{io, time::Duration};
    /// use async_shared_timeout::{runtime, Timeout, Wrapper};
    ///
    /// let runtime = runtime::Tokio::new();
    /// let timeout_dur = Duration::from_secs(10);
    /// let timeout = Timeout::new(runtime, timeout_dur);
    /// let mut remote_stream = Wrapper::new(remote_stream, &timeout);
    /// let mut local_stream = Wrapper::new(local_stream, &timeout);
    /// let (copied_a_to_b, copied_b_to_a) = tokio::select! {
    ///     _ = timeout.wait() => {
    ///         return Err(io::Error::new(io::ErrorKind::TimedOut, "stream timeout"))
    ///     }
    ///     x = tokio::io::copy_bidirectional(&mut remote_stream, &mut local_stream) => {
    ///         x?
    ///     }
    /// };
    /// # drop((copied_a_to_b, copied_b_to_a));
    /// # Ok(())
    /// # }
    /// ```
    pub struct Wrapper<'a, R: Runtime, T> {
        // The wrapped future/stream/reader/writer; structurally pinned so it can be polled.
        #[pin]
        inner: T,
        // The timeout reset on every successful event of `inner`.
        timeout: CowTimeout<'a, R>,
    }
}

impl<'a, R: Runtime, T> Wrapper<'a, R, T> {
    /// Create a wrapper around an object that will reset the given timeout upon successful
    /// operations.
    ///
    /// # Arguments
    ///
    /// - `inner` - the object to be wrapped
    /// - `timeout` - a reference to the timeout that will be [reset](`Timeout::reset`) whenever
    ///   an operation on `inner` succeeds
    #[must_use]
    pub fn new(inner: T, timeout: &'a Timeout<R>) -> Self {
        Self {
            inner,
            timeout: CowTimeout::Ref(timeout),
        }
    }
    /// Create a wrapper using a timeout behind an `Arc` pointer rather than a shared reference.
    /// See [`Wrapper::new`] for more info.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn new_arc(inner: T, timeout: Arc<Timeout<R>>) -> Self {
        Self {
            inner,
            timeout: CowTimeout::Arc(timeout),
        }
    }
    /// The timeout that is reset by this wrapper
    #[must_use]
    pub fn timeout(&self) -> &Timeout<R> {
        self.timeout.as_ref()
    }
    /// A reference to the underlying object
    #[must_use]
    pub fn inner(&self) -> &T {
        &self.inner
    }
    /// A mutable reference to the underlying object
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}

impl<R: Runtime, T> AsRef<T> for Wrapper<'_, R, T> {
    /// Borrow the wrapped object; same as [`Wrapper::inner`].
    fn as_ref(&self) -> &T {
        self.inner()
    }
}

impl<R: Runtime, T> AsMut<T> for Wrapper<'_, R, T> {
    /// Mutably borrow the wrapped object; same as [`Wrapper::inner_mut`].
    fn as_mut(&mut self) -> &mut T {
        self.inner_mut()
    }
}

impl<R: Runtime, T: Future> Future for Wrapper<'_, R, T> {
    type Output = T::Output;

    /// Poll the wrapped future. When it completes, the timeout is reset before the output is
    /// handed back; while it is pending the timeout is left untouched.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let timeout = this.timeout;
        this.inner.poll(cx).map(|output| {
            timeout.as_ref().reset();
            output
        })
    }
}