use std::collections::{BTreeSet, VecDeque};
use std::num::NonZeroUsize;
use std::{io, marker, mem, panic, thread};
mod context_switch;
mod stack;
mod uring;
use std::cell::UnsafeCell;
// Per-thread slot for the fiber runtime. It holds `None` until `exclusive_runtime`
// installs a `RuntimeState`, and is reset to `None` when that call returns.
// The `UnsafeCell` only hands out raw pointers, so every access in this file goes
// through `unsafe` code that must avoid overlapping mutable references itself.
thread_local! {
static RUNTIME: UnsafeCell<Option<RuntimeState>> = UnsafeCell::new(None);
}
/// Asserts that a runtime is currently installed on the calling thread.
///
/// # Panics
/// Panics when the thread-local runtime slot is empty, i.e. when the caller is
/// not running inside `start`.
fn ensure_runtime_exists() {
    RUNTIME.with(|tls| {
        // SAFETY: only a shared view of the slot is taken and it does not
        // escape this closure. Soundness relies on no `&mut` obtained through
        // `runtime()` being in use at the same time.
        let runtime = unsafe { (&*tls.get()).as_ref() };
        assert!(runtime.is_some());
    })
}
/// Returns a mutable reference to the runtime installed on the calling thread.
///
/// # Panics
/// Panics if no runtime is installed.
///
/// # Safety
/// The reference is handed out with a `'static` lifetime straight from an
/// `UnsafeCell`, so the borrow checker does not track it. The caller must not
/// use it after the runtime slot has been cleared, and must not keep using it
/// once another reference to the runtime has been obtained.
unsafe fn runtime() -> &'static mut RuntimeState {
    RUNTIME.with(|tls| {
        let slot: &'static mut Option<RuntimeState> = &mut *tls.get();
        slot.as_mut().unwrap()
    })
}
/// Creates a [`Waker`]/[`Waiter`] pair that both refer to the currently
/// running fiber.
///
/// # Panics
/// Panics if no runtime is installed or no fiber is currently running.
///
/// # Safety
/// Calls `runtime()`, so the caller must uphold that function's contract.
pub unsafe fn concurrency_pair() -> (Waker, Waiter) {
    let current = runtime().running();
    let waker = Waker(current);
    let waiter = Waiter(current);
    (waker, waiter)
}
/// Handle that puts one specific fiber back onto the ready queue.
#[derive(Debug)]
pub struct Waker(FiberIndex);

impl Waker {
    /// Appends the fiber to the back of the ready queue.
    ///
    /// # Safety
    /// Calls `runtime()`, so a runtime must be installed on this thread and the
    /// caller must uphold that function's contract.
    pub unsafe fn schedule(self) {
        let Waker(fiber) = self;
        runtime().ready_fibers.push_back(fiber);
    }

    /// Inserts the fiber at the front of the ready queue so that it is the next
    /// one popped.
    ///
    /// # Safety
    /// Calls `runtime()`, so a runtime must be installed on this thread and the
    /// caller must uphold that function's contract.
    pub unsafe fn schedule_immediately(self) {
        let Waker(fiber) = self;
        runtime().ready_fibers.push_front(fiber);
    }
}
/// Handle that suspends one specific fiber until it is scheduled again.
#[derive(Debug)]
pub struct Waiter(FiberIndex);

impl Waiter {
    /// Suspends the fiber and switches to the next ready fiber, blocking on
    /// I/O completions when nothing is ready.
    ///
    /// Returns once this fiber has been picked by the scheduler again. If the
    /// scheduler picks this very fiber (for example because its `Waker` was
    /// already scheduled and nothing else is ahead of it), the call returns
    /// immediately without a context switch.
    ///
    /// # Safety
    /// Calls `runtime()` and performs a raw context switch; a runtime must be
    /// installed and the waiter presumably has to belong to the fiber that
    /// calls `park` (the running fiber's context is saved into this waiter's
    /// fiber slot).
    pub unsafe fn park(self) {
        let next = runtime().process_io_and_wait();
        if next == self.0 {
            // The stored continuation of this fiber is stale (it describes an
            // earlier suspension point), so jumping to it would resume on a
            // stack frame that no longer exists. Being selected simply means
            // "keep running", mirroring the `running != to` check in `syscall`.
            return;
        }
        let to = runtime().fibers.get(next).continuation;
        let continuation = &mut runtime().fibers.get_mut(self.0).continuation;
        context_switch::jump(to, continuation);
    }
}
/// State of the fiber runtime installed on one thread.
struct RuntimeState {
/// io_uring wrapper used to issue and cancel syscalls and to reap their completions.
uring: uring::Uring,
/// Per-fiber state, addressed by `FiberIndex`.
fibers: Fibers,
/// Fibers that are ready to run; the scheduler pops from the front.
ready_fibers: VecDeque<FiberIndex>,
/// Fiber most recently selected by the scheduler, if any.
running_fiber: Option<FiberIndex>,
/// Stack base pointers of stacks that may be reused by newly created fibers.
stack_pool: Vec<*const u8>,
/// Continuation slot handed to `context_switch::jump` when `start` enters the root
/// fiber; the root fiber jumps back to it when it is finished.
bootstrap: mem::MaybeUninit<context_switch::Continuation>,
}
impl RuntimeState {
    /// Creates a runtime with a fresh io_uring instance, no fibers, and an
    /// empty stack pool.
    fn new() -> Self {
        let uring = uring::Uring::new();
        let fibers = Fibers::new();
        RuntimeState {
            uring,
            fibers,
            ready_fibers: VecDeque::new(),
            running_fiber: None,
            stack_pool: Vec::new(),
            bootstrap: mem::MaybeUninit::uninit(),
        }
    }

