#![deny(unsafe_code)]
use std::fmt;
use std::mem;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::sync::MutexGuard;
use std::sync::{Arc, Condvar, Mutex, OnceLock};
use std::thread;
use std::time::Duration;
use std::time::Instant;
/// A thread handle that can be joined through a shared reference (`&self`).
///
/// Every successful join returns a reference to the same stored return value, so the
/// handle can be shared between several joiners.
#[derive(Debug)]
pub struct SharedThread<T> {
    // `Running(handle)` until the first join takes the handle out; afterwards `Exited`
    // (thread returned normally) or `Panicked` (thread panicked).
    state: Mutex<State<T>>,
    // Flag set by the spawned thread once its closure has finished; joiners wait on it.
    exit_signal: Arc<ExitSignal>,
    // Filled once, by the first joiner, with the thread's return value.
    output: OnceLock<T>,
}
/// "The thread has exited" flag shared between the spawned thread and its joiners.
///
/// The flag only ever goes from `false` to `true`.
#[derive(Debug)]
struct ExitSignal {
    // `false` until the spawned thread's closure has returned or panicked.
    mutex: Mutex<bool>,
    // Woken with `notify_all` right after the flag in `mutex` is set to `true`.
    condvar: Condvar,
}
/// Where the underlying OS thread is in its lifecycle, from the joiners' point of view.
enum State<T> {
    /// Not joined yet; holds the handle that the first joiner consumes.
    Running(thread::JoinHandle<T>),
    /// Joined successfully; the return value has been stored separately.
    Exited,
    /// Joined and found to have panicked.
    Panicked,
}
use State::*;
/// Formats only the variant name; the `JoinHandle` inside `Running` is not printed.
impl<T> fmt::Debug for State<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Running(_) => "Running",
            Exited => "Exited",
            Panicked => "Panicked",
        };
        f.write_str(variant_name)
    }
}
impl<T: Send + 'static> SharedThread<T> {
    /// Spawns a new OS thread running `f` and returns a handle that can be joined through
    /// `&self`.
    ///
    /// Inside the new thread, any panic from `f` is caught first. The thread then sets the
    /// shared exit flag and wakes every waiter. Only after that does it hand back `f`'s
    /// value, or resume the caught panic, so the inner `JoinHandle` sees the same outcome.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create a thread, like [`thread::spawn`].
    pub fn spawn<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + 'static,
    {
        let signal = Arc::new(ExitSignal {
            mutex: Mutex::new(false),
            condvar: Condvar::new(),
        });
        let handle = thread::spawn({
            let signal = Arc::clone(&signal);
            move || {
                let outcome = catch_unwind(AssertUnwindSafe(f));
                // The guard stays alive until the closure returns below.
                let mut exited = lock_ignoring_poison(&signal.mutex);
                *exited = true;
                signal.condvar.notify_all();
                outcome.unwrap_or_else(|payload| resume_unwind(payload))
            }
        });
        Self {
            state: Mutex::new(Running(handle)),
            exit_signal: signal,
            output: OnceLock::new(),
        }
    }
}
impl<T> SharedThread<T> {
    /// Finishes a join once the exit flag has been observed as `true`.
    ///
    /// The first caller to get here takes the `JoinHandle` out of `state`, joins it, and
    /// stores the return value in `output`; later callers just read `output`. The caller's
    /// `exit_signal_guard` stays held for the whole call, so concurrent joiners pass through
    /// here one at a time.
    ///
    /// # Panics
    ///
    /// If the thread panicked, the first caller resumes the thread's original panic payload
    /// and every later caller panics with "shared thread panicked".
    fn join_exited_thread(&self, exit_signal_guard: MutexGuard<bool>) -> &T {
        debug_assert!(*exit_signal_guard, "the thread exited");
        let mut state_guard = lock_ignoring_poison(&self.state);
        match &*state_guard {
            Running(_) => {}
            Exited => return self.output.get().unwrap(),
            Panicked => panic!("shared thread panicked"),
        };
        // Store `Panicked` while joining. If the join reports a panic, it is re-raised below
        // and `Panicked` is exactly the state later callers should observe.
        let Running(handle) = mem::replace(&mut *state_guard, Panicked) else {
            unreachable!()
        };
        match handle.join() {
            Ok(return_value) => {
                let set_result = self.output.set(return_value);
                assert!(set_result.is_ok(), "output must be previously unset");
                *state_guard = Exited;
                self.output.get().unwrap()
            }
            Err(panic) => resume_unwind(panic),
        }
    }

