jod_thread/
lib.rs

//! **J**oin **O**n **D**rop thread (`jod_thread`) is a thin wrapper around `std::thread`
//! which ensures that, by default, every spawned thread is joined when its handle is dropped.
3
4use std::fmt;
5
/// Like `std::thread::JoinHandle`, but joins the thread on drop and propagates
/// panics by default.
///
/// Dropping this handle blocks the current thread until the spawned thread has
/// finished. Call `detach` to get a plain `std::thread::JoinHandle` instead.
///
/// The inner `Option` is `Some` for the whole life of the wrapper. It is only
/// emptied when the std handle is moved out (by `join`, `detach` or `Drop`).
#[must_use = "dropping a `JoinHandle` joins the thread immediately, blocking until it finishes"]
pub struct JoinHandle<T = ()>(Option<std::thread::JoinHandle<T>>);
9
10impl<T> fmt::Debug for JoinHandle<T> {
11    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
12        f.pad("JoinHandle { .. }")
13    }
14}
15
16impl<T> Drop for JoinHandle<T> {
17    fn drop(&mut self) {
18        if let Some(inner) = self.0.take() {
19            let res = inner.join();
20            if res.is_err() && !std::thread::panicking() {
21                res.unwrap();
22            }
23        }
24    }
25}
26
27impl<T> JoinHandle<T> {
28    pub fn thread(&self) -> &std::thread::Thread {
29        self.0.as_ref().unwrap().thread()
30    }
31    pub fn join(mut self) -> T {
32        let inner = self.0.take().unwrap();
33        inner.join().unwrap()
34    }
35    pub fn detach(mut self) -> std::thread::JoinHandle<T> {
36        let inner = self.0.take().unwrap();
37        inner
38    }
39}
40
41impl<T> From<std::thread::JoinHandle<T>> for JoinHandle<T> {
42    fn from(inner: std::thread::JoinHandle<T>) -> JoinHandle<T> {
43        JoinHandle(Some(inner))
44    }
45}
46
47#[derive(Debug)]
48pub struct Builder(std::thread::Builder);
49
50impl Builder {
51    pub fn new() -> Builder {
52        Builder(std::thread::Builder::new())
53    }
54
55    pub fn name(self, name: String) -> Builder {
56        Builder(self.0.name(name))
57    }
58
59    pub fn stack_size(self, size: usize) -> Builder {
60        Builder(self.0.stack_size(size))
61    }
62
63    pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
64    where
65        F: FnOnce() -> T,
66        F: Send + 'static,
67        T: Send + 'static,
68    {
69        self.0.spawn(f).map(JoinHandle::from)
70    }
71}
72
73pub fn spawn<F, T>(f: F) -> JoinHandle<T>
74where
75    F: FnOnce() -> T,
76    F: Send + 'static,
77    T: Send + 'static,
78{
79    Builder::new().spawn(f).expect("failed to spawn thread")
80}
81
#[test]
fn smoke() {
    use std::panic::catch_unwind;
    use std::sync::atomic::{AtomicU32, Ordering};

    static COUNTER: AtomicU32 = AtomicU32::new(0);
    let count = || COUNTER.load(Ordering::SeqCst);

    // Dropping a handle waits for the thread to run to completion.
    let handle = spawn(|| COUNTER.fetch_add(1, Ordering::SeqCst));
    drop(handle);
    assert_eq!(count(), 1);

    // A handle dropped while its owner unwinds is still joined; the owner's
    // panic is what `catch_unwind` observes.
    let outcome = catch_unwind(|| {
        let _guard = Builder::new()
            .name("panicky".to_string())
            .spawn(|| COUNTER.fetch_add(1, Ordering::SeqCst))
            .unwrap();
        panic!("boom")
    });
    assert!(outcome.is_err());
    assert_eq!(count(), 2);

    // `join` re-raises a panic from the spawned thread.
    let outcome = catch_unwind(|| {
        let () = spawn(|| panic!("boom")).join();
    });
    assert!(outcome.is_err());

    // So does a plain drop when the owner is not already panicking.
    let outcome = catch_unwind(|| {
        let _guard = spawn(|| panic!("boom"));
    });
    assert!(outcome.is_err());
}