use alloc::alloc::dealloc;
use alloc::boxed::Box;
use core::alloc::Layout;
use core::arch::global_asm;
use core::cell::UnsafeCell;
use core::fmt::Debug;
use core::marker::PhantomData;
use core::num::NonZeroUsize;
use core::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use sc::nr::MUNMAP;
use crate::eprintln;
use rusl::platform::{CloneFlags, MapAdditionalFlags, MapRequiredFlag, MemoryProtection};
use rusl::unistd::mmap;
use crate::error::Result;
use crate::sync::futex_wait_fast;
/// Handle to a spawned thread, parameterized on the thread's return type `T`.
///
/// Backed by a small heap-shared memory block (`Tsm`) that the spawned thread
/// and this handle use to hand over the result and to agree on which side
/// frees the block.
pub struct JoinHandle<T: Sized> {
    // Pointer to the thread-shared memory block allocated in `spawn`.
    tsm: Tsm,
    // Ties the handle to the thread's return type without storing a `T`.
    _pd: PhantomData<T>,
}
/// Futex value meaning the thread has not exited yet. The futex word is
/// registered as the child's clear-tid address (CLONE_CHILD_CLEARTID in
/// `spawn`), so the kernel zeroes it and wakes waiters when the thread exits.
const UNFINISHED: u32 = 1;
impl<T: Sized> JoinHandle<T> {
    /// Blocks until the spawned thread exits, then returns the value it
    /// produced.
    ///
    /// Waits on the shared futex while it still reads `UNFINISHED` (the
    /// kernel clears it at thread exit via CLONE_CHILD_CLEARTID, see `spawn`),
    /// moves the result out of the shared block, frees that block, and
    /// `forget`s `self` so `Drop` does not run a second teardown.
    ///
    /// Returns `None` if the thread never stored a value (e.g. it panicked
    /// before completing).
    #[must_use]
    pub fn join(self) -> Option<T> {
        unsafe {
            futex_wait_fast(self.tsm.get_futex(), UNFINISHED);
            let val = self.tsm.get_value::<T>().into_inner();
            self.tsm.dealloc();
            // Skip `Drop`: the shared block was already freed above.
            core::mem::forget(self);
            val
        }
    }
}
impl<T: Sized> Drop for JoinHandle<T> {
    /// Dropping an unjoined handle detaches the thread.
    ///
    /// The `sync` flag arbitrates who frees the shared block:
    /// - CAS succeeds: this drop ran before the thread finished; the thread
    ///   will see `true` after storing its result and deallocate on its own.
    /// - CAS fails: the thread already finished (it won the CAS), so wait for
    ///   the kernel to clear the tid-futex at thread exit, then free the
    ///   shared block here.
    ///
    /// Fix over the original: before deallocating, the stored result is moved
    /// out and dropped. The original called `dealloc()` directly, which frees
    /// the memory without running `T`'s destructor — leaking anything `T`
    /// owns (heap buffers, fds, ...) whenever a completed thread's handle was
    /// dropped instead of joined.
    fn drop(&mut self) {
        unsafe {
            if self
                .tsm
                .get_sync()
                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
                .is_err()
            {
                futex_wait_fast(self.tsm.get_futex(), UNFINISHED);
                // Thread has fully exited; take ownership of the result it
                // stored so `T`'s destructor runs instead of leaking.
                drop(self.tsm.get_value::<T>().into_inner());
                self.tsm.dealloc();
            }
        }
    }
}
/// Thread-shared-memory handle: a raw pointer to a single heap allocation
/// holding, in order: an `AtomicBool` sync flag, an `AtomicU32` futex word,
/// the allocation's size and alignment (two `usize`s), and the thread's
/// result as an `UnsafeCell<Option<T>>`. Layout is computed by
/// `Tsm::layout_thread_shared_memory`.
#[repr(transparent)]
#[derive(Copy, Clone, Debug)]
struct Tsm(*mut u8);
impl Tsm {
    /// Offset of the `AtomicU32` futex word: right after the `AtomicBool`
    /// sync flag, padded up to the futex's alignment.
    const FUTEX_OFFSET: usize = core::mem::size_of::<AtomicBool>()
        + padding(
            core::mem::size_of::<AtomicBool>(),
            core::mem::align_of::<AtomicU32>(),
        );
    /// Offset of the stored allocation size (`usize`), after the futex word.
    const SELF_SZ_OFFSET: usize = Self::FUTEX_OFFSET
        + core::mem::size_of::<AtomicU32>()
        + padding(
            Self::FUTEX_OFFSET + core::mem::size_of::<AtomicU32>(),
            core::mem::align_of::<usize>(),
        );
    /// Offset of the stored allocation alignment (`usize`), after the size.
    const SELF_ALIGN_OFFSET: usize = Self::SELF_SZ_OFFSET
        + core::mem::size_of::<usize>()
        + padding(
            Self::SELF_SZ_OFFSET + core::mem::size_of::<usize>(),
            core::mem::align_of::<usize>(),
        );

