1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright (c) 2022-2023 Yuki Kishimoto
// Distributed under the MIT software license

//! Thread

use core::fmt;
use core::time::Duration;

use futures_util::stream::{AbortHandle, Abortable};
use futures_util::Future;
#[cfg(feature = "blocking")]
use tokio::runtime::{Builder, Runtime};

#[cfg(target_arch = "wasm32")]
mod wasm;

type Result<T, E = Box<dyn std::error::Error>> = core::result::Result<T, E>;

#[cfg(feature = "blocking")]
fn new_current_thread() -> Result<Runtime> {
    Ok(Builder::new_current_thread().enable_all().build()?)
}

/// Thread Error
#[derive(Debug)]
pub enum Error {
    /// Join Error
    JoinError,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::JoinError => write!(f, "impossible to join thread"),
        }
    }
}

/// Join Handle
pub enum JoinHandle<T> {
    /// Std
    #[cfg(not(target_arch = "wasm32"))]
    Std(std::thread::JoinHandle<T>),
    /// Tokio
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(tokio::task::JoinHandle<T>),
    /// Wasm
    #[cfg(target_arch = "wasm32")]
    Wasm(self::wasm::JoinHandle<T>),
}

impl<T> JoinHandle<T> {
    /// Join
    pub async fn join(self) -> Result<T, Error> {
        match self {
            #[cfg(not(target_arch = "wasm32"))]
            Self::Std(handle) => handle.join().map_err(|_| Error::JoinError),
            #[cfg(not(target_arch = "wasm32"))]
            Self::Tokio(handle) => handle.await.map_err(|_| Error::JoinError),
            #[cfg(target_arch = "wasm32")]
            Self::Wasm(handle) => handle.join().await.map_err(|_| Error::JoinError),
        }
    }
}

/// Spawn
#[cfg(not(target_arch = "wasm32"))]
pub fn spawn<T>(future: T) -> Option<JoinHandle<T::Output>>
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    #[cfg(feature = "blocking")]
    match new_current_thread() {
        Ok(rt) => {
            let handle = std::thread::spawn(move || {
                let res = rt.block_on(future);
                rt.shutdown_timeout(Duration::from_millis(100));
                res
            });
            Some(JoinHandle::Std(handle))
        }
        Err(_) => None,
    }

    #[cfg(not(feature = "blocking"))]
    {
        let handle = tokio::task::spawn(future);
        Some(JoinHandle::Tokio(handle))
    }
}

/// Spawn
#[cfg(target_arch = "wasm32")]
pub fn spawn<T>(future: T) -> Option<JoinHandle<T::Output>>
where
    T: Future + 'static,
{
    let handle = self::wasm::spawn(future);
    Some(JoinHandle::Wasm(handle))
}

/// Spawn abortable thread
#[cfg(not(target_arch = "wasm32"))]
pub fn abortable<T>(future: T) -> AbortHandle
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    let (abort_handle, abort_registration) = AbortHandle::new_pair();
    spawn(Abortable::new(future, abort_registration));
    abort_handle
}

/// Spawn abortable thread
#[cfg(target_arch = "wasm32")]
pub fn abortable<T>(future: T) -> AbortHandle
where
    T: Future + 'static,
{
    let (abort_handle, abort_registration) = AbortHandle::new_pair();
    spawn(Abortable::new(future, abort_registration));
    abort_handle
}

/// Sleep
pub async fn sleep(duration: Duration) {
    #[cfg(not(target_arch = "wasm32"))]
    tokio::time::sleep(duration).await;
    #[cfg(target_arch = "wasm32")]
    gloo_timers::future::sleep(duration).await;
}