context_async/
timer.rs

1use std::future::Future;
2use std::pin::{Pin, pin};
3use std::sync::Arc;
4use std::task::Poll;
5use std::time;
6use log::error;
7use tokio::sync;
8use tokio::sync::RwLock;
9use tokio::time::Sleep;
10use crate::{Context, Error};
11#[cfg(feature = "name")]
12use crate::name::Name;
13
14/// The [`Timer`] structure is the default [`Context`].
15#[derive(Debug, Clone)]
16pub struct Timer {
17    inner: Arc<RwLock<Inner>>,
18}
19
20#[derive(Debug)]
21struct Inner {
22    #[cfg(feature = "name")]
23    name: Name,
24    expire_at: Option<time::Instant>,
25    cancelled: bool,
26    cancelled_sender: sync::broadcast::Sender<()>,
27    cancelled_receiver: sync::broadcast::Receiver<()>,
28    childs: Vec<Timer>,
29}
30
31impl Inner {
32    fn new() -> Self {
33        let (sender, receiver) = sync::broadcast::channel(32);
34
35        #[cfg(feature = "name")]
36        let name = Name::default();
37
38        #[cfg(feature = "tracing")]
39        {
40            #[cfg(feature = "name")]
41            tracing::trace!(context_new=name.as_u64());
42
43            #[cfg(not(feature = "name"))]
44            tracing::trace!(context_new="");
45        }
46
47        Self {
48            #[cfg(feature = "name")]
49            name,
50            expire_at: None,
51            cancelled: false,
52            cancelled_sender: sender,
53            cancelled_receiver: receiver,
54            childs: Default::default(),
55        }
56    }
57}
58
59#[async_trait::async_trait]
60impl Context for Timer {
61    type SubContext = Self;
62    
63    fn timer(&self) -> Timer {
64        (*self).clone()
65    }
66
67    #[cfg(feature = "name")]
68    async fn name(&self) -> Name {
69        self.inner.read().await.name
70    }
71
72    async fn deadline(&self) -> Option<time::Instant> {
73        self.inner.read().await
74            .expire_at
75    }
76
77    async fn cancel(&self) {
78        let mut inner = self.inner.write().await;
79        if !inner.cancelled {
80            inner.cancelled = true;
81            let _ = inner.cancelled_sender.send(());
82
83            for child in &inner.childs {
84                child.cancel().await;
85            }
86        }
87    }
88
89    async fn is_cancelled(&self) -> bool {
90       self.inner.read().await.cancelled
91    }
92
93    async fn is_timeout(&self) -> bool {
94        self.inner.read().await.expire_at
95            .is_some_and(|expire_at| expire_at < time::Instant::now())
96    }
97
98    async fn spawn(&self) -> Self {
99        let mut inner = self.inner.write().await;
100
101        let mut child = Inner::new();
102        child.expire_at = inner.expire_at;
103
104        #[cfg(feature = "tracing")]
105        {
106            #[cfg(feature = "name")]
107            tracing::trace!(context_spawn=inner.name.as_u64(), child=child.name.as_u64(), expire_at=?child.expire_at);
108            #[cfg(not(feature = "name"))]
109            tracing::trace!(context_spawn="", expire_at=?child.expire_at)
110        }
111
112        let child_timer = Self::from(child);
113        inner.childs.push(child_timer.clone());
114
115        child_timer
116    }
117
118    async fn spawn_with_timeout(&self, timeout: time::Duration) -> Self {
119        let mut inner = self.inner.write().await;
120
121        let child_expire_at = time::Instant::now() + timeout;
122        let child_expire_at = if let Some(expire_at) = inner.expire_at {
123            if child_expire_at > expire_at {
124                Some(expire_at)
125            } else {
126                Some(child_expire_at)
127            }
128        } else {
129            None
130        };
131
132        let mut child = Inner::new();
133        child.expire_at = child_expire_at;
134
135        #[cfg(feature = "tracing")]
136        {
137            #[cfg(feature = "name")]
138            tracing::trace!(context_spawn=inner.name.as_u64(), with_timeout=?timeout, child=child.name.as_u64(), expire_at=?child.expire_at);
139            #[cfg(not(feature = "name"))]
140            tracing::trace!(context_spawn="", with_timeout=?timeout, expire_at=?child.expire_at)
141        }
142
143        let child_timer = Self::from(child);
144        inner.childs.push(child_timer.clone());
145
146        child_timer
147    }
148
149    async fn handle<'a, Fut, Output>(&self, fut: Fut) -> crate::Result<Output>
150    where
151        Fut: Future<Output = Output> + Send + 'a
152    {
153        if self.is_cancelled().await {
154            return Err(Error::ContextCancelled);
155        } else if self.is_timeout().await {
156            return Err(Error::ContextTimeout);
157        }
158
159        let sleep = self.deadline().await
160            .map(tokio::time::Instant::from_std)
161            .map(tokio::time::sleep_until)
162            .map(Box::pin);
163
164        let mut cancel_receiver = self.cancel_receiver().await;
165        let cancel_receiver_fut = cancel_receiver.recv();
166
167        let task = Task {
168            cancel_receiver: Box::pin(cancel_receiver_fut),
169            sleep,
170            fut: Box::pin(fut),
171        };
172
173        task.await
174    }
175}
176
177impl From<Inner> for Timer {
178    fn from(value: Inner) -> Self {
179        Self { inner: Arc::new(RwLock::new(value)) }
180    }
181}
182
183impl Timer {
184    /// Create a default, independent timer with no time duration limit.
185    #[inline]
186    pub fn background() -> Self {
187        Self::from(Inner::new())
188    }
189
190    /// Create a default, independent timer with no time duration limit.
191    #[inline]
192    pub fn todo() -> Self {
193        Self::background()
194    }
195
196    /// Specify the maximum execution duration for the `Timer`.
197    #[inline]
198    pub fn with_timeout(timeout: time::Duration) -> Self {
199        let mut inner = Inner::new();
200        inner.expire_at = Some(time::Instant::now() + timeout);
201
202        Self::from(inner)
203    }
204
205    /// Specify the maximum execution duration for the `Timer`, in seconds.
206    #[inline]
207    pub fn in_seconds(secs: u64) -> Self {
208        Self::with_timeout(time::Duration::from_secs(secs))
209    }
210
211    /// Specify the maximum execution duration for the `Timer`, in seconds.
212    pub fn in_milliseconds(millis: u64) -> Self {
213        Self::with_timeout(time::Duration::from_millis(millis))
214    }
215
216    async fn cancel_receiver(&self) -> sync::broadcast::Receiver<()> {
217        self.inner.read().await.cancelled_receiver.resubscribe()
218    }
219}
220
221struct Task<Output, Fut, CErr, CFut>
222where
223    Fut: Future<Output = Output> + Send,
224    CErr: ::std::error::Error,
225    CFut: Future<Output = Result<(), CErr>> + Send,
226{
227    fut: Pin<Box<Fut>>,
228    sleep: Option<Pin<Box<Sleep>>>,
229    cancel_receiver: Pin<Box<CFut>>,
230}
231
232impl<Output, Fut, CErr, CFut> Future for Task<Output, Fut, CErr, CFut>
233where
234    Fut: Future<Output = Output> + Send,
235    CErr: std::error::Error,
236    CFut: Future<Output = Result<(), CErr>> + Send,
237{
238    type Output = crate::Result<Output>;
239
240    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
241        let this = self.get_mut();
242
243        if let Some(sleep) = this.sleep.as_mut() {
244            if pin!(sleep).poll(cx).is_ready() {
245                return Poll::Ready(Err(Error::ContextTimeout));
246            }
247        }
248
249        if let Poll::Ready(cancel_result) = pin!(&mut this.cancel_receiver).poll(cx) {
250            if let Err(e) = cancel_result {
251                error!("BUG: error when RecvError: {:?}", e);
252            }
253
254            return Poll::Ready(Err(Error::ContextCancelled));
255        }
256
257        this.fut.as_mut().poll(cx)
258            .map(|r| Ok(r))
259    }
260}
261
#[cfg(feature = "actix-web-from-request")]
impl actix_web::FromRequest for Timer {
    type Error = actix_web::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;

    /// Gives every request its own independent background timer. The request
    /// and payload are not inspected, and the timer is created when the
    /// returned future is first polled.
    fn from_request(_request: &actix_web::HttpRequest, _payload: &mut actix_web::dev::Payload) -> Self::Future {
        let make_timer = async { Ok::<Self, actix_web::Error>(Timer::background()) };
        Box::pin(make_timer)
    }
}
273
274impl Drop for Inner {
275    fn drop(&mut self) {
276        #[cfg(feature = "tracing")]
277        {
278            #[cfg(feature = "name")]
279            tracing::trace!(context_drop=self.name.as_u64(), cancelled=self.cancelled, timeout=?self.expire_at);
280
281            #[cfg(not(feature = "name"))]
282            tracing::trace!(context_drop="", cancelled=self.cancelled, timeout=?self.expire_at);
283        }
284    }
285}