Skip to main content

kithara_platform/
thread.rs

1pub use std::time::Duration;
2use std::{
3    hash::{DefaultHasher, Hash, Hasher},
4    sync::atomic::{AtomicUsize, Ordering},
5};
6
7#[cfg(target_arch = "wasm32")]
8use wasm_bindgen::JsCast;
9#[cfg(target_arch = "wasm32")]
10use wasm_safe_thread::Builder as WasmThreadBuilder;
11
/// Handle to a thread; on wasm32 this is `wasm_safe_thread::Thread`.
#[cfg(target_arch = "wasm32")]
pub type Thread = wasm_safe_thread::Thread;

/// Handle to a thread; on native targets this is [`std::thread::Thread`].
#[cfg(not(target_arch = "wasm32"))]
pub type Thread = ::std::thread::Thread;

/// Opaque per-thread identifier; on wasm32 this is `wasm_safe_thread::ThreadId`.
#[cfg(target_arch = "wasm32")]
pub type ThreadId = wasm_safe_thread::ThreadId;

/// Opaque per-thread identifier; on native targets this is [`std::thread::ThreadId`].
#[cfg(not(target_arch = "wasm32"))]
pub type ThreadId = ::std::thread::ThreadId;
23
/// Hint the scheduler that another thread may run now.
///
/// Native targets forward to [`std::thread::yield_now`].
#[cfg(not(target_arch = "wasm32"))]
#[inline]
pub fn yield_now() {
    ::std::thread::yield_now();
}

/// Hint the scheduler that another thread may run now.
///
/// On wasm32 this does nothing.
#[cfg(target_arch = "wasm32")]
#[inline]
pub fn yield_now() {}
33
/// Reports whether the caller runs inside a dedicated Web Worker.
///
/// The JS global object must cast to `DedicatedWorkerGlobalScope` for this to
/// return `true`; any other global (for example the browser `Window`) yields
/// `false`.
///
/// NOTE(review): shared/service workers and worklets have other global scope
/// types, so they are reported as "not a worker" here — confirm that is
/// intended.
#[cfg(target_arch = "wasm32")]
#[inline]
#[must_use]
pub fn is_worker_thread() -> bool {
    let global = js_sys::global();
    global
        .dyn_into::<web_sys::DedicatedWorkerGlobalScope>()
        .is_ok()
}

/// Native targets have no Web Workers, so this always returns `false`.
#[cfg(not(target_arch = "wasm32"))]
#[inline]
#[must_use]
pub fn is_worker_thread() -> bool {
    false
}

/// Reports whether the caller runs on the "main" thread.
///
/// Defined as the negation of [`is_worker_thread`] on every target. On wasm32
/// it is `true` whenever the global scope is not a `DedicatedWorkerGlobalScope`;
/// on native targets it is always `true`.
#[inline]
#[must_use]
pub fn is_main_thread() -> bool {
    !is_worker_thread()
}
67
/// Guard for calls that must run on the wasm main thread.
///
/// `_label` names the call site and is only used in the panic message.
///
/// # Panics
///
/// On wasm32, panics when invoked from a worker thread. On native targets this
/// function does nothing.
#[inline]
pub fn assert_main_thread(_label: &str) {
    #[cfg(target_arch = "wasm32")]
    assert!(
        is_main_thread(),
        "main-thread-only call executed on worker thread: {_label}"
    );
}

/// Guard for calls that must never run on the wasm main thread.
///
/// `_label` names the call site and is only used in the panic message.
///
/// # Panics
///
/// On wasm32, panics when invoked from the main thread. On native targets this
/// function does nothing.
#[inline]
pub fn assert_not_main_thread(_label: &str) {
    #[cfg(target_arch = "wasm32")]
    assert!(
        !is_main_thread(),
        "worker-thread-only call executed on main thread: {_label}"
    );
}
85
/// Owned permission to join a spawned thread (wasm32: `wasm_safe_thread`).
#[cfg(target_arch = "wasm32")]
pub type JoinHandle<T> = wasm_safe_thread::JoinHandle<T>;

