use std::fmt;
use std::marker::PhantomData;
use std::sync::OnceLock;
use std::sync::atomic::Ordering;
use std::thread::{self, ThreadId};
/// A `Copy` primitive that has a matching atomic cell type.
///
/// Implementors pair a primitive with its atomic counterpart and expose the
/// small set of operations [`RelaxedAtomic`] needs: construction, reading,
/// writing and unwrapping. None of the methods take an ordering argument, so
/// the memory ordering is fixed by each implementation.
pub trait AtomicRepr: 'static + Copy {
    /// The atomic cell that stores a `Self`; it must be shareable across threads.
    type Atomic: Send + Sync + 'static;

    /// Wraps `val` in a freshly created atomic cell.
    fn new_atomic(val: Self) -> Self::Atomic;

    /// Reads the value currently held by `atomic`.
    fn load(atomic: &Self::Atomic) -> Self;

    /// Replaces the value held by `atomic` with `val`.
    fn store(atomic: &Self::Atomic, val: Self);

    /// Consumes `atomic`, returning the value it held.
    fn into_inner(atomic: Self::Atomic) -> Self;
}
/// Implements [`AtomicRepr`] for the primitive `$prim`, backed by the std
/// atomic cell type `$cell`.
///
/// Every generated load and store uses `Ordering::Relaxed`. This guarantees
/// atomicity of the single value only, with no ordering relative to other
/// memory operations.
macro_rules! impl_atomic_repr {
    ($prim:ty, $cell:ty) => {
        impl AtomicRepr for $prim {
            type Atomic = $cell;

            fn new_atomic(val: Self) -> Self::Atomic {
                <$cell>::new(val)
            }

            fn into_inner(atomic: Self::Atomic) -> Self {
                <$cell>::into_inner(atomic)
            }

            fn load(atomic: &Self::Atomic) -> Self {
                <$cell>::load(atomic, Ordering::Relaxed)
            }

            fn store(atomic: &Self::Atomic, val: Self) {
                <$cell>::store(atomic, val, Ordering::Relaxed)
            }
        }
    };
}
// Flag type.
impl_atomic_repr!(bool, std::sync::atomic::AtomicBool);
// 8-bit integers.
impl_atomic_repr!(u8, std::sync::atomic::AtomicU8);
impl_atomic_repr!(i8, std::sync::atomic::AtomicI8);
// 16-bit integers.
impl_atomic_repr!(u16, std::sync::atomic::AtomicU16);
impl_atomic_repr!(i16, std::sync::atomic::AtomicI16);
// 32-bit integers.
impl_atomic_repr!(u32, std::sync::atomic::AtomicU32);
impl_atomic_repr!(i32, std::sync::atomic::AtomicI32);
// Pointer-sized integers.
impl_atomic_repr!(usize, std::sync::atomic::AtomicUsize);
impl_atomic_repr!(isize, std::sync::atomic::AtomicIsize);
// NOTE(review): no 64-bit impls are listed. This is possibly deliberate,
// because 64-bit atomics are missing on some targets; confirm before adding.
/// A `T` stored in its atomic cell type, for values that need atomicity but
/// not cross-variable ordering.
///
/// The built-in [`AtomicRepr`] impls in this module access the cell with
/// `Ordering::Relaxed`.
#[derive(Debug)]
pub struct RelaxedAtomic<T>
where
    T: AtomicRepr,
{
    /// The underlying std atomic cell.
    inner: T::Atomic,
}
impl<T> RelaxedAtomic<T>
where
    T: AtomicRepr,
{
    /// Creates a new cell initialised to `val`.
    #[inline]
    pub fn new(val: T) -> Self {
        let inner = T::new_atomic(val);
        Self { inner }
    }

    /// Returns the value currently stored, via [`AtomicRepr::load`].
    #[inline]
    pub fn load(&self) -> T {
        <T as AtomicRepr>::load(&self.inner)
    }

    /// Overwrites the stored value with `val`, via [`AtomicRepr::store`].
    #[inline]
    pub fn store(&self, val: T) {
        <T as AtomicRepr>::store(&self.inner, val)
    }

    /// Consumes the cell and returns the value it held.
    pub fn into_inner(self) -> T {
        let Self { inner } = self;
        T::into_inner(inner)
    }
}
impl RelaxedAtomic<u32> {
#[inline]
pub fn fetch_add(&self, val: u32) -> u32 {
self.inner.fetch_add(val, Ordering::Relaxed)
}
}
impl RelaxedAtomic<u32> {
#[inline]
pub fn fetch_sub(&self, val: u32) -> u32 {
self.inner.fetch_sub(val, Ordering::Relaxed)
}
}
/// Compares the values currently held by two cells.
///
/// `self` is loaded first, then `other`. The two loads are independent, so
/// under concurrent writes the result reflects two separate snapshots rather
/// than a single instant.
impl<T> PartialEq for RelaxedAtomic<T>
where
    T: AtomicRepr + PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        let mine = self.load();
        let theirs = other.load();
        mine == theirs
    }
}