    /// Allocates and initializes the shared block for a thread producing `T`:
    /// sync flag `false`, futex `UNFINISHED`, the block's own size/alignment
    /// (so `dealloc` can rebuild the `Layout` without knowing `T`), and an
    /// empty result slot.
    ///
    /// NOTE(review): a failed allocation (null return from `alloc`) is not
    /// checked before the writes below — consider `handle_alloc_error`.
    #[expect(clippy::cast_ptr_alignment)]
    unsafe fn init<T>() -> Self {
        let layout = Self::layout_thread_shared_memory::<T>();
        let ptr = alloc::alloc::alloc(layout);
        ptr.cast::<AtomicBool>().write(AtomicBool::new(false));
        ptr.add(Self::FUTEX_OFFSET)
            .cast::<AtomicU32>()
            .write(AtomicU32::new(UNFINISHED));
        ptr.add(Self::SELF_SZ_OFFSET)
            .cast::<usize>()
            .write(layout.size());
        ptr.add(Self::SELF_ALIGN_OFFSET)
            .cast::<usize>()
            .write(layout.align());
        ptr.add(Self::value_offset::<UnsafeCell<Option<T>>>())
            .cast::<UnsafeCell<Option<T>>>()
            .write(UnsafeCell::new(None));
        Self(ptr)
    }

    /// Computes the block's layout by laying out, in order: `AtomicBool`,
    /// `AtomicU32`, two `usize`s, then `UnsafeCell<Option<T>>`, and finally
    /// padding the total size up to the largest alignment encountered.
    const fn layout_thread_shared_memory<T: Sized>() -> Layout {
        let mut base = push_aligned::<AtomicBool>(0, 0);
        base = push_aligned::<AtomicU32>(base.size(), base.align());
        base = push_aligned::<usize>(base.size(), base.align());
        base = push_aligned::<usize>(base.size(), base.align());
        let last = push_aligned::<UnsafeCell<Option<T>>>(base.size(), base.align());
        unsafe {
            let padded = last.size() + padding(last.size(), last.align());
            Layout::from_size_align_unchecked(padded, last.align())
        }
    }