    /// Returns the base pointer of a stack for a new fiber, reusing a pooled
    /// stack when one is available.
    ///
    /// # Panics
    /// Panics if a new stack has to be created and `stack::Stack::new` fails.
    fn allocate_stack(&mut self) -> *const u8 {
        match self.stack_pool.pop() {
            Some(recycled) => recycled,
            None => {
                // Presumably (guard pages, usable pages); these match the
                // constants used when the pool is torn down in `Drop`.
                let stack =
                    stack::Stack::new(NonZeroUsize::MIN, NonZeroUsize::new(32).unwrap()).unwrap();
                let stack_base = stack.base();
                // Ownership is tracked through the raw base pointer from here
                // on, so the `Stack` value must not run its destructor now.
                mem::forget(stack);
                stack_base
            }
        }
    }

    /// Returns the fiber most recently selected by the scheduler.
    ///
    /// # Panics
    /// Panics if no fiber has been selected yet.
    fn running(&self) -> FiberIndex {
        self.running_fiber.unwrap()
    }

    /// Drains the completion queue: stores each result on the fiber named by
    /// the completion's user data and marks that fiber as ready.
    fn process_io(&mut self) {
        for (user_data, result) in self.uring.process_cq() {
            // User data carries the fiber index that issued the syscall.
            let woken = FiberIndex(user_data.0 as u32);
            let state = self.fibers.get_mut(woken);
            state.syscall_result = Some(result);
            self.ready_fibers.push_back(woken);
        }
    }

    /// Selects the next fiber to run, records it as the running fiber, and
    /// returns it.
    ///
    /// An already-ready fiber is taken without touching the completion queue.
    /// Otherwise completions are processed, waiting on the ring between
    /// rounds, until some fiber becomes ready.
    fn process_io_and_wait(&mut self) -> FiberIndex {
        let next = match self.ready_fibers.pop_front() {
            Some(fiber) => fiber,
            None => loop {
                self.process_io();
                match self.ready_fibers.pop_front() {
                    Some(fiber) => break fiber,
                    None => {
                        self.uring.wait_for_completed_syscall();
                    }
                }
            },
        };
        self.running_fiber = Some(next);
        next
    }

