#![feature(ptr_alignment_type)]
#![feature(coroutine_trait)]
#![feature(ptr_as_uninit)]
#![feature(ptr_mask)]
#![feature(ptr_sub_ptr)]
#![feature(strict_provenance)]
#![feature(nonzero_ops)]
#![feature(pointer_is_aligned_to)]
#![feature(naked_functions)]
#![feature(asm_const)]
mod arch;
pub(crate) mod os;
pub mod stack;
use crate::arch::STACK_FRAME_ALIGN;
use crate::stack::{align_alloc, Stack, StackOrientation, StackPool, COMMON_POOL};
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::num::NonZeroUsize;
use core::ops::{Coroutine, CoroutineState};
use core::pin::Pin;
use core::ptr::{Alignment, NonNull};
/// An instruction-pointer / stack-pointer pair passed to `arch::transfer_control`
/// when switching between a coroutine and the code that resumes it.
///
/// One `Control` lives at a fixed position inside every coroutine stack: at the
/// start of the `STACK_ALIGN`-aligned region for upward-growing stacks, or in its
/// last `Control::SIZE` bytes for downward-growing ones (see `current_control`).
///
/// NOTE(review): this struct uses the default `repr(Rust)`, so field order and
/// offsets are not guaranteed by the language. If the `arch` context-switch code
/// accesses these fields by hard-coded offset rather than `offset_of!`, consider
/// `#[repr(C)]`. Verify against `arch` first, because adding it may change the
/// layout the code currently relies on.
struct Control {
    // Address to jump to on the next transfer. `with_stack_from` initializes it
    // to the coroutine trampoline; a null value is what `Coro::is_finished`
    // treats as "the coroutine has completed".
    instr_ptr: *const u8,
    // Stack pointer to switch to on the next transfer. Once the coroutine has
    // finished, `Coro::resume` reads the return value from this address.
    stack_ptr: NonNull<u8>,
}
impl Control {
pub const SIZE: usize = size_of::<Self>();
}
/// A stackful coroutine that runs a closure on its own dedicated stack and
/// eventually produces a value of type `Return`.
pub struct Coro<Return> {
    // Ties `Return` to the struct without storing a value of that type.
    phantom: PhantomData<fn() -> Return>,
    // With `std`, the stack is wrapped in `ManuallyDrop` so that `Drop` can move
    // it out and hand it back to the stack pool.
    #[cfg(feature = "std")]
    stack: ManuallyDrop<Stack>,
    // Without `std`, the stack is stored directly.
    #[cfg(not(feature = "std"))]
    stack: Stack,
}
/// Size in bytes of every coroutine stack (2 MiB).
///
/// Built with the checked `NonZeroUsize::new` in a const context, so an invalid
/// value (zero) is rejected at compile time instead of being undefined behavior.
pub const STACK_SIZE: NonZeroUsize = match NonZeroUsize::new(1 << 21) {
    Some(size) => size,
    None => panic!("STACK_SIZE must be nonzero"),
};
/// Alignment of every coroutine stack; always equal to `STACK_SIZE`.
///
/// Because size and alignment are equal, masking any address inside a coroutine
/// stack with `STACK_ALIGN.mask()` yields the start of that stack, which is what
/// `stack_start` relies on. The checked `Alignment::new` turns a
/// non-power-of-two `STACK_SIZE` into a compile error rather than undefined
/// behavior.
pub const STACK_ALIGN: Alignment = match Alignment::new(STACK_SIZE.get()) {
    Some(align) => align,
    None => panic!("STACK_SIZE must be a power of two to be used as STACK_ALIGN"),
};
impl<Return> Coro<Return> {
    /// Creates a coroutine that will run `f` on a stack taken from `COMMON_POOL`.
    ///
    /// Nothing is executed here: the closure first runs when the coroutine is
    /// resumed.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce() -> Return + 'static,
    {
        COMMON_POOL.with_borrow_mut(|pool| Self::with_stack_from(pool, f))
    }
    /// Creates a coroutine that will run `f` on a stack of `STACK_SIZE` bytes,
    /// aligned to `STACK_ALIGN`, taken from `pool`.
    ///
    /// The closure is moved onto the new stack next to its bottom, and the
    /// stack's control block is primed so that the first transfer of control
    /// enters the internal trampoline. The trampoline moves the closure back off
    /// the stack and calls it.
    #[allow(rustdoc::broken_intra_doc_links)]
    pub fn with_stack_from<F>(pool: &mut StackPool, f: F) -> Self
    where
        F: FnOnce() -> Return + 'static,
    {
        // Entry point of a fresh coroutine; it never returns. It runs on the
        // coroutine stack, so it can find that stack via `stack_start`.
        extern "system" fn trampoline<F, Return>() -> !
        where
            F: FnOnce() -> Return + 'static,
        {
            let stack_start = unsafe { stack_start() };
            // The stack "bottom" is the first byte past the control block in the
            // direction of stack growth.
            let stack_bottom = match StackOrientation::current() {
                StackOrientation::Upwards => unsafe { stack_start.byte_add(Control::SIZE) },
                StackOrientation::Downwards => unsafe {
                    stack_start.byte_add(STACK_SIZE.get() - Control::SIZE)
                },
            };
            // Recompute where `with_stack_from` stored the closure. This must
            // match `stack.align_alloc(stack.bottom())` below; presumably
            // `Stack::bottom` is this same address (verify in `stack`).
            let f_ptr = align_alloc::<F>(stack_bottom, Alignment::of::<F>());
            // Move the closure out of its stack slot; it is consumed exactly once.
            let f = unsafe { f_ptr.read() };
            // `ManuallyDrop`: the value must not be dropped here, because
            // `Coro::resume` takes ownership of it by reading it off this stack.
            let ret_value = ManuallyDrop::new(f());
            let control = unsafe { current_control() };
            // Hands the address of the return value to `arch::return_control`,
            // which does not return. Presumably it publishes the address through
            // `Control::stack_ptr` and nulls `instr_ptr`, which is what
            // `resume`/`is_finished` check (confirm in `arch`).
            unsafe { arch::return_control(control, NonNull::from(&ret_value)) }
        }
        let stack = pool.take(STACK_SIZE, STACK_ALIGN);
        // Reserve a properly aligned slot for `F` at the stack bottom; the second
        // value is the stack pointer left after that reservation.
        let (f_ptr, stack_ptr) = stack.align_alloc(stack.bottom());
        unsafe { f_ptr.write(f) };
        // Align the coroutine's initial stack pointer to `STACK_FRAME_ALIGN`.
        let stack_ptr = align_alloc::<()>(stack_ptr, STACK_FRAME_ALIGN);
        let stack_ptr = unsafe { NonNull::new_unchecked(stack_ptr) }.cast();
        // The control slot may hold stale data from a previous user of a pooled
        // stack; `MaybeUninit::write` initializes it without reading or dropping
        // the old contents.
        let control = unsafe { stack.control().as_uninit_mut() };
        control.write(Control {
            instr_ptr: trampoline::<F, Return> as *const _,
            stack_ptr,
        });
        // With `std`, the field type is `ManuallyDrop<Stack>`.
        #[cfg(feature = "std")]
        let stack = ManuallyDrop::new(stack);
        Self {
            stack,
            phantom: PhantomData,
        }
    }
    /// Shared reference to this coroutine's control block inside its stack.
    fn control(&self) -> &Control {
        unsafe { self.stack.control().as_ref() }
    }
    /// Mutable reference to this coroutine's control block inside its stack.
    fn control_mut(&mut self) -> &mut Control {
        unsafe { self.stack.control().as_mut() }
    }
    /// Returns `true` when the saved instruction pointer is null, which this
    /// module uses as the "coroutine has completed" marker.
    pub fn is_finished(&self) -> bool {
        self.control().instr_ptr.is_null()
    }
}
/// Returns the coroutine's stack to `COMMON_POOL` so it can be reused.
///
/// Gated on the `std` feature because only then is `Coro::stack` a
/// `ManuallyDrop<Stack>` that can be moved out with `ManuallyDrop::take`.
/// Without `std` the field is a plain `Stack`, so this impl would not
/// type-check; the stack is then dropped by the compiler-generated drop glue.
///
/// NOTE(review): nothing that lives on the coroutine stack is dropped here. If
/// the coroutine was never resumed, the closure written onto the stack by
/// `with_stack_from` is leaked. If it is suspended mid-body, its live locals are
/// leaked as well.
#[cfg(feature = "std")]
impl<Return> Drop for Coro<Return> {
    fn drop(&mut self) {
        COMMON_POOL.with_borrow_mut(|pool| {
            // SAFETY: `drop` runs at most once and `self.stack` is never used
            // again afterwards, so moving it out of the `ManuallyDrop` is sound.
            let stack = unsafe { ManuallyDrop::take(&mut self.stack) };
            pool.give(stack);
        });
    }
}
impl<Return> Coroutine for Coro<Return> {
    type Yield = ();
    type Return = Return;