/// Owned permission to join a spawned thread (native: [`std::thread::JoinHandle`]).
#[cfg(not(target_arch = "wasm32"))]
pub type JoinHandle<T> = ::std::thread::JoinHandle<T>;
91
92/// Get a handle to the current thread.
93#[cfg(not(target_arch = "wasm32"))]
94#[inline]
95#[must_use]
96pub fn current() -> Thread {
97    std::thread::current()
98}
99
100/// Get a handle to the current thread.
101#[cfg(target_arch = "wasm32")]
102#[inline]
103#[must_use]
104pub fn current() -> Thread {
105    wasm_safe_thread::current()
106}
107
108/// Spawn a new thread.
109///
110/// On WASM, uses [`wasm_safe_thread::Builder`] with an explicit `shim_name`
111/// so workers can locate the wasm-bindgen JS shim for `initSync`.
112#[cfg(not(target_arch = "wasm32"))]
113pub fn spawn<F, T>(f: F) -> JoinHandle<T>
114where
115    F: FnOnce() -> T + Send + 'static,
116    T: Send + 'static,
117{
118    std::thread::spawn(f)
119}
120
/// Number of active threads spawned via [`spawn_named`].
///
/// Incremented when a closure is wrapped by `counted` (just before spawn).
/// Decremented when that wrapped closure finishes, whether it returns normally
/// or unwinds from a panic, or when it is dropped without ever running (for
/// example because the spawn itself failed). Used by thread-budget tests to
/// count only kithara-owned threads.
static ACTIVE_NAMED_THREADS: AtomicUsize = AtomicUsize::new(0);

/// Returns the number of currently active threads spawned via [`spawn_named`].
#[must_use]
pub fn active_named_thread_count() -> usize {
    ACTIVE_NAMED_THREADS.load(Ordering::Acquire)
}

/// One reserved slot in [`ACTIVE_NAMED_THREADS`].
///
/// Creating the guard increments the counter, and dropping it decrements it.
/// Tying the decrement to `Drop` means it also runs during panic unwinding
/// and when the owning closure is discarded unexecuted.
struct NamedThreadGuard;

impl NamedThreadGuard {
    /// Reserve a slot by incrementing the counter.
    fn acquire() -> Self {
        ACTIVE_NAMED_THREADS.fetch_add(1, Ordering::Release);
        Self
    }
}

impl Drop for NamedThreadGuard {
    fn drop(&mut self) {
        ACTIVE_NAMED_THREADS.fetch_sub(1, Ordering::Release);
    }
}

/// Wrap `f` so that its execution is bracketed by the named-thread counter.
///
/// The counter is incremented here, at the call site, before the thread is
/// spawned. It is decremented once the returned closure completes: after a
/// normal return, during unwinding if `f` panics, or when the closure is
/// dropped without being called. Used by all [`spawn_named`] variants.
fn counted<F, T>(f: F) -> impl FnOnce() -> T + Send + 'static
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let guard = NamedThreadGuard::acquire();
    move || {
        // Bind (not `let _ =`) so the guard is moved into the closure and is
        // released when this scope ends — on return or on unwind.
        let _guard = guard;
        f()
    }
}
148
149/// Spawn a new named thread.
150///
151/// Sets the OS thread name and tracks the thread in [`active_named_thread_count`].
152/// The counter is decremented automatically when `f` returns.
153///
154/// # Panics
155///
156/// Panics if the OS refuses to create the thread.
157#[cfg(not(target_arch = "wasm32"))]
158pub fn spawn_named<F, T, N: Into<String>>(name: N, f: F) -> JoinHandle<T>
159where
160    F: FnOnce() -> T + Send + 'static,
161    T: Send + 'static,
162{
163    std::thread::Builder::new()
164        .name(name.into())
165        .spawn(counted(f))
166        .expect(
167            "BUG: spawn_named must succeed; thread::Builder only fails on OS resource exhaustion",
168        )
169}
170
/// Spawn a new named thread (WASM variant).
///
/// `name` is converted to a `String` and then discarded; this variant does not
/// forward it to the spawned worker. `f` is still wrapped by `counted`, so the
/// thread shows up in [`active_named_thread_count`].
///
/// # Panics
///
/// Panics if the underlying worker spawn fails (see [`spawn`]).
#[cfg(target_arch = "wasm32")]
pub fn spawn_named<F, T, N: Into<String>>(name: N, f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // Perform the conversion for parity with the native variant, then drop it.
    let _: String = name.into();
    let tracked = counted(f);
    spawn(tracked)
}
185
/// Spawn a new thread as a Web Worker (WASM variant).
///
/// Workers are created through `wasm_safe_thread::Builder` with an explicit
/// shim name, so they can locate the wasm-bindgen JS shim for `initSync`.
/// Each worker installs `console_error_panic_hook` before running `f`.
///
/// # Panics
///
/// Panics if the worker cannot be spawned.
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Name of the wasm-bindgen JS shim that workers load for `initSync`.
    ///
    /// NOTE(review): wasm-bindgen shims are usually named after the crate with
    /// hyphens replaced by underscores, yet this value keeps the hyphen —
    /// confirm it matches what `wasm_safe_thread` and the build output expect.
    const SHIM_NAME: &str = "kithara-wasm";

    let worker_main = move || {
        console_error_panic_hook::set_once();
        f()
    };

    WasmThreadBuilder::new()
        .shim_name(String::from(SHIM_NAME))
        .spawn(worker_main)
        .expect("BUG: WASM Worker spawn must succeed; only fails on OS resource exhaustion")
}
204
/// Block the calling thread for `duration`.
///
/// On wasm32 this delegates to `wasm_safe_thread::sleep`.
#[cfg(target_arch = "wasm32")]
#[inline]
pub fn sleep(duration: Duration) {
    wasm_safe_thread::sleep(duration);
}

