use crate::arch::{
clone, get_thread_pointer, munmap_and_exit_thread, set_thread_pointer, TLS_OFFSET,
};
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::any::Any;
use core::cmp::max;
use core::ffi::c_void;
use core::mem::{align_of, size_of};
use core::ptr::{self, drop_in_place, null, null_mut};
use core::slice;
use core::sync::atomic::Ordering::SeqCst;
use core::sync::atomic::{AtomicU32, AtomicU8};
use memoffset::offset_of;
use rustix::io;
use rustix::param::{linux_execfn, page_size};
use rustix::process::{getrlimit, Pid, RawNonZeroPid, Resource};
use rustix::runtime::{set_tid_address, StartupTlsInfo};
use rustix::thread::gettid;
/// The entry point for new threads, entered via the `clone` shim in
/// `crate::arch` with `fn_` as its argument.
///
/// `fn_` is a raw pointer produced by `Box::into_raw` in `create_thread`;
/// ownership is reclaimed here with `Box::from_raw`.
///
/// # Safety
///
/// Must only be entered as the initial frame of a freshly `clone`d thread,
/// with `fn_` pointing to a leaked `Box` of the thread closure.
pub(super) unsafe extern "C" fn entry(fn_: *mut Box<dyn FnOnce() -> Option<Box<dyn Any>>>) -> ! {
    // Reclaim ownership of the closure leaked by `create_thread`.
    let fn_ = Box::from_raw(fn_);

    #[cfg(feature = "log")]
    log::trace!("Thread[{:?}] launched", current_thread_id());

    // In debug builds, sanity-check the brand-new frame: no return address
    // (nothing to return to), a non-null frame pointer aligned per the
    // target ABI, and a thread ID that agrees with the kernel's.
    #[cfg(debug_assertions)]
    {
        // Declare the LLVM intrinsics used for the frame checks below.
        extern "C" {
            #[link_name = "llvm.frameaddress"]
            fn builtin_frame_address(level: i32) -> *const u8;
            #[link_name = "llvm.returnaddress"]
            fn builtin_return_address(level: i32) -> *const u8;
            #[cfg(target_arch = "aarch64")]
            #[link_name = "llvm.sponentry"]
            fn builtin_sponentry() -> *const u8;
        }
        debug_assert_eq!(builtin_return_address(0), core::ptr::null());
        debug_assert_ne!(builtin_frame_address(0), core::ptr::null());
        // Frame alignment expectations: 16 bytes on most targets, 4 on
        // arm, and 16-with-an-8-byte-offset on x86 (per the asserts'
        // expected values).
        #[cfg(not(any(target_arch = "x86", target_arch = "arm")))]
        debug_assert_eq!(builtin_frame_address(0).addr() & 0xf, 0);
        #[cfg(target_arch = "arm")]
        debug_assert_eq!(builtin_frame_address(0).addr() & 0x3, 0);
        #[cfg(target_arch = "x86")]
        debug_assert_eq!(builtin_frame_address(0).addr() & 0xf, 8);
        // There is no caller frame above us.
        debug_assert_eq!(builtin_frame_address(1), core::ptr::null());
        #[cfg(target_arch = "aarch64")]
        debug_assert_ne!(builtin_sponentry(), core::ptr::null());
        #[cfg(target_arch = "aarch64")]
        debug_assert_eq!(builtin_sponentry().addr() & 0xf, 0);
        debug_assert_eq!(current_thread_id(), gettid());
    }

    // Run the closure; its return value is discarded here, then the
    // thread is torn down.
    let _result = fn_();
    exit_thread()
}
/// The memory block addressed by the thread pointer.
///
/// The thread pointer points at the `abi` field (see `current_metadata`,
/// which subtracts `offset_of!(Metadata, abi)`), so the relative order of
/// the fields is architecture-specific to match where each TLS ABI puts
/// its data relative to the thread pointer (see `current_thread_tls_addr`).
#[repr(C)]
struct Metadata {
    // aarch64/arm/riscv64 put TLS data above the thread pointer, so the
    // thread bookkeeping goes before the ABI header.
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    thread: ThreadData,

    /// The ABI-mandated fields located exactly at the thread pointer.
    abi: Abi,

