#![allow(clippy::disallowed_types, reason = "Need Arc to hold queue")]
use crate::runtime::{QueuedObject, UnwindPolicy};
use alloc::sync::{Arc, Weak};
use arbitrary_int::prelude::*;
use atomic::Atomic;
use core::cell::Cell;
use core::mem::ManuallyDrop;
use core::num::{NonZeroU16, NonZeroUsize};
use core::sync::atomic::AtomicBool;
use core::sync::atomic::Ordering;
use crossbeam_queue::SegQueue;
use std::thread::AccessError;
/// Thread identifier derived from a thread's zero-based registration index.
///
/// The stored value is `index + 1`, so it is always non-zero.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(transparent)]
pub struct UniqueThreadId(NonZeroUsize);
impl UniqueThreadId {
    /// Id of the thread registered at index 0.
    const MIN: UniqueThreadId = UniqueThreadId(NonZeroUsize::MIN);

    /// Converts a zero-based registration index into an id (`index + 1`).
    ///
    /// # Panics
    /// Panics (reporting the caller's location) if `index + 1` overflows `usize`,
    /// i.e. when `index == usize::MAX`.
    #[inline]
    #[track_caller]
    pub fn from_index(index: usize) -> Self {
        let first = Self::MIN.0;
        first
            .checked_add(index)
            .map(UniqueThreadId)
            .expect("impossible to have more than usize::MAX - 1 threads")
    }
}
/// Lifecycle flag of a thread, shared with other threads through an `Atomic`.
///
/// `Live` and `QueuedObjects` both mean the owning thread is still running;
/// `QueuedObjects` additionally signals that other threads have queued objects for it.
/// `Dying` is set while the owner tears down its queue, after which it becomes `Dead`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, bytemuck::NoUninit)]
#[repr(u8)]
enum ThreadStateFlag {
    Live = 0,
    QueuedObjects = 1,
    Dying = 2,
    Dead = 3,
}
impl ThreadStateFlag {
    /// `true` for `Live` and `QueuedObjects`, i.e. the owner has not begun teardown.
    #[inline]
    pub fn is_live(self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Dying | Self::Dead => false,
            Self::Live | Self::QueuedObjects => true,
        }
    }

    /// `true` for `Dying` and `Dead`; the exact complement of `is_live`.
    #[inline]
    pub fn is_dead_or_dying(self) -> bool {
        !Self::is_live(self)
    }
}
/// Queue of objects that other threads hand over to the owning thread.
///
/// Remote threads push through `SharedThreadInfo::queue_object`; the owner drains it
/// in `LocalThreadState::collect_force`. Anything still queued when the last `Arc` to
/// the queue is dropped is merged by `Drop`.
pub struct ObjectQueue {
    /// Short id of the owning thread, forwarded to `explicit_merge`.
    short_id: ShortThreadId,
    inner: SegQueue<QueuedObject>,
}
impl ObjectQueue {
    /// Creates an empty queue tagged with `short_id`.
    ///
    /// # Safety
    /// NOTE(review): the body performs no unsafe operation; presumably callers must pass
    /// the short id of the thread that owns this queue — confirm the intended contract.
    #[inline]
    pub unsafe fn new(short_id: ShortThreadId) -> ObjectQueue {
        let inner = SegQueue::new();
        ObjectQueue { inner, short_id }
    }

    /// Appends `object` to the queue.
    ///
    /// # Safety
    /// NOTE(review): contract not stated in this file — presumably `object` must be
    /// owned by (biased towards) this queue's thread. Confirm.
    #[inline]
    pub unsafe fn push(&self, object: QueuedObject) {
        self.inner.push(object);
    }

