use std::{
mem::MaybeUninit,
sync::{
atomic::{AtomicIsize, Ordering},
mpsc, Arc,
},
thread,
};
use winapi::um::{
errhandlingapi, handleapi, processthreadsapi, synchapi,
winbase::INFINITE,
winnt::{DUPLICATE_SAME_ACCESS, HANDLE},
};
/// Terminates the calling thread immediately with exit code 0.
///
/// # Safety
///
/// The thread is ended by the OS (`ExitThread`) without Rust unwinding, so
/// destructors of values living on the calling thread's stack are not run.
/// The caller must ensure that skipping them is acceptable.
pub unsafe fn exit_thread() -> ! {
    // SAFETY: the caller upholds the contract documented under `# Safety`.
    unsafe {
        processthreadsapi::ExitThread(0);
    }
    // `ExitThread` does not return to its caller.
    unreachable!();
}
pub use std::thread::ThreadId;
/// Handle to a thread created by [`spawn`].
///
/// Bundles the standard-library join handle with this module's [`Thread`]
/// handle for the same thread.
#[derive(Debug)]
pub struct JoinHandle<T> {
    // NOTE(review): this field is never read in this file, hence the
    // `allow`; presumably it is kept so the std handle is not discarded.
    // Confirm the intent before removing it.
    #[allow(dead_code)]
    std_handle: thread::JoinHandle<T>,
    /// Park/unpark handle of the spawned thread.
    thread: Thread,
}
/// Spawns a new OS thread that runs `f` and returns a handle to it.
///
/// The call blocks until the new thread has sent back its per-thread data.
/// As a result, the returned [`JoinHandle::thread`] is usable as soon as
/// this function returns.
pub fn spawn(f: impl FnOnce() + Send + 'static) -> JoinHandle<()> {
    let (tx, rx) = mpsc::channel();
    let std_handle = thread::spawn(move || {
        // Publish this thread's data to the spawner before running `f`.
        // A send error can only mean the receiver was dropped, so it is ignored.
        let own_data = THREAD_DATA.with(Arc::clone);
        let _ = tx.send(own_data);
        f()
    });
    let data = rx.recv().unwrap();
    JoinHandle {
        std_handle,
        thread: Thread { data },
    }
}
impl<T> JoinHandle<T> {
    /// Returns the [`Thread`] handle of the spawned thread.
    ///
    /// The handle can be used to park or unpark the thread.
    pub fn thread(&self) -> &Thread {
        let Self { thread, .. } = self;
        thread
    }
}
thread_local! {
    /// Lazily created per-thread state, shared (via `Arc`) with every
    /// [`Thread`] handle that refers to this thread.
    static THREAD_DATA: Arc<ThreadData> = {
        let data = ThreadData {
            // No pending park tokens initially.
            token_count: AtomicIsize::new(0),
            hthread: current_hthread(),
            remote_op_mutex: Mutex::new(()),
        };
        Arc::new(data)
    };
}
/// Cloneable handle to a thread.
///
/// Used by other threads to park or unpark the thread it refers to.
#[derive(Clone, Debug)]
pub struct Thread {
    /// State shared with the target thread's `THREAD_DATA`.
    data: Arc<ThreadData>,
}
#[derive(Debug)]
struct ThreadData {
token_count: AtomicIsize,
hthread: HANDLE,
remote_op_mutex: Mutex<()>,
}
unsafe impl Send for ThreadData {}
unsafe impl Sync for ThreadData {}
/// Returns a [`Thread`] handle for the calling thread.
// NOTE(review): `allow(dead_code)` presumably covers builds where this
// helper is unused; confirm before removing.
#[allow(dead_code)]
pub fn current() -> Thread {
    let data = THREAD_DATA.with(Arc::clone);
    Thread { data }
}
pub fn park() {
THREAD_DATA.with(|td| {
let token_count_cell = &td.token_count;
let mut token_count = token_count_cell.fetch_sub(1, Ordering::Relaxed) - 1;
while token_count < 0 {
unsafe {
synchapi::WaitOnAddress(
token_count_cell.as_mut_ptr().cast(), (&token_count) as *const _ as *mut _, std::mem::size_of::<isize>(), INFINITE, );
}
token_count = token_count_cell.load(Ordering::Relaxed);
}
})
}
impl Thread {
pub fn unpark(&self) {
let _guard = self.data.remote_op_mutex.lock().unwrap();
let token_count_cell = &self.data.token_count;
if token_count_cell.fetch_add(1, Ordering::Relaxed) == -1 {
unsafe { synchapi::WakeByAddressAll(token_count_cell.as_mut_ptr().cast()) };
unsafe { processthreadsapi::ResumeThread(self.data.hthread) };
}
}
pub fn park(&self) {
let _guard = self.data.remote_op_mutex.lock().unwrap();
let token_count_cell = &self.data.token_count;
if token_count_cell.fetch_sub(1, Ordering::Relaxed) == 0 {
unsafe { processthreadsapi::SuspendThread(self.data.hthread) };
unsafe {
processthreadsapi::GetThreadContext(
self.data.hthread,
MaybeUninit::uninit().as_mut_ptr(),
);
}
}
}
}
/// Returns a real (non-pseudo) handle to the calling thread.
///
/// `GetCurrentThread` only yields a pseudo-handle, so it is duplicated with
/// `DUPLICATE_SAME_ACCESS` into a real handle owned by the caller. Panics
/// with the last Win32 error if duplication fails or yields a null handle.
fn current_hthread() -> HANDLE {
    // SAFETY: both functions take no arguments and just return pseudo-handles.
    let (this_thread_pseudo, this_process) = unsafe {
        (
            processthreadsapi::GetCurrentThread(),
            processthreadsapi::GetCurrentProcess(),
        )
    };
    let mut real_handle = MaybeUninit::<HANDLE>::uninit();
    // SAFETY: `real_handle.as_mut_ptr()` is valid for writing one `HANDLE`.
    let ok = unsafe {
        handleapi::DuplicateHandle(
            this_process,
            this_thread_pseudo,
            this_process,
            real_handle.as_mut_ptr(),
            0, // desired access: ignored because of DUPLICATE_SAME_ACCESS
            0, // not inheritable
            DUPLICATE_SAME_ACCESS,
        )
    };
    assert_win32_ok(ok);
    // SAFETY: `DuplicateHandle` succeeded (checked above), so it wrote the handle.
    let handle = unsafe { real_handle.assume_init() };
    assert_win32_nonnull(handle)
}
/// Panics with the last Win32 error if `b` equals its type's default value.
///
/// For a `BOOL` result, the default value is `FALSE` (0), i.e. failure.
fn assert_win32_ok<T: Default + PartialEq<T> + Copy>(b: T) {
    if b != T::default() {
        return;
    }
    panic_last_error()
}
/// Returns `b` unchanged, or panics with the last Win32 error if it is null.
fn assert_win32_nonnull<T: IsNull>(b: T) -> T {
    if !b.is_null() {
        return b;
    }
    panic_last_error()
}
/// Raw pointer types whose nullness can be checked generically.
///
/// Used by `assert_win32_nonnull`.
trait IsNull {
    /// Returns `true` if the pointer is null.
    fn is_null(&self) -> bool;
}

