thread_amount/
lib.rs

1#![warn(clippy::pedantic)]
2
3use std::num::NonZeroUsize;
4
5#[cfg_attr(any(target_os = "macos", target_os = "ios"), path = "osx.rs")]
6#[cfg_attr(target_os = "freebsd", path = "freebsd.rs")]
7#[cfg_attr(target_os = "linux", path = "linux.rs")]
8#[cfg_attr(target_family = "windows", path = "windows.rs")]
9mod implementation;
10
11/// Gets the amount of threads for the current process.
12/// Returns `None` if there are no threads.
13#[must_use]
14pub fn thread_amount() -> Option<NonZeroUsize> {
15    implementation::thread_amount()
16}
17
18/// Check if the current process is single-threaded.
19#[must_use]
20pub fn is_single_threaded() -> bool {
21    implementation::is_single_threaded()
22}
23
#[cfg(test)]
mod tests {
    //! Checks against the real, process-wide thread count.
    //!
    //! The quantity under test is global to the process, while the default libtest harness runs
    //! tests concurrently. Each test therefore holds the `serialize` lock for its whole body,
    //! and takes its baseline with `settled_count` so that a previous test's thread that is
    //! still tearing down is not mistaken for the baseline. Running with
    //! `cargo test -- --test-threads=1` is still the most deterministic option.
    //!
    //! Under the default harness the process normally contains at least the harness's main
    //! thread plus the thread running the test, so the `initial == 1` branches are usually
    //! skipped.

    use std::sync::{mpsc, Arc, Barrier, Mutex, MutexGuard, PoisonError};
    use std::thread;
    use std::time::{Duration, Instant};

    use super::*;

    /// Delay between successive thread-count samples.
    const POLL_INTERVAL: Duration = Duration::from_millis(20);
    /// Upper bound on how long a single wait for the thread count may take.
    const TIMEOUT: Duration = Duration::from_millis(2500);

    /// Acquires the lock that keeps the tests in this module from running concurrently.
    ///
    /// Bind the guard to a named variable (`let _guard = serialize();`) so it is held until
    /// the end of the test; `let _ = ...` would release it immediately. A lock poisoned by a
    /// failed test is still usable because it guards no data.
    fn serialize() -> MutexGuard<'static, ()> {
        static LOCK: Mutex<()> = Mutex::new(());
        LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Reads the current thread count, panicking with a clear message if none is available.
    #[track_caller]
    fn current_count() -> usize {
        thread_amount()
            .expect("thread_amount() returned None")
            .get()
    }

    /// Returns the thread count once two consecutive samples agree (or the timeout expires).
    ///
    /// Right after the serialization lock is acquired, the previous test's thread may still be
    /// exiting and the harness may be spawning a replacement, so a single reading can be
    /// transient.
    #[track_caller]
    fn settled_count() -> usize {
        let deadline = Instant::now() + TIMEOUT;
        let mut previous = current_count();
        loop {
            thread::sleep(POLL_INTERVAL);
            let current = current_count();
            if current == previous || Instant::now() >= deadline {
                return current;
            }
            previous = current;
        }
    }

    /// Polls until the thread count equals `expected`, panicking if that does not happen
    /// within [`TIMEOUT`]. Thread exit is observed asynchronously by the OS, hence polling.
    #[track_caller]
    fn wait_for_count_to_stabilize(expected: usize) {
        let deadline = Instant::now() + TIMEOUT;
        loop {
            let current = current_count();
            if current == expected {
                return;
            }
            assert!(
                Instant::now() < deadline,
                "Timed out waiting for thread count to stabilize at {expected}. Last count: {current}"
            );
            thread::sleep(POLL_INTERVAL);
        }
    }

    mod thread_amount_tests {
        use super::*;

        #[test]
        fn spawn_increases_count() {
            let _guard = serialize();
            let initial = settled_count();
            let barrier = Arc::new(Barrier::new(2));
            let child_barrier = Arc::clone(&barrier);

            let handle = thread::spawn(move || {
                child_barrier.wait(); // Signal: alive.
                child_barrier.wait(); // Wait for permission to exit.
            });

            barrier.wait(); // The spawned thread is now running.
            wait_for_count_to_stabilize(initial + 1);

            barrier.wait();
            handle.join().unwrap();

            // Ensure count returns to baseline.
            wait_for_count_to_stabilize(initial);
        }

        #[test]
        fn many_threads_simultaneously() {
            let _guard = serialize();
            let initial = settled_count();
            let num_threads = 5;
            let barrier = Arc::new(Barrier::new(num_threads + 1));

            let handles: Vec<_> = (0..num_threads)
                .map(|_| {
                    let child_barrier = Arc::clone(&barrier);
                    thread::spawn(move || {
                        child_barrier.wait(); // Sync start.
                        child_barrier.wait(); // Sync end.
                    })
                })
                .collect();

            barrier.wait(); // All threads are now running.
            wait_for_count_to_stabilize(initial + num_threads);

            barrier.wait(); // Release all threads.
            for handle in handles {
                handle.join().unwrap();
            }

            wait_for_count_to_stabilize(initial);
        }

        #[test]
        fn nested_spawning() {
            let _guard = serialize();
            let initial = settled_count();
            let barrier = Arc::new(Barrier::new(2));
            let nested_barrier = Arc::clone(&barrier);

            // The outer thread spawns the inner one and then blocks in `join`, so both stay
            // alive until the inner thread is released.
            let outer = thread::spawn(move || {
                let inner = thread::spawn(move || {
                    nested_barrier.wait(); // Wait A: alive.
                    nested_barrier.wait(); // Wait B: allowed to exit.
                });
                inner.join().unwrap();
            });

            barrier.wait(); // Wait A: both threads are active.
            wait_for_count_to_stabilize(initial + 2); // baseline + outer + inner

            barrier.wait(); // Wait B: release and clean up.
            outer.join().unwrap();

            wait_for_count_to_stabilize(initial);
        }