    /// Runs the coroutine until it calls `yield_` or its closure returns.
    ///
    /// Returns `CoroutineState::Yielded(())` if the coroutine suspended itself,
    /// or `CoroutineState::Complete(value)` with the closure's return value if it
    /// finished.
    ///
    /// # Panics
    ///
    /// Panics if called again after the coroutine has already completed.
    /// Without this check, the context switch would target the null instruction
    /// pointer that marks a finished coroutine. In addition, the return value,
    /// which the previous call already moved out, would be read a second time.
    fn resume(self: Pin<&mut Self>, _arg: ()) -> CoroutineState<Self::Yield, Self::Return> {
        let coro = self.get_mut();
        assert!(!coro.is_finished(), "`Coro` resumed after completion");
        {
            let control = coro.control_mut();
            // SAFETY: the coroutine has not finished, so its control block holds
            // a live suspended context (or the initial trampoline context).
            unsafe { arch::transfer_control(control) };
        }
        if coro.is_finished() {
            // On completion, the control block's `stack_ptr` holds the address of
            // the return value the trampoline left on the coroutine stack.
            let ret_val_addr = coro.control().stack_ptr.cast::<Return>();
            // SAFETY: the value was wrapped in `ManuallyDrop` by the trampoline,
            // so it is still initialized and owned by nobody else. The assertion
            // above guarantees this read happens at most once.
            let ret_val = unsafe { ret_val_addr.read() };
            CoroutineState::Complete(ret_val)
        } else {
            CoroutineState::Yielded(())
        }
    }
}
/// Suspends the running coroutine and switches back to the context that last
/// resumed it. Execution continues after this call on the next `resume`.
///
/// NOTE(review): this is a safe function, but it locates the coroutine's control
/// block by masking the current stack pointer with `STACK_ALIGN`. If it is called
/// on a stack that was not created for a `Coro` (for example, the thread's main
/// stack), it dereferences an arbitrary address. Only call it from inside a
/// coroutine body.
pub fn yield_() {
    // SAFETY: relies on the caller running on a coroutine stack; see the NOTE
    // in this function's documentation.
    unsafe {
        let current = current_control();
        arch::transfer_control(current);
    }
}
/// Returns the start address of the `STACK_ALIGN`-aligned region that contains
/// the current stack pointer.
///
/// # Safety
///
/// Must only be called while running on a coroutine stack obtained with
/// `STACK_ALIGN` alignment. On any other stack the masked address is
/// meaningless, and it could even be zero, which would violate
/// `NonZeroUsize::new_unchecked`.
unsafe fn stack_start() -> NonNull<u8> {
    let sp = arch::stack_ptr();
    sp.map_addr(|addr| {
        let aligned_down = addr.get() & STACK_ALIGN.mask();
        NonZeroUsize::new_unchecked(aligned_down)
    })
}
/// Returns the control block of the coroutine whose stack is currently active.
///
/// The control block sits at the very start of the stack region when the stack
/// grows upwards, and in the region's last `Control::SIZE` bytes when it grows
/// downwards.
///
/// # Safety
///
/// Must be called while running on a coroutine stack (see `stack_start`). The
/// returned reference has an unbounded lifetime, so the caller must ensure it
/// does not alias another live reference to the same control block.
unsafe fn current_control<'a>() -> &'a mut Control {
    let region = stack_start().cast::<Control>();
    let offset = match StackOrientation::current() {
        StackOrientation::Downwards => STACK_SIZE.get() - Control::SIZE,
        StackOrientation::Upwards => 0,
    };
    region.byte_add(offset).as_mut()
}
#[cfg(test)]
mod tests {
    use crate::{yield_, Coro};
    use std::ops::{Coroutine, CoroutineState};
    use std::pin::Pin;

