use once_cell::sync::OnceCell;
use slab::Slab;
use std::{
panic::{catch_unwind, AssertUnwindSafe},
sync::{mpsc, Arc},
thread::Result,
};
use crate::threading;
/// Index of a worker thread's slot in `State::threads`.
type SlabPtr = usize;
#[cfg(test)]
mod tests;
/// A cloneable handle to a group of worker threads driven by a [`Scheduler`].
///
/// All clones share the same state behind a mutex; use [`ThreadGroup::lock`]
/// to spawn, preempt, or shut down threads.
#[derive(Debug)]
pub struct ThreadGroup<Sched: ?Sized> {
/// Shared group state; every worker thread of the group also holds a
/// reference to it.
state: Arc<threading::Mutex<State<Sched>>>,
}
impl<Sched: ?Sized> Clone for ThreadGroup<Sched> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}
/// Receives the outcome of a thread group; returned by [`ThreadGroup::new`].
#[derive(Debug)]
pub struct ThreadGroupJoinHandle {
/// Receives `Ok(())` when shutdown completes, or `Err(payload)` when a
/// worker thread's function panics.
result_recv: mpsc::Receiver<Result<()>>,
}
/// Exclusive access to a thread group's state, returned by [`ThreadGroup::lock`].
///
/// The group's mutex stays locked for as long as this guard is alive.
pub struct ThreadGroupLockGuard<'a, Sched: ?Sized> {
/// The shared state that `guard` locks; `spawn` clones it for new workers.
state_ref: &'a Arc<threading::Mutex<State<Sched>>>,
/// The held lock on that state.
guard: threading::MutexGuard<'a, State<Sched>>,
}
/// Identifies a worker thread within its thread group.
///
/// Wraps the thread's slot index in the group's slab. NOTE(review): `Slab`
/// reuses freed slots, so an ID may be handed to a new thread after the old
/// one has exited.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct ThreadId(SlabPtr);
/// Decides which worker thread of a group runs next.
///
/// Methods are called with the group's lock held, from whichever thread is
/// driving the group at the time (the controlling thread or a worker thread),
/// so the scheduler must be `Send`.
pub trait Scheduler: Send + 'static {
/// Returns the thread to unpark next, or `None` to leave every thread parked.
fn choose_next_thread(&mut self) -> Option<ThreadId>;
/// Called when `thread_id` exits, before its slot is freed. The default
/// implementation does nothing.
fn thread_exited(&mut self, thread_id: ThreadId) {
// Discard the parameter explicitly to avoid an unused-variable warning.
let _ = thread_id;
}
}
/// Shared, mutex-protected state of a thread group.
#[derive(Debug)]
struct State<Sched: ?Sized> {
/// One slot per live worker thread, indexed by `ThreadId`.
threads: Slab<WorkerThread>,
/// Number of live worker threads (incremented by `spawn`, decremented when a
/// thread is finalized).
num_threads: usize,
/// Thread most recently chosen by the scheduler, if any.
cur_thread_id: Option<ThreadId>,
/// Set once `shutdown` has been requested.
shutting_down: bool,
/// Reports the group's outcome to the `ThreadGroupJoinHandle`.
result_send: mpsc::Sender<Result<()>>,
/// The scheduler. Must remain the last field because it may be unsized
/// (`State<dyn Scheduler>`).
sched: Sched,
}
/// Per-thread bookkeeping stored in `State::threads`.
#[derive(Debug)]
struct WorkerThread {
/// Handle used to park/unpark the thread; `None` only while `spawn` is still
/// creating the thread.
join_handle: Option<threading::JoinHandle<()>>,
}
thread_local! {
/// Set once by each worker thread when it starts; stays empty on every
/// other thread.
static TLB: OnceCell<ThreadLocalBlock> = OnceCell::new();
}
/// Thread-local registration of a worker thread with its group.
struct ThreadLocalBlock {
/// The worker's own ID.
thread_id: ThreadId,
/// The group's state, with the concrete scheduler type erased.
state: Arc<threading::Mutex<State<dyn Scheduler>>>,
}
impl<Sched: Scheduler> ThreadGroup<Sched> {
    /// Creates an empty thread group driven by `sched`.
    ///
    /// Returns the group handle together with a join handle that yields the
    /// group's outcome: `Ok(())` once shutdown completes, or the panic payload
    /// of a worker thread.
    pub fn new(sched: Sched) -> (Self, ThreadGroupJoinHandle) {
        let (result_send, result_recv) = mpsc::channel();
        let initial = State {
            threads: Slab::new(),
            num_threads: 0,
            cur_thread_id: None,
            shutting_down: false,
            result_send,
            sched,
        };
        let group = Self {
            state: Arc::new(threading::Mutex::new(initial)),
        };
        let join_handle = ThreadGroupJoinHandle { result_recv };
        (group, join_handle)
    }
}
impl ThreadGroupJoinHandle {
    /// Blocks until the thread group reports an outcome and returns it.
    ///
    /// Whichever outcome is reported first wins: `Ok(())` once shutdown has
    /// completed, or the panic payload of a worker thread whose function
    /// panicked.
    ///
    /// # Errors
    ///
    /// Returns `Err(payload)` when a worker thread panicked before shutdown
    /// completed.
    ///
    /// # Panics
    ///
    /// Panics if the group's shared state was dropped without ever reporting an
    /// outcome (e.g. every `ThreadGroup` handle was dropped before `shutdown`
    /// while no worker threads remained).
    pub fn join(self) -> Result<()> {
        self.result_recv
            .recv()
            .expect("thread group was dropped before reporting an outcome")
    }
}
impl<Sched: Scheduler + ?Sized> ThreadGroup<Sched> {
    /// Locks the group's shared state and returns a guard exposing `spawn`,
    /// `preempt`, `shutdown` and `scheduler`.
    ///
    /// # Panics
    ///
    /// Panics if `threading::Mutex::lock` returns an error (presumably mutex
    /// poisoning — confirm against the `threading` module).
    pub fn lock(&self) -> ThreadGroupLockGuard<'_, Sched> {
        let guard = self.state.lock().unwrap();
        ThreadGroupLockGuard {
            state_ref: &self.state,
            guard,
        }
    }
}
impl<'a, Sched: Scheduler> ThreadGroupLockGuard<'a, Sched> {
    /// Creates a new worker thread that will run `f(thread_id)`.
    ///
    /// The thread starts parked: it registers itself in the thread-local block
    /// and waits to be unparked by the scheduler (via `preempt`, `yield_now`, or
    /// another worker exiting). This method does not schedule it.
    ///
    /// Returns the new thread's [`ThreadId`].
    ///
    /// # Panics
    ///
    /// Panics if the group has already finished shutting down.
    pub fn spawn(&mut self, f: impl FnOnce(ThreadId) + Send + 'static) -> ThreadId {
        let finished = self.guard.shutting_down && self.guard.num_threads == 0;
        if finished {
            panic!("thread group has already been shut down");
        }

        // Reserve a slot first so the thread's ID is known before the thread
        // exists; its join handle is filled in once `threading::spawn` returns.
        // The lock is held throughout, so no one can try to unpark it meanwhile.
        let slot: SlabPtr = self
            .guard
            .threads
            .insert(WorkerThread { join_handle: None });
        let thread_id = ThreadId(slot);
        self.guard.num_threads += 1;

        let shared = Arc::clone(self.state_ref);
        let body = move || {
            let for_finalize = Arc::clone(&shared);
            TLB.with(|cell| {
                cell.set(ThreadLocalBlock {
                    thread_id,
                    state: shared,
                })
                .ok()
                .unwrap()
            });
            // Wait to be chosen by the scheduler before running user code.
            // NOTE(review): if `threading::park` can return spuriously (as
            // `std::thread::park` may), the worker could start unscheduled — confirm.
            threading::park();
            let outcome = catch_unwind(AssertUnwindSafe(move || f(thread_id)));
            finalize_thread(for_finalize, thread_id, outcome);
        };
        let join_handle = threading::spawn(body);

        self.guard.threads[slot].join_handle = Some(join_handle);
        log::trace!("created {:?}", thread_id);
        thread_id
    }

