use crate::flog::{flog, FloggableDebug};
use nix::sys::signal::{SigSet, SigmaskHow, Signal};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
// Allows `std::thread::ThreadId` values to be passed to `flog!` (the pool logs
// `std::thread::current().id()`).
impl FloggableDebug for std::thread::ThreadId {}

/// A unit of work queued on a [`ThreadPool`].
type WorkItem = Box<dyn FnOnce() + Send + 'static>;

/// Maximum time an idle pool worker waits for new work before exiting (when it waits at all).
const IO_WAIT_FOR_WORK_DURATION: Duration = Duration::from_millis(500);

/// True in test builds; suppresses the panics of the main/background thread assertions.
const THREAD_ASSERTS_CFG_FOR_TESTING: bool = cfg!(test);

/// Set in a `fork()`ed child by the handler that [`init`] registers with `pthread_atfork()`.
static IS_FORKED_PROC: AtomicBool = AtomicBool::new(false);

/// The [`thread_id`] of the thread that called [`init`].
static MAIN_THREAD_ID: OnceLock<usize> = OnceLock::new();
/// Records the calling thread as the main thread and registers a `fork()` child handler.
///
/// Must be called exactly once, at startup, before [`is_main_thread`] and the thread
/// assertions are used.
///
/// # Panics
/// Panics if called more than once, or if `pthread_atfork()` fails.
pub fn init() {
    // Runs in the child process after `fork()`; flags it so `is_forked_child()` reports true.
    extern "C" fn child_post_fork() {
        IS_FORKED_PROC.store(true, Ordering::Relaxed);
    }

    MAIN_THREAD_ID
        .set(thread_id())
        .expect("threads::init() must only be called once (at startup)!");

    // SAFETY: `child_post_fork` is a capture-free `extern "C"` function that only performs a
    // lock-free atomic store, which is safe to run in a freshly forked child. No prepare/parent
    // handlers are registered.
    let result = unsafe { libc::pthread_atfork(None, None, Some(child_post_fork)) };
    // pthread_atfork() returns the error number directly; it does not set `errno`.
    assert_eq!(
        result,
        0,
        "pthread_atfork() failure: {}",
        std::io::Error::from_raw_os_error(result)
    );
}
/// Returns the [`thread_id`] recorded for the main thread by [`init`].
///
/// # Panics
/// Panics if [`init`] has not been called yet.
#[inline(always)]
fn main_thread_id() -> usize {
    #[cold]
    fn init_not_called() -> ! {
        panic!("threads::init() was not called at startup!");
    }

    let Some(&id) = MAIN_THREAD_ID.get() else {
        init_not_called()
    };
    id
}
/// Returns a small integer id for the calling thread.
///
/// Ids come from a process-wide counter starting at 1, so 0 never names a thread. Each thread
/// gets its id lazily on first call and keeps it for its lifetime.
#[inline(always)]
fn thread_id() -> usize {
    static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
    thread_local! {
        static CURRENT_ID: usize = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    }

    let current = CURRENT_ID.with(|value| *value);
    debug_assert_ne!(current, 0, "TLS storage not initialized!");
    current
}
/// Returns true when called on the thread that ran [`init`].
///
/// # Panics
/// Panics if [`init`] has not been called yet.
#[inline(always)]
pub fn is_main_thread() -> bool {
    let current = thread_id();
    current == main_thread_id()
}
/// Panics unless running on the main thread.
///
/// In test builds the "wrong thread" panic is suppressed, though [`init`] must still have been
/// called (the thread check itself still runs).
#[inline(always)]
pub fn assert_is_main_thread() {
    #[cold]
    fn not_main_thread() -> ! {
        panic!("Function is not running on the main thread!");
    }

    if is_main_thread() || THREAD_ASSERTS_CFG_FOR_TESTING {
        return;
    }
    not_main_thread();
}
/// Panics if running on the main thread.
///
/// In test builds the "wrong thread" panic is suppressed, though [`init`] must still have been
/// called (the thread check itself still runs).
#[inline(always)]
pub fn assert_is_background_thread() {
    #[cold]
    fn not_background_thread() -> ! {
        panic!("Function is not allowed to be called on the main thread!");
    }

    if !is_main_thread() || THREAD_ASSERTS_CFG_FOR_TESTING {
        return;
    }
    not_background_thread();
}
pub fn is_forked_child() -> bool {
IS_FORKED_PROC.load(Ordering::Relaxed)
}
/// Panics if the current process is a `fork()`ed child (see [`is_forked_child`]).
///
/// Use this to guard code that should not execute in a forked child.
#[inline(always)]
pub fn assert_is_not_forked_child() {
    // Declared as diverging, like the other cold panic helpers in this module, so the compiler
    // knows the call never returns.
    #[cold]
    fn panic_is_forked_child() -> ! {
        panic!("Function called from forked child!");
    }

    if is_forked_child() {
        panic_is_forked_child();
    }
}
pub fn spawn<F: FnOnce() + Send + 'static>(callback: F) -> bool {
let saved_set = {
let new_set = {
let mut set = SigSet::all();
set.remove(Signal::SIGILL); set.remove(Signal::SIGFPE); set.remove(Signal::SIGBUS); set.remove(Signal::SIGSEGV); set.remove(Signal::SIGSTOP); set.remove(Signal::SIGKILL); set
};
new_set
.thread_swap_mask(SigmaskHow::SIG_BLOCK)
.expect("Failed to override thread signal mask!")
};
let result = match std::thread::Builder::new().spawn(callback) {
Ok(handle) => {
let thread_id = thread_id();
flog!(iothread, "rust thread", thread_id, "spawned");
drop(handle);
true
}
Err(e) => {
eprintf!("rust thread spawn failure: %s\n", e);
false
}
};
saved_set
.thread_set_mask()
.expect("Failed to restore thread signal mask!");
result
}
/// Pool state shared between a [`ThreadPool`] and its workers; it lives inside the pool's mutex.
#[derive(Default)]
struct ThreadPoolProtected {
    /// Worker threads currently alive, including ones reserved but still being spawned.
    pub total_threads: usize,
    /// Workers currently blocked on the pool's condition variable waiting for work.
    pub waiting_threads: usize,
    /// Pending work items, consumed in FIFO order.
    pub request_queue: std::collections::VecDeque<WorkItem>,
}
/// A pool of detached worker threads that run queued closures.
///
/// Workers are spawned on demand (up to `max_threads`) when work is submitted, and exit on
/// their own once no work is left for them.
pub struct ThreadPool {
    /// Queue and thread bookkeeping, shared with the workers.
    shared: Mutex<ThreadPoolProtected>,
    /// Notified to wake an idle worker when new work is queued.
    cond_var: std::sync::Condvar,
    /// Hard upper bound on the number of live worker threads.
    max_threads: usize,
    /// Soft minimum worker count. Idle workers consult it when deciding whether to wait for
    /// more work or exit.
    soft_min_threads: usize,
}
impl std::fmt::Debug for ThreadPool {
    /// Shows only the configured thread limits; the mutex-protected state is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ThreadPool");
        builder.field("min_threads", &self.soft_min_threads);
        builder.field("max_threads", &self.max_threads);
        builder.finish()
    }
}
impl ThreadPool {
pub fn new(soft_min_threads: usize, max_threads: usize) -> Arc<Self> {
Arc::new(ThreadPool {
shared: Default::default(),
cond_var: Default::default(),
soft_min_threads,
max_threads,
})
}
pub fn perform<F: FnOnce() + 'static + Send>(self: &Arc<Self>, func: F) -> usize {
let work_item = Box::new(func);
enum ThreadAction {
None,
Wake,
Spawn,
}
let local_thread_count;
let thread_action = {
let mut data = self.shared.lock().expect("Mutex poisoned!");
local_thread_count = data.total_threads;
data.request_queue.push_back(work_item);
flog!(
iothread,
"enqueuing work item (count is ",
data.request_queue.len(),
")"
);
if data.waiting_threads >= data.request_queue.len() {
ThreadAction::Wake
} else if data.total_threads < self.max_threads {
data.total_threads += 1;
ThreadAction::Spawn
} else {
ThreadAction::None
}
};
match thread_action {
ThreadAction::None => (),
ThreadAction::Wake => {
flog!(iothread, "notifying thread ", std::thread::current().id());
self.cond_var.notify_one();
}
ThreadAction::Spawn => {
if self.spawn_thread() {
flog!(iothread, "pthread spawned");
} else {
self.shared.lock().expect("Mutex poisoned!").total_threads -= 1;
}
}
}
local_thread_count
}
fn spawn_thread(self: &Arc<Self>) -> bool {
let pool = Arc::clone(self);
self::spawn(move || {
pool.run_worker();
})
}
}
/// Wrapper that marks a value as `Send`/`Sync` while routing all access through
/// [`MainThread::get`], which asserts it is called on the main thread.
pub struct MainThread<T> {
    // A raw-pointer marker opts the type out of the automatic `Send`/`Sync` impls; they are
    // re-added by hand (with a `T: 'static` bound) right after this definition.
    _marker: PhantomData<*const ()>,
    data: T,
}
// SAFETY (intended reasoning, presumably): shared access to the wrapped value only happens
// through `MainThread::get`, which calls `assert_is_main_thread()`, so `&T` is never handed
// out on other threads.
// NOTE(review): these impls look broader than `get` alone justifies.
// - `Send` lets a `MainThread<T>` move to another thread, so `T`'s destructor can run there.
// - The main-thread panic in `assert_is_main_thread` is suppressed in test builds.
// Confirm this is acceptable for every `T` used (e.g. values only ever stored in statics).
unsafe impl<T: 'static> Send for MainThread<T> {}
unsafe impl<T: 'static> Sync for MainThread<T> {}
impl<T> MainThread<T> {
    /// Wraps `value`. Being a `const fn`, this can initialize a `static`.
    pub const fn new(value: T) -> Self {
        Self {
            _marker: PhantomData,
            data: value,
        }
    }

