#[cfg(not(windows))]
use std::sync::Mutex;
use std::sync::OnceLock;
#[cfg(not(windows))]
use libc::{MAP_ANON, MAP_FAILED, MAP_PRIVATE, PROT_NONE, PROT_READ, PROT_WRITE};
#[cfg(not(windows))]
use super::g::current_g;
use super::g::{Stack, STACK_GUARD, G};
/// Minimal hand-written `kernel32` bindings used to map, protect and release
/// goroutine stacks on Windows.
#[cfg(windows)]
mod win32 {
    // Allocation / free type flags (`flAllocationType`, `dwFreeType`).
    pub const MEM_COMMIT: u32 = 0x1000;
    pub const MEM_RESERVE: u32 = 0x2000;
    pub const MEM_RELEASE: u32 = 0x8000;

    // Page protection values (`flProtect`, `flNewProtect`).
    pub const PAGE_NOACCESS: u32 = 0x01;
    pub const PAGE_READWRITE: u32 = 0x04;

    #[link(name = "kernel32")]
    unsafe extern "system" {
        /// Reserves and/or commits a region; a null return signals failure.
        pub fn VirtualAlloc(
            address: *mut u8,
            size: usize,
            allocation_type: u32,
            protect: u32,
        ) -> *mut u8;

        /// Releases or decommits a region; returns a Win32 `BOOL`.
        pub fn VirtualFree(address: *mut u8, size: usize, free_type: u32) -> i32;

        /// Changes page protection, storing the previous protection in
        /// `old_protect`; a zero return signals failure.
        pub fn VirtualProtect(
            address: *mut u8,
            size: usize,
            new_protect: u32,
            old_protect: *mut u32,
        ) -> i32;
    }
}
/// Smallest stack size handed out to a goroutine: 8 KiB of usable space.
pub(crate) const STACK_MIN: usize = 8 << 10;
/// Largest stack a goroutine may grow to: 1 GiB. `newstack` aborts once a stack
/// has reached this size.
pub(crate) const STACK_MAX: usize = 1 << 30;
/// Usable size of the stack returned by `stack_alloc`.
pub(crate) const GOROUTINE_STACK_BYTES: usize = STACK_MIN;
/// Usable size of the stack returned by `g0_stack_alloc`: 64 KiB.
pub(crate) const G0_STACK_BYTES: usize = 64 << 10;
/// Returns the system page size in bytes.
///
/// The value is computed on first use and cached in a process-wide `OnceLock`.
/// On Unix it comes from `sysconf(_SC_PAGESIZE)`, and the call panics if that
/// reports a non-positive value. On Windows a fixed 4096 is assumed.
pub(crate) fn page_size() -> usize {
    #[cfg(not(windows))]
    fn query() -> usize {
        let n = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        assert!(n > 0, "sysconf(_SC_PAGESIZE) returned {n}");
        n as usize
    }

    #[cfg(windows)]
    fn query() -> usize {
        4096
    }

    static CACHED: OnceLock<usize> = OnceLock::new();
    *CACHED.get_or_init(query)
}
/// Allocates a stack with `size` usable bytes plus one guard page below it.
///
/// The mapping is laid out as `[guard page][usable stack]`. The guard page is made
/// inaccessible (`PROT_NONE` / `PAGE_NOACCESS`), so running off the low end of the
/// stack faults. The returned [`Stack`] covers only the usable part: `lo` is the
/// first byte above the guard page and `hi` is one past the end of the mapping.
///
/// `size` is expected to be a power of two (checked in debug builds only).
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `size` is zero.
/// - `size` plus one page overflows `usize`.
/// - The OS refuses the mapping.
/// - The guard page cannot be protected. The mapping is released before returning.
///
/// # Safety
///
/// The returned memory is owned by the caller and must be released exactly once
/// with `stack_free`.
pub(crate) unsafe fn stack_alloc_size(size: usize) -> Result<Stack, &'static str> {
    // A zero-byte stack would have `lo == hi`; nothing could ever run on it.
    if size == 0 {
        return Err("stack_alloc_size: size must be non-zero");
    }
    debug_assert!(size.is_power_of_two() || size == STACK_MAX,
        "stack_alloc_size: size must be a power of two");
    let ps = page_size();
    // Reject sizes whose guard-inclusive length would wrap. A wrapped `total` would
    // map a tiny region while the returned `lo`/`hi` described a huge or inverted one.
    let Some(total) = size.checked_add(ps) else {
        return Err("stack_alloc_size: size too large");
    };
    #[cfg(not(windows))]
    {
        // SAFETY: anonymous private mapping with no address hint; the result is
        // checked against MAP_FAILED before use.
        let base = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                total,
                PROT_READ | PROT_WRITE,
                MAP_ANON | MAP_PRIVATE,
                -1,
                0,
            )
        };
        if base == MAP_FAILED {
            return Err("stack_alloc_size: mmap failed");
        }
        // SAFETY: `base..base + ps` is the first page of the mapping created above.
        if unsafe { libc::mprotect(base, ps, PROT_NONE) } != 0 {
            // SAFETY: releasing the exact mapping we just created.
            unsafe { libc::munmap(base, total) };
            return Err("stack_alloc_size: mprotect guard page failed");
        }
        let base_addr = base as usize;
        return Ok(Stack { lo: base_addr + ps, hi: base_addr + total });
    }
    #[cfg(windows)]
    {
        use win32::*;
        // SAFETY: fresh reserve+commit with no address hint; checked for null below.
        let base = unsafe {
            VirtualAlloc(
                std::ptr::null_mut(),
                total,
                MEM_COMMIT | MEM_RESERVE,
                PAGE_READWRITE,
            )
        };
        if base.is_null() {
            return Err("stack_alloc_size: VirtualAlloc failed");
        }
        let mut old_protect: u32 = 0;
        // SAFETY: `base..base + ps` is the first page of the region allocated above.
        if unsafe { VirtualProtect(base, ps, PAGE_NOACCESS, &mut old_protect) } == 0 {
            // SAFETY: MEM_RELEASE of the whole region requires size 0 and the base address.
            unsafe { VirtualFree(base, 0, MEM_RELEASE) };
            return Err("stack_alloc_size: VirtualProtect guard page failed");
        }
        let base_addr = base as usize;
        return Ok(Stack { lo: base_addr + ps, hi: base_addr + total });
    }
}
/// Allocates a guarded stack of the default goroutine size,
/// `GOROUTINE_STACK_BYTES`.
///
/// # Safety
///
/// The result must eventually be released exactly once with `stack_free`.
pub(crate) unsafe fn stack_alloc() -> Result<Stack, &'static str> {
    let size = GOROUTINE_STACK_BYTES;
    unsafe { stack_alloc_size(size) }
}
/// Allocates a guarded stack of `G0_STACK_BYTES` usable bytes.
///
/// # Safety
///
/// The result must eventually be released exactly once with `stack_free`.
pub(crate) unsafe fn g0_stack_alloc() -> Result<Stack, &'static str> {
    let size = G0_STACK_BYTES;
    unsafe { stack_alloc_size(size) }
}
/// Releases a stack obtained from `stack_alloc_size` (or one of its wrappers).
///
/// The guard page directly below `stack.lo` is released as well. Errors reported
/// by `munmap` / `VirtualFree` are ignored.
///
/// # Safety
///
/// `stack` must describe a live allocation from `stack_alloc_size`. It must not
/// already have been freed, and nothing may still be running on it.
pub(crate) unsafe fn stack_free(stack: &Stack) {
    let guard = page_size();
    // The mapping starts at the guard page, one page below the usable region.
    let mapping_start = stack.lo - guard;
    let base = mapping_start as *mut u8;
    #[cfg(not(windows))]
    {
        let len = guard + (stack.hi - stack.lo);
        unsafe { libc::munmap(base.cast::<libc::c_void>(), len) };
    }
    #[cfg(windows)]
    {
        // MEM_RELEASE frees the whole original reservation and requires size 0.
        unsafe { win32::VirtualFree(base, 0, win32::MEM_RELEASE) };
    }
}
/// Grows the stack of goroutine `gp` after a fault in its guard page.
///
/// The replacement is twice the current size, capped at `STACK_MAX`. The steps are:
/// 1. Move the old contents with `copystack`.
/// 2. Point `gp.stack` and `gp.stackguard0` at the new allocation.
/// 3. Free the old stack.
///
/// Returns the relocation offset reported by `copystack` (new `hi` minus old `hi`).
/// The signal handler applies this offset to the interrupted stack pointer.
///
/// Aborts the process if the stack has already reached `STACK_MAX`. Panics if the
/// new stack cannot be allocated.
///
/// # Safety
///
/// `gp` must point to a valid `G` whose stack came from `stack_alloc_size`.
#[cfg(not(windows))]
pub(crate) unsafe fn newstack(gp: *mut G) -> isize {
    let (lo, hi) = unsafe { ((*gp).stack.lo, (*gp).stack.hi) };
    let old_stack = Stack { lo, hi };
    let old_size = hi - lo;
    if old_size >= STACK_MAX {
        eprintln!("goroutine stack overflow: stack size {old_size} >= STACK_MAX ({STACK_MAX})");
        unsafe { libc::abort() };
    }
    let grown = unsafe {
        stack_alloc_size(STACK_MAX.min(old_size * 2))
            .expect("newstack: failed to allocate new goroutine stack")
    };
    let delta = unsafe { copystack(gp, &old_stack, &grown) };
    unsafe {
        (*gp).stack = Stack { lo: grown.lo, hi: grown.hi };
        (*gp).stackguard0 = grown.lo + STACK_GUARD;
        stack_free(&old_stack);
    }
    delta
}
/// Copies goroutine `gp`'s stack into the top of `new_stack` and relocates pointers
/// that referred to the old stack.
///
/// The entire old stack `[old.lo, old.hi)` is copied so that it ends at `new.hi`.
/// Any old-stack address `a` therefore maps to `a + delta` in the new stack, where
/// `delta = new.hi - old.hi` is the value returned.
///
/// Why the whole stack is copied, not just `[gp.sched.sp, hi)`: this function is
/// reached from `sigsegv_handler` → `newstack` while `gp` is running. Nothing on
/// that path refreshes `sched.sp`, so it can lie above the frames that are actually
/// live, including the one that just faulted in the guard page. Copying only the
/// part above it would resume execution on uncopied memory.
///
/// Pointer relocation is conservative:
/// - Every word-aligned value in the copied region that falls in `[old.lo, old.hi)`
///   is shifted by `delta`. An integer that merely looks like a stack address is
///   rewritten too.
/// - `gp.sched.sp` and `gp.sched.bp` are shifted when they lie in `[old.lo, old.hi]`.
///   The top address is included because it is the stack pointer of an empty stack.
///
/// # Panics
///
/// Panics if `new_stack` is smaller than `old_stack`.
///
/// # Safety
///
/// `gp` must be valid. Both stacks must be live, non-overlapping allocations from
/// `stack_alloc_size`, which keeps `new_stack.hi - old size` word-aligned.
unsafe fn copystack(gp: *mut G, old_stack: &Stack, new_stack: &Stack) -> isize {
    let (old_lo, old_hi) = (old_stack.lo, old_stack.hi);
    let old_size = old_hi - old_lo;
    let new_size = new_stack.hi - new_stack.lo;
    // An undersized destination would make the copy below write below `new_stack.lo`.
    assert!(
        new_size >= old_size,
        "copystack: new stack ({new_size} B) smaller than old stack ({old_size} B)"
    );

    let dst_lo = new_stack.hi - old_size;
    // SAFETY: both ranges are inside live, distinct mappings (caller contract), and
    // `[dst_lo, new_stack.hi)` fits in `new_stack` by the assertion above.
    unsafe {
        std::ptr::copy_nonoverlapping(old_lo as *const u8, dst_lo as *mut u8, old_size);
    }

    let delta: isize = new_stack.hi as isize - old_hi as isize;
    let relocate = |addr: usize| addr.wrapping_add_signed(delta);

    let word = std::mem::size_of::<usize>();
    debug_assert_eq!(
        dst_lo % std::mem::align_of::<usize>(),
        0,
        "copystack: destination is not word-aligned"
    );
    // SAFETY: `[dst_lo, dst_lo + old_size)` was fully initialised by the copy above,
    // is word-aligned (see the debug assertion), and belongs only to the new stack.
    let words = unsafe { std::slice::from_raw_parts_mut(dst_lo as *mut usize, old_size / word) };
    for slot in words.iter_mut() {
        let value = *slot;
        if (old_lo..old_hi).contains(&value) {
            *slot = relocate(value);
        }
    }

    // SAFETY: `gp` is valid per the caller contract.
    unsafe {
        let sp = (*gp).sched.sp;
        if (old_lo..=old_hi).contains(&sp) {
            (*gp).sched.sp = relocate(sp);
        }
        let bp = (*gp).sched.bp;
        if (old_lo..=old_hi).contains(&bp) {
            (*gp).sched.bp = relocate(bp);
        }
    }
    delta
}
/// `SIGSEGV` disposition that was in effect before `install_sigsegv_handler` first
/// ran. Faults that are not goroutine guard-page hits are forwarded to it.
#[cfg(not(windows))]
static PREV_SIGSEGV: Mutex<Option<libc::sigaction>> = Mutex::new(None);