    // x86/x86_64 put TLS data below the thread pointer, so the thread
    // bookkeeping goes after the ABI header.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    thread: ThreadData,
}
/// Fields which the TLS ABI requires to sit exactly at the thread pointer.
#[repr(C)]
#[cfg_attr(target_arch = "arm", repr(align(8)))]
struct Abi {
    /// Dynamic thread vector slot; initialized to null in this runtime
    /// (see `initialize_main_thread`/`create_thread`) — presumably no
    /// dynamic TLS is supported here; confirm before relying on it.
    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
    dtv: *const c_void,

    /// On x86/x86_64 the word at the thread pointer is set to point at
    /// this `Abi` itself (`this: newtls` in both initialization paths),
    /// per the x86 TLS convention.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    this: *mut Abi,

    /// Padding word reserved ahead of the TLS data on aarch64/arm.
    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
    pad: [usize; 1],
}
/// An opaque, copyable handle to a thread's bookkeeping record.
#[derive(Copy, Clone)]
pub struct Thread(*mut ThreadData);

impl Thread {
    /// Rebuild a `Thread` handle from a pointer previously obtained from
    /// [`Thread::to_raw`].
    #[inline]
    pub fn from_raw(raw: *mut c_void) -> Self {
        Thread(raw as *mut ThreadData)
    }

