ctx_thread/
lib.rs

1#![doc(test(
2    no_crate_inject,
3    attr(
4        deny(warnings, rust_2018_idioms),
5        allow(dead_code, unused_assignments, unused_variables)
6    )
7))]
8#![deny(missing_docs, missing_debug_implementations, rust_2018_idioms)]
9
10//! Threads that run within a context.
11//!
12//! Most of the time, threads that outlive the parent thread are considered a code smell.
13//! Ctx-thread ensures that all threads are joined before returning from the scope. Child threads
14//! have access to the Context object, which they can use to poll the status of the thread group.
15//! If one of the threads panics, the context is cancelled.
16//!
17//! # Scope
18//!
//! This library is based on [crossbeam](https://docs.rs/crossbeam/0.8.0/crossbeam/)'s scoped threads:
20//!
21//! ```
22//! use ctx_thread::scope;
23//!
24//! let people = vec![
25//!     "Alice".to_string(),
26//!     "Bob".to_string(),
27//!     "Carol".to_string(),
28//! ];
29//!
30//! scope(|ctx| {
31//!     for person in &people {
32//!         ctx.spawn(move |_| {
33//!             println!("Hello, {}", person);
34//!         });
35//!     }
36//! }).unwrap();
37//! ```
38//!
39//! # Context
40//!
//! Aside from borrowing from the outer scope, threads can poll their context for cancellation
//! and return early when necessary:
43//!
44//! ```
45//! use ctx_thread::scope;
46//!
47//!
48//! scope(|ctx| {
49//!     ctx.spawn(|ctx| {
50//!         while ctx.active() {
51//!             // do work
52//!         }
53//!     });
54//!
55//!     ctx.spawn(|ctx| {
56//!         ctx.cancel();
57//!     });
58//! }).unwrap();
59//! ```
//! Note that context-based cancellation is cooperative: a thread only stops if it checks its
//! context, so a thread that is blocked, or never polls, keeps running after cancellation.
62
63use std::fmt;
64use std::io;
65use std::marker::PhantomData;
66use std::mem;
67use std::panic;
68use std::sync::{Arc, Mutex};
69use std::thread;
70
71use cfg_if::cfg_if;
72use crossbeam_utils::sync::WaitGroup;
73use std::sync::atomic::{AtomicBool, Ordering};
74
75type SharedVec<T> = Arc<Mutex<Vec<T>>>;
76type SharedOption<T> = Arc<Mutex<Option<T>>>;
77
78/// Creates a new scope for spawning threads.
79///
80/// All child threads that haven't been manually joined will be automatically joined just before
81/// this function invocation ends. If all joined threads have successfully completed, `Ok` is
82/// returned with the return value of `f`. If any of the joined threads has panicked, an `Err` is
83/// returned containing errors from panicked threads.
84///
85/// # Examples
86///
87/// ```
88/// use ctx_thread::scope;
89///
90/// let var = vec![1, 2, 3];
91///
92/// scope(|ctx| {
93///     ctx.spawn(|_| {
94///         println!("A child thread borrowing `var`: {:?}", var);
95///     });
96/// }).unwrap();
97/// ```
98pub fn scope<'env, F, R>(f: F) -> thread::Result<R>
99where
100    F: FnOnce(&Context<'env>) -> R,
101{
102    let wg = WaitGroup::new();
103
104    let ctx = Context::<'env> {
105        done: Arc::new(AtomicBool::new(false)),
106        handles: SharedVec::default(),
107        wait_group: wg.clone(),
108        _marker: PhantomData,
109    };
110
111    // Execute the scoped function, but catch any panics.
112    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&ctx)));
113
114    // Signal to any remaining threads that the context is done if f panicked.
115    if result.is_err() {
116        ctx.cancel();
117    }
118
119    // Wait until all nested scopes are dropped.
120    drop(ctx.wait_group);
121    wg.wait();
122
123    // Join all remaining spawned threads.
124    let panics: Vec<_> = ctx
125        .handles
126        .lock()
127        .unwrap()
128        // Filter handles that haven't been joined, join them, and collect errors.
129        .drain(..)
130        .filter_map(|handle| handle.lock().unwrap().take())
131        .filter_map(|handle| handle.join().err())
132        .collect();
133
134    // If `f` has panicked, resume unwinding.
135    // If any of the child threads have panicked, return the panic errors.
136    // Otherwise, everything is OK and return the result of `f`.
137    match result {
138        Err(err) => panic::resume_unwind(err),
139        Ok(res) => {
140            if panics.is_empty() {
141                Ok(res)
142            } else {
143                Err(Box::new(panics))
144            }
145        }
146    }
147}
148
149/// The context in which threads run, including their scope and thread group status.
150pub struct Context<'env> {
151    done: Arc<AtomicBool>,
152
153    /// The list of the thread join handles.
154    handles: SharedVec<SharedOption<thread::JoinHandle<()>>>,
155
156    /// Used to wait until all subscopes all dropped.
157    wait_group: WaitGroup,
158
159    /// Borrows data with invariant lifetime `'env`.
160    _marker: PhantomData<&'env mut &'env ()>,
161}
162
163unsafe impl Sync for Context<'_> {}
164
165impl<'env> Context<'env> {
166    /// Check if the current context has finished. Threads performing work should regularly check
167    /// and return early if cancellation has been signalled. Usually this indicates some critical
168    /// failure in a sibling thread, thus making the result of the current thread inconsequential.
169    ///
170    /// # Examples
171    ///
172    /// ```rust
173    /// use ctx_thread::scope;
174    ///
175    /// scope(|ctx| {
176    ///     ctx.spawn(|ctx| {
177    ///         assert_eq!(ctx.active(), !ctx.done());
178    ///         ctx.spawn(|ctx| {
179    ///            ctx.cancel()
180    ///         });
181    ///
182    ///         while ctx.active() {}
183    ///     });
184    /// }).unwrap();
185    /// ```
186    pub fn done(&self) -> bool {
187        self.done.load(Ordering::Relaxed)
188    }
189
190    /// Signals cancellation of the current context, causing [done] to return true. A cancelled
191    /// context cannot be re-enabled.
192    /// [done]: Context::done
193    pub fn cancel(&self) {
194        self.done.store(true, Ordering::Relaxed)
195    }
196
197    /// Alias for !ctx.done(); which is easier on the eyes in for loops.
198    pub fn active(&self) -> bool {
199        !self.done()
200    }
201
202    /// Spawns a scoped thread, providing a derived context.
203    ///
204    /// This method is similar to the [`spawn`] function in Rust's standard library. The difference
205    /// is that this thread is scoped, meaning it's guaranteed to terminate before the scope exits,
206    /// allowing it to reference variables outside the scope.
207    ///
208    /// The scoped thread is passed a reference to this scope as an argument, which can be used for
209    /// spawning nested threads.
210    ///
211    /// The returned [handle](ContextJoinHandle) can be used to manually
212    /// [join](ContextJoinHandle::join) the thread before the scope exits.
213    ///
214    /// This will create a thread using default parameters of [`ScopedThreadBuilder`], if you want to specify the
215    /// stack size or the name of the thread, use this API instead.
216    ///
217    /// [`spawn`]: std::thread::spawn
218    ///
219    /// # Panics
220    ///
221    /// Panics if the OS fails to create a thread; use [`ScopedThreadBuilder::spawn`]
222    /// to recover from such errors.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use ctx_thread::scope;
228    ///
229    /// scope(|ctx| {
230    ///     let handle = ctx.spawn(|_| {
231    ///         println!("A child thread is running");
232    ///         42
233    ///     });
234    ///
235    ///     // Join the thread and retrieve its result.
236    ///     let res = handle.join().unwrap();
237    ///     assert_eq!(res, 42);
238    /// }).unwrap();
239    /// ```
240    pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ContextJoinHandle<'scope, T>
241    where
242        F: FnOnce(&Context<'env>) -> T,
243        F: Send + 'env,
244        T: Send + 'env,
245    {
246        self.builder()
247            .spawn(|ctx| {
248                let result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(ctx)));
249                if let Err(e) = result {
250                    ctx.cancel();
251                    panic::resume_unwind(e)
252                }
253
254                result.unwrap()
255            })
256            .expect("failed to spawn scoped thread")
257    }
258
259    /// Creates a builder that can configure a thread before spawning.
260    ///
261    /// # Examples
262    ///
263    /// ```
264    /// use ctx_thread::scope;
265    ///
266    /// scope(|ctx| {
267    ///     ctx.builder()
268    ///         .name(String::from("child"))
269    ///         .stack_size(1024)
270    ///         .spawn(|_| println!("A child thread is running"))
271    ///         .unwrap();
272    /// }).unwrap();
273    /// ```
274    pub fn builder<'scope>(&'scope self) -> ContextThreadBuilder<'scope, 'env> {
275        ContextThreadBuilder {
276            scope: self,
277            builder: thread::Builder::new(),
278        }
279    }
280}
281
282impl fmt::Debug for Context<'_> {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        f.pad("Scope { .. }")
285    }
286}
287
288/// Configures the properties of a new thread.
289///
290/// The two configurable properties are:
291///
292/// - [`name`]: Specifies an [associated name for the thread][naming-threads].
293/// - [`stack_size`]: Specifies the [desired stack size for the thread][stack-size].
294///
295/// The [`spawn`] method will take ownership of the builder and return an [`io::Result`] of the
296/// thread handle with the given configuration.
297///
298/// The [`Context::spawn`] method uses a builder with default configuration and unwraps its return
299/// value. You may want to use this builder when you want to recover from a failure to launch a
300/// thread.
301///
302/// # Examples
303///
304/// ```
305/// use ctx_thread::scope;
306///
307/// scope(|ctx| {
308///     ctx.builder()
309///         .spawn(|_| println!("Running a child thread"))
310///         .unwrap();
311/// }).unwrap();
312/// ```
313///
314/// [`name`]: ContextThreadBuilder::name
315/// [`stack_size`]: ContextThreadBuilder::stack_size
316/// [`spawn`]: ContextThreadBuilder::spawn
317/// [`io::Result`]: std::io::Result
318/// [naming-threads]: std::thread#naming-threads
319/// [stack-size]: std::thread#stack-size
320#[derive(Debug)]
321pub struct ContextThreadBuilder<'scope, 'env> {
322    scope: &'scope Context<'env>,
323    builder: thread::Builder,
324}
325
326impl<'scope, 'env> ContextThreadBuilder<'scope, 'env> {
327    /// Sets the name for the new thread.
328    ///
329    /// The name must not contain null bytes (`\0`).
330    ///
331    /// For more information about named threads, see [here][naming-threads].
332    pub fn name(mut self, name: String) -> ContextThreadBuilder<'scope, 'env> {
333        self.builder = self.builder.name(name);
334        self
335    }
336
337    /// Sets the size of the stack for the new thread.
338    ///
339    /// The stack size is measured in bytes.
340    ///
341    /// For more information about the stack size for threads, see [here][stack-size].
342    pub fn stack_size(mut self, size: usize) -> ContextThreadBuilder<'scope, 'env> {
343        self.builder = self.builder.stack_size(size);
344        self
345    }
346
347    /// Spawns a scoped thread with this configuration, providing a derived context.
348    ///
349    /// The scoped thread is passed a reference to this scope as an argument, which can be used for
350    /// spawning nested threads.
351    ///
352    /// The returned handle can be used to manually join the thread before the scope exits.
353    ///
354    /// # Errors
355    ///
356    /// Unlike the [`Scope::spawn`] method, this method yields an
357    /// [`io::Result`] to capture any failure to create the thread at
358    /// the OS level.
359    ///
360    /// [`io::Result`]: std::io::Result
361    ///
362    /// # Panics
363    ///
364    /// Panics if a thread name was set and it contained null bytes.
365    ///
366    /// # Examples
367    ///
368    /// ```
369    /// use ctx_thread::scope;
370    ///
371    /// scope(|ctx| {
372    ///     let handle = ctx.builder()
373    ///         .spawn(|_| {
374    ///             println!("A child thread is running");
375    ///             42
376    ///         })
377    ///         .unwrap();
378    ///
379    ///     // Join the thread and retrieve its result.
380    ///     let res = handle.join().unwrap();
381    ///     assert_eq!(res, 42);
382    /// }).unwrap();
383    /// ```
384    pub fn spawn<F, T>(self, f: F) -> io::Result<ContextJoinHandle<'scope, T>>
385    where
386        F: FnOnce(&Context<'env>) -> T,
387        F: Send + 'env,
388        T: Send + 'env,
389    {
390        // The result of `f` will be stored here.
391        let result = SharedOption::default();
392
393        // Spawn the thread and grab its join handle and thread handle.
394        let (handle, thread) = {
395            let result = Arc::clone(&result);
396
397            // A clone of the context that will be moved into the new thread.
398            let ctx = Context::<'env> {
399                done: self.scope.done.clone(),
400                handles: Arc::clone(&self.scope.handles),
401                wait_group: self.scope.wait_group.clone(),
402                _marker: PhantomData,
403            };
404
405            // Spawn the thread.
406            let handle = {
407                let closure = move || {
408                    // Make sure the scope is inside the closure with the proper `'env` lifetime.
409                    let scope: Context<'env> = ctx;
410
411                    // Run the closure.
412                    let res = f(&scope);
413
414                    // Store the result if the closure didn't panic.
415                    *result.lock().unwrap() = Some(res);
416                };
417
418                // Allocate `closure` on the heap and erase the `'env` bound.
419                let closure: Box<dyn FnOnce() + Send + 'env> = Box::new(closure);
420                let closure: Box<dyn FnOnce() + Send + 'static> =
421                    unsafe { mem::transmute(closure) };
422
423                // Finally, spawn the closure.
424                self.builder.spawn(move || closure())?
425            };
426
427            let thread = handle.thread().clone();
428            let handle = Arc::new(Mutex::new(Some(handle)));
429            (handle, thread)
430        };
431
432        // Add the handle to the shared list of join handles.
433        self.scope.handles.lock().unwrap().push(Arc::clone(&handle));
434
435        Ok(ContextJoinHandle {
436            handle,
437            result,
438            thread,
439            _marker: PhantomData,
440        })
441    }
442}
443
444unsafe impl<T> Send for ContextJoinHandle<'_, T> {}
445unsafe impl<T> Sync for ContextJoinHandle<'_, T> {}
446
447/// A handle that can be used to join its context thread.
448///
449/// This struct is created by the [`Context::spawn`] method and the
450/// [`ContextJoinHandle::spawn`] method.
451pub struct ContextJoinHandle<'scope, T> {
452    /// A join handle to the spawned thread.
453    handle: SharedOption<thread::JoinHandle<()>>,
454
455    /// Holds the result of the inner closure.
456    result: SharedOption<T>,
457
458    /// A handle to the the spawned thread.
459    thread: thread::Thread,
460
461    /// Borrows the parent scope with lifetime `'scope`.
462    _marker: PhantomData<&'scope ()>,
463}
464
465impl<T> ContextJoinHandle<'_, T> {
466    /// Waits for the thread to finish and returns its result.
467    ///
468    /// If the child thread panics, an error is returned.
469    ///
470    /// # Panics
471    ///
472    /// This function may panic on some platforms if a thread attempts to join itself or otherwise
473    /// may create a deadlock with joining threads.
474    ///
475    /// # Examples
476    ///
477    /// ```
478    /// use ctx_thread::scope;
479    ///
480    /// scope(|ctx| {
481    ///     let handle1 = ctx.spawn(|_| println!("I'm a happy thread :)"));
482    ///     let handle2 = ctx.spawn(|_| panic!("I'm a sad thread :("));
483    ///
484    ///     // Join the first thread and verify that it succeeded.
485    ///     let res = handle1.join();
486    ///     assert!(res.is_ok());
487    ///
488    ///     // Join the second thread and verify that it panicked.
489    ///     let res = handle2.join();
490    ///     assert!(res.is_err());
491    /// }).unwrap();
492    /// ```
493    pub fn join(self) -> thread::Result<T> {
494        // Take out the handle. The handle will surely be available because the root scope waits
495        // for nested scopes before joining remaining threads.
496        let handle = self.handle.lock().unwrap().take().unwrap();
497
498        // Join the thread and then take the result out of its inner closure.
499        handle
500            .join()
501            .map(|()| self.result.lock().unwrap().take().unwrap())
502    }
503
504    /// Returns a handle to the underlying thread.
505    pub fn thread(&self) -> &thread::Thread {
506        &self.thread
507    }
508}
509
510cfg_if! {
511    if #[cfg(unix)] {
512        use std::os::unix::thread::{JoinHandleExt, RawPthread};
513
514        impl<T> JoinHandleExt for ContextJoinHandle<'_, T> {
515            fn as_pthread_t(&self) -> RawPthread {
516                // Borrow the handle. The handle will surely be available because the root scope waits
517                // for nested scopes before joining remaining threads.
518                let handle = self.handle.lock().unwrap();
519                handle.as_ref().unwrap().as_pthread_t()
520            }
521            fn into_pthread_t(self) -> RawPthread {
522                self.as_pthread_t()
523            }
524        }
525    } else if #[cfg(windows)] {
526        use std::os::windows::io::{AsRawHandle, IntoRawHandle, RawHandle};
527
528        impl<T> AsRawHandle for ContextJoinHandle<'_, T> {
529            fn as_raw_handle(&self) -> RawHandle {
530                // Borrow the handle. The handle will surely be available because the root scope waits
531                // for nested scopes before joining remaining threads.
532                let handle = self.handle.lock().unwrap();
533                handle.as_ref().unwrap().as_raw_handle()
534            }
535        }
536
537        #[cfg(windows)]
538        impl<T> IntoRawHandle for ContextJoinHandle<'_, T> {
539            fn into_raw_handle(self) -> RawHandle {
540                self.as_raw_handle()
541            }
542        }
543    }
544}
545
546impl<T> fmt::Debug for ContextJoinHandle<'_, T> {
547    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548        f.pad(&format!(
549            "ScopedJoinHandle {{ name: {:?} }}",
550            self.thread.name()
551        ))
552    }
553}
554
555#[cfg(test)]
556mod tests {
557    use super::*;
558
559    #[test]
560    fn test_cancellation_nested() {
561        scope(|ctx| {
562            ctx.spawn(|ctx| while !ctx.done() {});
563
564            ctx.spawn(|ctx| {
565                while ctx.active() {
566                    ctx.spawn(|ctx| ctx.cancel());
567                }
568            });
569        })
570        .unwrap()
571    }
572
573    #[test]
574    #[should_panic]
575    fn test_panic_cancellation() {
576        scope(|ctx| {
577            ctx.spawn(|_| panic!());
578            assert!(ctx.active())
579        })
580        .unwrap()
581    }
582}