1use std::fmt;
5
/// A wrapper around [`std::thread::JoinHandle`] that joins the thread when it is dropped.
///
/// The inner `Option` is emptied only by the methods that consume the wrapper
/// (`join`, `detach`) and by the `Drop` impl, so it is `Some` for as long as the
/// wrapper is observable by callers.
#[must_use = "dropping a JoinHandle immediately blocks until the thread finishes"]
pub struct JoinHandle<T = ()>(Option<std::thread::JoinHandle<T>>);
9
10impl<T> fmt::Debug for JoinHandle<T> {
11 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
12 f.pad("JoinHandle { .. }")
13 }
14}
15
16impl<T> Drop for JoinHandle<T> {
17 fn drop(&mut self) {
18 if let Some(inner) = self.0.take() {
19 let res = inner.join();
20 if res.is_err() && !std::thread::panicking() {
21 res.unwrap();
22 }
23 }
24 }
25}
26
27impl<T> JoinHandle<T> {
28 pub fn thread(&self) -> &std::thread::Thread {
29 self.0.as_ref().unwrap().thread()
30 }
31 pub fn join(mut self) -> T {
32 let inner = self.0.take().unwrap();
33 inner.join().unwrap()
34 }
35 pub fn detach(mut self) -> std::thread::JoinHandle<T> {
36 let inner = self.0.take().unwrap();
37 inner
38 }
39}
40
41impl<T> From<std::thread::JoinHandle<T>> for JoinHandle<T> {
42 fn from(inner: std::thread::JoinHandle<T>) -> JoinHandle<T> {
43 JoinHandle(Some(inner))
44 }
45}
46
47#[derive(Debug)]
48pub struct Builder(std::thread::Builder);
49
50impl Builder {
51 pub fn new() -> Builder {
52 Builder(std::thread::Builder::new())
53 }
54
55 pub fn name(self, name: String) -> Builder {
56 Builder(self.0.name(name))
57 }
58
59 pub fn stack_size(self, size: usize) -> Builder {
60 Builder(self.0.stack_size(size))
61 }
62
63 pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
64 where
65 F: FnOnce() -> T,
66 F: Send + 'static,
67 T: Send + 'static,
68 {
69 self.0.spawn(f).map(JoinHandle::from)
70 }
71}
72
73pub fn spawn<F, T>(f: F) -> JoinHandle<T>
74where
75 F: FnOnce() -> T,
76 F: Send + 'static,
77 T: Send + 'static,
78{
79 Builder::new().spawn(f).expect("failed to spawn thread")
80}
81
#[test]
fn smoke() {
    use std::panic::catch_unwind;
    use std::sync::atomic::{AtomicU32, Ordering};

    static HITS: AtomicU32 = AtomicU32::new(0);

    fn hits() -> u32 {
        HITS.load(Ordering::SeqCst)
    }

    // Dropping the handle blocks until the thread has run.
    let handle = spawn(|| HITS.fetch_add(1, Ordering::SeqCst));
    drop(handle);
    assert_eq!(hits(), 1);

    // A panic in the parent still joins the child while unwinding; the child
    // itself succeeds, so no second panic is raised.
    let outcome = catch_unwind(|| {
        let _guard = Builder::new()
            .name(String::from("panicky"))
            .spawn(|| HITS.fetch_add(1, Ordering::SeqCst))
            .unwrap();
        panic!("boom")
    });
    assert!(outcome.is_err());
    assert_eq!(hits(), 2);

    // An explicit `join` on a panicked thread propagates the panic.
    let outcome = catch_unwind(|| {
        let () = spawn(|| panic!("boom")).join();
    });
    assert!(outcome.is_err());

    // Dropping the handle of a panicked thread propagates the panic as well.
    let outcome = catch_unwind(|| {
        let _guard = spawn(|| panic!("boom"));
    });
    assert!(outcome.is_err());
}