#![doc = include_str!("../docs.md")]
#![warn(
clippy::complexity,
clippy::correctness,
clippy::style,
future_incompatible,
missing_debug_implementations,
missing_docs,
rust_2018_idioms,
rustdoc::all,
clippy::undocumented_unsafe_blocks
)]
use core::{
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
/// The signature of a plain function that can be turned into a coroutine.
///
/// The function is handed a [`Handle`] for yielding values of type `S` (and receiving replies of
/// type `R`) and returns the future `F` that forms the coroutine body.
pub type CoroFn<S, R, F> = fn(Handle<S, R>) -> F;
/// A type that describes a coroutine purely through its type; no `self` value is needed.
///
/// Implementors provide [`AsCoro::as_coro_fn`], the async body of the coroutine, and get
/// [`AsCoro::as_coro`] for free. Because no value is involved, an implementing type can also be
/// run inline inside another coroutine with matching types via [`Handle::yield_from_type`].
pub trait AsCoro {
    /// The type of value this coroutine yields to its driver.
    type Snd: Unpin + 'static;
    /// The type of value the driver sends back after each yield.
    type Rcv: Unpin;
    /// The final output of the coroutine once it completes.
    type Out;

    /// The body of the coroutine: use `handle` to yield values and receive replies.
    fn as_coro_fn(handle: Handle<Self::Snd, Self::Rcv>) -> impl Future<Output = Self::Out>;

    /// Build a [`ReadyCoro`] whose body is [`AsCoro::as_coro_fn`].
    ///
    /// The body future is boxed and pinned immediately, but it is not polled until the returned
    /// coroutine is first resumed (or iterated, for a [`Generator`]).
    fn as_coro() -> ReadyCoro<Self::Snd, Self::Rcv, Self::Out, impl Future<Output = Self::Out>> {
        Coro {
            _lifecycle: Ready,
            state: SharedState::default(),
            fut: Box::pin(Self::as_coro_fn(Handle {
                _snd: PhantomData,
                _rcv: PhantomData,
            })),
        }
    }
}
/// A value that can be consumed to build a coroutine.
///
/// This is the by-value counterpart of [`AsCoro`]. Implementors provide
/// [`IntoCoro::into_coro_fn`], whose body may use data owned by `self`. They get
/// [`IntoCoro::into_coro`] and [`IntoCoro::into_dyn_coro`] for free.
pub trait IntoCoro: Sized {
    /// The type of value this coroutine yields to its driver.
    type Snd: Unpin + 'static;
    /// The type of value the driver sends back after each yield.
    type Rcv: Unpin;
    /// The final output of the coroutine once it completes.
    type Out;

    /// The body of the coroutine: consumes `self` and uses `handle` to yield values and receive
    /// replies.
    fn into_coro_fn(self, handle: Handle<Self::Snd, Self::Rcv>) -> impl Future<Output = Self::Out>;

    /// Consume `self` and build a [`ReadyCoro`] whose body is [`IntoCoro::into_coro_fn`].
    ///
    /// The body future is boxed and pinned immediately, but it is not polled until the returned
    /// coroutine is first resumed.
    fn into_coro(
        self,
    ) -> ReadyCoro<Self::Snd, Self::Rcv, Self::Out, impl Future<Output = Self::Out>> {
        Coro {
            _lifecycle: Ready,
            state: SharedState::default(),
            fut: Box::pin(self.into_coro_fn(Handle {
                _snd: PhantomData,
                _rcv: PhantomData,
            })),
        }
    }

    /// Like [`IntoCoro::into_coro`], but the body future is stored as a
    /// `dyn Future<Output = Self::Out>`, which hides its concrete type.
    ///
    /// This requires both `Self` and `Self::Out` to be `'static`.
    fn into_dyn_coro(
        self,
    ) -> ReadyCoro<Self::Snd, Self::Rcv, Self::Out, dyn Future<Output = Self::Out>>
    where
        Self: 'static,
        Self::Out: 'static,
    {
        Coro {
            _lifecycle: Ready,
            state: SharedState::default(),
            fut: Box::pin(self.into_coro_fn(Handle {
                _snd: PhantomData,
                _rcv: PhantomData,
            })),
        }
    }
}
impl<F, S, R, Fut, O> From<F> for ReadyCoro<S, R, O, Fut>
where
    F: FnOnce(Handle<S, R>) -> Fut,
    S: Unpin + 'static,
    R: Unpin,
    Fut: Future<Output = O>,
{
    /// Turn a closure or function that takes a [`Handle`] into a coroutine ready to be resumed.
    ///
    /// `f` is called right away with a fresh handle. The future it returns is boxed and pinned,
    /// but it is not polled until the coroutine is resumed.
    fn from(f: F) -> Self {
        let handle = Handle {
            _snd: PhantomData,
            _rcv: PhantomData,
        };
        let body = f(handle);

        Self {
            _lifecycle: Ready,
            state: SharedState::default(),
            fut: Box::pin(body),
        }
    }
}
impl<S, R, O> ReadyCoro<S, R, O, dyn Future<Output = O>>
where
    S: Unpin + 'static,
    R: Unpin,
{
    /// Build a coroutine whose body future is stored as a `dyn Future<Output = O>`.
    ///
    /// `f` is called right away with a fresh [`Handle`]. The future it returns is boxed, pinned
    /// and type-erased. As a result, every coroutine built this way for a given `S`, `R` and `O`
    /// has the same type, whichever closure produced it. The body future must be `'static`.
    pub fn from_dyn<F, Fut>(f: F) -> Self
    where
        F: FnOnce(Handle<S, R>) -> Fut,
        Fut: Future<Output = O> + 'static,
    {
        Coro {
            _lifecycle: Ready,
            state: SharedState::default(),
            fut: Box::pin((f)(Handle {
                _snd: PhantomData,
                _rcv: PhantomData,
            })),
        }
    }
}
/// Panic message used whenever a coroutine body tries to use the waker it was polled with.
const DENY_FUT: &str = "a Coro is not permitted to await a future that uses an arbitrary Waker";

/// Vtable of the waker handed to coroutine bodies while they are polled.
///
/// Cloning or waking this waker panics with [`DENY_FUT`], and dropping it does nothing. `Yield`
/// compares the waker it is polled with against this vtable to check that it is running inside
/// a coroutine.
static WAKER_VTABLE: RawWakerVTable =
    RawWakerVTable::new(deny_clone, deny_wake, deny_wake, ignore_drop);

/// `clone` entry of [`WAKER_VTABLE`]: always panics.
fn deny_clone(_: *const ()) -> RawWaker {
    panic!("{DENY_FUT}")
}

/// `wake` and `wake_by_ref` entries of [`WAKER_VTABLE`]: always panic.
fn deny_wake(_: *const ()) {
    panic!("{DENY_FUT}")
}

/// `drop` entry of [`WAKER_VTABLE`]: the waker owns nothing, so there is nothing to release.
fn ignore_drop(_: *const ()) {}
/// Hand-off slots shared between a running coroutine body and the code driving it.
///
/// When the body yields, the value is stored in `s` for the driver to take. The driver's reply
/// is stored in `r`, and the yield point takes it when it is polled again.
#[derive(Debug, Clone)]
struct SharedState<S: 'static, R> {
    /// Value yielded by the body, waiting to be collected by the driver.
    s: Option<S>,
    /// Reply from the driver, waiting to be collected by the body.
    r: Option<R>,
}

