1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
use std::sync::OnceLock;

use futures::{
    executor::ThreadPool, future::BoxFuture, task::SpawnExt, Future, FutureExt, SinkExt, StreamExt,
};

/// Interface an executor implements so that it can be registered with the hala
/// global spawner registry through [`register_spawner`].
pub trait FutureSpawner {
    /// Hands `task` to the underlying executor to be run.
    ///
    /// There is no error return: an implementation must panic if it fails to
    /// spawn `task`.
    fn spawn_boxed_future(&self, task: BoxFuture<'static, ()>);
}

/// Process-wide spawner slot. It is filled at most once, either by
/// [`register_spawner`] or lazily by [`future_spawn`], and then read by every
/// [`future_spawn`] call.
static REGISTER: OnceLock<Box<dyn FutureSpawner + Sync + Send>> = OnceLock::new();

/// Installs `spawner` as the process-wide executor used by [`future_spawn`]
/// (and therefore by [`block_on`]).
///
/// Call this once, before the first [`future_spawn`] / [`block_on`] call.
///
/// # Panics
///
/// Panics if a spawner is already installed. This happens when this function
/// was called before, or when the `futures-executor` feature is enabled and
/// [`future_spawn`] already ran and lazily installed its default thread pool.
pub fn register_spawner<S: FutureSpawner + Send + Sync + 'static>(spawner: S) {
    // `OnceLock::set` fails whenever the slot is already filled, regardless of
    // who filled it, so the message must cover both causes.
    if REGISTER.set(Box::new(spawner)).is_err() {
        panic!(
            "register_spawner: a spawner is already installed (register_spawner was called \
             twice, or future_spawn ran first and installed the default spawner)"
        );
    }
}

/// Spawns `fut` as a detached task on the globally registered [`FutureSpawner`].
///
/// If no spawner has been registered yet:
/// - with the `futures-executor` feature enabled, a
///   [`futures::executor::ThreadPool`] with `num_cpus::get()` threads is created
///   and installed as the global spawner. Any later [`register_spawner`] call
///   will then panic.
/// - without that feature, this function panics.
///
/// # Panics
///
/// - Without the `futures-executor` feature, if [`register_spawner`] was never called.
/// - With that feature, if the default thread pool cannot be created.
/// - If the registered spawner fails to spawn `fut`, because the
///   [`FutureSpawner`] contract requires it to panic in that case.
pub fn future_spawn<Fut>(fut: Fut)
where
    Fut: Future<Output = ()> + Send + 'static,
{
    let spawner = REGISTER.get_or_init(|| {
        #[cfg(not(feature = "futures-executor"))]
        panic!("Call register_spawner first");

        #[cfg(feature = "futures-executor")]
        Box::new(
            ThreadPool::builder()
                .pool_size(num_cpus::get())
                .create()
                .expect("future_spawn: failed to create the default futures-executor ThreadPool"),
        )
    });

    // Type-erase the future so it can cross the object-safe trait boundary.
    spawner.spawn_boxed_future(fut.boxed())
}

impl FutureSpawner for futures::executor::ThreadPool {
    fn spawn_boxed_future(&self, future: BoxFuture<'static, ()>) {
        self.spawn(future)
            .expect("futures::executor::ThreadPool spawn failed");
    }
}

/// Runs `fut` on the registered spawner and blocks the current thread until
/// its output is available.
///
/// `fut` is handed to [`future_spawn`], so it executes on whatever executor
/// [`future_spawn`] uses. The calling thread only waits for the result via
/// [`futures::executor::block_on`].
///
/// # Panics
///
/// - Under the same conditions as [`future_spawn`].
/// - If the spawned task is dropped before delivering its output, for example
///   because `fut` panicked on the executor.
///
/// NOTE(review): calling this from a thread owned by the registered executor
/// may deadlock if that executor has no other thread free to poll `fut`.
/// Confirm this against the executors used with this crate.
pub fn block_on<Fut, R>(fut: Fut) -> R
where
    Fut: Future<Output = R> + Send + 'static,
    R: Send + 'static,
{
    // An mpsc channel's capacity is `buffer + number of senders`. Even with
    // buffer 0, the spawned task's single `send` therefore does not wait on the
    // receiver.
    let (mut sender, mut receiver) = futures::channel::mpsc::channel::<R>(0);

    future_spawn(async move {
        let output = fut.await;
        // `send` can only fail if the receiver was dropped, i.e. nobody is
        // waiting for the result any more. Ignoring that case is deliberate.
        let _ = sender.send(output).await;
    });

    futures::executor::block_on(async move {
        // `None` means the sender was dropped without sending: the task never
        // completed (e.g. `fut` panicked). Report that instead of a bare unwrap.
        receiver.next().await.expect(
            "block_on: spawned future was dropped before producing a result (did it panic?)",
        )
    })
}