// cs_utils/utils/futures/with_timeout.rs

1use std::pin::Pin;
2
3use futures::{future, Future};
4
5use super::wait;
6
7/// Runs a future to completion or to timeout, whatever happens first.
8/// 
9/// Wraps the original result of a future into an `Option` to indicate if
10/// the future completed (`Some(original result)`) or timed out (`None`).
11///
12/// ### Examples
13/// 
14/// ```
15/// use std::time::Instant;
16///
17/// use cs_utils::{
18///     random_number,
19///     futures::{wait, with_timeout},
20/// };
21///
22/// #[tokio::main]
23/// async fn main() {
24///     let timeout: u64 = 25;
25///
26///     // as future that never completes
27///     let forever_future = async move {
28///         loop {
29///             wait(5).await;
30///         }
31///     };
32///
33///     // wrap the original future to get the `stop` function
34///     let with_timeout_future = with_timeout(forever_future, timeout);
35///
36///     // record starting time
37///     let start_time = Instant::now();
38/// 
39///     // wait for the future to complete
40///     let result = with_timeout_future.await;
41///         
42///     // calculate elapsed time
43///     let time_delta_ms = (Instant::now() - start_time).as_millis();
44///
45///     assert!(
46///         result.is_none(),
47///         "Timed out future must complete with `None` result.",
48///     );
49///
50///     // assert that the completion time of the future is close to the `timeout`
51///     assert!(
52///         time_delta_ms >= (timeout - 2) as u128,
53///         "Must have waited for at least duration of the timeout.",
54///     );
55///     assert!(
56///         time_delta_ms <= (timeout + 2) as u128,
57///         "Must have waited for at most duration of the timeout.",
58///     );
59/// }
60/// ```
61pub fn with_timeout<
62    T: Send + 'static,
63    TFuture: Future<Output = T> + Send + 'static,
64>(
65    original_future: TFuture,
66    timeout_ms: u64,
67) -> Pin<Box<dyn Future<Output = Option<T>> + Send + 'static>> {
68    let futures: Vec<Pin<Box<dyn Future<Output = Option<T>> + Send + 'static>>> = vec![
69        Box::pin(async move {
70            wait(timeout_ms).await;
71
72            return None;
73        }),
74        
75        Box::pin(async move {
76            return Some(
77                original_future.await,
78            );
79        }),
80    ];
81
82    let result_future: Pin<Box<dyn Future<Output = Option<T>> + Send + 'static>> = Box::pin(async move {
83        let (result, _, _) = future::select_all(futures).await;
84
85        result
86    });
87
88    return result_future;
89}
90
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use cs_utils::{
        random_number,
        futures::{wait, with_timeout},
    };

    /// How many milliseconds earlier than expected a wrapped future may
    /// resolve; absorbs millisecond rounding of the elapsed-time measurement.
    const EARLY_TOLERANCE_MS: u64 = 2;

    /// How many milliseconds later than expected a wrapped future may resolve.
    /// Looser than the early bound because scheduler jitter on a busy (CI)
    /// machine delays wake-ups far more often than it advances them.
    const LATE_TOLERANCE_MS: u64 = 10;

    /// Asserts that `elapsed_ms` lies within
    /// `[expected_ms - EARLY_TOLERANCE_MS, expected_ms + LATE_TOLERANCE_MS]`.
    ///
    /// `what` names the expected duration in the failure messages.
    fn assert_elapsed_close_to(elapsed_ms: u128, expected_ms: u64, what: &str) {
        let min_ms = u128::from(expected_ms.saturating_sub(EARLY_TOLERANCE_MS));
        let max_ms = u128::from(expected_ms + LATE_TOLERANCE_MS);

        assert!(
            elapsed_ms >= min_ms,
            "Must have waited for at least duration of the {what} (\"{delta}\" vs \"{expected}\").",
            what = what,
            delta = elapsed_ms,
            expected = expected_ms,
        );
        assert!(
            elapsed_ms <= max_ms,
            "Must have waited for at most duration of the {what} (\"{delta}\" vs \"{expected}\").",
            what = what,
            delta = elapsed_ms,
            expected = expected_ms,
        );
    }

    #[tokio::test]
    async fn can_stop_a_future_after_a_timeout() {
        let timeout: u64 = 25;

        // a future that never completes on its own
        let forever_future = async move {
            loop {
                wait(5).await;
            }
        };

        // wrap the original future so it resolves after at most `timeout` ms
        let with_timeout_future = with_timeout(forever_future, timeout);

        // record the start right before awaiting (futures are lazy)
        let start_time = Instant::now();
        let result = with_timeout_future.await;
        let time_delta_ms = start_time.elapsed().as_millis();

        assert!(
            result.is_none(),
            "Timed out future must complete with `None` result.",
        );

        // the completion time of the future must be close to the `timeout`
        assert_elapsed_close_to(time_delta_ms, timeout, "timeout");
    }

    #[tokio::test]
    async fn completed_future_returns_some_result() {
        let timeout: u64 = 25;
        let completion_delay: u64 = 5;
        let completing_future_result = random_number(0..=i128::MAX);

        // a future that completes well before the timeout
        let completing_future = async move {
            wait(completion_delay).await;

            completing_future_result
        };

        // wrap the original future; it should finish before the timeout fires
        let with_timeout_future = with_timeout(completing_future, timeout);

        // record the start right before awaiting (futures are lazy)
        let start_time = Instant::now();
        let result = with_timeout_future.await
            .expect("Completed future must complete with `Some(original result)`.");
        let time_delta_ms = start_time.elapsed().as_millis();

        assert_eq!(
            result,
            completing_future_result,
            "Original and received future results must match.",
        );

        // the completion time of the future must be close to the `completion_delay`
        assert_elapsed_close_to(time_delta_ms, completion_delay, "completion delay");
    }
}