    /// Marks `fiber` and all of its descendants as cancelled, and asks the
    /// ring to cancel the syscall of every one of them that is currently
    /// issuing one.
    ///
    /// Fibers are visited parent first, children in ascending index order.
    fn cancel(&mut self, fiber: FiberIndex) {
        let mut pending = vec![fiber];
        while let Some(current) = pending.pop() {
            let state = self.fibers.get_mut(current);
            state.is_cancelled = true;
            if state.issuing_syscall {
                self.uring
                    .cancel_syscall(uring::UserData(current.0 as u64));
            }
            // Pushed in reverse so the smallest child is popped (visited) first.
            pending.extend(state.children.iter().rev().copied());
        }
    }
}
impl Drop for RuntimeState {
    /// Releases every pooled stack by rebuilding its `stack::Stack` value and
    /// dropping it.
    ///
    /// # Panics
    /// Panics if the system page size is not 4096 bytes.
    fn drop(&mut self) {
        // Must agree with the sizes passed to `stack::Stack::new` in
        // `allocate_stack`.
        const GUARD_PAGES: usize = 1;
        const USABLE_PAGES: usize = 32;

        // SAFETY: `sysconf` is a plain FFI query without memory preconditions.
        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) as usize };
        assert_eq!(page_size, 4096);
        let length = (GUARD_PAGES + USABLE_PAGES) * page_size;

        self.stack_pool.drain(..).for_each(|stack_bottom| {
            // The pool stores base pointers; the mapping presumably starts
            // `length` bytes below the base.
            let pointer = unsafe { stack_bottom.sub(length) } as *mut u8;
            drop(stack::Stack { pointer, length });
        });
    }
}
struct Fibers(slab::Slab<FiberState>);
impl Fibers {
fn new() -> Self {
Fibers(slab::Slab::new())
}
fn get(&self, fiber: FiberIndex) -> &FiberState {
&self.0[fiber.0 as usize]
}
fn get_mut(&mut self, fiber: FiberIndex) -> &mut FiberState {
&mut self.0[fiber.0 as usize]
}
fn add(
&mut self,
parent: Option<FiberIndex>,
stack_base: *const u8,
continuation: context_switch::Continuation,
is_contained: bool,
) -> FiberIndex {
let index = self.0.insert(FiberState {
stack_base,
continuation,
is_completed: false,
join_handle: JoinHandleState::Unused,
syscall_result: None,
parent,
children: BTreeSet::new(),
is_cancelled: false,
is_contained,
issuing_syscall: false,
});
FiberIndex(index as u32)
}
fn remove(&mut self, fiber: FiberIndex) {
self.0.remove(fiber.0 as usize);
}
fn nearest_contained_ancestor(&self, fiber: FiberIndex) -> FiberIndex {
let mut nearest_contained_ancestor = fiber;
while !self.get(nearest_contained_ancestor).is_contained {
nearest_contained_ancestor = self.get(nearest_contained_ancestor).parent.unwrap();
}
nearest_contained_ancestor
}
}
/// Index of a fiber inside `Fibers`. The same number is also used as the io_uring
/// user data of syscalls issued by that fiber.
#[repr(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct FiberIndex(u32);
/// Bookkeeping for a single fiber.
#[derive(Debug)]
struct FiberState {
// `stack_base`: base pointer of the fiber's stack. The fiber's closure, and later its
// `thread::Result` output, are stored in the slot directly below this address.
// `continuation`: saved execution context used to switch to this fiber.
stack_base: *const u8, continuation: context_switch::Continuation,
/// Set once the closure has returned or panicked and its output has been written.
is_completed: bool,
/// What the fiber's `JoinHandle` is currently doing (see `JoinHandleState`).
join_handle: JoinHandleState,
/// Result of the fiber's last syscall, filled in by `RuntimeState::process_io` and
/// taken by `syscall`.
syscall_result: Option<io::Result<u32>>,
/// Fiber that spawned this one; `None` for the root fiber.
parent: Option<FiberIndex>,
// `children`: fibers spawned by this one that have not finished yet.
// `is_cancelled`: set by `RuntimeState::cancel`; observed through `is_cancelled()`.
children: BTreeSet<FiberIndex>, is_cancelled: bool,
/// Whether the fiber is a cancellation boundary: a panic in a descendant cancels
/// the nearest ancestor (or self) that has this flag set.
is_contained: bool,
/// True while the fiber is inside `syscall` waiting for its completion.
issuing_syscall: bool,
}
/// Relationship between a fiber and its `JoinHandle`.
#[derive(Debug, Clone)]
enum JoinHandleState {
/// Initial state: nobody is joining the fiber and the handle has not been dropped.
Unused,
// `Waiting(fiber)`: `fiber` is blocked in `JoinHandle::join` and is pushed onto the
// ready queue when the joined fiber finishes.
// `Dropped`: the handle has been dropped, so nobody will read the fiber's output.
Waiting(FiberIndex), Dropped,
}
/// Installs a runtime on the current thread, runs `f` as the root fiber, and
/// returns its output once the root fiber has finished.
///
/// A panic inside `f` is caught and returned as `Err`.
///
/// # Panics
/// Panics if a runtime is already installed on this thread (nested `start`).
pub fn start<F: FnOnce() -> T, T>(f: F) -> thread::Result<T> {
    unsafe {
        exclusive_runtime(|| {
            let stack_base = runtime().allocate_stack();

            // The closure lives in the slot directly below the stack base; the
            // trampoline reads it from there and later overwrites the slot
            // with the output.
            let closure_slot = (stack_base as *mut F).sub(1);
            closure_slot.write(f);

            let initial_stack_pointer = stack_base.sub(closure_union_size::<F, T>()) as *mut u8;
            let entry = start_trampoline::<F, T> as *const ();
            let continuation = context_switch::prepare_stack(initial_stack_pointer, entry);

            let root_fiber = runtime().fibers.add(None, stack_base, continuation, true);
            runtime().running_fiber = Some(root_fiber);

            // Control comes back here when the root fiber jumps to `bootstrap`.
            let bootstrap = runtime().bootstrap.as_mut_ptr();
            context_switch::jump(continuation, bootstrap);

            let output_slot = (stack_base as *const thread::Result<T>).sub(1);
            output_slot.read()
        })
    }
}
/// Installs a fresh runtime on the current thread, runs `f`, and uninstalls the
/// runtime again before returning `f`'s output.
///
/// The runtime is uninstalled by a drop guard, so the thread-local slot is also
/// cleared when `f` unwinds. Without that, a panic escaping `f` would leave a
/// stale runtime behind and every later call on this thread would fail the
/// "no runtime installed" assertion.
///
/// # Panics
/// Panics if a runtime is already installed on this thread; the existing
/// runtime is left untouched in that case.
///
/// # Safety
/// Accesses the thread-local runtime slot through raw pointers; no reference
/// obtained from `runtime()` may be used once this function has returned.
unsafe fn exclusive_runtime<T>(f: impl FnOnce() -> T) -> T {
    /// Clears the thread-local runtime slot when dropped.
    struct Uninstall;

    impl Drop for Uninstall {
        fn drop(&mut self) {
            RUNTIME.with(|tls| {
                // SAFETY: `f` has finished (or unwound), so no code that uses
                // the runtime is executing any more.
                unsafe { *tls.get() = None };
            });
        }
    }

    RUNTIME.with(|tls| {
        let runtime = &mut *tls.get();
        assert!(runtime.is_none());
        *runtime = Some(RuntimeState::new());
    });
    // Armed only after installation succeeded, so a rejected nested call never
    // tears down the runtime that is already installed.
    let _uninstall = Uninstall;
    f()
}
/// Entry point of the root fiber created by `start`; never returns.
///
/// Reads the closure from the slot below the stack base, runs it while catching
/// panics, stores the result in the same slot, waits until all child fibers have
/// finished, and finally jumps back to the `bootstrap` continuation so that
/// `start` can read the result.
///
/// # Safety
/// Must only be entered through the continuation prepared by `start`, with the
/// closure of type `F` already written below the running fiber's stack base.
unsafe extern "C" fn start_trampoline<F: FnOnce() -> T, T>() -> ! {
let running = runtime().running();
let stack_base = runtime().fibers.get(running).stack_base;
// Closure and output share the slot directly below the stack base.
let closure_pointer = (stack_base as *const F).sub(1);
let output_pointer = (stack_base as *mut thread::Result<T>).sub(1);
// Move the closure out of the slot before the output overwrites it.
let closure = closure_pointer.read();
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| (closure)()));
output_pointer.write(result);
runtime().fibers.get_mut(running).is_completed = true;
// With children still alive, switch away; the last child to finish pushes this
// fiber back onto the ready queue (see `spawn_trampoline`).
if !runtime().fibers.get(running).children.is_empty() {
let to = runtime().process_io_and_wait();
let to = runtime().fibers.get(to).continuation;
let continuation = &mut runtime().fibers.get_mut(running).continuation;
context_switch::jump(to, continuation); }
// The stack goes back to the pool although this code still runs on it; its memory
// is only released when the runtime itself is dropped.
runtime().stack_pool.push(stack_base);
runtime().fibers.remove(running);
// `bootstrap` was filled in by the jump performed in `start`.
let to = runtime().bootstrap.assume_init();
// This fiber is never resumed, so its saved context goes to a throwaway slot.
let mut dummy = mem::MaybeUninit::uninit();
unsafe { context_switch::jump(to, dummy.as_mut_ptr()) };
unreachable!();
}
/// Gives other ready fibers a chance to run.
///
/// Pending I/O completions are processed first. If no other fiber is ready
/// afterwards the call returns immediately; otherwise the current fiber is put
/// at the back of the ready queue and suspended until it is picked again.
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn yield_now() {
    ensure_runtime_exists();
    unsafe {
        runtime().process_io();
        let someone_else_is_ready = !runtime().ready_fibers.is_empty();
        if someone_else_is_ready {
            let (waker, waiter) = concurrency_pair();
            waker.schedule();
            waiter.park();
        }
    }
}
/// Spawns `f` as a child fiber of the running fiber and returns a handle for
/// joining or cancelling it.
///
/// The child is not a cancellation boundary: a panic inside it cancels its
/// nearest contained ancestor.
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn spawn<F: FnOnce() -> T + 'static, T: 'static>(f: F) -> JoinHandle<T> {
    ensure_runtime_exists();
    let is_contained = false;
    // SAFETY: a runtime is installed, as checked above.
    unsafe { spawn_inner(f, is_contained) }
}
/// Runs `f` in a child fiber that acts as a cancellation boundary and waits for
/// its result.
///
/// A panic inside `f` is returned as `Err`.
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn contain<F: FnOnce() -> T + 'static, T: 'static>(f: F) -> thread::Result<T> {
    ensure_runtime_exists();
    let is_contained = true;
    // SAFETY: a runtime is installed, as checked above.
    let handle = unsafe { spawn_inner(f, is_contained) };
    handle.join()
}
/// Creates a child fiber of the running fiber that will execute `f`, queues it
/// at the back of the ready queue, and returns its join handle.
///
/// `is_contained` marks the child as a cancellation boundary.
///
/// # Safety
/// A runtime must be installed on this thread and a fiber must be running.
unsafe fn spawn_inner<F: FnOnce() -> T + 'static, T: 'static>(
    f: F,
    is_contained: bool,
) -> JoinHandle<T> {
    let stack_base = runtime().allocate_stack();

    // The closure is parked in the slot directly below the stack base, where
    // the trampoline picks it up.
    let closure_slot = (stack_base as *mut F).sub(1);
    closure_slot.write(f);

    let initial_stack_pointer = stack_base.sub(closure_union_size::<F, T>()) as *mut u8;
    let entry = spawn_trampoline::<F, T> as *const ();
    let continuation = context_switch::prepare_stack(initial_stack_pointer, entry);

    let parent = runtime().running();
    let child = runtime()
        .fibers
        .add(Some(parent), stack_base, continuation, is_contained);
    runtime().fibers.get_mut(parent).children.insert(child);
    runtime().ready_fibers.push_back(child);

    JoinHandle::new(child)
}
/// Entry point of every fiber created by `spawn_inner`; never returns.
///
/// Runs the closure stored below the stack base while catching panics, stores
/// the result in the same slot, cancels the nearest contained ancestor on
/// panic, waits for all children, wakes a joiner and/or a completed parent, and
/// finally switches to the next ready fiber for good.
///
/// If the join handle has already been dropped, nobody will read the output,
/// so both the stack and the fiber's slab entry are released here. Previously
/// only the stack was recycled and the entry leaked for the lifetime of the
/// runtime; `JoinHandle::drop` performs the same two steps when the handle is
/// dropped after completion.
///
/// # Safety
/// Must only be entered through the continuation prepared by `spawn_inner`,
/// with the closure of type `F` already written below the running fiber's
/// stack base.
unsafe extern "C" fn spawn_trampoline<F: FnOnce() -> T, T>() -> ! {
    let running = runtime().running();
    let stack_base = runtime().fibers.get(running).stack_base;
    // Closure and output share the slot directly below the stack base.
    let closure_pointer = (stack_base as *const F).sub(1);
    let output_pointer = (stack_base as *mut thread::Result<T>).sub(1);
    // Move the closure out of the slot before the output overwrites it.
    let closure = closure_pointer.read();
    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| (closure)()));
    let is_err = result.is_err();
    output_pointer.write(result);
    runtime().fibers.get_mut(running).is_completed = true;

    // A panic takes down the whole hierarchy below the nearest boundary.
    if is_err {
        let nearest_contained_ancestor = runtime().fibers.nearest_contained_ancestor(running);
        runtime().cancel(nearest_contained_ancestor);
    }

    // With children still alive, switch away; the last child to finish pushes
    // this fiber back onto the ready queue (see the parent handling below).
    if !runtime().fibers.get(running).children.is_empty() {
        let to = runtime().process_io_and_wait();
        let to = runtime().fibers.get(to).continuation;
        let continuation = &mut runtime().fibers.get_mut(running).continuation;
        context_switch::jump(to, continuation);
    }

    // Wake the fiber blocked in `JoinHandle::join`, if there is one.
    if let JoinHandleState::Waiting(fiber) = runtime().fibers.get(running).join_handle {
        runtime().ready_fibers.push_back(fiber);
    }

    // Detach from the parent and wake it if it was only waiting for children.
    let parent = runtime().fibers.get(running).parent.unwrap();
    runtime().fibers.get_mut(parent).children.remove(&running);
    if runtime().fibers.get(parent).is_completed && runtime().fibers.get(parent).children.is_empty()
    {
        runtime().ready_fibers.push_back(parent);
    }

    if let JoinHandleState::Dropped = runtime().fibers.get(running).join_handle {
        // The stack goes back to the pool although this code still runs on it;
        // it is not handed out again before the final jump below.
        runtime().stack_pool.push(stack_base);
        // Nothing refers to this fiber any more, so release its slab entry too.
        runtime().fibers.remove(running);
    }

    let to = runtime().process_io_and_wait();
    let to = runtime().fibers.get(to).continuation;
    // This fiber is never resumed, so its saved context goes to a throwaway slot.
    let mut dummy = mem::MaybeUninit::uninit();
    unsafe { context_switch::jump(to, dummy.as_mut_ptr()) };
    unreachable!();
}
/// Handle to a spawned fiber that allows waiting for its output or cancelling it.
#[derive(Debug)]
pub struct JoinHandle<T> {
/// The fiber this handle refers to.
fiber: FiberIndex,
/// Ties the handle to the fiber's output type without storing a value.
output: marker::PhantomData<T>,
}
impl<T> JoinHandle<T> {
/// Wraps the index of a freshly spawned fiber.
fn new(fiber: FiberIndex) -> Self {
JoinHandle {
fiber,
output: marker::PhantomData,
}
}
/// Waits for the fiber to complete and returns its output; a panic inside the
/// fiber is returned as `Err`.
///
/// If the fiber has already completed, the output is returned immediately.
/// Otherwise the calling fiber registers itself as the waiter and is suspended
/// until the joined fiber wakes it.
///
/// The handle is consumed, so its `Drop` implementation runs when this method
/// returns.
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn join(self) -> thread::Result<T> {
ensure_runtime_exists();
unsafe {
let stack_base = runtime().fibers.get(self.fiber).stack_base;
// The trampoline stores the output in the slot directly below the stack base.
let output_pointer = (stack_base as *const thread::Result<T>).sub(1);
// NOTE(review): `spawn_trampoline` sets `is_completed` before it waits for the
// fiber's children. Returning here (and then dropping the handle, which recycles
// the stack and removes the fiber) while that fiber is still parked looks unsafe -
// confirm whether this window can be hit.
if runtime().fibers.get(self.fiber).is_completed {
return output_pointer.read();
}
let running = runtime().running();
runtime().fibers.get_mut(self.fiber).join_handle = JoinHandleState::Waiting(running);
// Switch to the next ready fiber; execution continues below once the joined
// fiber has pushed this fiber back onto the ready queue.
let to = runtime().process_io_and_wait();
let to = runtime().fibers.get(to).continuation;
let continuation = &mut runtime().fibers.get_mut(running).continuation;
context_switch::jump(to, continuation);
assert!(runtime().fibers.get(self.fiber).is_completed);
output_pointer.read()
}
}
/// Cancels the fiber and all of its descendants (see `RuntimeState::cancel`).
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn cancel(&self) {
ensure_runtime_exists();
unsafe {
runtime().cancel(self.fiber);
}
}
}
impl<T> Drop for JoinHandle<T> {
    /// Records that the handle is gone. If the fiber has already completed,
    /// its stack is returned to the pool and its state is removed.
    ///
    /// # Panics
    /// Panics if no runtime is installed on the calling thread.
    fn drop(&mut self) {
        ensure_runtime_exists();
        let fiber = self.fiber;
        unsafe {
            let state = runtime().fibers.get_mut(fiber);
            state.join_handle = JoinHandleState::Dropped;
            if !state.is_completed {
                // The fiber cleans up after itself once it finishes.
                return;
            }
            let stack_base = state.stack_base;
            runtime().stack_pool.push(stack_base);
            runtime().fibers.remove(fiber);
        }
    }
}
/// Issues `sqe` on behalf of the running fiber, suspends the fiber until the
/// completion arrives, and returns the syscall's result.
///
/// # Panics
/// Panics if no runtime is installed, if the fiber is already issuing a
/// syscall, or if a stale result is still stored on the fiber.
pub(crate) fn syscall(sqe: io_uring::squeue::Entry) -> io::Result<u32> {
    ensure_runtime_exists();
    unsafe {
        let running = runtime().running();
        assert!(!runtime().fibers.get(running).issuing_syscall);
        runtime().fibers.get_mut(running).issuing_syscall = true;
        assert!(runtime().fibers.get(running).syscall_result.is_none());

        // The fiber index doubles as user data so the completion can be routed
        // back to this fiber.
        let user_data = uring::UserData(running.0 as u64);
        runtime().uring.issue_syscall(user_data, sqe);

        let next = runtime().process_io_and_wait();
        // When the scheduler picks this very fiber again there is nothing to
        // switch to; just carry on.
        if next != running {
            let to = runtime().fibers.get(next).continuation;
            let continuation = &mut runtime().fibers.get_mut(running).continuation;
            context_switch::jump(to, continuation);
        }

        assert!(runtime().fibers.get(running).issuing_syscall);
        let state = runtime().fibers.get_mut(running);
        state.issuing_syscall = false;
        state.syscall_result.take().unwrap()
    }
}
/// Issues an io_uring no-op and waits for it to complete.
///
/// # Errors
/// Returns the error reported for the operation.
///
/// # Panics
/// Panics if the operation succeeds with a non-zero result, or if no runtime
/// is installed on the calling thread.
pub fn nop() -> io::Result<()> {
    let entry = io_uring::opcode::Nop::new().build();
    syscall(entry).map(|result| {
        assert_eq!(result, 0);
    })
}
/// Returns the number of bytes to reserve below a fiber's stack base for the
/// slot that first holds the closure `F` and later its output.
///
/// The trampolines write a `thread::Result<T>` into that slot, not a bare `T`,
/// so the slot must be large enough for the whole `Result`. Sizing it by `T`
/// alone lets the output overlap the top of the fiber's own stack frame
/// whenever `F` is smaller than `thread::Result<T>` (for example a capture-less
/// closure returning `()`).
const fn closure_union_size<F: FnOnce() -> T, T>() -> usize {
    let closure_size = mem::size_of::<F>();
    let output_size = mem::size_of::<thread::Result<T>>();
    if closure_size > output_size {
        closure_size
    } else {
        output_size
    }
}
/// Cancels the running fiber and all of its descendants.
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn cancel() {
    ensure_runtime_exists();
    // SAFETY: a runtime is installed, as checked above; the reference is not
    // kept beyond this function.
    let state = unsafe { runtime() };
    let current = state.running();
    state.cancel(current);
}
/// Reports whether the running fiber has been cancelled.
///
/// # Panics
/// Panics if no runtime is installed on the calling thread.
pub fn is_cancelled() -> bool {
    ensure_runtime_exists();
    // SAFETY: a runtime is installed, as checked above; the reference is not
    // kept beyond this function.
    let state = unsafe { runtime() };
    state.fibers.get(state.running()).is_cancelled
}
#[cfg(test)]
mod tests {
use super::*;
use crate::time;
use std::thread;
use std::time::{Duration, Instant};
// Tests for `start`: output propagation, panic capture, nesting, and per-thread isolation.
mod start {
use super::*;
use std::time::Duration;
// The closure's return value comes back as `Ok`.
#[test]
fn returns_output() {
let output = start(|| 123);
assert_eq!(output.unwrap(), 123);
}
// A panic in the root closure is reported as `Err` instead of unwinding.
#[test]
fn catches_panic() {
let result = start(|| panic!("oops"));
assert!(result.is_err());
}
// A nested `start` panics; the outer `start` returns that panic as `Err`, which
// the final `unwrap` turns back into a panic.
#[test]
#[should_panic]
fn cant_nest() {
start(|| {
start(|| {}).unwrap();
})
.unwrap();
}
// `start` only returns after all children have run, whether their handles were
// dropped or forgotten.
#[test]
fn waits_for_children() {
static mut VALUE: usize = 0;
start(|| {
let handle = spawn(|| unsafe { VALUE += 1 });
drop(handle);
let handle = spawn(|| unsafe { VALUE += 1 });
mem::forget(handle);
})
.unwrap();
assert_eq!(unsafe { VALUE }, 2);
}
// The runtime slot is cleared after `start`, so it can be used again.
#[test]
fn works_consecutively() {
start(|| {}).unwrap();
start(|| {}).unwrap();
}
// Runtimes are thread-local: two threads can each run one at the same time.
#[test]
fn works_in_parallel() {
let handle = thread::spawn(|| {
start(|| {
thread::sleep(Duration::from_millis(2));
})
.unwrap();
});
thread::sleep(Duration::from_millis(1));
start(|| {}).unwrap();
assert!(handle.join().is_ok());
}
}
// Tests for `contain`: output propagation, panic capture, nesting, and waiting for children.
mod contain {
use super::*;
// The closure's return value comes back as `Ok`.
#[test]
fn returns_output() {
start(|| {
let output = contain(|| 123);
assert_eq!(output.unwrap(), 123);
})
.unwrap();
}
// A panic in the contained closure is reported as `Err`.
#[test]
fn catches_panic() {
start(|| {
let result = contain(|| panic!("oops"));
assert!(result.is_err());
})
.unwrap();
}
// Calling `start` inside a contained fiber panics; `contain` reports it as `Err`.
#[test]
fn cant_nest_start() {
start(|| {
let result = contain(|| start(|| {}).unwrap());
assert!(result.is_err());
})
.unwrap();
}
// `contain` only returns after all children of the contained fiber have run.
#[test]
fn waits_for_children() {
start(|| {
static mut VALUE: usize = 0;
contain(|| {
let handle = spawn(|| unsafe { VALUE += 1 });
drop(handle);
let handle = spawn(|| unsafe { VALUE += 1 });
mem::forget(handle);
})
.unwrap();
assert_eq!(unsafe { VALUE }, 2);
})
.unwrap();
}
}
// Tests for `spawn` and `JoinHandle::join`.
mod spawn {
use super::*;
// Joining a child returns the child's output.
#[test]
fn returns_child_output() {
start(|| {
let handle = spawn(|| 123);
let output = handle.join();
assert_eq!(output.unwrap(), 123);
})
.unwrap();
}
// A handle can be moved into and joined from a fiber that is not the parent.
#[test]
fn returns_non_child_output() {
start(|| {
let other = spawn(|| 123);
let handle = spawn(|| other.join().unwrap());
let output = handle.join();
assert_eq!(output.unwrap(), 123);
})
.unwrap();
}
// Joining after the child has already finished returns its stored output.
#[test]
fn returns_already_completed_output() {
start(|| {
let handle = spawn(|| 123);
yield_now();
let output = handle.join();
assert_eq!(output.unwrap(), 123);
})
.unwrap();
}
// A panic in the child is reported as `Err` by `join`.
#[test]
fn catches_panic() {
start(|| {
let result = spawn(|| panic!("oops")).join();
assert!(result.is_err());
})
.unwrap();
}
// Calling `start` inside a spawned fiber panics; `join` reports it as `Err`.
#[test]
fn cant_nest_start() {
start(|| {
let result = spawn(|| start(|| {})).join();
assert!(result.is_err());
})
.unwrap();
}
// `join` only returns after the joined fiber's own children have run.
#[test]
fn waits_for_children() {
start(|| {
static mut VALUE: usize = 0;
spawn(|| {
let handle = spawn(|| unsafe { VALUE += 1 });
drop(handle);
let handle = spawn(|| unsafe { VALUE += 1 });
mem::forget(handle);
})
.join()
.unwrap();
assert_eq!(unsafe { VALUE }, 2);
})
.unwrap();
}
}
// Tests for `yield_now`.
mod yield_now {
use super::*;
// With no other fiber ready, yielding returns straight away.
#[test]
fn to_same_fiber() {
start(|| {
yield_now();
})
.unwrap();
}
// Yielding lets a previously spawned fiber run before the caller continues.
#[test]
fn to_other_fiber() {
start(|| {
static mut VALUE: usize = 0;
spawn(|| unsafe { VALUE += 1 });
assert_eq!(unsafe { VALUE }, 0);
yield_now();
assert_eq!(unsafe { VALUE }, 1);
})
.unwrap();
}
}
// Tests for cancellation via the `cancel` function, via panics, and via
// `JoinHandle::cancel`.
mod cancellation {
    use super::*;

