1use pin_project::pin_project;
2use std::{
3 future::Future,
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10#[pin_project]
12pub struct InspectTimeout<Fut, F, T> {
13 #[pin]
14 fut: Fut,
15 #[pin]
16 delay: tokio::time::Sleep,
17 elapse_fn: Option<F>,
18 delay_state: DelayState,
19 _phantom: PhantomData<T>,
20}
21
22impl<Fut, F, T> InspectTimeout<Fut, F, T>
23where
24 F: FnOnce(),
25{
26 pub fn new(fut: Fut, dur: Duration, elapse_fn: F) -> Self {
27 Self {
28 fut,
29 delay: tokio::time::sleep(dur),
30 elapse_fn: Some(elapse_fn),
31 delay_state: DelayState::Idle,
32 _phantom: PhantomData,
33 }
34 }
35
36 fn call_elapse_fn(self: Pin<&mut Self>) {
37 let this = self.project();
38
39 this.elapse_fn
40 .take()
41 .expect("elapse_fn must be called once")();
42
43 *this.delay_state = DelayState::Completed;
44 }
45}
46
47impl<Fut, F, T> Future for InspectTimeout<Fut, F, T>
48where
49 Fut: Future<Output = T>,
50 F: FnOnce(),
51{
52 type Output = T;
53
54 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55 let this = self.as_mut().project();
56
57 if let Poll::Ready(r) = this.fut.poll(cx) {
58 return Poll::Ready(r);
59 };
60
61 match this.delay_state {
63 DelayState::Idle => match this.delay.poll(cx) {
64 Poll::Ready(_) => {
65 self.as_mut().call_elapse_fn();
66 }
67 Poll::Pending => *this.delay_state = DelayState::Running,
68 },
69 DelayState::Running => {
70 if this.delay.poll(cx).is_ready() {
71 self.as_mut().call_elapse_fn();
72 }
73 }
74 DelayState::Completed => {}
75 };
76
77 Poll::Pending
78 }
79}
80
81pub trait InspectTimeoutExt<Fut, F, T>
82where
83 Fut: Future<Output = T>,
84 F: FnOnce(),
85{
86 fn inspect_timeout(self, dur: Duration, elapse_fn: F) -> InspectTimeout<Fut, F, T>;
88}
89
90impl<Fut, F, T> InspectTimeoutExt<Fut, F, T> for Fut
91where
92 Fut: Future<Output = T>,
93 F: FnOnce(),
94{
95 fn inspect_timeout(self, dur: Duration, elapse_fn: F) -> InspectTimeout<Fut, F, T> {
96 InspectTimeout::new(self, dur, elapse_fn)
97 }
98}
99
/// Lifecycle of the timeout timer inside `InspectTimeout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DelayState {
    /// The timer has not been polled yet.
    Idle,
    /// The timer has been polled at least once and has not fired yet.
    Running,
    /// The timer fired and the callback was taken; the timer is no longer polled.
    Completed,
}