        #[test]
        fn count_decreases_after_join() {
            let _guard = serialize();
            let initial = settled_count();
            let (release_tx, release_rx) = mpsc::channel::<()>();

            // The child blocks until told to exit, so it is guaranteed to be alive while the
            // count is polled (a fixed sleep could expire before the thread was observed).
            let h = thread::spawn(move || release_rx.recv().unwrap());

            wait_for_count_to_stabilize(initial + 1);

            release_tx.send(()).unwrap();
            h.join().unwrap();

            // Ensure it goes back down.
            wait_for_count_to_stabilize(initial);
        }

        #[test]
        fn rapid_churn() {
            let _guard = serialize();
            let initial = settled_count();
            for _ in 0..50 {
                thread::spawn(|| {}).join().unwrap();
            }
            wait_for_count_to_stabilize(initial);
        }

        #[test]
        fn named_threads_are_counted() {
            let _guard = serialize();
            let initial = settled_count();
            let barrier = Arc::new(Barrier::new(2));
            let child_barrier = Arc::clone(&barrier);

            let h = thread::Builder::new()
                .name("test-worker".into())
                .spawn(move || {
                    child_barrier.wait();
                    child_barrier.wait();
                })
                .unwrap();

            barrier.wait();
            wait_for_count_to_stabilize(initial + 1);

            barrier.wait();
            h.join().unwrap();
            wait_for_count_to_stabilize(initial);
        }

        #[test]
        fn panicking_thread_decrements_count() {
            let _guard = serialize();
            let initial = settled_count();
            let h = thread::spawn(|| panic!("Intentional panic for testing"));
            assert!(h.join().is_err(), "the spawned thread was expected to panic");

            wait_for_count_to_stabilize(initial);
        }
    }

    mod is_single_threaded_tests {
        use super::*;

        #[test]
        fn lifecycle_is_relative() {
            let _guard = serialize();
            // Establish baseline for THIS test run.
            let initial = settled_count();

            // Only test the `true` case if we start at 1.
            if initial == 1 {
                assert!(is_single_threaded(), "Should be true when count is 1");
            }

            let barrier = Arc::new(Barrier::new(2));
            let child_barrier = Arc::clone(&barrier);

            let h = thread::spawn(move || {
                child_barrier.wait(); // Sync 1: alive.
                child_barrier.wait(); // Sync 2: ready to exit.
            });

            barrier.wait(); // The new thread is definitely active.

            // Count MUST be higher now.
            wait_for_count_to_stabilize(initial + 1);
            assert!(!is_single_threaded(), "Cannot be single-threaded with active child");

            // Finish child thread.
            barrier.wait();
            h.join().unwrap();

            // Wait for count to return to original baseline.
            wait_for_count_to_stabilize(initial);

            if initial == 1 {
                assert!(is_single_threaded(), "Should return to true");
            }
        }

        #[test]
        fn test_lifecycle_relative_to_baseline() {
            let _guard = serialize();
            let initial_count = settled_count();
            let initial_state = is_single_threaded();

            // The `true` case can only be tested if the baseline happens to be 1.
            if initial_count == 1 {
                assert!(initial_state, "Test started at 1, so state should be true");
            } else {
                assert!(!initial_state, "Test started at >1, so state should be false");
            }

            let barrier = Arc::new(Barrier::new(2));
            let child_barrier = Arc::clone(&barrier);

            let h = thread::spawn(move || {
                child_barrier.wait(); // Sync 1: alive.
                child_barrier.wait(); // Sync 2: ready to exit.
            });

            barrier.wait();
            wait_for_count_to_stabilize(initial_count + 1);

            // We are *definitely* multi-threaded now.
            assert!(!is_single_threaded(), "Should be false when multi-threaded");

            barrier.wait();
            h.join().unwrap();

            wait_for_count_to_stabilize(initial_count);

            // The state should be restored to whatever it was at the start.
            assert_eq!(
                is_single_threaded(),
                initial_state,
                "State should be restored to initial state"
            );
        }

        #[test]
        fn test_state_is_restored_after_panic() {
            let _guard = serialize();
            let initial_count = settled_count();
            let initial_state = is_single_threaded();

            let h = thread::spawn(|| {
                panic!("Intentional panic to test thread cleanup");
            });

            // The join must report the panic.
            assert!(h.join().is_err(), "the spawned thread was expected to panic");

            // Wait for the OS to reap the thread.
            wait_for_count_to_stabilize(initial_count);

            assert_eq!(
                is_single_threaded(),
                initial_state,
                "State should be restored after panicking thread is joined"
            );
        }

        #[test]
        fn test_state_with_many_threads() {
            let _guard = serialize();
            let initial_count = settled_count();
            let initial_state = is_single_threaded();
            let num_threads = 10;
            let barrier = Arc::new(Barrier::new(num_threads + 1));

            let handles: Vec<_> = (0..num_threads)
                .map(|_| {
                    let child_barrier = Arc::clone(&barrier);
                    thread::spawn(move || {
                        child_barrier.wait(); // All threads sync here.
                        child_barrier.wait(); // All threads wait to exit.
                    })
                })
                .collect();

            // Wait for all threads to be active.
            barrier.wait();
            wait_for_count_to_stabilize(initial_count + num_threads);

            assert!(!is_single_threaded(), "Should be false with 10 active threads");

            // Release threads.
            barrier.wait();
            for h in handles {
                h.join().unwrap();
            }

            wait_for_count_to_stabilize(initial_count);

            assert_eq!(
                is_single_threaded(),
                initial_state,
                "State should be restored after many threads are joined"
            );
        }
    }
}