    // A fresh root fiber is not cancelled.
    #[test]
    fn initially_not_cancelled() {
        start(|| {
            assert!(!is_cancelled());
        })
        .unwrap();
    }

    // `cancel()` marks the calling fiber and its children, but not its parent.
    #[test]
    fn function_cancels_fiber_hierarchy() {
        start(|| {
            contain(|| {
                let handle = spawn(|| {
                    assert!(is_cancelled());
                });
                cancel();
                handle.join().unwrap();
                assert!(is_cancelled());
            })
            .unwrap();
            assert!(!is_cancelled());
        })
        .unwrap();
    }

    // A panic cancels everything below the nearest contained ancestor,
    // including fibers spawned by the panicking fiber, but nothing above it.
    #[test]
    fn panic_cancels_fiber_hierarchy() {
        static mut GRAND_CHILD_CANCELLED: bool = false;
        start(|| {
            let _ = contain(|| {
                let handle = spawn(|| {
                    spawn(|| unsafe {
                        GRAND_CHILD_CANCELLED = is_cancelled();
                    });
                    panic!("oops");
                });
                // NOTE(review): the spawned fiber panics, so `join()` returns
                // `Err` and this `unwrap()` itself panics inside the contained
                // fiber; the assertion below is therefore never reached.
                // Confirm whether `unwrap_err()` was intended.
                handle.join().unwrap();
                assert!(is_cancelled());
            });
            assert!(unsafe { GRAND_CHILD_CANCELLED });
            assert!(!is_cancelled());
        })
        .unwrap();
    }

    // `JoinHandle::cancel` marks the target fiber and its children, but not
    // the fiber that holds the handle.
    #[test]
    fn method_cancels_fiber_hierarchy() {
        start(|| {
            let handle = spawn(|| {
                spawn(|| {
                    assert!(is_cancelled());
                });
                yield_now();
                assert!(is_cancelled());
            });
            yield_now();
            handle.cancel();
            handle.join().unwrap();
            assert!(!is_cancelled());
        })
        .unwrap();
    }
}
}