impl<S, R> Default for SharedState<S, R> {
    /// Both slots start out empty (no `S: Default` / `R: Default` bound is required).
    fn default() -> Self {
        let (s, r) = (None, None);
        Self { s, r }
    }
}
/// Marker trait for the type-level states a [`Coro`] can be in.
///
/// The state determines which operations are available. A [`Ready`] coroutine can be resumed;
/// a [`Pending`] one has yielded and is waiting for a reply.
pub trait Lifecycle: fmt::Debug {}

/// Lifecycle state of a coroutine that can be resumed.
#[derive(Debug)]
pub struct Ready;
impl Lifecycle for Ready {}

/// Lifecycle state of a coroutine that has yielded and is waiting for a reply to be sent.
#[derive(Debug)]
pub struct Pending;
impl Lifecycle for Pending {}

/// A [`Coro`] in the [`Ready`] state: it can be resumed.
pub type ReadyCoro<S, R, O, F> = Coro<S, R, O, F, Ready>;
/// A [`Coro`] in the [`Pending`] state: it needs a reply before it can be resumed again.
pub type PendingCoro<S, R, O, F> = Coro<S, R, O, F, Pending>;
/// A coroutine driven by polling an `async` body future.
///
/// Type parameters:
/// - `S`: the type yielded to the driver.
/// - `R`: the type sent back as a reply.
/// - `O`: the final output.
/// - `F`: the body future.
/// - `L`: the [`Lifecycle`] state.
///
/// A [`Ready`] coroutine is advanced with [`Coro::resume`]. That either completes, or hands
/// back a [`Pending`] coroutine together with the yielded value; [`Coro::send`] then supplies
/// the reply and returns a [`Ready`] coroutine again.
///
/// The body may only suspend by yielding through its [`Handle`]. The waker it is polled with
/// panics if it is cloned or woken.
pub struct Coro<S, R, O, F: ?Sized, L>
where
    S: Unpin + 'static,
    R: Unpin,
    F: Future<Output = O>,
    L: Lifecycle,
{
    _lifecycle: L,
    state: SharedState<S, R>,
    fut: Pin<Box<F>>,
}
impl<S, R, O, F, L> fmt::Debug for Coro<S, R, O, F, L>
where
    S: Unpin,
    R: Unpin,
    F: Future<Output = O> + ?Sized,
    L: Lifecycle,
{
    /// Shows only the lifecycle state; the body future and the hand-off slots are not printed.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = out.debug_struct("Coro");
        dbg.field("lifecycle", &self._lifecycle);
        dbg.finish()
    }
}
impl<S, R, O, Fut: ?Sized> Coro<S, R, O, Fut, Ready>
where
    S: Unpin,
    R: Unpin,
    Fut: Future<Output = O>,
{
    /// Drive the coroutine to completion on the current thread.
    ///
    /// Every value the coroutine yields is passed to `step_fn`, and whatever `step_fn` returns
    /// is sent back as the reply. Returns the coroutine's final output.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Coro::resume`].
    pub fn run_sync<F>(self, mut step_fn: F) -> O
    where
        F: FnMut(S) -> R,
    {
        let mut coro = self;
        loop {
            coro = match coro.resume() {
                CoroState::Complete(res) => return res,
                CoroState::Pending(c, s) => c.send((step_fn)(s)),
            };
        }
    }

    /// Poll the coroutine body once.
    ///
    /// Returns [`CoroState::Complete`] with the output if the body finished. Otherwise returns
    /// [`CoroState::Pending`] with the now-[`Pending`] coroutine and the value the body yielded.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - the body awaits a future that clones or wakes the waker;
    /// - the body suspends without having yielded a value through its [`Handle`].
    pub fn resume(mut self) -> CoroState<S, R, O, Fut> {
        // `Yield::poll` writes the yielded value through the waker's data pointer. The pointer
        // must therefore come from a *mutable* place: writing through a pointer obtained from
        // `&self.state` (a shared reference) is undefined behaviour under Rust's aliasing rules.
        let state_ptr = core::ptr::addr_of_mut!(self.state).cast::<()>().cast_const();
        // SAFETY: no function in `WAKER_VTABLE` dereferences the data pointer (clone and wake
        // panic, drop does nothing), so any pointer satisfies the `RawWaker` contract. Only
        // `Yield::poll` dereferences it, during the `poll` call below. At that point
        // `self.state` is still in place and is not otherwise borrowed.
        let waker = unsafe { Waker::from_raw(RawWaker::new(state_ptr, &WAKER_VTABLE)) };
        let mut ctx = Context::from_waker(&waker);

        match self.fut.as_mut().poll(&mut ctx) {
            Poll::Ready(val) => CoroState::Complete(val),
            Poll::Pending => {
                // The body can only legitimately suspend by yielding, which fills `state.s`.
                let s = self.state.s.take().unwrap_or_else(|| panic!("{DENY_FUT}"));
                let sm = Coro {
                    _lifecycle: Pending,
                    state: self.state,
                    fut: self.fut,
                };
                CoroState::Pending(sm, s)
            }
        }
    }
}
impl<S, R, O, F: ?Sized> Coro<S, R, O, F, Pending>
where
    S: Unpin,
    R: Unpin,
    F: Future<Output = O>,
{
    /// Supply the reply to the value this coroutine last yielded.
    ///
    /// The reply becomes the result of the pending `yield_value(..).await` in the body the next
    /// time the returned [`ReadyCoro`] is resumed.
    pub fn send(mut self, r: R) -> ReadyCoro<S, R, O, F> {
        self.state.r = Some(r);
        Coro {
            _lifecycle: Ready,
            state: self.state,
            fut: self.fut,
        }
    }
}
/// A coroutine that yields values of type `T`, receives `()` and finishes with `()`.
///
/// Generators implement [`Iterator`], so they can be consumed with `for` loops and iterator
/// adaptors instead of being resumed by hand.
pub type Generator<T, F> = Coro<T, (), (), F, Ready>;