    /// Takes the processor away from the current worker thread and hands it to
    /// whichever thread the scheduler picks next.
    ///
    /// The current thread (if any) is parked through its handle, then
    /// `Scheduler::choose_next_thread` is consulted and the chosen thread, if
    /// any, is unparked.
    ///
    /// # Panics
    ///
    /// Panics if called from a worker thread of a thread group.
    pub fn preempt(&mut self) {
        let on_worker = TLB.with(|cell| cell.get().is_some());
        assert!(!on_worker, "this method cannot be called from a worker thread");

        let state = &mut *self.guard;
        log::trace!("preempting {:?}", state.cur_thread_id);
        if let Some(running) = state.cur_thread_id {
            state.threads[running.0]
                .join_handle
                .as_ref()
                .unwrap()
                .thread()
                .park();
        }
        state.unpark_next_thread();
    }

    /// Requests that the group shut down; calling it again has no effect.
    ///
    /// If no worker threads remain, shutdown completes immediately and the join
    /// handle is sent `Ok(())`. Otherwise completion is deferred until the last
    /// worker thread exits normally.
    pub fn shutdown(&mut self) {
        let state = &mut *self.guard;
        if state.shutting_down {
            return;
        }
        log::trace!("shutdown requested");
        state.shutting_down = true;
        match state.num_threads {
            0 => state.complete_shutdown(),
            remaining => {
                log::trace!(
                    "shutdown is pending because there are {} thread(s) remaining",
                    remaining
                );
            }
        }
    }
}
impl<'a, Sched: Scheduler + ?Sized> ThreadGroupLockGuard<'a, Sched> {
    /// Returns the group's scheduler for inspection or reconfiguration.
    ///
    /// The borrow is tied to this guard, so the group stays locked while the
    /// scheduler is in use.
    pub fn scheduler(&mut self) -> &mut Sched {
        let state = &mut *self.guard;
        &mut state.sched
    }
}
/// Sized-scheduler entry points. Worker threads only ever see the type-erased
/// `State<dyn Scheduler>`, so the real logic lives there; these wrappers
/// coerce `self` and forward.
impl<Sched: Scheduler> State<Sched> {
    /// Forwards to `State<dyn Scheduler>::unpark_next_thread`.
    fn unpark_next_thread(&mut self) {
        let erased: &mut State<dyn Scheduler> = self;
        erased.unpark_next_thread();
    }