impl<T: ?Sized> IsNull for *const T {
    fn is_null(&self) -> bool {
        // Name the inherent method explicitly, so the call can never be
        // mistaken for (or recurse into) this trait method.
        <*const T>::is_null(*self)
    }
}

impl<T: ?Sized> IsNull for *mut T {
    fn is_null(&self) -> bool {
        <*mut T>::is_null(*self)
    }
}
/// Panics with the calling thread's last Win32 error code, formatted as hex.
#[cold]
fn panic_last_error() -> ! {
    // SAFETY: `GetLastError` has no preconditions.
    let code = unsafe { errhandlingapi::GetLastError() };
    panic!("Win32 error 0x{:08x}", code);
}
pub use mutex::{Mutex, MutexGuard};
mod mutex {
    //! A minimal blocking mutex built directly on `WaitOnAddress` and
    //! `WakeByAddressSingle`. It has no poisoning: [`Mutex::lock`] always
    //! returns `Ok`.

    use std::{
        cell::UnsafeCell,
        fmt,
        sync::atomic::{AtomicBool, Ordering},
    };
    use winapi::um::{synchapi, winbase::INFINITE};

    /// A mutual-exclusion lock protecting a value of type `T`.
    pub struct Mutex<T: ?Sized> {
        /// `true` while a [`MutexGuard`] for this mutex is alive.
        locked: AtomicBool,
        data: UnsafeCell<T>,
    }