/// Installs `sigsegv_handler` as the process-wide `SIGSEGV` handler.
///
/// The handler is registered with these flags:
/// - `SA_SIGINFO`: it needs the fault address and the interrupted context.
/// - `SA_ONSTACK`: lets it run on an alternate signal stack when the goroutine stack
///   is exhausted. This presumably relies on a `sigaltstack` set up elsewhere
///   (TODO confirm).
/// - `SA_RESTART`.
///
/// Only the disposition displaced by the *first* installation is remembered.
/// Re-installing would otherwise record `sigsegv_handler` itself as the "previous"
/// handler. Forwarding an unrelated fault to it would then recurse without bound.
///
/// # Panics
///
/// Panics if `sigaction(2)` fails.
#[cfg(not(windows))]
pub(crate) unsafe fn install_sigsegv_handler() {
    let ours = sigsegv_handler as *const () as usize;
    let mut sa: libc::sigaction = unsafe { std::mem::zeroed() };
    sa.sa_sigaction = ours;
    sa.sa_flags = (libc::SA_SIGINFO | libc::SA_ONSTACK | libc::SA_RESTART) as _;
    unsafe { libc::sigemptyset(&mut sa.sa_mask) };
    let mut old: libc::sigaction = unsafe { std::mem::zeroed() };
    let ret = unsafe { libc::sigaction(libc::SIGSEGV, &sa, &mut old) };
    assert_eq!(ret, 0, "install_sigsegv_handler: sigaction failed");
    // The guarded value is a plain `Copy` struct, so a poisoned lock is still usable.
    let mut prev = PREV_SIGSEGV
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    if prev.is_none() && old.sa_sigaction != ours {
        *prev = Some(old);
    }
}
/// `SA_SIGINFO` handler for `SIGSEGV` that grows goroutine stacks on guard-page hits.
///
/// Guard-page faults: if the fault address lies in
/// `[stack.lo - page_size(), stack.lo)` for the current goroutine, the handler
/// 1. grows the stack with `newstack`, and
/// 2. shifts the interrupted context's stack pointer by the returned offset.
///
/// The faulting instruction is then retried on the new stack.
///
/// Any other fault is forwarded to the disposition saved in `PREV_SIGSEGV`, with
/// the calling convention that matches its `SA_SIGINFO` flag. If there is no usable
/// previous handler (none, `SIG_DFL`, or `SIG_IGN`), the handler does two things:
/// 1. Restores `SIG_DFL`. Re-raising while this handler is still installed would
///    re-enter it forever.
/// 2. Re-raises `SIGSEGV`, so the process terminates with the default action.
#[cfg(not(windows))]
unsafe extern "C" fn sigsegv_handler(
    sig: libc::c_int,
    info: *mut libc::siginfo_t,
    ctx: *mut libc::c_void,
) {
    let gp = current_g();
    if !gp.is_null() {
        let fault_addr = unsafe { (*info).si_addr() } as usize;
        let guard_hi = unsafe { (*gp).stack.lo };
        // A G without a stack (lo == 0) has no guard page. `checked_sub` avoids an
        // underflow panic inside the signal handler.
        if let Some(guard_lo) = guard_hi.checked_sub(page_size()) {
            if (guard_lo..guard_hi).contains(&fault_addr) {
                let delta = unsafe { newstack(gp) };
                unsafe { update_sp_in_context(ctx, delta) };
                return;
            }
        }
    }

    // Never block inside a signal handler: the interrupted code may hold this lock.
    // A poisoned lock must not turn a segfault into a panic either.
    let prev = match PREV_SIGSEGV.try_lock() {
        Ok(guard) => *guard,
        Err(std::sync::TryLockError::Poisoned(poisoned)) => *poisoned.into_inner(),
        Err(std::sync::TryLockError::WouldBlock) => None,
    };

    match prev {
        Some(old) if old.sa_sigaction != libc::SIG_DFL
            && old.sa_sigaction != libc::SIG_IGN => {
            let wants_siginfo =
                (old.sa_flags as libc::c_long) & (libc::SA_SIGINFO as libc::c_long) != 0;
            if wants_siginfo {
                type SaFn =
                    unsafe extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut libc::c_void);
                // SAFETY: with SA_SIGINFO set, `sa_sigaction` holds a three-argument handler.
                let f: SaFn = unsafe { std::mem::transmute(old.sa_sigaction) };
                unsafe { f(sig, info, ctx) };
            } else {
                type HandlerFn = unsafe extern "C" fn(libc::c_int);
                // SAFETY: without SA_SIGINFO, `sa_sigaction` holds a one-argument handler.
                let f: HandlerFn = unsafe { std::mem::transmute(old.sa_sigaction) };
                unsafe { f(sig) };
            }
        }
        _ => unsafe {
            // Restore the default action first; otherwise the pending signal (or the
            // re-executed faulting instruction) would re-enter this handler forever.
            let mut dfl: libc::sigaction = std::mem::zeroed();
            dfl.sa_sigaction = libc::SIG_DFL;
            libc::sigemptyset(&mut dfl.sa_mask);
            libc::sigaction(libc::SIGSEGV, &dfl, std::ptr::null_mut());
            libc::raise(libc::SIGSEGV);
        },
    }
}
/// Shifts the stack pointer saved in a signal `ucontext` by `delta` bytes.
///
/// When the handler returns, execution resumes with the adjusted stack pointer.
/// Only the stack-pointer register is touched; no other saved register is
/// relocated. On targets other than Linux/macOS on x86_64/aarch64 this does nothing.
///
/// # Safety
///
/// `ctx` must be the context pointer passed as the third argument to an
/// `SA_SIGINFO` signal handler.
#[cfg(not(windows))]
// `ctx` and `delta` are only read on the targets handled below; elsewhere this is a no-op.
#[allow(unused_variables)]
unsafe fn update_sp_in_context(ctx: *mut libc::c_void, delta: isize) {
    #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
    unsafe {
        let gregs = &mut (*ctx.cast::<libc::ucontext_t>()).uc_mcontext.gregs;
        let rsp = gregs[libc::REG_RSP as usize] as isize;
        gregs[libc::REG_RSP as usize] = (rsp + delta) as libc::greg_t;
    }
    #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
    unsafe {
        let mcontext = &mut (*ctx.cast::<libc::ucontext_t>()).uc_mcontext;
        mcontext.sp = (mcontext.sp as isize + delta) as u64;
    }
    #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
    unsafe {
        let thread_state = &mut (*(*ctx.cast::<libc::ucontext_t>()).uc_mcontext).__ss;
        thread_state.__rsp = (thread_state.__rsp as isize + delta) as u64;
    }
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    unsafe {
        let thread_state = &mut (*(*ctx.cast::<libc::ucontext_t>()).uc_mcontext).__ss;
        thread_state.__sp = (thread_state.__sp as isize + delta) as u64;
    }
}
/// Proactively grows `gp`'s stack when its saved stack pointer is within
/// `2 * STACK_GUARD` bytes of the low end.
///
/// Nothing happens in any of these cases:
/// - `gp.sched.sp` is 0.
/// - There is still at least `2 * STACK_GUARD` bytes of headroom.
/// - The stack is already `STACK_MAX` bytes.
///
/// Otherwise the stack is doubled (capped at `STACK_MAX`) and its contents are moved
/// with `copystack`. `gp.stack` and `gp.stackguard0` are updated and the old stack is
/// freed. The relocation offset from `copystack` is not needed here.
///
/// # Panics
///
/// Panics if the new stack cannot be allocated.
///
/// # Safety
///
/// `gp` must point to a valid `G` whose stack came from `stack_alloc_size`.
/// Presumably `gp` must not be executing on that stack at the time (TODO confirm).
pub(crate) unsafe fn grow_stack_if_needed(gp: *mut G) {
    let sp = unsafe { (*gp).sched.sp };
    let lo = unsafe { (*gp).stack.lo };
    // No saved stack pointer, or enough room left: nothing to do.
    if sp == 0 || sp >= lo + STACK_GUARD * 2 {
        return;
    }

    let old_stack = Stack { lo, hi: unsafe { (*gp).stack.hi } };
    let old_size = old_stack.hi - old_stack.lo;
    if old_size >= STACK_MAX {
        return;
    }

    let grown = unsafe {
        stack_alloc_size(STACK_MAX.min(old_size * 2))
            .expect("grow_stack_if_needed: allocation failed")
    };
    unsafe {
        copystack(gp, &old_stack, &grown);
        (*gp).stack = Stack { lo: grown.lo, hi: grown.hi };
        (*gp).stackguard0 = grown.lo + STACK_GUARD;
        stack_free(&old_stack);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-size stack is page-aligned and has the requested usable size.
    /// Its topmost word can be written and read back.
    #[test]
    fn alloc_write_free() {
        const PATTERN: u64 = 0xDEAD_BEEF_CAFE_BABE;
        unsafe {
            let stack = stack_alloc().expect("stack_alloc failed");
            let ps = page_size();
            assert_eq!(stack.hi - stack.lo, GOROUTINE_STACK_BYTES);
            assert_eq!(stack.lo % ps, 0);
            assert!(stack.hi > stack.lo);
            let top = (stack.hi - 8) as *mut u64;
            top.write(PATTERN);
            assert_eq!(top.read(), PATTERN);
            stack_free(&stack);
        }
    }

    /// The reported page size is a plausible power of two.
    #[test]
    fn page_size_sanity() {
        let ps = page_size();
        assert!(ps.is_power_of_two());
        assert!(ps >= 4096);
        println!("page_size = {ps}");
    }

    /// Concurrent first calls all observe the same cached page size.
    #[test]
    fn page_size_concurrent() {
        // Spawn every thread before joining any of them so the calls actually overlap.
        let handles: Vec<_> = (0..8).map(|_| std::thread::spawn(page_size)).collect();
        let sizes: Vec<usize> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        assert!(sizes.iter().all(|&s| s == sizes[0]));
    }

    /// Stacks of 8, 16, 32 and 64 KiB all come back with exactly the requested size.
    #[test]
    fn variable_size_alloc() {
        for shift in 13..=16 {
            let size = 1usize << shift;
            unsafe {
                let stack = stack_alloc_size(size).unwrap();
                assert_eq!(stack.hi - stack.lo, size);
                stack_free(&stack);
            }
        }
    }
}