    /// Blocks until the thread exits, then returns a reference to its return value.
    ///
    /// Takes `&self`, so it may be called repeatedly and from several threads at once;
    /// every call returns a reference to the same stored value.
    ///
    /// # Panics
    ///
    /// If the thread panicked, the first call that observes the exit re-raises the original
    /// panic. Every later call panics with "shared thread panicked".
    pub fn join(&self) -> &T {
        let mut exit_signal_guard = lock_ignoring_poison(&self.exit_signal.mutex);
        // Loop to tolerate spurious condvar wakeups.
        while !*exit_signal_guard {
            exit_signal_guard = wait_ignoring_poison(&self.exit_signal.condvar, exit_signal_guard);
        }
        self.join_exited_thread(exit_signal_guard)
    }

    /// Like [`join`](Self::join), but returns `None` if the thread has not exited within
    /// `timeout`.
    ///
    /// Some timeouts cannot be added to the current [`Instant`] (for example
    /// [`Duration::MAX`]). Such a deadline can never be reached, so this waits without a
    /// limit instead of panicking on overflow.
    ///
    /// # Panics
    ///
    /// Same as [`join`](Self::join) if the thread panicked.
    pub fn join_timeout(&self, timeout: Duration) -> Option<&T> {
        match Instant::now().checked_add(timeout) {
            Some(deadline) => self.join_deadline(deadline),
            None => Some(self.join()),
        }
    }

    /// Like [`join`](Self::join), but returns `None` if the thread has not exited by
    /// `deadline`.
    ///
    /// If the thread has already exited, this returns `Some` even when `deadline` is in the
    /// past.
    ///
    /// # Panics
    ///
    /// Same as [`join`](Self::join) if the thread panicked.
    pub fn join_deadline(&self, deadline: Instant) -> Option<&T> {
        let mut exit_signal_guard = lock_ignoring_poison(&self.exit_signal.mutex);
        while !*exit_signal_guard {
            // `>=`: once the deadline is reached, waiting again with a zero timeout would
            // only spin.
            if Instant::now() >= deadline {
                return None;
            }
            exit_signal_guard = wait_deadline_ignoring_poison(
                &self.exit_signal.condvar,
                exit_signal_guard,
                deadline,
            );
        }
        Some(self.join_exited_thread(exit_signal_guard))
    }

    /// Returns the thread's return value if it has already exited, or `None` if it is still
    /// running. Never waits for the thread to finish.
    ///
    /// # Panics
    ///
    /// Same as [`join`](Self::join) if the thread has exited by panicking.
    pub fn try_join(&self) -> Option<&T> {
        let exit_signal_guard = lock_ignoring_poison(&self.exit_signal.mutex);
        if *exit_signal_guard {
            Some(self.join_exited_thread(exit_signal_guard))
        } else {
            None
        }
    }

    /// Consumes the handle, waits for the thread, and returns its return value by value.
    ///
    /// # Panics
    ///
    /// Same as [`join`](Self::join) if the thread panicked.
    pub fn into_output(self) -> T {
        self.join();
        // A successful `join` always leaves `output` set.
        self.output.into_inner().expect("should be set")
    }

    /// Returns `true` once the thread's closure has finished, whether it returned or
    /// panicked. Does not join the thread.
    pub fn is_finished(&self) -> bool {
        *lock_ignoring_poison(&self.exit_signal.mutex)
    }
}
/// Locks `mutex` and returns the guard, even if the mutex is poisoned.
///
/// This file re-raises panics while holding its locks (for example `resume_unwind` in
/// `spawn` and in `join_exited_thread`), which poisons them. A poisoned lock is therefore
/// treated as usable rather than as an error.
fn lock_ignoring_poison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Waits on `condvar` with `guard`'s mutex and returns the re-acquired guard.
///
/// If the mutex became poisoned while waiting, the guard is still returned. Like
/// [`Condvar::wait`], this can wake spuriously, so callers re-check their condition in a
/// loop.
fn wait_ignoring_poison<'guard, T>(
    condvar: &Condvar,
    guard: MutexGuard<'guard, T>,
) -> MutexGuard<'guard, T> {
    condvar.wait(guard).unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Waits on `condvar` until notified or until `deadline`, returning the re-acquired guard.