impl<T, F: ?Sized> Iterator for Generator<T, F>
where
    T: Unpin,
    F: Future<Output = ()>,
{
    type Item = T;

    /// Resume the generator, returning the next yielded value or `None` once the body finishes.
    fn next(&mut self) -> Option<T> {
        // NOTE(review): this iterator is not fused. Calling `next` again after it returned
        // `None` polls an already-completed body future; `async` bodies panic when that happens.

        // Derived from a mutable place because `Yield::poll` writes through this pointer;
        // writing through a pointer obtained from a shared reference would be UB.
        let state_ptr = core::ptr::addr_of_mut!(self.state).cast::<()>().cast_const();
        // SAFETY: no function in `WAKER_VTABLE` dereferences the data pointer (clone and wake
        // panic, drop does nothing), so any pointer satisfies the `RawWaker` contract. Only
        // `Yield::poll` dereferences it, during the `poll` call below, while `self.state` is
        // not otherwise borrowed.
        let waker = unsafe { Waker::from_raw(RawWaker::new(state_ptr, &WAKER_VTABLE)) };
        let mut ctx = Context::from_waker(&waker);

        match self.fut.as_mut().poll(&mut ctx) {
            Poll::Pending => {
                let val = self.state.s.take().unwrap_or_else(|| panic!("{DENY_FUT}"));
                // Pre-load the `()` reply so the suspended yield completes on the next call.
                self.state.r = Some(());
                Some(val)
            }
            Poll::Ready(()) => None,
        }
    }
}
/// The result of resuming a [`Coro`].
#[derive(Debug)]
#[must_use]
pub enum CoroState<S, R, T, F: ?Sized>
where
    S: Unpin + 'static,
    R: Unpin,
    F: Future<Output = T>,
{
    /// The coroutine yielded the contained value and is waiting for a reply via [`Coro::send`].
    Pending(PendingCoro<S, R, T, F>, S),
    /// The coroutine finished and produced its output.
    Complete(T),
}
impl<S, R, T, F: ?Sized> CoroState<S, R, T, F>
where
    S: Unpin + 'static,
    R: Unpin,
    F: Future<Output = T>,
{
    /// Returns `true` if the coroutine yielded and is waiting for a reply.
    pub fn is_pending(&self) -> bool {
        matches!(self, Self::Pending(..))
    }

    /// Returns `true` if the coroutine has finished.
    pub fn is_complete(&self) -> bool {
        matches!(self, Self::Complete(_))
    }

    /// Answer a pending yield with `f(yielded_value)` and return the coroutine, ready to resume.
    ///
    /// `f` is called at most once, so any `FnOnce` closure is accepted.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has already completed. The reported location is the caller's.
    #[track_caller]
    pub fn unwrap_pending(self, f: impl FnOnce(S) -> R) -> ReadyCoro<S, R, T, F> {
        match self {
            Self::Pending(coro, s) => coro.send(f(s)),
            Self::Complete(_) => {
                panic!("called `CoroState::unwrap_pending` on a `Complete` value")
            }
        }
    }

    /// Return the output of a completed coroutine.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine is still pending. The reported location is the caller's.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            Self::Pending(_, _) => {
                panic!("called `CoroState::unwrap` on a `Pending` value")
            }
            Self::Complete(t) => t,
        }
    }
}
/// The capability a coroutine body uses to yield values of type `S` and receive replies of type
/// `R` (which defaults to `()`).
///
/// A handle stores no data; it only records the `S` and `R` types. It is handed to the body when
/// the coroutine is constructed. It is `Copy`, so it can be passed by value to helper
/// `async fn`s and to nested coroutines run with [`Handle::yield_from`].
#[derive(Debug)]
pub struct Handle<S, R = ()>
where
    S: Unpin,
    R: Unpin,
{
    _snd: PhantomData<S>,
    _rcv: PhantomData<R>,
}

