use crate::stack::Stack;
use crate::tagged_stack::StackValue;
use may::coroutine;
use std::cell::Cell;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Condvar, Mutex, Once, OnceLock};
use std::time::{Duration, Instant};
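// Global scheduler state: one-time init guard, start timestamp, lifecycle
// counters, and the mutex/condvar pair that lets scheduler_run and
// wait_all_strands park until ACTIVE_STRANDS drains to zero.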
static SCHEDULER_INIT: Once = Once::new();
static SCHEDULER_START_TIME: OnceLock<Instant> = OnceLock::new();
pub static ACTIVE_STRANDS: AtomicUsize = AtomicUsize::new(0);
pub(crate) static SHUTDOWN_CONDVAR: Condvar = Condvar::new();
pub(crate) static SHUTDOWN_MUTEX: Mutex<()> = Mutex::new(());
pub static TOTAL_SPAWNED: AtomicU64 = AtomicU64::new(0);
pub static TOTAL_COMPLETED: AtomicU64 = AtomicU64::new(0);
pub static PEAK_STRANDS: AtomicUsize = AtomicUsize::new(0);
static NEXT_STRAND_ID: AtomicU64 = AtomicU64::new(1);
#[cfg(feature = "diagnostics")]
const DEFAULT_REGISTRY_SIZE: usize = 1024;
#[cfg(feature = "diagnostics")]
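/// One registry slot: a strand ID (0 = free) plus its spawn time in Unix seconds.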
pub struct StrandSlot {
pub strand_id: AtomicU64,
pub spawn_time: AtomicU64,
}
#[cfg(feature = "diagnostics")]
impl StrandSlot {
const fn new() -> Self {
Self {
strand_id: AtomicU64::new(0),
spawn_time: AtomicU64::new(0),
}
}
}
#[cfg(feature = "diagnostics")]
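/// Fixed-size, lock-free table of live strands for diagnostics. When every
/// slot is taken, registration fails and is tallied in `overflow_count`
/// instead of blocking or failing the spawn.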
pub struct StrandRegistry {
slots: Box<[StrandSlot]>,
pub overflow_count: AtomicU64,
}
#[cfg(feature = "diagnostics")]
impl StrandRegistry {
fn new(capacity: usize) -> Self {
let mut slots = Vec::with_capacity(capacity);
for _ in 0..capacity {
slots.push(StrandSlot::new());
}
Self {
slots: slots.into_boxed_slice(),
overflow_count: AtomicU64::new(0),
}
}
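    /// Claims the first free slot for `strand_id` via CAS, recording the spawn
    /// time, and returns the slot index; returns `None` (and bumps
    /// `overflow_count`) when the registry is full.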
    pub fn register(&self, strand_id: u64) -> Option<usize> {
        let spawn_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        for (idx, slot) in self.slots.iter().enumerate() {
            // Claim the slot first, and only then record the spawn time, so
            // occupied slots belonging to other strands are never overwritten.
            if slot
                .strand_id
                .compare_exchange(0, strand_id, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                slot.spawn_time.store(spawn_time, Ordering::Release);
                return Some(idx);
            }
        }
        self.overflow_count.fetch_add(1, Ordering::Relaxed);
        None
    }
pub fn unregister(&self, strand_id: u64) -> bool {
for slot in self.slots.iter() {
if slot
.strand_id
.compare_exchange(strand_id, 0, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
slot.spawn_time.store(0, Ordering::Release);
return true;
}
}
false
}
pub fn active_strands(&self) -> impl Iterator<Item = (u64, u64)> + '_ {
self.slots.iter().filter_map(|slot| {
let id = slot.strand_id.load(Ordering::Acquire);
if id > 0 {
let time = slot.spawn_time.load(Ordering::Relaxed);
Some((id, time))
} else {
None
}
})
}
pub fn capacity(&self) -> usize {
self.slots.len()
}
}
#[cfg(feature = "diagnostics")]
static STRAND_REGISTRY: OnceLock<StrandRegistry> = OnceLock::new();
#[cfg(feature = "diagnostics")]
pub fn strand_registry() -> &'static StrandRegistry {
STRAND_REGISTRY.get_or_init(|| {
let size = std::env::var("SEQ_STRAND_REGISTRY_SIZE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_REGISTRY_SIZE);
StrandRegistry::new(size)
})
}
pub fn scheduler_elapsed() -> Option<Duration> {
SCHEDULER_START_TIME.get().map(|start| start.elapsed())
}
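/// Default per-coroutine stack size: 0x20000 bytes (128 KiB).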
const DEFAULT_STACK_SIZE: usize = 0x20000;
fn parse_stack_size(env_value: Option<String>) -> usize {
match env_value {
Some(val) => match val.parse::<usize>() {
Ok(0) => {
eprintln!(
"Warning: SEQ_STACK_SIZE=0 is invalid, using default {}",
DEFAULT_STACK_SIZE
);
DEFAULT_STACK_SIZE
}
Ok(size) => size,
Err(_) => {
eprintln!(
"Warning: SEQ_STACK_SIZE='{}' is not a valid number, using default {}",
val, DEFAULT_STACK_SIZE
);
DEFAULT_STACK_SIZE
}
},
None => DEFAULT_STACK_SIZE,
}
}
const DEFAULT_POOL_CAPACITY: usize = 10000;
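/// One-time scheduler initialization: sizes the `may` runtime from
/// `SEQ_STACK_SIZE` and `SEQ_POOL_CAPACITY`, records the start time, and
/// installs the SIGINT handler (plus diagnostics hooks when that feature is
/// enabled). Calls after the first are no-ops.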
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_scheduler_init() {
SCHEDULER_INIT.call_once(|| {
let stack_size = parse_stack_size(std::env::var("SEQ_STACK_SIZE").ok());
let pool_capacity = std::env::var("SEQ_POOL_CAPACITY")
.ok()
.and_then(|s| s.parse().ok())
.filter(|&v| v > 0)
.unwrap_or(DEFAULT_POOL_CAPACITY);
may::config()
.set_stack_size(stack_size)
.set_pool_capacity(pool_capacity);
SCHEDULER_START_TIME.get_or_init(Instant::now);
#[cfg(unix)]
{
use std::sync::atomic::{AtomicBool, Ordering};
static SIGINT_RECEIVED: AtomicBool = AtomicBool::new(false);
            extern "C" fn sigint_handler(_: libc::c_int) {
                // Second Ctrl-C: exit immediately via the async-signal-safe
                // _exit, with the conventional 128 + SIGINT status of 130.
                if SIGINT_RECEIVED.swap(true, Ordering::SeqCst) {
                    unsafe { libc::_exit(130) };
                }
                // First Ctrl-C: request a normal process exit.
                std::process::exit(130);
            }
unsafe {
libc::signal(
libc::SIGINT,
sigint_handler as *const () as libc::sighandler_t,
);
}
}
#[cfg(feature = "diagnostics")]
crate::diagnostics::install_signal_handler();
#[cfg(feature = "diagnostics")]
crate::watchdog::install_watchdog();
});
}
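/// Blocks the calling thread until every strand has finished, then returns a
/// null stack.
///
/// Illustrative lifecycle (sketch only; `entry` is a hypothetical strand
/// entry point and crate paths are elided):
///
/// ```ignore
/// unsafe {
///     scheduler_init();                          // configure `may` once
///     strand_spawn(entry, std::ptr::null_mut()); // returns a unique strand ID
///     scheduler_run();                           // park until ACTIVE_STRANDS == 0
/// }
/// ```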
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_scheduler_run() -> Stack {
let mut guard = SHUTDOWN_MUTEX.lock().expect(
"scheduler_run: shutdown mutex poisoned - strand panicked during shutdown synchronization",
);
while ACTIVE_STRANDS.load(Ordering::Acquire) > 0 {
guard = SHUTDOWN_CONDVAR
.wait(guard)
.expect("scheduler_run: condvar wait failed - strand panicked during shutdown wait");
}
std::ptr::null_mut()
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_scheduler_shutdown() {
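    // Intentionally a no-op: shutdown is driven by scheduler_run /
    // wait_all_strands observing ACTIVE_STRANDS reach zero.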
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_strand_spawn(
entry: extern "C" fn(Stack) -> Stack,
initial_stack: Stack,
) -> i64 {
unsafe { patch_seq_strand_spawn_with_base(entry, initial_stack, std::ptr::null_mut()) }
}
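/// Spawns a strand running `entry` with `initial_stack` as its data stack and,
/// when `stack_base` is non-null, registers it via `patch_seq_set_stack_base`.
/// Returns the strand's unique positive ID.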
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_strand_spawn_with_base(
entry: extern "C" fn(Stack) -> Stack,
initial_stack: Stack,
stack_base: Stack,
) -> i64 {
let strand_id = NEXT_STRAND_ID.fetch_add(1, Ordering::Relaxed);
let new_count = ACTIVE_STRANDS.fetch_add(1, Ordering::Release) + 1;
TOTAL_SPAWNED.fetch_add(1, Ordering::Relaxed);
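    // Lock-free running maximum: raise PEAK_STRANDS until it is >= new_count.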
let mut peak = PEAK_STRANDS.load(Ordering::Acquire);
while new_count > peak {
match PEAK_STRANDS.compare_exchange_weak(
peak,
new_count,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(current) => peak = current,
}
}
#[cfg(feature = "diagnostics")]
let _ = strand_registry().register(strand_id);
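    // `Stack` is a raw pointer and therefore not `Send`; smuggle the addresses
    // across the spawn closure as `usize` and rebuild the pointers inside.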
let entry_fn = entry;
let stack_addr = initial_stack as usize;
let base_addr = stack_base as usize;
unsafe {
coroutine::spawn(move || {
let stack_ptr = stack_addr as *mut StackValue;
let base_ptr = base_addr as *mut StackValue;
debug_assert!(
stack_ptr.is_null()
|| stack_addr.is_multiple_of(std::mem::align_of::<StackValue>()),
"Stack pointer must be null or properly aligned"
);
debug_assert!(
stack_ptr.is_null() || stack_addr > 0x1000,
"Stack pointer appears to be in invalid memory region (< 0x1000)"
);
if !base_ptr.is_null() {
crate::stack::patch_seq_set_stack_base(base_ptr);
}
let final_stack = entry_fn(stack_ptr);
free_stack(final_stack);
#[cfg(feature = "diagnostics")]
strand_registry().unregister(strand_id);
let prev_count = ACTIVE_STRANDS.fetch_sub(1, Ordering::AcqRel);
TOTAL_COMPLETED.fetch_add(1, Ordering::Release);
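            // Last strand out: acquire the mutex before notifying so a waiter
            // cannot miss the wakeup between checking the counter and parking.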
if prev_count == 1 {
let _guard = SHUTDOWN_MUTEX.lock()
.expect("strand_spawn: shutdown mutex poisoned - strand panicked during shutdown notification");
SHUTDOWN_CONDVAR.notify_all();
}
});
}
strand_id as i64
}
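/// Reclaims a strand's final stack. Stack values are arena-backed, so a single
/// arena reset releases them; there is no per-cell free.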
fn free_stack(_stack: Stack) {
crate::arena::arena_reset();
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_spawn_strand(entry: extern "C" fn(Stack) -> Stack) {
unsafe {
patch_seq_strand_spawn(entry, std::ptr::null_mut());
}
}
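/// Cooperatively yields the current strand and returns the stack unchanged.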
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_yield_strand(stack: Stack) -> Stack {
coroutine::yield_now();
stack
}
// Yield safety valve: a per-thread tail-call counter checked against a global
// threshold parsed once from SEQ_YIELD_INTERVAL.
static YIELD_THRESHOLD: OnceLock<u64> = OnceLock::new();
thread_local! {
static TAIL_CALL_COUNTER: Cell<u64> = const { Cell::new(0) };
}
fn get_yield_threshold() -> u64 {
*YIELD_THRESHOLD.get_or_init(|| {
match std::env::var("SEQ_YIELD_INTERVAL") {
Ok(s) if s.is_empty() => 0,
Ok(s) => match s.parse::<u64>() {
Ok(n) => n,
Err(_) => {
eprintln!(
"Warning: SEQ_YIELD_INTERVAL='{}' is not a valid positive integer, yield safety valve disabled",
s
);
0
}
},
Err(_) => 0,
}
})
}
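/// Cooperative-yield safety valve for tail-call loops.
///
/// Increments a thread-local counter and, once it reaches `SEQ_YIELD_INTERVAL`,
/// resets it and yields the current coroutine so a tight loop cannot starve
/// other strands. Disabled (early return) when the interval is 0 or unset.
///
/// ```ignore
/// // Hypothetical generated loop body:
/// loop {
///     step();                  // one unit of work (placeholder)
///     patch_seq_maybe_yield(); // yields every SEQ_YIELD_INTERVAL iterations
/// }
/// ```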
#[unsafe(no_mangle)]
pub extern "C" fn patch_seq_maybe_yield() {
let threshold = get_yield_threshold();
if threshold == 0 {
return;
}
TAIL_CALL_COUNTER.with(|counter| {
let count = counter.get().wrapping_add(1);
counter.set(count);
if count >= threshold {
counter.set(0);
coroutine::yield_now();
}
});
}
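/// Blocks until ACTIVE_STRANDS reaches zero, parking on the shared shutdown
/// condvar rather than spinning.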
#[unsafe(no_mangle)]
pub unsafe extern "C" fn patch_seq_wait_all_strands() {
let mut guard = SHUTDOWN_MUTEX.lock()
.expect("wait_all_strands: shutdown mutex poisoned - strand panicked during shutdown synchronization");
while ACTIVE_STRANDS.load(Ordering::Acquire) > 0 {
guard = SHUTDOWN_CONDVAR
.wait(guard)
.expect("wait_all_strands: condvar wait failed - strand panicked during shutdown wait");
}
}
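// Unprefixed aliases for in-crate callers and the tests below.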
pub use patch_seq_maybe_yield as maybe_yield;
pub use patch_seq_scheduler_init as scheduler_init;
pub use patch_seq_scheduler_run as scheduler_run;
pub use patch_seq_scheduler_shutdown as scheduler_shutdown;
pub use patch_seq_spawn_strand as spawn_strand;
pub use patch_seq_strand_spawn as strand_spawn;
pub use patch_seq_wait_all_strands as wait_all_strands;
pub use patch_seq_yield_strand as yield_strand;
#[cfg(test)]
mod tests {
use super::*;
use crate::stack::push;
use crate::value::Value;
use std::sync::atomic::{AtomicU32, Ordering};
#[test]
fn test_spawn_strand() {
unsafe {
static COUNTER: AtomicU32 = AtomicU32::new(0);
extern "C" fn test_entry(_stack: Stack) -> Stack {
COUNTER.fetch_add(1, Ordering::SeqCst);
std::ptr::null_mut()
}
for _ in 0..100 {
spawn_strand(test_entry);
}
std::thread::sleep(std::time::Duration::from_millis(200));
assert_eq!(COUNTER.load(Ordering::SeqCst), 100);
}
}
#[test]
fn test_scheduler_init_idempotent() {
unsafe {
scheduler_init();
scheduler_init();
scheduler_init();
}
}
#[test]
fn test_free_stack_null() {
free_stack(std::ptr::null_mut());
}
#[test]
fn test_free_stack_valid() {
unsafe {
let stack = push(crate::stack::alloc_test_stack(), Value::Int(42));
free_stack(stack);
}
}
#[test]
fn test_strand_spawn_with_stack() {
unsafe {
static COUNTER: AtomicU32 = AtomicU32::new(0);
extern "C" fn test_entry(stack: Stack) -> Stack {
COUNTER.fetch_add(1, Ordering::SeqCst);
stack
}
let initial_stack = push(crate::stack::alloc_test_stack(), Value::Int(99));
strand_spawn(test_entry, initial_stack);
std::thread::sleep(std::time::Duration::from_millis(200));
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
}
}
#[test]
fn test_scheduler_shutdown() {
unsafe {
scheduler_init();
scheduler_shutdown();
}
}
#[test]
fn test_many_strands_stress() {
unsafe {
static COUNTER: AtomicU32 = AtomicU32::new(0);
extern "C" fn increment(_stack: Stack) -> Stack {
COUNTER.fetch_add(1, Ordering::SeqCst);
std::ptr::null_mut()
}
COUNTER.store(0, Ordering::SeqCst);
for _ in 0..1000 {
strand_spawn(increment, std::ptr::null_mut());
}
wait_all_strands();
assert_eq!(COUNTER.load(Ordering::SeqCst), 1000);
}
}
#[test]
fn test_strand_ids_are_unique() {
unsafe {
use std::collections::HashSet;
extern "C" fn noop(_stack: Stack) -> Stack {
std::ptr::null_mut()
}
let mut ids = Vec::new();
for _ in 0..100 {
let id = strand_spawn(noop, std::ptr::null_mut());
ids.push(id);
}
wait_all_strands();
let unique_ids: HashSet<_> = ids.iter().collect();
assert_eq!(unique_ids.len(), 100, "All strand IDs should be unique");
assert!(
ids.iter().all(|&id| id > 0),
"All strand IDs should be positive"
);
}
}
#[test]
fn test_arena_reset_with_strands() {
unsafe {
use crate::arena;
use crate::seqstring::arena_string;
extern "C" fn create_temp_strings(stack: Stack) -> Stack {
for i in 0..100 {
let temp = arena_string(&format!("temporary string {}", i));
assert!(!temp.as_str().is_empty());
}
let stats = arena::arena_stats();
assert!(stats.allocated_bytes > 0, "Arena should have allocations");
                stack
            }
arena::arena_reset();
strand_spawn(create_temp_strings, std::ptr::null_mut());
wait_all_strands();
let stats_after = arena::arena_stats();
assert_eq!(
stats_after.allocated_bytes, 0,
"Arena should be reset after strand exits"
);
}
}
#[test]
fn test_arena_with_channel_send() {
unsafe {
use crate::channel::{close_channel, make_channel, receive, send};
use crate::stack::{pop, push};
use crate::value::Value;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, AtomicU32, Ordering};
static RECEIVED_COUNT: AtomicU32 = AtomicU32::new(0);
static CHANNEL_PTR: AtomicI64 = AtomicI64::new(0);
let stack = crate::stack::alloc_test_stack();
let stack = make_channel(stack);
let (stack, chan_val) = pop(stack);
let channel = match chan_val {
Value::Channel(ch) => ch,
_ => panic!("Expected Channel"),
};
            let ch_ptr = Arc::as_ptr(&channel) as i64;
            CHANNEL_PTR.store(ch_ptr, Ordering::Release);
            // Leak two extra refcounts so the raw pointer stays valid while the
            // sender and receiver strands each borrow it via Arc::from_raw.
            std::mem::forget(channel.clone());
            std::mem::forget(channel.clone());
extern "C" fn sender(_stack: Stack) -> Stack {
use crate::seqstring::arena_string;
use crate::value::ChannelData;
use std::sync::Arc;
unsafe {
let ch_ptr = CHANNEL_PTR.load(Ordering::Acquire) as *const ChannelData;
let channel = Arc::from_raw(ch_ptr);
let channel_clone = Arc::clone(&channel);
std::mem::forget(channel);
let msg = arena_string("Hello from sender!");
let stack = push(crate::stack::alloc_test_stack(), Value::String(msg));
let stack = push(stack, Value::Channel(channel_clone));
send(stack)
}
}
extern "C" fn receiver(_stack: Stack) -> Stack {
use crate::value::ChannelData;
use std::sync::Arc;
use std::sync::atomic::Ordering;
unsafe {
let ch_ptr = CHANNEL_PTR.load(Ordering::Acquire) as *const ChannelData;
let channel = Arc::from_raw(ch_ptr);
let channel_clone = Arc::clone(&channel);
std::mem::forget(channel);
let stack = push(
crate::stack::alloc_test_stack(),
Value::Channel(channel_clone),
);
let stack = receive(stack);
let (stack, _success) = pop(stack);
let (_stack, msg_val) = pop(stack);
match msg_val {
Value::String(s) => {
assert_eq!(s.as_str(), "Hello from sender!");
RECEIVED_COUNT.fetch_add(1, Ordering::SeqCst);
}
_ => panic!("Expected String"),
}
std::ptr::null_mut()
}
}
spawn_strand(sender);
spawn_strand(receiver);
wait_all_strands();
assert_eq!(
RECEIVED_COUNT.load(Ordering::SeqCst),
1,
"Receiver should have received message"
);
let stack = push(stack, Value::Channel(channel));
close_channel(stack);
}
}
#[test]
fn test_no_memory_leak_over_many_iterations() {
unsafe {
use crate::arena;
use crate::seqstring::arena_string;
extern "C" fn allocate_strings_and_exit(stack: Stack) -> Stack {
for i in 0..50 {
let temp = arena_string(&format!("request header {}", i));
assert!(!temp.as_str().is_empty());
}
stack
}
let iterations = 10_000;
for i in 0..iterations {
arena::arena_reset();
strand_spawn(allocate_strings_and_exit, std::ptr::null_mut());
wait_all_strands();
if i % 1000 == 0 {
let stats = arena::arena_stats();
assert_eq!(
stats.allocated_bytes, 0,
"Arena not reset after iteration {} (leaked {} bytes)",
i, stats.allocated_bytes
);
}
}
let final_stats = arena::arena_stats();
assert_eq!(
final_stats.allocated_bytes, 0,
"Arena leaked memory after {} iterations ({} bytes)",
iterations, final_stats.allocated_bytes
);
println!(
"✓ Memory leak test passed: {} iterations with no growth",
iterations
);
}
}
#[test]
fn test_parse_stack_size_valid() {
assert_eq!(parse_stack_size(Some("2097152".to_string())), 2097152);
assert_eq!(parse_stack_size(Some("1".to_string())), 1);
assert_eq!(parse_stack_size(Some("999999999".to_string())), 999999999);
}
#[test]
fn test_parse_stack_size_none() {
assert_eq!(parse_stack_size(None), DEFAULT_STACK_SIZE);
}
#[test]
fn test_parse_stack_size_zero() {
assert_eq!(parse_stack_size(Some("0".to_string())), DEFAULT_STACK_SIZE);
}
#[test]
fn test_parse_stack_size_invalid() {
assert_eq!(
parse_stack_size(Some("invalid".to_string())),
DEFAULT_STACK_SIZE
);
assert_eq!(
parse_stack_size(Some("-100".to_string())),
DEFAULT_STACK_SIZE
);
assert_eq!(parse_stack_size(Some("".to_string())), DEFAULT_STACK_SIZE);
assert_eq!(
parse_stack_size(Some("1.5".to_string())),
DEFAULT_STACK_SIZE
);
}
#[test]
#[cfg(feature = "diagnostics")]
fn test_strand_registry_basic() {
let registry = StrandRegistry::new(10);
        assert_eq!(registry.register(1), Some(0));
        assert_eq!(registry.register(2), Some(1));
        assert_eq!(registry.register(3), Some(2));
let active: Vec<_> = registry.active_strands().collect();
assert_eq!(active.len(), 3);
assert!(registry.unregister(2));
let active: Vec<_> = registry.active_strands().collect();
assert_eq!(active.len(), 2);
assert!(!registry.unregister(999));
}
#[test]
#[cfg(feature = "diagnostics")]
fn test_strand_registry_overflow() {
let registry = StrandRegistry::new(3);
assert!(registry.register(1).is_some());
assert!(registry.register(2).is_some());
assert!(registry.register(3).is_some());
assert!(registry.register(4).is_none());
assert_eq!(registry.overflow_count.load(Ordering::Relaxed), 1);
assert!(registry.register(5).is_none());
assert_eq!(registry.overflow_count.load(Ordering::Relaxed), 2);
}
#[test]
#[cfg(feature = "diagnostics")]
fn test_strand_registry_slot_reuse() {
let registry = StrandRegistry::new(3);
registry.register(1);
registry.register(2);
registry.register(3);
registry.unregister(2);
assert!(registry.register(4).is_some());
assert_eq!(registry.active_strands().count(), 3);
}
#[test]
#[cfg(feature = "diagnostics")]
fn test_strand_registry_concurrent_stress() {
use std::sync::Arc;
use std::thread;
let registry = Arc::new(StrandRegistry::new(50));
let handles: Vec<_> = (0..100)
.map(|i| {
                let reg = Arc::clone(&registry);
thread::spawn(move || {
let id = (i + 1) as u64;
let _ = reg.register(id);
thread::yield_now();
reg.unregister(id);
})
})
.collect();
for h in handles {
h.join().unwrap();
}
assert_eq!(registry.active_strands().count(), 0);
}
#[test]
fn test_strand_lifecycle_counters() {
unsafe {
let initial_spawned = TOTAL_SPAWNED.load(Ordering::Relaxed);
let initial_completed = TOTAL_COMPLETED.load(Ordering::Relaxed);
static COUNTER: AtomicU32 = AtomicU32::new(0);
extern "C" fn simple_work(_stack: Stack) -> Stack {
COUNTER.fetch_add(1, Ordering::SeqCst);
std::ptr::null_mut()
}
COUNTER.store(0, Ordering::SeqCst);
for _ in 0..10 {
strand_spawn(simple_work, std::ptr::null_mut());
}
wait_all_strands();
let final_spawned = TOTAL_SPAWNED.load(Ordering::Relaxed);
let final_completed = TOTAL_COMPLETED.load(Ordering::Relaxed);
assert!(
final_spawned >= initial_spawned + 10,
"TOTAL_SPAWNED should have increased by at least 10"
);
assert!(
final_completed >= initial_completed + 10,
"TOTAL_COMPLETED should have increased by at least 10"
);
assert_eq!(COUNTER.load(Ordering::SeqCst), 10);
}
}
#[test]
fn test_maybe_yield_disabled_by_default() {
for _ in 0..1000 {
patch_seq_maybe_yield();
}
}
#[test]
fn test_tail_call_counter_increments() {
TAIL_CALL_COUNTER.with(|counter| {
let initial = counter.get();
patch_seq_maybe_yield();
patch_seq_maybe_yield();
patch_seq_maybe_yield();
let _ = counter.get();
counter.set(initial);
});
}
#[test]
fn test_counter_overflow_safety() {
TAIL_CALL_COUNTER.with(|counter| {
let initial = counter.get();
counter.set(u64::MAX - 1);
patch_seq_maybe_yield();
patch_seq_maybe_yield();
patch_seq_maybe_yield();
counter.set(initial);
});
}
}