    /// Sync flag arbitrating which side (thread or `JoinHandle`) frees the
    /// block.
    ///
    /// NOTE(review): the `'static` lifetime is a convenience lie — the
    /// reference is only valid until `dealloc` is called.
    #[inline]
    unsafe fn get_sync(self) -> &'static AtomicBool {
        self.0.cast::<AtomicBool>().as_ref().unwrap_unchecked()
    }

    /// Futex word registered as the thread's clear-tid address; the kernel
    /// zeroes and wakes it when the thread exits. Same `'static` lifetime
    /// caveat as `get_sync`.
    #[inline]
    #[expect(clippy::cast_ptr_alignment)]
    unsafe fn get_futex(self) -> &'static AtomicU32 {
        self.0
            .add(Self::FUTEX_OFFSET)
            .cast::<AtomicU32>()
            .as_ref()
            .unwrap_unchecked()
    }

    /// Rebuilds the allocation's `Layout` from the size/alignment that `init`
    /// stored inside the block itself.
    #[inline]
    #[expect(clippy::cast_ptr_alignment)]
    unsafe fn get_layout(self) -> Layout {
        let size = self.0.add(Self::SELF_SZ_OFFSET).cast::<usize>().read();
        let align = self.0.add(Self::SELF_ALIGN_OFFSET).cast::<usize>().read();
        Layout::from_size_align_unchecked(size, align)
    }

    /// Offset of the result slot: after the stored alignment word, padded to
    /// `T`'s alignment (callers pass `T = UnsafeCell<Option<..>>`).
    #[inline]
    unsafe fn value_offset<T>() -> usize {
        Self::SELF_ALIGN_OFFSET
            + core::mem::size_of::<usize>()
            + padding(
                Self::SELF_ALIGN_OFFSET + core::mem::size_of::<usize>(),
                core::mem::align_of::<T>(),
            )
    }

    /// Reads the result slot out of the block by value (a bitwise copy).
    /// The caller must ensure the writer has finished and that the value is
    /// consumed at most once.
    #[inline]
    unsafe fn get_value<T>(self) -> UnsafeCell<Option<T>> {
        self.0
            .add(Self::value_offset::<UnsafeCell<Option<T>>>())
            .cast::<UnsafeCell<Option<T>>>()
            .read()
    }

    /// Raw mutable pointer into the result slot; used by the spawned thread
    /// to store its return value.
    #[inline]
    unsafe fn value_mut<T>(self) -> *mut Option<T> {
        self.0
            .add(Self::value_offset::<UnsafeCell<Option<T>>>())
            .cast::<UnsafeCell<Option<T>>>()
            .as_ref()
            .unwrap_unchecked()
            .get()
    }

    /// Frees the block using the layout stored inside it. Does not drop the
    /// result slot's contents; take the value out first if it may be `Some`.
    #[inline]
    unsafe fn dealloc(self) {
        let layout = self.get_layout();
        dealloc(self.0, layout);
    }
}
/// Lays out one field of type `T` starting at byte `base`: rounds `base` up
/// to `T`'s alignment, then returns a `Layout` whose size is the end of the
/// field and whose alignment is the running maximum of `max_align` and
/// `T`'s alignment.
const fn push_aligned<T>(base: usize, max_align: usize) -> Layout {
    let field_align = core::mem::align_of::<T>();
    // Round `base` up to the field's alignment.
    let rem = base % field_align;
    let start = if rem == 0 {
        base
    } else {
        base + (field_align - rem)
    };
    // Track the largest alignment seen so far.
    let combined_align = if max_align > field_align {
        max_align
    } else {
        field_align
    };
    // SAFETY: the caller threads through alignments originating from Rust
    // types (powers of two), matching the original construction; size/align
    // invariants are upheld the same way as before.
    unsafe { Layout::from_size_align_unchecked(start + core::mem::size_of::<T>(), combined_align) }
}
/// Const-context maximum of two `usize`s (returns `b` when equal, same as the
/// original — the values are identical either way).
const fn max(a: usize, b: usize) -> usize {
    match a > b {
        true => a,
        false => b,
    }
}
/// Number of padding bytes needed to round `base` up to a multiple of
/// `align`. Panics (like the original) if `align` is zero, since the first
/// operation is a remainder by `align`.
const fn padding(base: usize, align: usize) -> usize {
    let rem = base % align;
    // `(align - rem) % align` collapses the rem == 0 case to 0 and is the
    // plain `align - rem` otherwise.
    (align - rem) % align
}
/// Per-thread TLS block, reachable through the architecture TLS register
/// (fs base on x86_64, tpidr_el0 on aarch64); installed via CLONE_SETTLS in
/// `spawn`.
#[repr(C)]
#[derive(Copy, Clone)]
pub(crate) struct ThreadLocalStorage {
    /// This struct's own heap address, patched in by `spawn`. Must stay the
    /// first field: x86_64 `get_tls_ptr` reads it via `fs:0`. A value of 0
    /// means the TLS was never initialized (see `thread_stack_info`).
    pub(crate) self_addr: usize,
    /// Stack/result bookkeeping for spawned threads; `None` on the main
    /// thread.
    pub(crate) stack_info: Option<ThreadDealloc>,
}
impl ThreadLocalStorage {
    /// Returns the stack/result bookkeeping for this thread, or `None` for a
    /// thread that has none (the main thread).
    #[inline]
    fn thread_stack_info(&self) -> Option<&ThreadDealloc> {
        #[cfg(target_arch = "x86_64")]
        {
            self.stack_info.as_ref()
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NOTE(review): on aarch64 a `self_addr` of 0 is treated as
            // "no TLS installed" — presumably guarding the main thread, whose
            // tpidr_el0 does not point at a `ThreadLocalStorage`. Confirm
            // against the runtime's main-thread setup.
            if self.self_addr == 0 {
                None
            } else {
                self.stack_info.as_ref()
            }
        }
    }
}
/// Everything a spawned thread needs at exit: its mmap'd stack region (so the
/// exit path can munmap it) and the shared-memory handle for its result.
#[repr(C)]
#[derive(Copy, Clone)]
pub(crate) struct ThreadDealloc {
    /// Base address of the thread's mmap'd stack mapping.
    stack_addr: usize,
    /// Total size in bytes of that mapping.
    stack_sz: usize,
    /// Shared block used to hand the result back to the `JoinHandle`.
    tsm: Tsm,
}
/// Spawns a new thread running `func` and returns a handle that can be joined
/// for its result.
///
/// The thread is created with a raw `clone` syscall through the `__clone` asm
/// shim: it shares this process's address space, fds, filesystem info, and
/// signal handlers, runs on a fresh anonymous mmap'd stack, and gets a
/// heap-allocated `ThreadLocalStorage` installed via CLONE_SETTLS. The shared
/// `Tsm` futex word is registered as the child's clear-tid address
/// (CLONE_CHILD_CLEARTID), so the kernel zeroes it and wakes `join` when the
/// thread exits.
///
/// # Errors
/// Fails if the stack mapping cannot be created.
pub fn spawn<T, F>(func: F) -> Result<JoinHandle<T>>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let flags = CloneFlags::CLONE_VM
        | CloneFlags::CLONE_FS
        | CloneFlags::CLONE_FILES
        | CloneFlags::CLONE_SIGHAND
        | CloneFlags::CLONE_THREAD
        | CloneFlags::CLONE_SYSVSEM
        | CloneFlags::CLONE_CHILD_CLEARTID
        | CloneFlags::CLONE_DETACHED
        | CloneFlags::CLONE_SETTLS;
    // 8192 * 16 * 16 = 2 MiB of stack.
    let stack_sz = 8192 * 16 * 16;
    // NOTE(review): no guard page is reserved below the stack; an overflow
    // walks straight into whatever is mapped there. Consider a PROT_NONE
    // guard region.
    let guard_sz = 0;
    let size = guard_sz + stack_sz;
    let tsm = unsafe { Tsm::init::<T>() };
    // Runs on the child thread: store the result, then race the handle's
    // `Drop` on the sync flag to decide who frees the shared block.
    let df = move || {
        unsafe {
            let func_ret = func();
            (*tsm.value_mut()) = Some(func_ret);
            if tsm
                .get_sync()
                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
                .is_err()
            {
                // Handle already dropped (detached). Unregister the clear-tid
                // address so the kernel won't write into the block after it
                // is freed below.
                sc::syscall!(SET_TID_ADDRESS, 0);
                // NOTE(review): this frees the block without dropping the
                // `Some(T)` stored just above — `T`'s destructor never runs
                // in the detached case.
                tsm.dealloc();
            }
            // The child owns its TLS block; free it before exiting.
            dealloc(get_tls_ptr().cast(), Layout::new::<ThreadLocalStorage>());
        }
    };
    // Split the closure into a trampoline address and a boxed-closure address
    // that can travel through the raw `clone`.
    let (start_fn, fn_caller) = unsafe { onwed_split_fn_once(df) };
    let map_ptr = unsafe {
        mmap(
            None,
            NonZeroUsize::new_unchecked(size),
            MemoryProtection::PROT_READ | MemoryProtection::PROT_WRITE,
            MapRequiredFlag::MapPrivate,
            MapAdditionalFlags::MAP_ANONYMOUS,
            None,
            0,
        )?
    };
    // Stacks grow down: start at the top of the mapping, round down to word
    // size (`__clone` re-aligns to 16), and reserve room for the start args.
    let mut stack = map_ptr + size;
    stack -= stack % core::mem::size_of::<usize>();
    stack -= core::mem::size_of::<StartArgs>();
    let args = stack as *mut StartArgs;
    unsafe {
        (*args).start_arg = fn_caller;
    }
    // TLS block for the child; `self_addr` is patched to the final heap
    // address so `get_tls_ptr` can read it back through the TLS register.
    let tls = Box::into_raw(Box::new(ThreadLocalStorage {
        self_addr: 0,
        stack_info: Some(ThreadDealloc {
            stack_addr: map_ptr,
            stack_sz: size,
            tsm,
        }),
    }));
    unsafe {
        (*tls).self_addr = tls as usize;
    }
    // NOTE(review): `__clone`'s return value is discarded — if the syscall
    // fails, the futex is never cleared and `join` on the returned handle
    // would block forever.
    #[expect(clippy::cast_possible_truncation)]
    unsafe {
        __clone(
            start_fn,
            stack,
            flags.bits() as i32,
            args as _,
            tls as usize,
            // Kernel clears and futex-wakes this word on thread exit.
            tsm.get_futex().as_ptr() as usize,
            // The child munmaps its own stack on the exit path.
            map_ptr,
            stack_sz,
        );
    }
    Ok(JoinHandle {
        tsm,
        _pd: PhantomData,
    })
}
/// Splits an owned `FnOnce` closure into two raw words: the address of the
/// monomorphized trampoline (`start_fn::<F>`) and the address of the
/// heap-boxed closure the trampoline will consume.
///
/// # Safety
/// The returned closure address must be handed to `start_fn::<F>` exactly
/// once; otherwise the box is either leaked or double-freed.
#[inline]
unsafe fn onwed_split_fn_once<F: FnOnce()>(f: F) -> (usize, usize) {
    let trampoline = start_fn::<F> as *const () as usize;
    let boxed_closure = Box::into_raw(Box::new(f)) as usize;
    (trampoline, boxed_closure)
}
/// Arguments placed at the top of the child's new stack by `spawn`; popped by
/// the `__clone` asm and handed to `start_fn`.
#[repr(C)]
struct StartArgs {
    /// Raw address of the boxed `FnOnce` closure (see `onwed_split_fn_once`).
    start_arg: usize,
}
unsafe extern "C" fn start_fn<F: FnOnce()>(ptr: *mut StartArgs) -> i32 {
let args = ptr.read();
let func = args.start_arg as *mut F;
let boxed_run = Box::from_raw(func);
(boxed_run)();
0
}
extern "C" {
fn __clone(
start_fn: usize,
stack_ptr: usize,
flags: i32,
args_ptr: usize,
tls_ptr: usize,
child_tid_ptr: usize,
stack_unmap_ptr: usize,
stack_sz: usize,
) -> i32;
}
#[cfg(target_arch = "x86_64")]
global_asm!(
".text",
".global __clone",
".hidden __clone",
".type __clone,@function",
"__clone:",
"xor eax, eax",
"mov al, 56",
"mov r11, rdi",
"mov rdi, rdx",
"xor rdx, rdx",
"mov r10, r9",
"mov r9, r11",
"and rsi, -16",
"sub rsi,8",
"mov [rsi], rcx",
"sub rsi, 8",
"mov rcx, [8 + rsp]",
"mov [rsi], rcx",
"mov rcx, [16 + rsp]",
"sub rsi, 8",
"mov [rsi], rcx",
"syscall",
"test eax, eax",
"jnz 1f",
"xor ebp, ebp",
"pop r13",
"pop r12",
"pop rdi",
"call r9",
"xor rax, rax",
"mov al, 11",
"mov rdi, r12",
"mov rsi, r13",
"syscall",
"xor eax,eax",
"mov al, 60",
"mov rdi, 0",
"syscall",
"1: ret",
);
#[cfg(target_arch = "aarch64")]
global_asm!(
".global __clone",
".hidden __clone",
".type __clone,@function",
"__clone:",
"and x1, x1, #-16",
"stp x0, x3, [x1, #-16]!",
"stp x7, x6, [x1, #-16]!",
"uxtw x0, w2",
"eor x2, x2, x2",
"mov x3, x4",
"mov x4, x5",
"mov x8, #220",
"svc #0",
"cbz x0, 1f",
"ret",
"1: ldp x21, x20, [sp], #16",
"ldp x1, x0, [sp], #16",
"blr x1",
"mov x0, x20",
"mov x1, x21",
"mov x8, #215",
"svc #0",
"mov x0, 0",
"mov x8, #93",
"svc #0",
);
/// Reads the current thread's `ThreadLocalStorage` pointer out of the
/// architecture's TLS register.
///
/// x86_64: loads the word at `fs:0`, i.e. the `self_addr` field that
/// `ThreadLocalStorage` keeps as its first field (the fs base is the struct
/// itself, installed via CLONE_SETTLS). aarch64: reads `tpidr_el0` directly.
#[inline]
#[must_use]
fn get_tls_ptr() -> *mut ThreadLocalStorage {
    let mut output: usize;
    #[cfg(target_arch = "x86_64")]
    unsafe {
        core::arch::asm!("mov {x}, fs:0", x = out(reg) output);
    }
    #[cfg(target_arch = "aarch64")]
    unsafe {
        core::arch::asm!("mrs {x}, tpidr_el0", x = out(reg) output);
    }
    output as _
}
/// Global panic handler.
///
/// A spawned thread (recognized by the `ThreadDealloc` in its TLS) tears
/// itself down: it frees its TLS block, settles shared-memory ownership with
/// the `JoinHandle` via the sync flag, then munmaps its own stack and exits
/// in raw asm — no Rust code may run once the stack is unmapped. The main
/// thread just reports the panic and exits the process.
#[panic_handler]
pub fn on_panic(info: &core::panic::PanicInfo) -> ! {
    let tls = get_tls_ptr();
    unsafe {
        // Copy the TLS contents out before freeing the block they live in.
        let stack_info = tls.read();
        if let Some(stack_dealloc) = stack_info.thread_stack_info() {
            dealloc(tls.cast(), Layout::new::<ThreadLocalStorage>());
            let map_ptr = stack_dealloc.stack_addr;
            let map_len = stack_dealloc.stack_sz;
            let tsm = stack_dealloc.tsm;
            // Same handshake as a normal exit: a failed CAS means the handle
            // is already gone (detached), so this thread must free the shared
            // block itself.
            let should_dealloc = tsm
                .get_sync()
                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
                .is_err();
            if should_dealloc {
                // Unregister the clear-tid address so the kernel doesn't
                // write into the block after it is freed.
                sc::syscall!(SET_TID_ADDRESS, 0);
                tsm.dealloc();
            }
            // munmap(map_ptr, map_len) followed by exit(0), written in asm
            // because the stack is gone after the first syscall.
            #[cfg(target_arch = "x86_64")]
            core::arch::asm!(
                "syscall",
                // exit(0): rax = 60 = SYS_exit.
                "xor eax, eax",
                "mov al, 60",
                "mov rdi, 0",
                "syscall",
                in("rax") MUNMAP,
                in("rdi") map_ptr,
                in("rsi") map_len,
                options(nostack, noreturn)
            );
            #[cfg(target_arch = "aarch64")]
            core::arch::asm!(
                "svc #0",
                // exit(0): x8 = 93 = SYS_exit.
                "mov x0, 0",
                "mov x8, #93",
                "svc #0",
                in("x8") MUNMAP,
                in("x0") map_ptr,
                in("x1") map_len,
                options(nostack, noreturn)
            );
        } else {
            eprintln!("Main thread panicked: {}", info);
            rusl::process::exit(1)
        }
    }
}