// Equivalence is inherited from `T: Eq`. The snapshot caveat on `PartialEq`
// still applies when another thread writes between the two loads.
impl<T> Eq for RelaxedAtomic<T> where T: AtomicRepr + Eq {}
impl<T> Default for RelaxedAtomic<T>
where
    T: AtomicRepr + Default,
{
    /// Creates a cell holding `T::default()`.
    fn default() -> Self {
        Self::new(Default::default())
    }
}
impl<T: AtomicRepr> Clone for RelaxedAtomic<T> {
fn clone(&self) -> Self {
Self::new(self.load())
}
}
impl<T> fmt::Display for RelaxedAtomic<T>
where
    T: AtomicRepr + fmt::Display,
{
    /// Formats the currently stored value.
    ///
    /// The formatter is passed straight through to `T`'s `Display` impl, so
    /// width, fill and precision flags are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current = self.load();
        fmt::Display::fmt(&current, f)
    }
}
/// Process-wide record of the thread designated as "main"; set at most once.
static MAIN_THREAD_ID: OnceLock<ThreadId> = OnceLock::new();

/// Records the calling thread as the main thread.
///
/// Only the first call in the process takes effect. Every later call, from any
/// thread, is silently ignored and leaves the original designation in place.
pub fn designate_main_thread() {
    MAIN_THREAD_ID.get_or_init(|| thread::current().id());
}

/// Returns the designated main thread's id, or `None` if
/// [`designate_main_thread`] has not been called yet.
pub fn main_thread_id() -> Option<ThreadId> {
    MAIN_THREAD_ID.get().cloned()
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MainThreadToken(PhantomData<*mut ()>);
impl MainThreadToken {
#[expect(
unsafe_code,
reason = "phantom data marker; !Send + !Sync prevents token leakage"
)]
pub unsafe fn new_unchecked() -> Self {
Self(PhantomData)
}
pub fn try_new() -> Option<Self> {
let designated = MAIN_THREAD_ID.get()?;
if *designated == thread::current().id() {
Some(Self(PhantomData))
} else {
None
}
}
}
pub use send_wrapper::SendWrapper;
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A freshly spawned thread can never be the designated main thread, so it
    /// must never obtain a token.
    #[test]
    fn worker_thread_never_obtains_token() {
        // Ensure a designation exists. Otherwise the assertion below would pass
        // trivially just because `try_new` returns `None` with nothing
        // designated. Designation is first-wins, so this is a no-op if another
        // test designated first.
        designate_main_thread();
        assert!(main_thread_id().is_some());

        let on_worker = thread::spawn(|| MainThreadToken::try_new().is_some())
            .join()
            .expect("worker thread");
        assert!(!on_worker);
    }

    /// After designation, `try_new` succeeds exactly on the designated thread.
    ///
    /// Tests run in parallel on separate threads, and designation is
    /// process-wide and first-wins. Another test may therefore already own the
    /// main-thread slot, so both outcomes are checked.
    #[test]
    fn try_new_returns_some_after_designation_on_same_thread() {
        designate_main_thread();
        let designated =
            main_thread_id().expect("a designation must exist after designate_main_thread");
        let token = MainThreadToken::try_new();
        if designated == thread::current().id() {
            assert!(token.is_some());
        } else {
            assert!(token.is_none());
        }
    }
}