    /// Pops every currently queued object and merges it via `super::explicit_merge`.
    ///
    /// Objects pushed concurrently may or may not be picked up by this call.
    #[cold]
    #[inline(never)]
    unsafe fn do_process(&self) {
        loop {
            let Some(object) = self.inner.pop() else {
                break;
            };
            // SAFETY: NOTE(review) — relies on `explicit_merge`'s (unseen) contract
            // being satisfied by pairing each object with its owner's `short_id`.
            unsafe {
                super::explicit_merge(self.short_id, object);
            }
        }
    }
}
impl Drop for ObjectQueue {
    /// Merges whatever is still queued so no object is lost with the queue.
    fn drop(&mut self) {
        unsafe { self.do_process() }
    }
}
/// Per-thread information visible to every thread through the `THREADS` registry.
///
/// One instance is leaked per successfully initialized thread, so references to it are
/// `'static` and stay valid after the thread exits.
pub struct SharedThreadInfo {
    _id: UniqueThreadId,
    short_id: ShortThreadId,
    /// Lifecycle flag; `QueuedObjects` hints that the owner has work waiting.
    state_flag: Atomic<ThreadStateFlag>,
    /// Weak handle to the owner's queue; upgrading fails once the owner dropped it.
    queued_objects: Weak<ObjectQueue>,
}
impl SharedThreadInfo {
    /// Looks up the shared info registered for `id`.
    ///
    /// Returns `None` if the registry has no slot at `id.index()` or the slot holds an
    /// error.
    #[inline]
    pub fn get_by_id(id: ShortThreadId) -> Option<&'static SharedThreadInfo> {
        THREADS.get(id.index())?.ok()
    }

    /// Hands `object` over to the thread described by `self`.
    ///
    /// If that thread's queue is still alive, the object is pushed onto it and the flag is
    /// moved from `Live` to `QueuedObjects` as a hint. Otherwise the owner is dying or dead
    /// and the object is merged right away on the calling thread.
    ///
    /// The body runs inside `nounwind::abort_unwind`, so panics do not unwind out of it.
    #[cold]
    #[inline(never)]
    pub unsafe fn queue_object(&self, object: QueuedObject) {
        nounwind::abort_unwind(|| {
            match self.queued_objects.upgrade() {
                Some(queue) => {
                    unsafe {
                        queue.push(object);
                    }
                    // Only a hint for the owner (the object itself travels through the
                    // queue), hence Relaxed. Only `Live` is changed: `QueuedObjects`
                    // already signals pending work, and `Dying`/`Dead` must stay intact.
                    let _ = self.state_flag.compare_exchange(
                        ThreadStateFlag::Live,
                        ThreadStateFlag::QueuedObjects,
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    );
                }
                None => {
                    // A failed `Weak::upgrade` only performs a relaxed load of the strong
                    // count, so by itself it does not synchronize with the owner's release
                    // decrement in `Arc::drop`. This acquire fence does, making the
                    // owner's earlier `Dying` store visible to the load below. Without it,
                    // the assertion could spuriously observe `Live`.
                    core::sync::atomic::fence(Ordering::Acquire);
                    let this_state = self.state_flag.load(Ordering::Acquire);
                    assert!(
                        this_state.is_dead_or_dying(),
                        "thread is {this_state:?} but has no queue"
                    );
                    unsafe {
                        super::explicit_merge(self.short_id, object);
                    }
                }
            }
        });
    }
}
/// Per-thread state stored in the `THIS_THREAD_STATE` thread-local.
///
/// Owns the strong `Arc` to this thread's [`ObjectQueue`]. Other threads only hold the
/// `Weak` kept in [`SharedThreadInfo`], plus short-lived upgrades of it.
pub struct LocalThreadState {
    /// Leaked information other threads use to reach this thread.
    shared_info: &'static SharedThreadInfo,
    /// Copy of `shared_info.short_id`.
    short_id: ShortThreadId,
    /// Released explicitly in `Drop`, after the shared flag becomes `Dying` and before
    /// it becomes `Dead`.
    queue: ManuallyDrop<Arc<ObjectQueue>>,
}
impl LocalThreadState {
    /// This thread's short id.
    #[inline]
    pub fn short_id(&self) -> ShortThreadId {
        self.short_id
    }

    /// Initializes the current thread's state if needed and returns its short id.
    ///
    /// Returns `None` if the thread is dead or dying, or its id overflowed. The body runs
    /// inside `nounwind::abort_unwind`.
    #[cold]
    #[inline(never)]
    pub fn init_tid() -> Option<ShortThreadId> {
        nounwind::abort_unwind(|| LocalThreadState::with_current(LocalThreadState::short_id).ok())
    }

    /// Runs `func` with the current thread's state, lazily initializing it on first access.
    ///
    /// # Errors
    /// - [`LocalThreadAccessError::IdOverflow`] if the thread could not get a short id.
    /// - [`LocalThreadAccessError::Dead`] if the state was already torn down or the
    ///   thread-local can no longer be accessed.
    #[inline]
    pub fn with_current<R>(
        func: impl FnOnce(&LocalThreadState) -> R,
    ) -> Result<R, LocalThreadAccessError> {
        match THIS_THREAD_STATE.try_with(|this| match this {
            Ok(state) => Ok(func(state)),
            Err(error) => Err(*error),
        }) {
            Ok(Ok(res)) => Ok(res),
            Ok(Err(ThreadStateInitError::IdOverflow(cause))) => {
                Err(LocalThreadAccessError::IdOverflow(cause))
            }
            Ok(Err(ThreadStateInitError::AlreadyDied)) | Err(AccessError { .. }) => {
                Err(LocalThreadAccessError::Dead)
            }
        }
    }