    /// Forwards to `State<dyn Scheduler>::complete_shutdown`.
    fn complete_shutdown(&mut self) {
        let erased: &mut State<dyn Scheduler> = self;
        erased.complete_shutdown();
    }
}
impl State<dyn Scheduler> {
    /// Asks the scheduler for the next thread, records it as current, and
    /// unparks it. When the scheduler returns `None`, no thread is recorded as
    /// current and nothing is unparked.
    ///
    /// # Panics
    ///
    /// Panics if the chosen ID has no slot in `threads` or its join handle is
    /// unset.
    fn unpark_next_thread(&mut self) {
        let next = self.sched.choose_next_thread();
        self.cur_thread_id = next;
        log::trace!("scheduling {:?}", next);
        if let Some(chosen) = next {
            self.threads[chosen.0]
                .join_handle
                .as_ref()
                .unwrap()
                .thread()
                .unpark();
        }
    }

    /// Reports successful completion (`Ok(())`) to the join handle.
    ///
    /// # Panics
    ///
    /// Panics if any worker thread is still alive.
    fn complete_shutdown(&mut self) {
        assert_eq!(self.num_threads, 0);
        log::trace!("shutdown is complete");
        // The join handle may already have been dropped; that is not an error.
        let _ = self.result_send.send(Ok(()));
    }
}
pub fn yield_now() {
let thread_group: Arc<threading::Mutex<State<dyn Scheduler>>> = TLB
.with(|cell| cell.get().map(|tlb| Arc::clone(&tlb.state)))
.expect("current thread does not belong to a thread group");
{
let mut state_guard = thread_group.lock().unwrap();
log::trace!("{:?} yielded the processor", state_guard.cur_thread_id);
state_guard.unpark_next_thread();
}
threading::park();
}
/// Terminates the calling worker thread after detaching it from its group.
///
/// The thread is finalized as though its function had returned normally
/// (`Ok(())`). The scheduler is notified and the thread's slot is freed. Then
/// either shutdown is completed (if this was the last thread and shutdown was
/// requested) or the next thread is scheduled.
///
/// # Panics
///
/// Panics if the calling thread is not a worker thread of a thread group.
///
/// # Safety
///
/// This forwards to `threading::exit_thread`, whose contract the caller must
/// uphold. NOTE(review): presumably that ends the thread without unwinding, so
/// destructors of values still alive on the caller's stack would not run —
/// confirm against `threading::exit_thread`.
pub unsafe fn exit_thread() -> ! {
    let registration = TLB.with(|cell| {
        cell.get()
            .map(|tlb| (tlb.thread_id, Arc::clone(&tlb.state)))
    });
    let (thread_id, thread_group) =
        registration.expect("current thread does not belong to a thread group");

    finalize_thread(thread_group, thread_id, Ok(()));

    // SAFETY: the caller upholds `threading::exit_thread`'s contract, as
    // required by this function's `# Safety` section.
    unsafe { threading::exit_thread() };
}
/// Removes a finished worker thread from its group and passes the processor on.
///
/// Runs on the exiting worker itself: it is called from the closure built in
/// `spawn` once the user function returns or panics, or from `exit_thread`.
///
/// * `thread_group` – the group's type-erased shared state.
/// * `thread_id` – the exiting thread.
/// * `result` – `Err(payload)` if the user function panicked.
///
/// If `result` is an error, it is forwarded to the join handle and no other
/// thread is scheduled. Otherwise, if this was the last thread and shutdown has
/// been requested, shutdown is completed; else the scheduler picks the next
/// thread.
///
/// # Panics
///
/// Panics if locking the group's state fails.
fn finalize_thread(
    thread_group: Arc<threading::Mutex<State<dyn Scheduler>>>,
    thread_id: ThreadId,
    result: Result<()>,
) {
    log::trace!("{:?} exited with result {:?}", thread_id, result);
    let mut state_guard = thread_group.lock().unwrap();
    state_guard.sched.thread_exited(thread_id);
    state_guard.threads.remove(thread_id.0);
    state_guard.num_threads -= 1;

    // The exiting thread no longer has a slot. Forget it as the current thread
    // so the early returns below don't leave a dangling ID behind: `preempt`
    // indexes `threads` with `cur_thread_id` and would otherwise panic on the
    // freed slot (or park an unrelated thread once `spawn` reuses the slot).
    if state_guard.cur_thread_id == Some(thread_id) {
        state_guard.cur_thread_id = None;
    }

    if let Err(payload) = result {
        // Report the panic; no further thread is scheduled.
        let _ = state_guard.result_send.send(Err(payload));
        return;
    }
    if state_guard.num_threads == 0 && state_guard.shutting_down {
        state_guard.complete_shutdown();
        return;
    }
    state_guard.unpark_next_thread();
}
/// Returns the calling thread's ID if it is a worker thread of a thread group,
/// or `None` otherwise.
pub fn current_thread() -> Option<ThreadId> {
    TLB.with(|cell| {
        let tlb = cell.get()?;
        Some(tlb.thread_id)
    })
}