//! thread-group 1.0.0
//!
//! A prototype of `std::thread::ThreadGroup`.
use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem;
use std::panic;
use std::sync::{Arc, Mutex};
use std::thread;

use crate::wait_group::WaitGroup;
use cfg_if::cfg_if;
use panic::catch_unwind;

/// An optional value shared between threads, guarded by a mutex.
type SharedOption<T> = Arc<Mutex<Option<T>>>;
/// A growable list shared between threads, guarded by a mutex.
type SharedVec<T> = Arc<Mutex<Vec<T>>>;

/// A group for spawning threads.
///
/// `'env` is the lifetime of the data the spawned threads may borrow; `R` is the type returned
/// by the closure passed to `ThreadGroup::new`.
pub struct ThreadGroup<'env, R> {
    /// Join handles of every thread spawned in this group or in any of the groups handed to
    /// spawned threads (they all share this list). Each handle sits in its own `Option` so it
    /// can be taken out once it is joined.
    handles: SharedVec<SharedOption<thread::JoinHandle<()>>>,

    /// Used to wait until all nested groups are dropped.
    wait_group: WaitGroup,

    /// Outcome of the closure passed to `ThreadGroup::new` (a caught panic is stored as `Err`).
    /// `None` for the groups created for spawned threads.
    res: Option<thread::Result<R>>,

    /// Borrows data with invariant lifetime `'env`.
    _marker: PhantomData<&'env mut &'env ()>,
}

impl<'env, R> ThreadGroup<'env, R> {
    /// Create a new instance of `ThreadGroup`.
    pub fn new<F>(mut f: F) -> Self
    where
        F: FnMut(&Self) -> R,
    {
        let mut group = ThreadGroup::<'env> {
            handles: SharedVec::default(),
            wait_group: WaitGroup::new(),
            res: None,
            _marker: PhantomData,
        };

        // Execute the groupd function, but catch any panics.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&group)));
        group.res = Some(res);

        group
    }

    /// Wait for all threads to finish.
    pub fn join(self) -> thread::Result<R> {
        if let Some(Err(err)) = self.res {
            // TODO: cancel all other threads.
            return Err(err);
        }

        if let Some(Err(err)) = self.res {
            // TODO: cancel all other threads.
            return Err(err);
        }

        // Wait until all nested groups are dropped.
        if let Err(id) = self.wait_group.wait() {
            for handle in self.handles.lock().unwrap().iter() {
                let handle = handle.lock().unwrap().take().unwrap();
                if id == handle.thread().id() {
                    let err = handle.join().unwrap_err();
                    // TODO: cancel all other threads.
                    return Err(err);
                }
            }
        };

        // Join all remaining spawned threads.
        let panics: Vec<_> = self
            .handles
            .lock()
            .unwrap()
            // Filter handles that haven't been joined, join them, and collect errors.
            .drain(..)
            .filter_map(|handle| handle.lock().unwrap().take())
            .filter_map(|handle| handle.join().err())
            .collect();

        // If `f` has panicked, resume unwinding.
        // If any of the child threads have panicked, return the panic errors.
        // Otherwise, everything is OK and return the result of `f`.
        match self.res.unwrap() {
            Err(err) => panic::resume_unwind(err),
            Ok(res) => {
                if panics.is_empty() {
                    Ok(res)
                } else {
                    Err(Box::new(panics))
                }
            }
        }
    }
}

// SAFETY: `ThreadGroup` is not automatically `Sync` because `res` can hold a
// `Box<dyn Any + Send>` panic payload and `R` is unbounded. In the code visible in this file,
// `res` is only accessed through owned values (`new` writes it on a local, `join` takes
// `self`); shared references only reach `handles` (an `Arc<Mutex<_>>`) and `wait_group`.
// NOTE(review): this also assumes `WaitGroup` is safe to share between threads — confirm.
unsafe impl<R> Sync for ThreadGroup<'_, R> {}

impl<'env, R: 'env + Send> ThreadGroup<'env, R> {
    /// Spawns a thread into this group using the default thread configuration.
    ///
    /// The thread receives its own handle to the group, so it can spawn further threads.
    ///
    /// # Panics
    ///
    /// Panics if the thread could not be created.
    pub fn spawn<'group, F, T>(&'group self, f: F)
    where
        F: FnOnce(&ThreadGroup<'env, R>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        // TODO: this requires a join right here.
        let outcome = self.builder().spawn(f);
        outcome.expect("failed to spawn thread in group")
    }

    /// Returns a builder for configuring a thread (name, stack size) before spawning it into
    /// this group.
    pub fn builder<'group>(&'group self) -> ThreadGroupBuilder<'group, 'env, R> {
        let builder = thread::Builder::new();
        ThreadGroupBuilder {
            builder,
            group: self,
        }
    }
}

impl<R> fmt::Debug for ThreadGroup<'_, R> {
    /// Writes an opaque placeholder (honouring width/fill flags); internals are not shown.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "ThreadGroup { .. }";
        out.pad(PLACEHOLDER)
    }
}

/// Configures the properties of a new thread.
#[derive(Debug)]
pub struct ThreadGroupBuilder<'group, 'env, R> {
    group: &'group ThreadGroup<'env, R>,
    builder: thread::Builder,
}

