cs_utils/utils/futures/with_thread.rs

1use std::thread;
2
3use futures::Future;
4use tokio::runtime::Handle;
5
6use super::wait;
7
8/// Run a future on a separate thread.
9///
10/// Useful when future has a blocking code. Normally such blocking code
11/// will also block current all asynchronous tasks. This helper mitigates
12/// such issue by running blocking future on a separate thread with its
13/// own Tokio runtime.
14/// 
15/// ### Examples
16/// 
17/// ```
18/// use std::{thread, pin::Pin, time::Duration};
19///
20/// use futures::{future, Future};
21/// use cs_utils::futures::{wait, with_thread};
22/// 
23/// #[tokio::main(worker_threads = 1)]
24/// async fn main() {
25///     // variable to count how many iterations
26///     // the normal future has run
27///     static mut RUN_CNT: u64 = 0;
28///
29///     // create a future that blocks current thread
30///     let blocking_future = async move {
31///         thread::sleep(Duration::from_secs(1));    
32///     };
33///     
34///     // create a future that intended to run in background
35///     let normal_future = async move {
36///         loop {
37///             unsafe { RUN_CNT += 1 }
38///             wait(100).await;
39///         }
40///     };
41///     
42///     // create futures list
43///     let futures: Vec<Pin<Box<dyn Future<Output = ()>>>> = vec![
44///         Box::pin(with_thread(blocking_future)), // <-- wrap the blocking future here
45///         Box::pin(normal_future),
46///     ];
47///     
48///     // race the futures to completion 
49///     future::select_all(futures).await;
50/// 
51///     // must go thru multiple iterations in the normal future
52///     assert!(
53///         unsafe { RUN_CNT >= 7 } && unsafe { RUN_CNT <= 10 },
54///         "Normal future must run iterations multiple times.",
55///     );
56/// }
57/// ```
58pub async fn with_thread<
59    T: Send + 'static,
60    TFuture: Future<Output = T> + Send + 'static,
61>(
62    original_future: TFuture,
63) -> T {
64    let tokio_handle = Handle::try_current()
65        .expect("Needs running Tokio runtime.");
66
67    let other_thread = thread::spawn(move || {
68        let _guard = tokio_handle.enter();
69
70        return tokio_handle.block_on(original_future);
71    });
72
73    // poll the thread, yielding if it is not finished yet
74    while !other_thread.is_finished() {
75        // common thread time slice is `~100ms`, so
76        // `5ms` delay should be granular enough here 
77        wait(5).await;
78
79        continue;
80    }
81
82    return other_thread
83        .join().unwrap();
84}
85
#[cfg(test)]
mod tests {
    use std::{thread, pin::Pin, time::Duration};
    use std::sync::atomic::{AtomicU64, Ordering};
    use futures::{future, Future};
    use cs_utils::futures::{wait, with_thread};

    /// How long (ms) the blocking future keeps its thread busy.
    const BLOCK_FOR_MS: u64 = 1000;
    /// Delay (ms) between two iterations of the normal future.
    const RUN_EACH_MS: u64 = 100;
    /// Maximum allowed distance between the observed and the ideal iteration
    /// count; leaves room for timer jitter and `with_thread` polling latency.
    const MAX_RUN_DELTA: u64 = 3;

    type BoxedUnitFuture = Pin<Box<dyn Future<Output = ()>>>;

    /// A future that blocks its thread for `BLOCK_FOR_MS` with a synchronous sleep.
    fn blocking_future() -> impl Future<Output = ()> + Send + 'static {
        async move {
            thread::sleep(Duration::from_millis(BLOCK_FOR_MS));
        }
    }

    /// Same as `blocking_future`, but the blocking call sits inside an inner, awaited future.
    fn nested_blocking_future() -> impl Future<Output = ()> + Send + 'static {
        async move {
            let fut = async move {
                thread::sleep(Duration::from_millis(BLOCK_FOR_MS));
            };

            fut.await;
        }
    }

    /// An endless future that increments `counter` every `RUN_EACH_MS`.
    ///
    /// The counter is atomic because, once wrapped in `with_thread`, this loop
    /// runs on another OS thread (and keeps running after the race is decided)
    /// while the test thread reads the counter.
    fn counting_future(counter: &'static AtomicU64) -> impl Future<Output = ()> + Send + 'static {
        async move {
            loop {
                wait(RUN_EACH_MS).await;

                counter.fetch_add(1, Ordering::SeqCst);
            }
        }
    }

    /// Races `futures` until the first one completes.
    async fn race(futures: Vec<BoxedUnitFuture>) {
        future::select_all(futures).await;
    }

    /// Asserts that the normal future ran close to `BLOCK_FOR_MS / RUN_EACH_MS` times.
    fn assert_ran_multiple_times(counter: &AtomicU64) {
        // normal future can run at most `BLOCK_FOR_MS` / `RUN_EACH_MS` times
        let expected_run_count = BLOCK_FOR_MS / RUN_EACH_MS;
        let run_count = counter.load(Ordering::SeqCst);
        // `abs_diff` cannot underflow, unlike `expected - actual` on `u64`
        let run_delta = expected_run_count.abs_diff(run_count);

        assert!(
            run_delta <= MAX_RUN_DELTA,
            "Must run normal future multiple times (ran {} times, expected about {}).",
            run_count,
            expected_run_count,
        );
    }

    #[tokio::test]
    async fn run_blocking_future_on_separate_thread() {
        static NORMAL_FUTURE_RUN_COUNTER: AtomicU64 = AtomicU64::new(0);

        let futures: Vec<BoxedUnitFuture> = vec![
            Box::pin(with_thread(blocking_future())),
            Box::pin(counting_future(&NORMAL_FUTURE_RUN_COUNTER)),
        ];
        race(futures).await;

        assert_ran_multiple_times(&NORMAL_FUTURE_RUN_COUNTER);
    }

    #[tokio::test]
    async fn shares_runtime() {
        static NORMAL_FUTURE_RUN_COUNTER: AtomicU64 = AtomicU64::new(0);

        let futures: Vec<BoxedUnitFuture> = vec![
            Box::pin(with_thread(blocking_future())),
            Box::pin(with_thread(counting_future(&NORMAL_FUTURE_RUN_COUNTER))),
        ];
        race(futures).await;

        assert_ran_multiple_times(&NORMAL_FUTURE_RUN_COUNTER);
    }

    #[tokio::test]
    async fn runs_nested_blocking_futures() {
        static NORMAL_FUTURE_RUN_COUNTER: AtomicU64 = AtomicU64::new(0);

        let futures: Vec<BoxedUnitFuture> = vec![
            Box::pin(with_thread(nested_blocking_future())),
            Box::pin(counting_future(&NORMAL_FUTURE_RUN_COUNTER)),
        ];
        race(futures).await;

        assert_ran_multiple_times(&NORMAL_FUTURE_RUN_COUNTER);
    }

    #[tokio::test]
    async fn runs_nested_futures() {
        static NORMAL_FUTURE_RUN_COUNTER: AtomicU64 = AtomicU64::new(0);

        let futures: Vec<BoxedUnitFuture> = vec![
            Box::pin(
                with_thread(
                    with_thread(nested_blocking_future()),
                ),
            ),
            Box::pin(
                with_thread(
                    with_thread(
                        with_thread(counting_future(&NORMAL_FUTURE_RUN_COUNTER)),
                    ),
                ),
            ),
        ];
        race(futures).await;

        assert_ran_multiple_times(&NORMAL_FUTURE_RUN_COUNTER);
    }
}