    // SAFETY: `data` is only reachable through a `MutexGuard`, and `locked`
    // ensures at most one guard exists at a time. `T: Send` is required
    // because the guard hands out `&mut T` on whichever thread holds it.
    unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}

    /// Guard returned by [`Mutex::lock`]; the lock is released when it is dropped.
    pub struct MutexGuard<'a, T: ?Sized> {
        data: &'a mut T,
        locked: &'a AtomicBool,
    }

    impl<T> Mutex<T> {
        /// Creates a new, unlocked mutex holding `x`.
        #[inline]
        pub const fn new(x: T) -> Self {
            Self {
                locked: AtomicBool::new(false),
                data: UnsafeCell::new(x),
            }
        }
    }

    impl<T: ?Sized> Mutex<T> {
        /// Acquires the lock, blocking the calling thread until it is free.
        ///
        /// The result is always `Ok`. NOTE(review): presumably the `Result`
        /// exists so call sites can use the same `.lock().unwrap()` shape as
        /// `std::sync::Mutex`.
        #[inline]
        pub fn lock(&self) -> Result<MutexGuard<'_, T>, ()> {
            let flag = &self.locked;
            loop {
                match flag.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) {
                    Ok(_) => break,
                    Err(_) => {
                        // Sleep while the flag still reads `true`. Spurious or
                        // stale wake-ups just lead to another compare-exchange.
                        // SAFETY: both pointers refer to live `bool`s for the
                        // whole call, and the comparand is only read.
                        unsafe {
                            synchapi::WaitOnAddress(
                                flag.as_mut_ptr().cast(),
                                (&true) as *const bool as *mut _,
                                std::mem::size_of::<bool>(),
                                INFINITE,
                            );
                        }
                    }
                }
            }
            // SAFETY: the compare-exchange succeeded, so this thread holds the
            // lock and no other reference to `data` exists until the guard drops.
            let data = unsafe { &mut *self.data.get() };
            Ok(MutexGuard { data, locked: flag })
        }
    }

    impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // The contents are deliberately not printed: that would require
            // taking the lock.
            f.write_str("Mutex")
        }
    }

    impl<T: ?Sized> Drop for MutexGuard<'_, T> {
        #[inline]
        fn drop(&mut self) {
            let flag = self.locked;
            flag.store(false, Ordering::Release);
            // Wake one thread blocked in `Mutex::lock`, if there is one.
            // SAFETY: `flag` points to a live `AtomicBool`.
            unsafe { synchapi::WakeByAddressSingle(flag.as_mut_ptr().cast()) };
        }
    }

    impl<T: ?Sized> std::ops::Deref for MutexGuard<'_, T> {
        type Target = T;

        #[inline]
        fn deref(&self) -> &T {
            &*self.data
        }
    }

    impl<T: ?Sized> std::ops::DerefMut for MutexGuard<'_, T> {
        #[inline]
        fn deref_mut(&mut self) -> &mut T {
            &mut *self.data
        }
    }
}