sync_utils/
blocking_async.rs

1use std::{
2    cell::UnsafeCell,
3    future::Future,
4    mem::transmute,
5    sync::{Arc, Condvar, Mutex},
6};
7
8use tokio::runtime::Runtime;
9
10/// For the use in blocking context.
11/// spawn a future into given tokio runtime and wait for result.
12///
13struct BlockingFutureInner<R>
14where
15    R: Sync + Send + 'static,
16{
17    res: UnsafeCell<Option<R>>,
18    cond: Condvar,
19    done: Mutex<bool>,
20}
21
22impl<R> BlockingFutureInner<R>
23where
24    R: Sync + Send + 'static,
25{
26    #[inline(always)]
27    fn done(&self, r: R) {
28        let _res: &mut Option<R> = unsafe { transmute(self.res.get()) };
29        _res.replace(r);
30        let mut guard = self.done.lock().unwrap();
31        *guard = true;
32        self.cond.notify_one();
33    }
34
35    #[inline(always)]
36    fn take_res(&self) -> R {
37        let _res: &mut Option<R> = unsafe { transmute(self.res.get()) };
38        _res.take().unwrap()
39    }
40}
41
42unsafe impl<R> Send for BlockingFutureInner<R> where R: Sync + Send + Clone + 'static {}
43
44unsafe impl<R> Sync for BlockingFutureInner<R> where R: Sync + Send + Clone + 'static {}
45
46pub struct BlockingFuture<R: Sync + Send + 'static>(Arc<BlockingFutureInner<R>>);
47
48impl<R> BlockingFuture<R>
49where
50    R: Sync + Send + Clone + 'static,
51{
52    #[inline(always)]
53    pub fn new() -> Self {
54        Self(Arc::new(BlockingFutureInner {
55            res: UnsafeCell::new(None),
56            cond: Condvar::new(),
57            done: Mutex::new(false),
58        }))
59    }
60
61    pub fn block_on<F>(&mut self, rt: &Runtime, f: F) -> R
62    where
63        F: Future<Output = R> + Send + Sync + 'static,
64    {
65        let _self = self.0.clone();
66        let _ = rt.spawn(async move {
67            let res = f.await;
68            _self.done(res);
69        });
70        let _self = self.0.as_ref();
71        let mut guard = _self.done.lock().unwrap();
72        loop {
73            if *guard {
74                return _self.take_res();
75            }
76            guard = _self.cond.wait(guard).unwrap();
77        }
78    }
79}
80
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::sleep;

    use super::*;

    /// Builds a one-worker multi-threaded runtime with timers enabled,
    /// because the futures below use `tokio::time::sleep`.
    fn runtime() -> Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(1)
            .build()
            .unwrap()
    }

    /// A future that suspends across timer awaits delivers its value to the
    /// blocked caller.
    #[test]
    fn test_spawn() {
        let rt = runtime();

        let mut bf = BlockingFuture::new();
        let res = bf.block_on(&rt, async move {
            sleep(Duration::from_millis(10)).await;
            println!("exec future");
            sleep(Duration::from_millis(10)).await;
            "hello world".to_string()
        });
        assert_eq!(res, "hello world");
    }

    /// A future that is ready on its first poll is also returned correctly.
    #[test]
    fn test_ready_future() {
        let rt = runtime();

        let mut bf = BlockingFuture::new();
        assert_eq!(bf.block_on(&rt, async { 42_u32 }), 42);
    }
}