use alloc::boxed::Box;
use core::cell::Cell;
use core::ptr::{null, null_mut};
use core::sync::atomic::{AtomicU32, Ordering};
use errno::{set_errno, Errno};
use libc::{c_int, c_void};
use rustix_futex_sync::RwLock;
// Maximum number of thread-specific-data keys, matching the limit of
// the libc flavor being emulated (glibc exposes 1024, musl 128).
#[cfg(target_env = "gnu")]
const PTHREAD_KEYS_MAX: u32 = 1024;
#[cfg(target_env = "musl")]
const PTHREAD_KEYS_MAX: u32 = 128;
// Number of passes over outstanding destructors made at thread exit;
// POSIX's minimum required value for this constant is 4.
const PTHREAD_DESTRUCTOR_ITERATIONS: u8 = 4;
// Process-global bookkeeping for pthread keys, stored behind the
// `KEY_DATA` `RwLock` below.
#[derive(Clone, Copy)]
struct KeyData {
// Lowest key index that has never been handed out; indices below this
// are either live or deleted (a deleted slot has a `None` destructor).
next_key: libc::pthread_key_t,
// Per-key destructor table. `None` marks a free slot (never created,
// or deleted); keys created without a destructor get an empty stub so
// they remain distinguishable from free slots.
destructors: [Option<unsafe extern "C" fn(_: *mut c_void)>; PTHREAD_KEYS_MAX as usize],
}
// A thread-local value tagged with the epoch of the key incarnation it
// was stored under. When a key index is recycled, the key's global
// epoch (in `EPOCHS`) is bumped; values carrying an older epoch then
// read back as null, invalidating stale data lazily without having to
// visit every thread's storage.
#[derive(Clone, Copy)]
struct ValueWithEpoch {
epoch: u32,
data: *mut c_void,
}
impl ValueWithEpoch {
// An empty slot: epoch 0 and a null data pointer.
const fn new() -> Self {
ValueWithEpoch {
epoch: 0,
data: null_mut(),
}
}
}
// Global key table: the destructor for each key plus the next-fresh-key
// cursor, guarded by a futex-based `RwLock`.
static KEY_DATA: RwLock<KeyData> = RwLock::new(KeyData {
next_key: 0,
destructors: [None; PTHREAD_KEYS_MAX as usize],
});
// Per-key generation counters, bumped when `pthread_key_create`
// recycles a deleted key index; see `ValueWithEpoch`.
static EPOCHS: [AtomicU32; PTHREAD_KEYS_MAX as usize] =
[const { AtomicU32::new(0) }; PTHREAD_KEYS_MAX as usize];
// This thread's values, each tagged with the epoch it was stored under.
#[thread_local]
static VALUES: [Cell<ValueWithEpoch>; PTHREAD_KEYS_MAX as usize] =
[const { Cell::new(ValueWithEpoch::new()) }; PTHREAD_KEYS_MAX as usize];
// Whether this thread has already registered its at-exit destructor
// pass (done lazily by `pthread_setspecific`).
#[thread_local]
static HAS_REGISTERED_CLEANUP: Cell<bool> = Cell::new(false);
#[no_mangle]
unsafe extern "C" fn pthread_getspecific(key: libc::pthread_key_t) -> *mut c_void {
    libc!(libc::pthread_getspecific(key));

    // An out-of-range key yields null rather than faulting.
    let Some(counter) = EPOCHS.get(key as usize) else {
        return null_mut();
    };
    let current_epoch = counter.load(Ordering::SeqCst);

    // A value stamped with an older epoch was stored under a previous
    // incarnation of this (recycled) key index, so report it as unset.
    let stored = VALUES[key as usize].get();
    if stored.epoch < current_epoch {
        return null_mut();
    }
    stored.data
}
#[no_mangle]
unsafe extern "C" fn pthread_setspecific(key: libc::pthread_key_t, value: *const c_void) -> c_int {
    libc!(libc::pthread_setspecific(key, value));

    // Lazily register, once per thread, a thread-exit callback that
    // runs key destructors as POSIX requires.
    if !HAS_REGISTERED_CLEANUP.get() {
        origin::thread::at_exit(Box::new(move || {
            // POSIX: if destructor calls leave new non-null values
            // behind, repeat the whole pass, up to
            // PTHREAD_DESTRUCTOR_ITERATIONS times.
            for _ in 0..PTHREAD_DESTRUCTOR_ITERATIONS {
                let mut ran_dtor = false;
                for i in 0..PTHREAD_KEYS_MAX {
                    let data = pthread_getspecific(i as libc::pthread_key_t);
                    if data.is_null() {
                        continue;
                    }
                    // Read the destructor under the lock, but drop the
                    // lock before invoking it: the destructor may call
                    // back into these pthread_* functions.
                    let dtor = {
                        let key_data = KEY_DATA.read();
                        key_data.destructors[i as usize]
                    };
                    if let Some(dtor) = dtor {
                        // Clear the value before the call, per POSIX,
                        // so a value re-stored by the destructor is
                        // detectable on the next pass.
                        pthread_setspecific(i as libc::pthread_key_t, null());
                        dtor(data);
                        // Bug fix: only an actual destructor invocation
                        // warrants another pass. Previously this flag
                        // was set for any non-null value — including
                        // keys with no destructor (e.g. deleted keys) —
                        // forcing pointless extra full scans.
                        ran_dtor = true;
                    }
                }
                if !ran_dtor {
                    break;
                }
            }
        }));
        HAS_REGISTERED_CLEANUP.set(true);
    }

    // Reject out-of-range keys with EINVAL (pthread functions return
    // the error number directly).
    let latest_epoch = match EPOCHS.get(key as usize) {
        Some(epoch) => epoch.load(Ordering::SeqCst),
        None => return libc::EINVAL,
    };
    // Stamp the value with the key's current epoch so that, if this key
    // index is later deleted and recycled, the stale value reads back
    // as null.
    VALUES[key as usize].set(ValueWithEpoch {
        epoch: latest_epoch,
        data: value.cast_mut(),
    });
    0
}
#[no_mangle]
unsafe extern "C" fn pthread_key_create(
    key: *mut libc::pthread_key_t,
    dtor: Option<unsafe extern "C" fn(_: *mut c_void)>,
) -> c_int {
    libc!(libc::pthread_key_create(key, dtor));

    // Record "no destructor" as an empty function so that an in-use key
    // remains distinguishable from a free slot (`None`) in the table.
    extern "C" fn empty_dtor(_: *mut c_void) {}

    let mut key_data = KEY_DATA.write();
    let mut next_key = key_data.next_key;
    if next_key < PTHREAD_KEYS_MAX {
        // Fast path: hand out a never-before-used key index.
        key_data.next_key = next_key + 1;
    } else {
        // Every index has been handed out at least once; search for a
        // deleted slot to recycle.
        for (index, dtor) in key_data.destructors.iter().enumerate() {
            if dtor.is_none() {
                // Bump the slot's epoch so values stored under the old
                // incarnation of this key read back as null.
                //
                // Bug fix: `fetch_add` returns the *previous* value, so
                // overflow is detected when it was `u32::MAX` (the new
                // epoch wrapped to 0, which would resurrect stale
                // values). The old `== 0` test instead fired on the
                // very first reuse of a slot, whose epoch starts at 0.
                if EPOCHS[index].fetch_add(1, Ordering::SeqCst) == u32::MAX {
                    panic!("detected epoch counter overflow");
                }
                next_key = index as libc::pthread_key_t;
                break;
            }
        }
        if next_key >= PTHREAD_KEYS_MAX {
            // No free slots. Per POSIX, pthread functions report errors
            // via the return value rather than `errno`; returning
            // EAGAIN here also matches `pthread_setspecific`, which
            // returns EINVAL directly. (Previously this set `errno` and
            // returned -1.)
            return libc::EAGAIN;
        }
    }
    *key = next_key;
    key_data.destructors[next_key as usize] = Some(dtor.unwrap_or(empty_dtor));
    0
}
#[no_mangle]
unsafe extern "C" fn pthread_key_delete(key: libc::pthread_key_t) -> c_int {
    libc!(libc::pthread_key_delete(key));

    let mut key_data = KEY_DATA.write();
    // Bug fix: bounds-check the key. The original indexed the array
    // directly and would panic on an out-of-range key; report EINVAL
    // instead, consistent with `pthread_setspecific`.
    match key_data.destructors.get_mut(key as usize) {
        Some(slot) => {
            // Mark the slot free so `pthread_key_create` can recycle
            // it. Per POSIX, no destructors are invoked on delete; it
            // is the application's responsibility to free any
            // per-thread data for this key.
            *slot = None;
            0
        }
        None => libc::EINVAL,
    }
}