/// Block the calling thread for at least `duration`.
///
/// On native targets this delegates to [`std::thread::sleep`].
#[cfg(not(target_arch = "wasm32"))]
#[inline]
pub fn sleep(duration: Duration) {
    ::std::thread::sleep(duration);
}
217
/// Block the calling thread until its park token is made available.
///
/// On wasm32 this delegates to `wasm_safe_thread::park`.
#[cfg(target_arch = "wasm32")]
#[inline]
pub fn park() {
    wasm_safe_thread::park();
}

/// Block the calling thread until its park token is made available.
///
/// On native targets this delegates to [`std::thread::park`].
#[cfg(not(target_arch = "wasm32"))]
#[inline]
pub fn park() {
    ::std::thread::park();
}
231
/// Block the calling thread until it is unparked or `duration` has elapsed.
///
/// On wasm32 this delegates to `wasm_safe_thread::park_timeout`.
#[cfg(target_arch = "wasm32")]
#[inline]
pub fn park_timeout(duration: Duration) {
    wasm_safe_thread::park_timeout(duration);
}

/// Block the calling thread until it is unparked or `duration` has elapsed.
///
/// On native targets this delegates to [`std::thread::park_timeout`].
#[cfg(not(target_arch = "wasm32"))]
#[inline]
pub fn park_timeout(duration: Duration) {
    ::std::thread::park_timeout(duration);
}
245
246/// Hash of the current thread's ID, usable for shard indexing.
247#[cfg(not(target_arch = "wasm32"))]
248#[inline]
249#[must_use]
250pub fn current_thread_id() -> u64 {
251    let id = current().id();
252    let mut hasher = DefaultHasher::new();
253    id.hash(&mut hasher);
254    hasher.finish()
255}
256
257#[cfg(target_arch = "wasm32")]
258#[inline]
259#[must_use]
260pub fn current_thread_id() -> u64 {
261    let id = current().id();
262    let mut hasher = DefaultHasher::new();
263    id.hash(&mut hasher);
264    hasher.finish()
265}
266
/// Number of hardware threads the platform reports, or `None` if unknown.
///
/// On native targets this wraps [`std::thread::available_parallelism`],
/// mapping its error case to `None`.
#[cfg(not(target_arch = "wasm32"))]
#[inline]
#[must_use]
pub fn available_parallelism() -> Option<std::num::NonZeroUsize> {
    let reported = ::std::thread::available_parallelism();
    reported.ok()
}

/// Number of hardware threads the platform reports, or `None` if unknown.
///
/// On wasm32 this wraps `wasm_safe_thread::available_parallelism`, mapping its
/// error case to `None`.
#[cfg(target_arch = "wasm32")]
#[inline]
#[must_use]
pub fn available_parallelism() -> Option<std::num::NonZeroUsize> {
    let reported = wasm_safe_thread::available_parallelism();
    reported.ok()
}
281
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use kithara_test_utils::kithara;

    use super::*;

    /// On native targets the thread detectors are constant, and both
    /// main/worker assertion helpers must be silent no-ops.
    #[kithara::test]
    fn native_thread_detectors_are_consistent() {
        #[cfg(not(target_arch = "wasm32"))]
        {
            assert!(!is_worker_thread());
            assert!(is_main_thread());
            assert_not_main_thread("native-main");
            assert_main_thread("native-main");
        }
    }

    /// A helper thread unparks us after a short sleep. `park_timeout` must
    /// then return well before its 1 s ceiling.
    #[kithara::test]
    fn park_timeout_returns_after_unpark() {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let waiter = current();
            let started = Instant::now();
            let waker = spawn(move || {
                sleep(Duration::from_millis(5));
                waiter.unpark();
            });
            park_timeout(Duration::from_secs(1));
            waker
                .join()
                .expect("BUG: wake-helper thread joined cleanly without panicking");
            let elapsed = started.elapsed();
            assert!(elapsed < Duration::from_millis(250));
        }
    }
}