    /// Returns the current thread's short id without triggering initialization.
    ///
    /// # Errors
    /// - [`LocalThreadAccessError::Uninitialized`] if no state has been installed yet
    ///   (this is also the permanent result for a thread whose id overflowed).
    /// - [`LocalThreadAccessError::Dead`] once teardown has started.
    #[inline]
    pub fn existing_short_id() -> Result<ShortThreadId, LocalThreadAccessError> {
        match THIS_THREAD_STATE_FAST.with(|fast| (fast.status.get(), fast.short_id.get())) {
            (LocalThreadStatus::Uninit, None) => Err(LocalThreadAccessError::Uninitialized),
            (LocalThreadStatus::DeadOrDying, None) => Err(LocalThreadAccessError::Dead),
            (LocalThreadStatus::Active, Some(short_id)) => Ok(short_id),
            (_, Some(_)) | (LocalThreadStatus::Active, None) => {
                // SAFETY: within this file, `init_thread` sets `status = Active` together
                // with `short_id = Some(..)`, and `Drop for LocalThreadState` sets
                // `DeadOrDying` together with `None`, so `short_id` is `Some` exactly
                // when `status` is `Active`.
                unsafe { core::hint::unreachable_unchecked() }
            }
        }
    }

    /// Cheap check whether the current thread should take the `collect_slow` path.
    ///
    /// Returns `true` when the flag referenced by the fast thread-local is anything but
    /// `Live`. A thread with no installed state points at `DUMMY_STATE_FLAG` (`Dying`), so
    /// it also reports `true`. Returns `false` if the fast thread-local is inaccessible.
    #[inline]
    #[expect(
        clippy::manual_unwrap_or_default,
        reason = "clearer handling of AccessError"
    )]
    pub fn currently_needs_collect() -> bool {
        match THIS_THREAD_STATE_FAST.try_with(|fast| {
            !matches!(
                fast.shared_state_flag.get().load(Ordering::Relaxed),
                ThreadStateFlag::Live
            )
        }) {
            Ok(res) => res,
            Err(AccessError { .. }) => false,
        }
    }

    /// Slow path: drains this thread's queue if its shared flag is not `Live`.
    ///
    /// Does nothing while the thread is panicking or when the state is inaccessible.
    /// `T::maybe_abort_unwind` presumably decides whether a panic in here may unwind.
    #[cold]
    #[inline(never)]
    pub(super) fn collect_slow<T: UnwindPolicy>() {
        let _ = Self::with_current(|state| {
            T::maybe_abort_unwind(|| {
                if std::thread::panicking() {
                    return;
                }
                if !matches!(
                    state.shared_info.state_flag.load(Ordering::Relaxed),
                    ThreadStateFlag::Live
                ) {
                    state.collect_force();
                }
            });
        });
    }

    /// Drains this thread's queue until it is observed empty, resetting
    /// `QueuedObjects` back to `Live`. Does nothing once the thread is dying or dead.
    #[cold]
    #[inline(never)]
    pub fn collect_force(&self) {
        let this_state = self.shared_info.state_flag.load(Ordering::Acquire);
        match this_state {
            ThreadStateFlag::Live | ThreadStateFlag::QueuedObjects => {
                // `self.queue` is only released in `Drop for LocalThreadState`, which
                // cannot run while `&self` is borrowed. So the queue is alive, and using it
                // directly avoids a `Weak::upgrade` (two atomic RMWs) that cannot fail.
                let queue: &ObjectQueue = &self.queue;
                loop {
                    unsafe {
                        queue.do_process();
                    }
                    let _ = self.shared_info.state_flag.compare_exchange(
                        ThreadStateFlag::QueuedObjects,
                        ThreadStateFlag::Live,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    );
                    // Re-check after resetting the flag: a push may have raced with the
                    // drain, and its `Live -> QueuedObjects` hint would have been lost.
                    if queue.inner.is_empty() {
                        break;
                    }
                }
            }
            ThreadStateFlag::Dead | ThreadStateFlag::Dying => {}
        }
    }
}
impl Drop for LocalThreadState {
    // Runs when the `THIS_THREAD_STATE` thread-local holding this value is destroyed.
    // The steps are order-sensitive.
    fn drop(&mut self) {
        // 1. Mark the fast-path mirror as dead first. From here on `existing_short_id`
        //    reports `Dead`, and `currently_needs_collect` reads `DUMMY_STATE_FLAG`.
        //    The asserts check this thread was `Active` with exactly this short id.
        THIS_THREAD_STATE_FAST.with(|fast| {
            assert_eq!(
                fast.status.replace(LocalThreadStatus::DeadOrDying),
                LocalThreadStatus::Active,
            );
            assert_eq!(fast.short_id.replace(None), Some(self.short_id));
            fast.shared_state_flag.set(&DUMMY_STATE_FLAG);
        });
        // 2. Publish `Dying` to other threads. This is sequenced before the strong
        //    reference to the queue is released in step 3. Threads that later fail to
        //    upgrade the `Weak` in `SharedThreadInfo` assert that they see `Dying`/`Dead`.
        let old_state = self
            .shared_info
            .state_flag
            .swap(ThreadStateFlag::Dying, Ordering::SeqCst);
        assert!(old_state.is_live(), "Cannot destroy a {old_state:?} thread");
        // 3. Release this thread's strong reference to the queue. If it is the last one,
        //    `ObjectQueue::drop` merges any objects still queued. If another thread is
        //    momentarily holding an upgraded `Arc` (see `SharedThreadInfo::queue_object`),
        //    that thread's drop does it instead.
        // SAFETY: `self.queue` is not touched again after this point.
        unsafe { ManuallyDrop::drop(&mut self.queue) };
        // 4. Finish the transition. Within this file nothing else moves the flag out of
        //    `Dying`, so the exchange is expected to succeed.
        assert_eq!(
            self.shared_info.state_flag.compare_exchange(
                ThreadStateFlag::Dying,
                ThreadStateFlag::Dead,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ),
            Ok(ThreadStateFlag::Dying),
        );
    }
}
/// Coarse lifecycle of the current thread, mirrored in `LocalThreadStateFast::status`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, bytemuck::NoUninit)]
#[repr(u8)]
enum LocalThreadStatus {
    /// Teardown has started (or finished); the thread is never initialized again.
    DeadOrDying = 0,
    /// No `LocalThreadState` installed yet. This is also the permanent status of a thread
    /// whose id overflowed, since `init_thread` fails before updating the mirror.
    Uninit = 1,
    /// `init_thread` succeeded and the `LocalThreadState` has not been dropped.
    Active = 2,
}