/// Shorthand for a [`Handle`] that yields and receives `()`.
#[doc(hidden)]
pub type HandOwl = Handle<(), ()>;

// `Clone` and `Copy` are written by hand because `#[derive]` would add `S: Clone` / `R: Clone`
// (and `Copy`) bounds. Only `PhantomData`s are stored, so those bounds are unnecessary.
impl<S, R> Clone for Handle<S, R>
where
    S: Unpin + 'static,
    R: Unpin,
{
    fn clone(&self) -> Self {
        *self
    }
}

impl<S, R> Copy for Handle<S, R>
where
    S: Unpin + 'static,
    R: Unpin,
{
}
impl<S, R> Handle<S, R>
where
    S: Unpin + 'static,
    R: Unpin,
{
    /// Yield `snd` to whoever is driving the coroutine and wait for their reply.
    ///
    /// The coroutine suspends here. The driver receives `snd` from [`Coro::resume`] (or
    /// [`Iterator::next`] for a [`Generator`]), and the value it passes to [`Coro::send`] is
    /// what this future resolves to.
    ///
    /// Only one yield should be outstanding at a time. Awaiting several concurrently (e.g. with
    /// a join combinator) is not supported, because the coroutine has a single slot for each
    /// direction.
    ///
    /// # Panics
    ///
    /// The returned future panics if it is polled by anything other than a [`Coro`].
    pub async fn yield_value(&self, snd: S) -> R {
        Yield {
            polled: false,
            s: Some(snd),
            _r: PhantomData,
        }
        .await
    }

    /// Run another coroutine with the same yield and reply types to completion inside this one.
    ///
    /// Every value the inner coroutine yields is passed through to this coroutine's driver, and
    /// every reply is passed back in. Returns the inner coroutine's output. `coro` can be
    /// anything convertible into a [`ReadyCoro`], such as an `async fn` that takes a [`Handle`].
    /// It can also be a type-erased coroutine built with [`ReadyCoro::from_dyn`].
    pub async fn yield_from<T, C, F>(&self, coro: C) -> T
    where
        C: Into<ReadyCoro<S, R, T, F>>,
        // `?Sized` so that type-erased `ReadyCoro<.., dyn Future<Output = T>>` is accepted too.
        F: Future<Output = T> + ?Sized,
    {
        coro.into().fut.await
    }

    /// Run the coroutine described by the [`AsCoro`] type `C` inline inside this one.
    ///
    /// Like [`Handle::yield_from`], but `C::as_coro_fn` is awaited directly with a copy of this
    /// handle, so no separate boxed coroutine is created.
    pub async fn yield_from_type<C, T>(&self) -> T
    where
        C: AsCoro<Snd = S, Rcv = R, Out = T>,
    {
        C::as_coro_fn(*self).await
    }
}
/// The future behind [`Handle::yield_value`].
///
/// On its first poll it hands `s` to the driving coroutine and suspends. On the following poll
/// it resolves to the reply the driver sent back.
struct Yield<S: Unpin + 'static, R: Unpin> {
    /// Set once the value has been handed over, i.e. the next poll is the resumption.
    polled: bool,
    /// The value to yield; taken on the first poll.
    s: Option<S>,
    /// Records the reply type without storing a reply.
    _r: PhantomData<R>,
}
impl<S, R> Future for Yield<S, R>
where
    S: Unpin + 'static,
    R: Unpin,
{
    type Output = R;

    /// Two-phase poll.
    ///
    /// - First poll: store the yielded value in the coroutine's shared state and suspend.
    /// - Next poll: take the reply the driver stored and complete.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - it is polled with a waker that was not built from `WAKER_VTABLE`;
    /// - it is polled again while no reply is available, e.g. when several yields are awaited
    ///   concurrently or after it has already completed.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<R> {
        // Compare by address. `WAKER_VTABLE` is a `static`, so every coroutine waker points at
        // this exact instance; that is more reliable than comparing the function pointers it
        // contains.
        assert!(
            core::ptr::eq(ctx.waker().vtable(), &WAKER_VTABLE),
            "this future must be awaited inside of a Coro"
        );

        let state_ptr = ctx.waker().data().cast_mut().cast::<SharedState<S, R>>();
        // SAFETY: the vtable check above means the waker was built by the coroutine driving
        // this future (`Coro::resume` / `Generator::next`). Its data pointer is the non-null
        // address of that coroutine's `SharedState`, which the driver must derive from a
        // mutable place because we write through it. The state is valid, and not otherwise
        // accessed, for the duration of this poll.
        // NOTE(review): this also assumes the driving coroutine's state has the same `S`/`R`
        // as this `Yield`. A `Handle` moved into a nested coroutine with different types would
        // break that assumption, and nothing here checks it — confirm this is acceptable.
        let state = unsafe { &mut *state_ptr };

        if self.polled {
            // A reply is only present if the driver called `send` (or `Generator::next`
            // pre-loaded `()`) since this yield suspended. A checked `take` turns concurrent
            // yields or a poll after completion into a panic instead of undefined behaviour.
            let reply = state
                .r
                .take()
                .expect("a Coro yield point was polled again without a reply having been sent");
            Poll::Ready(reply)
        } else {
            self.polled = true;
            // `self.s` starts as `Some` and is only taken here, guarded by `polled`.
            state.s = self.s.take();
            Poll::Pending
        }
    }
}
// Unit tests for the coroutine API: resuming and sending, generator iteration, nested
// `yield_from`, waker misuse detection and moving coroutines between threads.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        io::{self, Cursor, ErrorKind},
        pin::pin,
        sync::mpsc::channel,
        thread::spawn,
    };

    // Coroutine that yields 42, asserts the reply is `true`, then returns a fixed string.
    fn yield_recv_return()
    -> ReadyCoro<usize, bool, &'static str, impl Future<Output = &'static str>> {
        Coro::from(async |handle: Handle<usize, bool>| {
            assert!(handle.yield_value(42).await);
            "hello, world!"
        })
    }

    // One yield/reply round trip followed by completion.
    #[test]
    fn yield_recv_return_works() {
        let mut coro = yield_recv_return();
        coro = coro.resume().unwrap_pending(|n| {
            assert_eq!(n, 42);
            true
        });
        let s = coro.resume().unwrap();
        assert_eq!(s, "hello, world!");
    }

    // Same as above, written as a single method chain on `CoroState`.
    #[test]
    fn coro_state_chaining_works() {
        let s = yield_recv_return()
            .resume()
            .unwrap_pending(|n| {
                assert_eq!(n, 42);
                true
            })
            .resume()
            .unwrap();
        assert_eq!(s, "hello, world!");
    }

    // The first resume always yields, so `unwrap` must panic.
    #[test]
    #[should_panic = "called `CoroState::unwrap` on a `Pending` value"]
    fn unwrap_on_pending_panics() {
        yield_recv_return().resume().unwrap();
    }

    // The second resume completes, so `unwrap_pending` must panic.
    #[test]
    #[should_panic = "called `CoroState::unwrap_pending` on a `Complete` value"]
    fn unwrap_pending_on_complete_panics() {
        let mut coro = yield_recv_return();
        coro = coro.resume().unwrap_pending(|n| {
            assert_eq!(n, 42);
            true
        });
        coro.resume().unwrap_pending(|n| {
            assert_eq!(n, 42);
            true
        });
    }

    // Polling a `yield_value` future with a foreign (noop) waker trips the vtable check in
    // `Yield::poll`.
    #[test]
    #[should_panic = "this future must be awaited inside of a Coro"]
    fn manually_polling_yield_panics() {
        let coro = Coro::from(async |handle: Handle<()>| {
            let fut = handle.yield_value(());
            let waker = Waker::noop();
            let mut cx = Context::from_waker(waker);
            let _ = pin!(fut).as_mut().poll(&mut cx);
        });
        let _ = coro.resume();
    }

    // Type-erased coroutine built with `from_dyn` behaves like the concrete one.
    #[test]
    fn dyn_future_type_works() {
        let coro = Coro::from_dyn(|handle: Handle<usize, bool>| async move {
            assert!(handle.yield_value(42).await);
            "hello, dyn world!"
        });
        let mut coro = coro;
        coro = coro.resume().unwrap_pending(|n| {
            assert_eq!(n, 42);
            true
        });
        let s = coro.resume().unwrap();
        assert_eq!(s, "hello, dyn world!");
    }

    // Coroutine borrowing `nums`: yields each number and expects its double back.
    fn double_nums(
        nums: &[usize],
    ) -> ReadyCoro<usize, usize, &'static str, impl Future<Output = &'static str>> {
        Coro::from(async |handle: Handle<usize, usize>| {
            for &n in nums.iter() {
                let doubled = handle.yield_value(n).await;
                assert_eq!(doubled, n * 2);
            }
            "done"
        })
    }

    // Hand-written resume/send loop over a coroutine that captures a reference.
    #[test]
    fn closures_capturing_references_work() {
        let mut coro = double_nums(&[1, 2, 3]);
        loop {
            coro = match coro.resume() {
                CoroState::Pending(c, n) => c.send(n * 2),
                CoroState::Complete(res) => {
                    assert_eq!(res, "done");
                    return;
                }
            };
        }
    }

    // `run_sync` performs the same loop as the previous test.
    #[test]
    fn run_sync_works() {
        let res = double_nums(&[1, 2, 3]).run_sync(|n| n * 2);
        assert_eq!(res, "done");
    }

    // `AsCoro` implementor that yields 0..N.
    struct Counter<const N: usize>;
    impl<const N: usize> AsCoro for Counter<N> {
        type Snd = usize;
        type Rcv = ();
        type Out = ();
        async fn as_coro_fn(handle: Handle<usize>) {
            for n in 0..N {
                handle.yield_value(n).await
            }
        }
    }

    // A `Coro<usize, (), ()>` is a `Generator`, so it can be collected as an iterator.
    #[test]
    fn generator_iter_works() {
        let nums: Vec<usize> = Counter::<6>::as_coro().collect();
        assert_eq!(nums, vec![0, 1, 2, 3, 4, 5]);
    }

    // Values yielded by the inlined `Counter<6>` surface through the outer generator; the
    // expected sum omits 0 because it does not contribute.
    #[test]
    fn yield_from_type_works() {
        let coro = Coro::from(async |handle: Handle<usize>| {
            handle.yield_from_type::<Counter<6>, _>().await
        });
        let total: usize = coro.sum();
        assert_eq!(total, 1 + 2 + 3 + 4 + 5);
    }

    // Generator built from a closure that captures `k` by move.
    #[test]
    fn capturing_nested_closure_works() {
        let g = |k: usize| {
            Generator::from(move |handle: Handle<usize>| async move {
                for n in 0..k {
                    handle.yield_value(n).await;
                }
            })
        };
        let nums: Vec<usize> = g(6).collect();
        assert_eq!(nums, vec![0, 1, 2, 3, 4, 5]);
    }

    // Coroutine whose body awaits a tokio timer rather than yielding.
    fn tokio_boom() -> ReadyCoro<(), (), &'static str, impl Future<Output = &'static str>> {
        Coro::from(async |_: Handle<()>| {
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            "boom!"
        })
    }

    // Inside a tokio runtime the sleep future presumably tries to register (clone/wake) the
    // waker, which hits the coroutine vtable's DENY_FUT panic.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    #[should_panic = "a Coro is not permitted to await a future that uses an arbitrary Waker"]
    async fn awaiting_a_future_that_needs_a_waker_panics() {
        let _ = tokio_boom().resume();
    }

    // Outside a runtime, tokio itself panics first (message asserted below).
    #[test]
    #[should_panic = "there is no reactor running, must be called from the context of a Tokio 1.x runtime"]
    fn awaiting_a_tokio_future_panics_outside_of_tokio() {
        let _ = tokio_boom().resume();
    }

    // `IntoCoro` implementor that yields `initial`, then echoes back every reply forever.
    struct Echo<T> {
        initial: T,
    }
    impl<T: Unpin + 'static> IntoCoro for Echo<T> {
        type Snd = T;
        type Rcv = T;
        type Out = ();
        async fn into_coro_fn(self, handle: Handle<T, T>) {
            let mut val = self.initial;
            loop {
                val = handle.yield_value(val).await;
            }
        }
    }

    // Ping-pongs a coroutine between two threads over channels; this compiles only if the
    // coroutine (and its body future) is `Send`.
    #[test]
    fn moving_between_threads_works() {
        let mut ping_pong = Echo { initial: "ping" }.into_coro();
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let jh1 = spawn(move || {
            for _ in 0..3 {
                ping_pong = match ping_pong.resume() {
                    CoroState::Pending(c, s) => {
                        assert_eq!(s, "ping");
                        let coro = c.send("pong");
                        tx1.send(coro).unwrap();
                        rx2.recv().unwrap()
                    }
                    CoroState::Complete(_) => break,
                };
            }
        });
        let jh2 = spawn(move || {
            for _ in 0..3 {
                let ping_pong = rx1.recv().unwrap();
                match ping_pong.resume() {
                    CoroState::Pending(c, s) => {
                        assert_eq!(s, "pong");
                        let coro = c.send("ping");
                        tx2.send(coro).unwrap();
                    }
                    CoroState::Complete(_) => break,
                }
            }
        });
        jh1.join().unwrap();
        jh2.join().unwrap();
    }

    // Encoded string list as parsed by `read_9p_string_vec`:
    // - u16 LE count 2;
    // - u16 LE length 5, then "Hello";
    // - u16 LE length 6, then the UTF-8 bytes of "世界".
    const HELLO_WORLD: [u8; 17] = [
        0x02, 0x00, 0x05, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x06, 0x00, 0xe4, 0xb8, 0x96, 0xe7,
        0x95, 0x8c,
    ];

    // Requests 2 bytes from the driver and decodes them as a little-endian u16.
    async fn read_9p_u16(handle: Handle<usize, Vec<u8>>) -> io::Result<u16> {
        let n = size_of::<u16>();
        let buf = handle.yield_value(n).await;
        let data = buf[0..n].try_into().unwrap();
        Ok(u16::from_le_bytes(data))
    }

    // Reads a u16 length prefix (via a nested coroutine), then that many bytes as UTF-8.
    async fn read_9p_string(handle: Handle<usize, Vec<u8>>) -> io::Result<String> {
        let len = handle.yield_from(read_9p_u16).await? as usize;
        let buf = handle.yield_value(len).await;
        String::from_utf8(buf).map_err(|e| io::Error::new(ErrorKind::InvalidData, e.to_string()))
    }

    // Reads a u16 count, then that many length-prefixed strings; nests two levels deep.
    async fn read_9p_string_vec(handle: Handle<usize, Vec<u8>>) -> io::Result<Vec<String>> {
        let len = handle.yield_from(read_9p_u16).await? as usize;
        let mut buf = Vec::with_capacity(len);
        for _ in 0..len {
            buf.push(handle.yield_from(read_9p_string).await?);
        }
        Ok(buf)
    }

    // Driver answers each "read n bytes" request synchronously from a cursor.
    #[test]
    fn nested_yield_from_works() {
        use std::io::Read;
        let mut coro = Coro::from(read_9p_string_vec);
        let mut r = Cursor::new(HELLO_WORLD.to_vec());
        loop {
            coro = match coro.resume() {
                CoroState::Complete(parsed) => {
                    assert_eq!(parsed.unwrap(), &["Hello", "世界"]);
                    return;
                }
                CoroState::Pending(c, n) => c.send({
                    let mut buf = vec![0; n];
                    r.read_exact(&mut buf).unwrap();
                    buf
                }),
            };
        }
    }

    // Same parser, but the driver performs the reads with tokio's async I/O; the coroutine
    // itself stays executor-agnostic.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn nested_yield_from_works_with_async() {
        use tokio::io::AsyncReadExt;
        let mut coro = Coro::from(read_9p_string_vec);
        let mut r = Cursor::new(HELLO_WORLD.to_vec());
        loop {
            coro = match coro.resume() {
                CoroState::Complete(parsed) => {
                    assert_eq!(parsed.unwrap(), &["Hello", "世界"]);
                    return;
                }
                CoroState::Pending(c, n) => c.send({
                    let mut buf = vec![0; n];
                    r.read_exact(&mut buf).await.unwrap();
                    buf
                }),
            };
        }
    }

    // `run_sync` drives the nested parser with the same read closure.
    #[test]
    fn run_sync_works_with_nested_yields() {
        use std::io::Read;
        let mut r = Cursor::new(HELLO_WORLD.to_vec());
        let parsed = Coro::from(read_9p_string_vec)
            .run_sync(|n| {
                let mut buf = vec![0; n];
                r.read_exact(&mut buf).unwrap();
                buf
            })
            .expect("parsing to be successful");
        assert_eq!(parsed, &["Hello", "世界"]);
    }
}