    /// A coroutine that yields exactly once and then completes.
    #[test]
    fn simple() {
        let mut coro = Coro::new(|| {
            eprintln!("[c] activated for the first time");
            yield_();
            eprintln!("[c] resumed fine, exiting");
        });

        // First resume runs up to the `yield_()` call.
        if let CoroutineState::Complete(_) = Pin::new(&mut coro).resume(()) {
            panic!("completed before expected");
        }
        eprintln!("[m] yielded once");

        // Second resume runs the rest of the closure.
        if let CoroutineState::Yielded(_) = Pin::new(&mut coro).resume(()) {
            panic!("yielded when not expected");
        }
        eprintln!("[m] completed!");
    }

    /// The closure's return value is delivered through `Complete`.
    #[test]
    fn return_value() {
        let mut coro = Coro::new(|| 12345);
        let state = Pin::new(&mut coro).resume(());
        assert_eq!(state, CoroutineState::Complete(12345));
    }

    /// Many round trips between the caller and the coroutine.
    #[test]
    fn lots_of_yields() {
        const ITER_COUNT: usize = 10_000_000;
        let mut coro = Coro::new(|| {
            (0..ITER_COUNT).fold(0, |acc, i| {
                let acc = acc | core::hint::black_box(i);
                yield_();
                acc
            })
        });

        let mut expected = 0;
        for i in 0..ITER_COUNT {
            expected |= i;
            let state = Pin::new(&mut coro).resume(());
            assert_eq!(state, CoroutineState::Yielded(()));
        }
        let last = Pin::new(&mut coro).resume(());
        assert_eq!(last, CoroutineState::Complete(expected));
    }
}