/// Placeholder flag that `LocalThreadStateFast` points at whenever the thread has no
/// installed `SharedThreadInfo`. It starts as `Dying` and nothing in this file stores to
/// it, so `currently_needs_collect` reports `true` for such threads.
static DUMMY_STATE_FLAG: Atomic<ThreadStateFlag> = Atomic::new(ThreadStateFlag::Dying);

/// Const-initialized mirror of the current thread's state. Reading it never runs
/// `init_thread`.
#[derive(Debug)]
pub struct LocalThreadStateFast {
    status: Cell<LocalThreadStatus>,
    /// `Some` exactly when `status` is `Active`.
    short_id: Cell<Option<ShortThreadId>>,
    /// The owning `SharedThreadInfo::state_flag`, or `DUMMY_STATE_FLAG` when none.
    shared_state_flag: Cell<&'static Atomic<ThreadStateFlag>>,
}
thread_local! {
    /// Full per-thread state, created lazily by `init_thread` on first access. The
    /// initializer is wrapped in `nounwind::abort_unwind`.
    static THIS_THREAD_STATE: Result<LocalThreadState, ThreadStateInitError> =
        nounwind::abort_unwind(init_thread);
}

thread_local! {
    /// Fast-path mirror of `THIS_THREAD_STATE`. It is const-initialized, so accessing it
    /// never runs `init_thread`.
    static THIS_THREAD_STATE_FAST: LocalThreadStateFast = const {
        LocalThreadStateFast {
            status: Cell::new(LocalThreadStatus::Uninit),
            short_id: Cell::new(None),
            shared_state_flag: Cell::new(&DUMMY_STATE_FLAG),
        }
    };
}

/// Set once a thread fails to get a short id. Later threads then fail fast without
/// pushing into `THREADS`.
static SHORT_THREAD_IDS_EXHAUSTED: AtomicBool = AtomicBool::new(false);

/// Registry of threads registered by `init_thread`. The slot index is the argument
/// given to `UniqueThreadId::from_index` and equals `ShortThreadId::index()`. Slots
/// hold `Err` for threads whose id overflowed.
static THREADS: boxcar::Vec<Result<&'static SharedThreadInfo, ThreadIdOverflowError>> =
    boxcar::Vec::new();