    /// Borrows the wrapped value.
    ///
    /// # Panics
    /// Panics if [`init`] was never called. Outside test builds, it also panics when called
    /// from a thread other than the main thread.
    pub fn get(&self) -> &T {
        assert_is_main_thread();
        let Self { data, .. } = self;
        data
    }
}
impl ThreadPool {
    /// Worker thread body: runs queued work items until the pool tells this worker to exit.
    fn run_worker(&self) {
        while let Some(work_item) = self.dequeue_work_or_commit_to_exit() {
            flog!(
                iothread,
                "pthread ",
                std::thread::current().id(),
                " got work"
            );
            work_item();
        }
        flog!(
            iothread,
            "pthread ",
            std::thread::current().id(),
            " exiting"
        );
    }

    /// Pops the next work item, or commits this worker to exiting if none is available.
    ///
    /// When the queue is empty and the pool is at or below its soft minimum thread count, the
    /// worker first waits up to `IO_WAIT_FOR_WORK_DURATION` for new work.
    ///
    /// A `None` return means the caller must exit. `total_threads` has already been
    /// decremented on its behalf while the lock was held, so the pool never counts on this
    /// worker again.
    fn dequeue_work_or_commit_to_exit(&self) -> Option<WorkItem> {
        let mut data = self.shared.lock().expect("Mutex poisoned!");

        // `total_threads` includes this worker, so exiting would leave the pool below its soft
        // minimum whenever `total_threads <= soft_min_threads`. (An `==` test let workers of an
        // under-populated pool exit without waiting.)
        if data.request_queue.is_empty()
            && data.total_threads <= self.soft_min_threads
            && IO_WAIT_FOR_WORK_DURATION > Duration::ZERO
        {
            data.waiting_threads += 1;
            // Re-check the queue on every wakeup. Otherwise a spurious wakeup, or a notification
            // whose item another worker grabbed first, would end the wait early.
            data = self
                .cond_var
                .wait_timeout_while(data, IO_WAIT_FOR_WORK_DURATION, |state| {
                    state.request_queue.is_empty()
                })
                .expect("Mutex poisoned!")
                .0;
            data.waiting_threads -= 1;
        }

        let result = data.request_queue.pop_front();
        if result.is_none() {
            // Balance the increment made when this worker was spawned. It must happen under the
            // lock because the worker has now committed to exiting.
            data.total_threads -= 1;
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::{spawn, thread_id};
    use nix::sys::signal::{SigSet, SigmaskHow, Signal};
    use std::sync::{
        atomic::{AtomicI32, Ordering},
        Arc, Condvar, Mutex,
    };
    use std::time::Duration;

    /// `thread_id()` is stable within a thread and differs across threads.
    #[test]
    fn test_thread_ids() {
        let start_thread_id = thread_id();
        assert_eq!(start_thread_id, thread_id());
        let spawned_thread_id = std::thread::spawn(thread_id).join();
        assert_ne!(start_thread_id, spawned_thread_id.unwrap());
    }

    /// A thread started through the standard library inherits the spawning thread's signal
    /// mask. `spawn()` relies on this to start its threads with signals blocked.
    #[test]
    fn std_thread_inherits_sigmask() {
        // Block SIGILL on this thread and capture the resulting mask.
        let (saved_set, t1_set) = {
            let saved_set = {
                let mut new_set = SigSet::empty();
                new_set.add(Signal::SIGILL); // mask bad jump
                new_set
                    .thread_swap_mask(SigmaskHow::SIG_BLOCK)
                    .expect("Failed to set thread mask!")
            };
            // Unblocking the empty set leaves the mask unchanged and returns it.
            let t1_set = SigSet::empty()
                .thread_swap_mask(SigmaskHow::SIG_UNBLOCK)
                .expect("Failed to get own altered thread mask!");
            (saved_set, t1_set)
        };

        // Read the initial mask of a genuinely new thread. The closure passed to
        // `std::thread::scope` itself runs on the *current* thread, so the probe has to be
        // started through the scope handle.
        let t2_set = std::thread::scope(|scope| {
            scope
                .spawn(|| {
                    SigSet::empty()
                        .thread_swap_mask(SigmaskHow::SIG_BLOCK)
                        .expect("Failed to get existing sigmask for new thread")
                })
                .join()
                .expect("sigmask probe thread panicked")
        });

        // Restore before asserting so a failure doesn't leave this thread's mask altered.
        saved_set
            .thread_set_mask()
            .expect("Failed to restore sigmask!");
        assert_eq!(t1_set, t2_set);
    }

    /// `spawn()` starts a thread that runs its callback, and that thread can signal a condvar
    /// the spawning thread is waiting on.
    #[test]
    fn test_pthread() {
        struct Context {
            val: AtomicI32,
            condvar: Condvar,
        }
        let ctx = Arc::new(Context {
            val: AtomicI32::new(3),
            condvar: Condvar::new(),
        });
        let mutex = Mutex::new(());
        let ctx2 = ctx.clone();
        let made = spawn(move || {
            ctx2.val.fetch_add(2, Ordering::Release);
            // Notified without holding `mutex`.
            ctx2.condvar.notify_one();
            printf!("condvar signalled\n");
        });
        assert!(made);

        let lock = mutex.lock().unwrap();
        let (_lock, timeout) = ctx
            .condvar
            .wait_timeout_while(lock, Duration::from_secs(5), |()| {
                printf!("looping with lock held\n");
                if ctx.val.load(Ordering::Acquire) != 5 {
                    printf!("test_pthread: value did not yet reach goal\n");
                    return true;
                }
                false
            })
            .unwrap();
        assert!(
            !timeout.timed_out(),
            concat!(
                "Timeout waiting for condition variable to be notified! ",
                "Does the platform support signalling a condvar without the mutex held?"
            )
        );
    }
}