///
/// A deadline that has already passed becomes a zero-length wait. Poisoning is ignored, and
/// whether the wait timed out is discarded, so callers must re-check both their condition
/// and the clock.
fn wait_deadline_ignoring_poison<'guard, T>(
    condvar: &Condvar,
    guard: MutexGuard<'guard, T>,
    deadline: Instant,
) -> MutexGuard<'guard, T> {
    let remaining = deadline.saturating_duration_since(Instant::now());
    let (reacquired, _timed_out) = condvar
        .wait_timeout(guard, remaining)
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    reacquired
}
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    /// Busy-waits until `flag` is set, telling the CPU this is a spin loop.
    fn spin_until(flag: &AtomicBool) {
        while !flag.load(Relaxed) {
            std::hint::spin_loop();
        }
    }

    #[test]
    fn test_join_and_try_join() {
        static STOP_FLAG: AtomicBool = AtomicBool::new(false);
        let bg_thread = SharedThread::spawn(|| {
            spin_until(&STOP_FLAG);
            42
        });
        thread::scope(|scope| {
            // Several blocking joiners, all waiting on the same exit.
            let mut joiners = Vec::new();
            for _ in 0..10 {
                joiners.push(scope.spawn(|| {
                    assert_eq!(bg_thread.join(), &42);
                }));
            }
            for _ in 0..100 {
                assert!(bg_thread.try_join().is_none());
                assert!(!bg_thread.is_finished());
            }
            STOP_FLAG.store(true, Relaxed);
            while !bg_thread.is_finished() {
                std::hint::spin_loop();
            }
            assert_eq!(bg_thread.try_join(), Some(&42));
            for joiner in joiners {
                joiner.join().unwrap();
            }
        });
    }

    #[test]
    fn test_try_join_only() {
        static STOP_FLAG: AtomicBool = AtomicBool::new(false);
        let bg_thread = SharedThread::spawn(|| {
            spin_until(&STOP_FLAG);
            42
        });
        for _ in 0..100 {
            assert!(bg_thread.try_join().is_none());
        }
        STOP_FLAG.store(true, Relaxed);
        while bg_thread.try_join().is_none() {
            std::hint::spin_loop();
        }
        assert_eq!(bg_thread.try_join(), Some(&42));
    }

    #[test]
    fn test_join_timeout_and_deadline() {
        static STOP_FLAG: AtomicBool = AtomicBool::new(false);
        let bg_thread = SharedThread::spawn(|| {
            spin_until(&STOP_FLAG);
            7
        });
        // Still running: both timed joins give up.
        assert_eq!(bg_thread.join_timeout(Duration::from_millis(10)), None);
        assert_eq!(bg_thread.join_deadline(Instant::now()), None);
        STOP_FLAG.store(true, Relaxed);
        assert_eq!(bg_thread.join_timeout(Duration::from_secs(60)), Some(&7));
        // After exit, even an already-passed deadline yields the stored output.
        assert_eq!(bg_thread.join_deadline(Instant::now()), Some(&7));
        assert!(bg_thread.is_finished());
    }

    #[test]
    fn test_into_output() {
        let thread = SharedThread::spawn(|| String::from("foo"));
        let result: String = thread.into_output();
        assert_eq!(result, "foo");
    }

    #[test]
    fn test_panic_messages() {
        let thread = SharedThread::spawn(|| panic!("original message"));
        let panic_error = catch_unwind(|| thread.join()).unwrap_err();
        assert_eq!(panic_error.downcast_ref(), Some(&"original message"));
        let second_panic_error = catch_unwind(|| thread.join()).unwrap_err();
        assert_eq!(
            second_panic_error.downcast_ref(),
            Some(&"shared thread panicked"),
        );
    }
}