impl<'group, 'env, R: 'env + Send> ThreadGroupBuilder<'group, 'env, R> {
    /// Sets the name for the new thread.
    pub fn name(mut self, name: String) -> Self {
        self.builder = self.builder.name(name);
        self
    }

    /// Sets the size of the stack for the new thread.
    pub fn stack_size(mut self, size: usize) -> Self {
        self.builder = self.builder.stack_size(size);
        self
    }

    /// Spawns a grouped thread with this configuration.
    pub fn spawn<F, T>(self, f: F) -> io::Result<()>
    where
        F: FnOnce(&ThreadGroup<'env, R>) -> T,
        F: Send + 'env,
        T: Send + 'env,
    {
        // The result of `f` will be stored here.
        let result = SharedOption::default();
        let result = Arc::clone(&result);

        // A clone of the group that will be moved into the new thread.
        let group = ThreadGroup::<'env> {
            handles: Arc::clone(&self.group.handles),
            wait_group: self.group.wait_group.clone(),
            res: None,
            _marker: PhantomData,
        };

        // Create the closure which gets spawned.
        let closure = move || {
            // Make sure the group is inside the closure with the proper `'env` lifetime.
            let group: ThreadGroup<'env, R> = group;

            // Run the closure and store the result if the closure didn't panic.
            match catch_unwind(panic::AssertUnwindSafe(|| f(&group))) {
                Ok(res) => *result.lock().unwrap() = Some(res),
                Err(err) => {
                    group.wait_group.set_panic_id(thread::current().id());
                    panic::resume_unwind(err);
                }
            };
        };

        // Allocate `closure` on the heap and erase the `'env` bound.
        let closure: Box<dyn FnOnce() + Send + 'env> = Box::new(closure);
        let closure: Box<dyn FnOnce() + Send + 'static> = unsafe { mem::transmute(closure) };

        // Finally, spawn the closure.
        let handle = self.builder.spawn(move || closure())?;
        let handle = Arc::new(Mutex::new(Some(handle)));

        // Add the handle to the shared list of join handles.
        self.group.handles.lock().unwrap().push(handle);

        Ok(())
    }
}

// SAFETY: `join` moves the stored `T` to whichever thread owns the handle, and dropping the
// handle may drop `T`, so the handle may only change threads when `T: Send`. The remaining
// fields (`Arc<Mutex<Option<thread::JoinHandle<()>>>>`, `thread::Thread`, `PhantomData`) are
// `Send` already.
unsafe impl<T: Send> Send for JoinHandle<'_, T> {}
// SAFETY: through `&JoinHandle` only `thread()` and the platform handle accessors are
// reachable; the stored `T` is never exposed by shared reference.
unsafe impl<T> Sync for JoinHandle<'_, T> {}

/// A handle that can be used to join its grouped thread.
pub struct JoinHandle<'group, T> {
    /// A join handle to the spawned thread. The slot is emptied when the thread is joined.
    handle: SharedOption<thread::JoinHandle<()>>,

    /// Holds the result of the inner closure.
    result: SharedOption<T>,

    /// A handle to the spawned thread.
    thread: thread::Thread,

    /// Borrows the parent group with lifetime `'group`.
    _marker: PhantomData<&'group ()>,
}

impl<T> JoinHandle<'_, T> {
    /// Blocks until the spawned thread has finished, then returns the value its closure
    /// produced.
    ///
    /// If the child thread panics, its panic payload is returned as the error.
    pub fn join(self) -> thread::Result<T> {
        // The handle is expected to still be present, because the root group waits for nested
        // groups before joining the remaining threads.
        let native = {
            let mut slot = self.handle.lock().unwrap();
            slot.take().unwrap()
        };

        // Propagate a panic; otherwise the closure has stored its return value.
        native.join()?;
        let value = self.result.lock().unwrap().take().unwrap();
        Ok(value)
    }

    /// Borrows the `Thread` describing the spawned thread.
    pub fn thread(&self) -> &thread::Thread {
        &self.thread
    }
}

cfg_if! {
    if #[cfg(unix)] {
        use std::os::unix::thread::{JoinHandleExt, RawPthread};

        // Implemented for this module's `JoinHandle` (there is no `ScopedJoinHandle` type).
        impl<T> JoinHandleExt for JoinHandle<'_, T> {
            /// Returns the raw pthread handle of the spawned thread.
            fn as_pthread_t(&self) -> RawPthread {
                // Borrow the handle. The handle will surely be available because the root group
                // waits for nested groups before joining remaining threads.
                let handle = self.handle.lock().unwrap();
                handle.as_ref().unwrap().as_pthread_t()
            }

            /// Returns the raw pthread handle of the spawned thread.
            ///
            /// NOTE(review): unlike `std`, this does not give up ownership of the thread: the
            /// underlying `std::thread::JoinHandle` stays in its shared slot and is presumably
            /// still joined by the group, so callers should not join or detach the returned
            /// `pthread_t` themselves.
            fn into_pthread_t(self) -> RawPthread {
                self.as_pthread_t()
            }
        }
    } else if #[cfg(windows)] {
        use std::os::windows::io::{AsRawHandle, IntoRawHandle, RawHandle};

        impl<T> AsRawHandle for JoinHandle<'_, T> {
            /// Returns the raw Windows handle of the spawned thread.
            fn as_raw_handle(&self) -> RawHandle {
                // Borrow the handle. The handle will surely be available because the root group
                // waits for nested groups before joining remaining threads.
                let handle = self.handle.lock().unwrap();
                handle.as_ref().unwrap().as_raw_handle()
            }
        }

        impl<T> IntoRawHandle for JoinHandle<'_, T> {
            /// Returns the raw Windows handle of the spawned thread.
            ///
            /// NOTE(review): the underlying `std::thread::JoinHandle` keeps owning the handle and
            /// presumably closes it when joined or dropped, so callers should not close the
            /// returned handle themselves.
            fn into_raw_handle(self) -> RawHandle {
                self.as_raw_handle()
            }
        }
    }
}

impl<T> fmt::Debug for JoinHandle<'_, T> {
    /// Writes an opaque placeholder (honouring width/fill flags); no fields are printed.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = "JoinHandle { .. }";
        out.pad(text)
    }
}