    /// Convert this handle into an opaque raw pointer, suitable for
    /// round-tripping through [`Thread::from_raw`].
    #[inline]
    pub fn to_raw(self) -> *mut c_void {
        self.0 as *mut c_void
    }
}
/// Per-thread bookkeeping, stored inside `Metadata`.
struct ThreadData {
    /// Raw thread ID; zeroed by the kernel when the thread exits
    /// (`CLONE_CHILD_CLEARTID`), which is what `wait_for_thread_exit`
    /// futex-waits on.
    thread_id: AtomicU32,
    /// Detach state: one of `INITIAL`, `DETACHED`, or `ABANDONED`.
    detached: AtomicU8,
    /// Lowest address of the stack (just above the guard region).
    stack_addr: *mut c_void,
    /// Usable stack size in bytes.
    stack_size: usize,
    /// Size of the inaccessible guard region below the stack.
    guard_size: usize,
    /// Total size of the thread's mapping, or 0 when the memory is not
    /// owned by this runtime (the main thread's OS-provided stack).
    map_size: usize,
    /// Functions registered via `at_thread_exit`, run LIFO on exit.
    dtors: Vec<Box<dyn FnOnce()>>,
}
// States for `ThreadData::detached`: a thread starts `INITIAL`;
// `detach_thread` moves it to `DETACHED`; a thread that exits while still
// joinable marks itself `ABANDONED` (see `exit_thread`), telling the
// eventual joiner/detacher to free its memory.
const INITIAL: u8 = 0;
const DETACHED: u8 = 1;
const ABANDONED: u8 = 2;
impl ThreadData {
    /// Build a fresh `ThreadData` record.
    ///
    /// `tid` may be `None` when the kernel will fill the ID in later
    /// (via `CLONE_CHILD_SETTID`); the stored raw ID is then zero.
    #[inline]
    fn new(
        tid: Option<Pid>,
        stack_addr: *mut c_void,
        stack_size: usize,
        guard_size: usize,
        map_size: usize,
    ) -> Self {
        // Atomics initialized up front; the remaining fields use the
        // field-init shorthand.
        let thread_id = AtomicU32::new(Pid::as_raw(tid));
        let detached = AtomicU8::new(INITIAL);
        Self {
            thread_id,
            detached,
            dtors: Vec::new(),
            stack_addr,
            stack_size,
            guard_size,
            map_size,
        }
    }
}
/// Locate the current thread's `Metadata`.
///
/// The thread pointer points at the `abi` field within `Metadata`, so
/// subtracting that field's offset recovers the start of the struct.
#[inline]
fn current_metadata() -> *mut Metadata {
    get_thread_pointer()
        .cast::<u8>()
        .wrapping_sub(offset_of!(Metadata, abi))
        .cast()
}
/// Return a handle to the calling thread.
#[inline]
pub fn current_thread() -> Thread {
    let metadata = current_metadata();
    // SAFETY: `current_metadata` yields a pointer to this thread's own
    // live `Metadata`, so taking a pointer to its `thread` field is valid.
    unsafe { Thread(&mut (*metadata).thread) }
}
/// Return the calling thread's ID, read from the cached `thread_id`
/// field rather than a system call.
#[inline]
pub fn current_thread_id() -> Pid {
    let raw = unsafe { (*current_thread().0).thread_id.load(SeqCst) };
    // The field is only zero after the kernel clears it at thread exit,
    // which can't have happened for the calling thread.
    debug_assert_ne!(raw, 0);
    let tid = unsafe { Pid::from_raw_nonzero(RawNonZeroPid::new_unchecked(raw)) };
    // Cross-check the cached value against the kernel in debug builds.
    debug_assert_eq!(tid, gettid(), "`current_thread_id` disagrees with `gettid`");
    tid
}
/// Rewrite the cached thread ID after a `fork`, where the child inherits
/// this thread's data structures but receives a new ID from the kernel.
///
/// # Safety
///
/// Must be called in the child immediately after a fork, with `tid`
/// being the child's actual thread ID.
#[cfg(feature = "set_thread_id")]
#[doc(hidden)]
#[inline]
pub unsafe fn set_current_thread_id_after_a_fork(tid: Pid) {
    // A matching ID would mean no fork actually happened.
    assert_ne!(
        tid.as_raw_nonzero().get(),
        (*current_thread().0).thread_id.load(SeqCst),
        "current thread ID already matches new thread ID"
    );
    assert_eq!(tid, gettid(), "new thread ID disagrees with `gettid`");
    (*current_thread().0)
        .thread_id
        .store(tid.as_raw_nonzero().get(), SeqCst);
}
/// Compute the address of the thread-local storage slot at `offset`
/// within the current thread's TLS block.
///
/// `TLS_OFFSET` is an architecture-specific adjustment defined in
/// `crate::arch`.
#[inline]
pub fn current_thread_tls_addr(offset: usize) -> *mut c_void {
    // On aarch64/arm/riscv64, TLS data lives above the thread pointer,
    // after the ABI-reserved `Abi` header.
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    {
        crate::arch::get_thread_pointer()
            .cast::<u8>()
            .wrapping_add(TLS_OFFSET)
            .wrapping_add(size_of::<Abi>())
            .wrapping_add(offset)
            .cast()
    }
    // On x86/x86_64, TLS data lives below the thread pointer, occupying
    // `STARTUP_TLS_INFO.mem_size` bytes just beneath it.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    unsafe {
        get_thread_pointer()
            .cast::<u8>()
            .wrapping_add(TLS_OFFSET)
            .wrapping_sub(STARTUP_TLS_INFO.mem_size)
            .wrapping_add(offset)
            .cast()
    }
}
/// Return `thread`'s stack: its lowest address, usable size, and guard
/// region size.
///
/// # Safety
///
/// `thread` must refer to a live thread record that has not been freed.
#[inline]
pub unsafe fn thread_stack(thread: Thread) -> (*mut c_void, usize, usize) {
    (
        (*thread.0).stack_addr,
        (*thread.0).stack_size,
        (*thread.0).guard_size,
    )
}
/// Register `func` to be called when the current thread exits; functions
/// run in reverse registration order (see `call_thread_dtors`).
pub fn at_thread_exit(func: Box<dyn FnOnce()>) {
    let current = current_thread();
    // SAFETY: `current` refers to the calling thread's own live record.
    unsafe { (*current.0).dtors.push(func) }
}
/// Run the functions registered with `at_thread_exit` for `current`,
/// in LIFO order. Functions registered during this loop (by other
/// destructors) are also run.
pub(crate) fn call_thread_dtors(current: Thread) {
    // Pop one at a time so that re-entrant `at_thread_exit` calls made
    // by a destructor are picked up too.
    while let Some(func) = unsafe { (*current.0).dtors.pop() } {
        #[cfg(feature = "log")]
        if log::log_enabled!(log::Level::Trace) {
            log::trace!(
                "Thread[{:?}] calling `at_thread_exit`-registered function",
                unsafe { (*current.0).thread_id.load(SeqCst) },
            );
        }
        func();
    }
}
/// Run the current thread's exit destructors and terminate the thread.
///
/// If the thread was already detached, its stack/TLS mapping is unmapped
/// here via `munmap_and_exit_thread` (which must exit without touching
/// the unmapped stack). Otherwise the thread marks itself `ABANDONED`
/// and leaves its memory for `join_thread`/`detach_thread` to free.
unsafe fn exit_thread() -> ! {
    let current = current_thread();

    // Run `at_thread_exit` functions while the thread data is fully alive.
    call_thread_dtors(current);

    // Try `INITIAL` -> `ABANDONED`; failure means another thread already
    // marked us `DETACHED`, so we must clean up after ourselves.
    let state = (*current.0)
        .detached
        .compare_exchange(INITIAL, ABANDONED, SeqCst, SeqCst);
    if let Err(e) = state {
        // Copy out everything we need *before* dropping the thread data,
        // since the memory holding it is about to be unmapped.
        #[cfg(feature = "log")]
        let current_thread_id = (*current.0).thread_id.load(SeqCst);
        let current_map_size = (*current.0).map_size;
        let current_stack_addr = (*current.0).stack_addr;
        let current_guard_size = (*current.0).guard_size;
        #[cfg(feature = "log")]
        log::trace!("Thread[{:?}] exiting as detached", current_thread_id);
        // The only non-`INITIAL` state another thread can set is `DETACHED`.
        debug_assert_eq!(e, DETACHED);

        // Drop the `dtors` vector and other owned data in place.
        drop_in_place(current.0);

        // `map_size` is 0 for the main thread (see `initialize_main_thread`),
        // which has no runtime-owned mapping to free.
        let map_size = current_map_size;
        if map_size != 0 {
            // Clear the tid address so the kernel's CHILD_CLEARTID write
            // doesn't target memory we're about to unmap.
            let _ = set_tid_address(null_mut());
            // The mapping begins `guard_size` bytes below the stack.
            let map = current_stack_addr.cast::<u8>().sub(current_guard_size);
            munmap_and_exit_thread(map.cast(), map_size);
        }
    } else {
        #[cfg(feature = "log")]
        if log::log_enabled!(log::Level::Trace) {
            log::trace!(
                "Thread[{:?}] exiting as joinable",
                (*current.0).thread_id.load(SeqCst)
            );
        }
    }
    rustix::runtime::exit_thread(0)
}
/// Set up the thread pointer, `Metadata`, and TLS block for the main
/// thread, whose OS-provided stack already exists and must be located.
///
/// # Safety
///
/// `mem` must be the initial stack address provided at process startup,
/// and this must be called exactly once, before any TLS access.
pub(super) unsafe fn initialize_main_thread(mem: *mut c_void) {
    use rustix::mm::{mmap_anonymous, MapFlags, ProtFlags};

    // Capture the executable's TLS template for use here and by
    // `create_thread` later.
    STARTUP_TLS_INFO = rustix::runtime::startup_tls_info();

    // Linux places the `AT_EXECFN` string at the top of the stack; the
    // end of that string, rounded up to a page, is taken as the stack base.
    let execfn = linux_execfn().to_bytes_with_nul();
    let stack_base = execfn.as_ptr().add(execfn.len());
    let stack_base = stack_base.map_addr(|ptr| round_up(ptr, page_size())) as *mut c_void;

    // The soft stack rlimit bounds how far down the stack may grow.
    let stack_map_size = getrlimit(Resource::Stack).current.unwrap() as usize;
    let stack_least = stack_base.cast::<u8>().sub(stack_map_size);
    // NOTE(review): `stack_least` lies below `mem`, so this difference
    // looks negative before the `as usize` cast — confirm the intended
    // operand order (`mem.offset_from(stack_least)`?).
    let stack_size = stack_least.offset_from(mem.cast::<u8>()) as usize;
    // The main thread's stack is owned by the OS: record no guard region
    // and a zero `map_size` so exit/join paths don't try to unmap it.
    let guard_size = 0;
    let map_size = 0;

    // Compute the allocation layout for the metadata + TLS block.
    let tls_data_align = STARTUP_TLS_INFO.align;
    let header_align = align_of::<Metadata>();
    let metadata_align = max(tls_data_align, header_align);
    let mut alloc_size = 0;

    // On x86/x86_64 the TLS data sits below the thread pointer, so it
    // comes before the `Metadata` header in the allocation.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let tls_data_bottom = alloc_size;
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        alloc_size += round_up(STARTUP_TLS_INFO.mem_size, metadata_align);
    }
    let header = alloc_size;
    alloc_size += size_of::<Metadata>();
    // On aarch64/arm/riscv64 the TLS data sits above the thread pointer,
    // so it follows the header.
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    {
        alloc_size = round_up(alloc_size, tls_data_align);
    }
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    let tls_data_bottom = alloc_size;
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    {
        alloc_size += round_up(STARTUP_TLS_INFO.mem_size, tls_data_align);
    }

    // Allocate the block; anonymous mappings are page-aligned, which
    // satisfies `metadata_align` (asserted below).
    let new = mmap_anonymous(
        null_mut(),
        alloc_size,
        ProtFlags::READ | ProtFlags::WRITE,
        MapFlags::PRIVATE,
    )
    .unwrap()
    .cast::<u8>();
    debug_assert_eq!(new.addr() % metadata_align, 0);

    let tls_data = new.add(tls_data_bottom);
    let metadata: *mut Metadata = new.add(header).cast();
    let newtls: *mut Abi = &mut (*metadata).abi;

    // Initialize the metadata in place; the x86 `this` field points at
    // the `Abi` itself, per the x86 TLS convention.
    ptr::write(
        metadata,
        Metadata {
            abi: Abi {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                this: newtls,
                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                dtv: null(),
                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                pad: [0_usize; 1],
            },
            thread: ThreadData::new(
                Some(gettid()),
                stack_least.cast(),
                stack_size,
                guard_size,
                map_size,
            ),
        },
    );

    // Copy the TLS initialization image...
    slice::from_raw_parts_mut(tls_data, STARTUP_TLS_INFO.file_size).copy_from_slice(
        slice::from_raw_parts(
            STARTUP_TLS_INFO.addr.cast::<u8>(),
            STARTUP_TLS_INFO.file_size,
        ),
    );
    // ...and zero the `.tbss`-style tail beyond the file image.
    slice::from_raw_parts_mut(
        tls_data.add(STARTUP_TLS_INFO.file_size),
        STARTUP_TLS_INFO.mem_size - STARTUP_TLS_INFO.file_size,
    )
    .fill(0);

    // Install the new thread pointer, pointing at the `Abi` header.
    set_thread_pointer(newtls.cast::<u8>().cast());
}
/// Create a new thread running `fn_` on a freshly mapped stack.
///
/// A single anonymous mapping holds, from the bottom: an inaccessible
/// guard region, the stack, and (in per-architecture order) the TLS data
/// and the `Metadata` header. The new thread's thread pointer is set to
/// the `Abi` within that `Metadata`, and its stack pointer starts at
/// `stack_top`.
///
/// Returns a `Thread` handle for use with `join_thread`/`detach_thread`.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn create_thread(
    fn_: Box<dyn FnOnce() -> Option<Box<dyn Any>>>,
    stack_size: usize,
    guard_size: usize,
) -> io::Result<Thread> {
    use rustix::mm::{mmap_anonymous, mprotect, MapFlags, MprotectFlags, ProtFlags};

    // Snapshot the startup TLS requirements, which were recorded during
    // single-threaded startup in `initialize_main_thread`.
    let (startup_tls_align, startup_tls_mem_size) =
        unsafe { (STARTUP_TLS_INFO.align, STARTUP_TLS_INFO.mem_size) };

    // Alignment requirements: TLS data, the `Metadata` header, and the
    // stack all need alignment, and the combined requirement must not
    // exceed what page-aligned `mmap` provides.
    let tls_data_align = startup_tls_align;
    let page_align = page_size();
    let stack_align = 16;
    let header_align = align_of::<Metadata>();
    let metadata_align = max(tls_data_align, header_align);
    let stack_metadata_align = max(stack_align, metadata_align);
    assert!(stack_metadata_align <= page_align);

    // Lay out the mapping bottom-up: guard region, then stack.
    let mut map_size = 0;
    map_size += round_up(guard_size, page_align);
    let stack_bottom = map_size;
    map_size += round_up(stack_size, stack_metadata_align);
    let stack_top = map_size;

    // On x86/x86_64, TLS data sits below the thread pointer, so it goes
    // before the `Metadata` header.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let tls_data_bottom = map_size;
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        map_size += round_up(startup_tls_mem_size, tls_data_align);
    }
    let header = map_size;
    map_size += size_of::<Metadata>();
    // On aarch64/arm/riscv64, TLS data sits above the thread pointer,
    // so it follows the header.
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    {
        map_size = round_up(map_size, tls_data_align);
    }
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    let tls_data_bottom = map_size;
    #[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
    {
        map_size += round_up(startup_tls_mem_size, tls_data_align);
    }

    unsafe {
        // Map everything inaccessible first; then open up read/write on
        // everything above the guard region.
        let map = mmap_anonymous(
            null_mut(),
            map_size,
            ProtFlags::empty(),
            MapFlags::PRIVATE | MapFlags::STACK,
        )?
        .cast::<u8>();
        mprotect(
            map.add(stack_bottom).cast(),
            map_size - stack_bottom,
            MprotectFlags::READ | MprotectFlags::WRITE,
        )?;

        let stack = map.add(stack_top);
        let stack_least = map.add(stack_bottom);
        let tls_data = map.add(tls_data_bottom);
        let metadata: *mut Metadata = map.add(header).cast();
        let newtls: *mut Abi = &mut (*metadata).abi;

        // Initialize the metadata. The thread ID starts as `None`; the
        // kernel fills it in via `CLONE_CHILD_SETTID`/`PARENT_SETTID`.
        ptr::write(
            metadata,
            Metadata {
                abi: Abi {
                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    this: newtls,
                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                    dtv: null(),
                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                    pad: [0_usize; 1],
                },
                thread: ThreadData::new(
                    None,
                    stack_least.cast(),
                    stack_size,
                    guard_size,
                    map_size,
                ),
            },
        );

        // Copy the TLS initialization image; the anonymous mapping is
        // already zeroed, so no explicit fill of the tail is needed here.
        slice::from_raw_parts_mut(tls_data, STARTUP_TLS_INFO.file_size).copy_from_slice(
            slice::from_raw_parts(
                STARTUP_TLS_INFO.addr.cast::<u8>(),
                STARTUP_TLS_INFO.file_size,
            ),
        );

        // `CHILD_CLEARTID` makes the kernel zero `thread_id` and wake
        // its futex on exit (see `wait_for_thread_exit`); the SETTID
        // flags publish the new ID to both parent and child.
        let flags = CloneFlags::VM
            | CloneFlags::FS
            | CloneFlags::FILES
            | CloneFlags::SIGHAND
            | CloneFlags::THREAD
            | CloneFlags::SYSVSEM
            | CloneFlags::SETTLS
            | CloneFlags::CHILD_CLEARTID
            | CloneFlags::CHILD_SETTID
            | CloneFlags::PARENT_SETTID;
        let thread_id_ptr = (*metadata).thread.thread_id.as_mut_ptr();

        // The `clone` argument order differs by architecture: x86_64
        // passes `ctid` before the TLS argument; the others pass the TLS
        // argument first.
        #[cfg(target_arch = "x86_64")]
        let clone_res = clone(
            flags.bits(),
            stack.cast(),
            thread_id_ptr,
            thread_id_ptr,
            newtls.cast::<u8>().cast(),
            Box::into_raw(Box::new(fn_)),
        );
        #[cfg(any(
            target_arch = "x86",
            target_arch = "aarch64",
            target_arch = "arm",
            target_arch = "riscv64"
        ))]
        let clone_res = clone(
            flags.bits(),
            stack.cast(),
            thread_id_ptr,
            newtls.cast::<u8>().cast(),
            thread_id_ptr,
            Box::into_raw(Box::new(fn_)),
        );

        // A negative result is a negated errno; otherwise the child is
        // running and we hand back a handle to its metadata.
        if clone_res >= 0 {
            Ok(Thread(&mut (*metadata).thread))
        } else {
            Err(io::Errno::from_raw_os_error(-clone_res as i32))
        }
    }
}
/// Round `addr` up to the nearest multiple of `boundary`.
///
/// `boundary` must be a nonzero power of two; the mask form below
/// silently produces garbage otherwise, so this is now checked in debug
/// builds. All callers in this file pass page sizes or type alignments,
/// which are always powers of two.
fn round_up(addr: usize, boundary: usize) -> usize {
    debug_assert!(
        boundary.is_power_of_two(),
        "`round_up` requires a power-of-two boundary"
    );
    // For a power-of-two boundary, `boundary.wrapping_neg()` is the mask
    // that clears the low bits, i.e. `!(boundary - 1)`.
    (addr + (boundary - 1)) & boundary.wrapping_neg()
}
/// Mark `thread` as detached; if it has already exited (`ABANDONED`),
/// free its memory now.
///
/// # Safety
///
/// `thread` must refer to a live, not-yet-freed thread record, and must
/// not be joined or detached again afterwards.
#[inline]
pub unsafe fn detach_thread(thread: Thread) {
    #[cfg(feature = "log")]
    let thread_id = (*thread.0).thread_id.load(SeqCst);
    #[cfg(feature = "log")]
    if log::log_enabled!(log::Level::Trace) {
        log::trace!(
            "Thread[{:?}] marked as detached by Thread[{:?}]",
            thread_id,
            current_thread_id()
        );
    }
    // If the thread already exited joinable (`ABANDONED`), its memory is
    // ours to reclaim; otherwise it will clean up after itself on exit.
    if (*thread.0).detached.swap(DETACHED, SeqCst) == ABANDONED {
        wait_for_thread_exit(thread);
        #[cfg(feature = "log")]
        log_thread_to_be_freed(thread_id);
        free_thread_memory(thread);
    }
}
/// Wait for `thread` to exit, then free its memory.
///
/// # Safety
///
/// `thread` must refer to a live, non-detached thread record, and must
/// not be used again after this call (its memory is freed here).
pub unsafe fn join_thread(thread: Thread) {
    #[cfg(feature = "log")]
    let thread_id = (*thread.0).thread_id.load(SeqCst);
    #[cfg(feature = "log")]
    if log::log_enabled!(log::Level::Trace) {
        log::trace!(
            "Thread[{:?}] is being joined by Thread[{:?}]",
            thread_id,
            current_thread_id()
        );
    }
    wait_for_thread_exit(thread);
    // A joinable thread marks itself `ABANDONED` in `exit_thread`.
    debug_assert_eq!((*thread.0).detached.load(SeqCst), ABANDONED);
    #[cfg(feature = "log")]
    log_thread_to_be_freed(thread_id);
    free_thread_memory(thread);
}
/// Block until `thread` has exited.
///
/// This relies on `CLONE_CHILD_CLEARTID` (set in `create_thread`): the
/// kernel zeroes the thread's `thread_id` field and wakes a futex on it
/// when the thread exits.
unsafe fn wait_for_thread_exit(thread: Thread) {
    use rustix::thread::{futex, FutexFlags, FutexOperation};
    let thread = &mut *thread.0;
    let thread_id = &mut thread.thread_id;
    // A zero ID (`None` below) means the kernel already cleared it —
    // the thread has exited and there is nothing to wait for.
    let id_value = thread_id.load(SeqCst);
    if let Some(id_value) = Pid::from_raw(id_value) {
        match futex(
            thread_id.as_mut_ptr(),
            FutexOperation::Wait,
            FutexFlags::empty(),
            id_value.as_raw_nonzero().get(),
            null(),
            null_mut(),
            0,
        ) {
            Ok(_) => {}
            // `EAGAIN` means the value changed between our load and the
            // wait — i.e. the thread exited in that window, which is fine.
            Err(e) => debug_assert_eq!(e, io::Errno::AGAIN),
        }
    }
}
/// Trace-log that `thread_id`'s memory is about to be freed.
#[cfg(feature = "log")]
unsafe fn log_thread_to_be_freed(thread_id: u32) {
    // Bail out early when trace logging is disabled.
    if !log::log_enabled!(log::Level::Trace) {
        return;
    }
    log::trace!("Thread[{:?}] memory being freed", thread_id);
}
/// Drop `thread`'s owned data and unmap its stack/TLS mapping.
///
/// Must only be called after the thread has exited (see the callers in
/// `join_thread`/`detach_thread`). A zero `map_size` (the main thread)
/// means there is no runtime-owned mapping to free.
unsafe fn free_thread_memory(thread: Thread) {
    use rustix::mm::munmap;
    // Copy the fields out before dropping the data they live in.
    let map_size = (*thread.0).map_size;
    let stack_addr = (*thread.0).stack_addr;
    let guard_size = (*thread.0).guard_size;
    // Free the `dtors` vector and other owned data.
    drop_in_place(thread.0);
    if map_size != 0 {
        // The mapping begins `guard_size` bytes below the stack.
        let map = stack_addr.cast::<u8>().sub(guard_size);
        munmap(map.cast(), map_size).unwrap();
    }
}
/// Return the default stack size for new threads: at least two pages,
/// or `STARTUP_TLS_INFO.stack_size` if larger (presumably the size
/// requested by the executable's headers — confirm against
/// `rustix::runtime::startup_tls_info`).
#[inline]
pub fn default_stack_size() -> usize {
    // SAFETY-relevant note: `STARTUP_TLS_INFO` is written once during
    // single-threaded startup in `initialize_main_thread`.
    unsafe { max(page_size() * 2, STARTUP_TLS_INFO.stack_size) }
}
/// Return the default size of the guard region placed below each new
/// thread's stack.
#[inline]
pub fn default_guard_size() -> usize {
    // Four pages, matching the original default.
    4 * page_size()
}
/// The executable's TLS template, captured from the startup info by
/// `initialize_main_thread` and read (without synchronization) by
/// `create_thread`, `current_thread_tls_addr`, and `default_stack_size`.
static mut STARTUP_TLS_INFO: StartupTlsInfo = StartupTlsInfo {
    addr: null(),
    mem_size: 0,
    file_size: 0,
    align: 0,
    stack_size: 0,
};
/// The ARM EABI helper the compiler emits calls to for reading the
/// thread pointer; forwarded to this crate's `get_thread_pointer`.
#[cfg(target_arch = "arm")]
#[no_mangle]
unsafe extern "C" fn __aeabi_read_tp() -> *mut c_void {
    get_thread_pointer()
}
bitflags::bitflags! {
    /// Flag arguments for the `clone` system call, mirroring the kernel's
    /// `CLONE_*` constants. Reformatted one flag per line.
    struct CloneFlags: u32 {
        const NEWTIME = linux_raw_sys::general::CLONE_NEWTIME;
        const VM = linux_raw_sys::general::CLONE_VM;
        const FS = linux_raw_sys::general::CLONE_FS;
        const FILES = linux_raw_sys::general::CLONE_FILES;
        const SIGHAND = linux_raw_sys::general::CLONE_SIGHAND;
        const PIDFD = linux_raw_sys::general::CLONE_PIDFD;
        const PTRACE = linux_raw_sys::general::CLONE_PTRACE;
        const VFORK = linux_raw_sys::general::CLONE_VFORK;
        const PARENT = linux_raw_sys::general::CLONE_PARENT;
        const THREAD = linux_raw_sys::general::CLONE_THREAD;
        const NEWNS = linux_raw_sys::general::CLONE_NEWNS;
        const SYSVSEM = linux_raw_sys::general::CLONE_SYSVSEM;
        const SETTLS = linux_raw_sys::general::CLONE_SETTLS;
        const PARENT_SETTID = linux_raw_sys::general::CLONE_PARENT_SETTID;
        const CHILD_CLEARTID = linux_raw_sys::general::CLONE_CHILD_CLEARTID;
        const DETACHED = linux_raw_sys::general::CLONE_DETACHED;
        const UNTRACED = linux_raw_sys::general::CLONE_UNTRACED;
        const CHILD_SETTID = linux_raw_sys::general::CLONE_CHILD_SETTID;
        const NEWCGROUP = linux_raw_sys::general::CLONE_NEWCGROUP;
        const NEWUTS = linux_raw_sys::general::CLONE_NEWUTS;
        const NEWIPC = linux_raw_sys::general::CLONE_NEWIPC;
        const NEWUSER = linux_raw_sys::general::CLONE_NEWUSER;
        const NEWPID = linux_raw_sys::general::CLONE_NEWPID;
        const NEWNET = linux_raw_sys::general::CLONE_NEWNET;
        const IO = linux_raw_sys::general::CLONE_IO;
    }
}