fn init_thread() -> Result<LocalThreadState, ThreadStateInitError> {
let old_status = THIS_THREAD_STATE_FAST.with(|fast| fast.status.get());
match old_status {
LocalThreadStatus::DeadOrDying => {
return Err(ThreadStateInitError::AlreadyDied);
}
LocalThreadStatus::Uninit => {} LocalThreadStatus::Active => {
panic!("Thread already initialized")
}
}
if SHORT_THREAD_IDS_EXHAUSTED.load(Ordering::Acquire) {
Err(ThreadIdOverflowError.into())
} else {
let mut queued_objects = None;
let index = THREADS.push_with(|id| {
let id = UniqueThreadId::from_index(id);
match ShortThreadId::try_from(id) {
Ok(short_id) => {
queued_objects = Some(Arc::new(unsafe { ObjectQueue::new(short_id) }));
Ok(Box::leak(Box::new(SharedThreadInfo {
_id: id,
short_id,
state_flag: Atomic::new(ThreadStateFlag::Live),
queued_objects: Arc::downgrade(queued_objects.as_ref().unwrap()),
})))
}
Err(ThreadIdOverflowError) => {
SHORT_THREAD_IDS_EXHAUSTED.store(true, Ordering::Release);
Err(ThreadIdOverflowError)
}
}
});
let shared_info = THREADS[index]?;
assert_eq!(
THIS_THREAD_STATE_FAST.with(|fast| {
(
core::ptr::from_ref(fast.shared_state_flag.replace(&shared_info.state_flag)),
fast.status.replace(LocalThreadStatus::Active),
fast.short_id.replace(Some(shared_info.short_id)),
)
}),
(
core::ptr::from_ref(&DUMMY_STATE_FLAG),
LocalThreadStatus::Uninit,
None
)
);
Ok(LocalThreadState {
shared_info,
short_id: shared_info.short_id,
queue: ManuallyDrop::new(queued_objects.unwrap()),
})
}
}
#[derive(Debug, thiserror::Error, Copy, Clone)]
pub enum ThreadStateInitError {
#[error("Thread has already died so cannot be re-initialized")]
AlreadyDied,
#[error("Failed to initialize thread: {0}")]
IdOverflow(#[from] ThreadIdOverflowError),
}
#[derive(Copy, Clone, Debug, thiserror::Error, Eq, PartialEq)]
#[error(
"Thread ID overflows {} bits, so cannot participate in biased reference counting",
ShortThreadId::BITS
)]
pub struct ThreadIdOverflowError;
#[derive(Debug, thiserror::Error, Clone, Eq, PartialEq)]
pub enum LocalThreadAccessError {
#[error("Local thread has not been initialized yet")]
Uninitialized,
#[error("Local thread is either dead or dying")]
Dead,
#[error("Local thread cannot participate in biased reference counting: {0}")]
IdOverflow(#[from] ThreadIdOverflowError),
}
/// Compact thread id in `1..=ShortThreadId::MAX`, i.e. it fits in `BITS` bits.
///
/// Backed by `NonZeroU16`, so `Option<ShortThreadId>` is the same size as `u16`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(transparent)]
pub struct ShortThreadId(NonZeroU16);
impl ShortThreadId {
    /// Number of significant bits in a short id.
    pub const BITS: u32 = 12;
    /// Largest representable short id.
    pub const MAX: u12 = u12::MAX;

    /// Wraps `x`, returning `None` when it is zero.
    #[inline]
    pub const fn new(x: u12) -> Option<Self> {
        match NonZeroU16::new(x.value()) {
            Some(raw) => Some(ShortThreadId(raw)),
            None => None,
        }
    }

    /// The id as a 12-bit integer.
    #[inline]
    pub const fn value(self) -> u12 {
        // SAFETY: both constructors (`new` from a `u12`, and `TryFrom<UniqueThreadId>`
        // which checks `<= Self::MAX`) only store values that fit in 12 bits.
        unsafe { u12::new_unchecked(self.0.get()) }
    }

    /// Zero-based slot of this thread in `THREADS` (`value() - 1`).
    #[inline]
    pub const fn index(self) -> usize {
        // SAFETY: the wrapped value is non-zero, so subtracting one cannot underflow.
        unsafe { self.0.get().unchecked_sub(1) as usize }
    }
}
impl TryFrom<UniqueThreadId> for ShortThreadId {
type Error = ThreadIdOverflowError;
#[inline]
fn try_from(value: UniqueThreadId) -> Result<Self, Self::Error> {
let value = NonZeroU16::try_from(value.0).map_err(|_| ThreadIdOverflowError)?;
if value.get() <= Self::MAX.value() {
Ok(ShortThreadId(value))
} else {
Err(ThreadIdOverflowError)
}
}
}