#![cfg_attr(not(test), no_std)]
#![warn(missing_docs)]
use core::future::Future;
use core::marker::PhantomData;
use core::mem;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::ptr;
use core::task::Context;
use core::task::Poll;
/// A type-erased future stored inline, in a fixed-size byte buffer, instead of behind a `Box`.
///
/// A `StackFuture` can hold any `Future<Output = T> + Send + 'a` whose size is at most
/// `STACK_SIZE` bytes. The future's alignment must not exceed the alignment of `StackFuture`
/// itself. That is the alignment of a function pointer, because the function-pointer fields
/// are its most-aligned fields.
///
/// The concrete future type is erased. Two monomorphized function pointers, `poll_fn` and
/// `drop_fn`, remember how to poll and drop the value stored in `data`.
#[repr(C)] // `data` must stay the first field so it starts at offset 0 (checked in `from`).
pub struct StackFuture<'a, T, const STACK_SIZE: usize> {
    /// Raw storage for the erased future; initialized by [`StackFuture::from`].
    data: [MaybeUninit<u8>; STACK_SIZE],
    /// Polls the future stored in `data`; monomorphized for its concrete type.
    poll_fn: fn(this: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T>,
    /// Drops the future stored in `data` in place; monomorphized for its concrete type.
    drop_fn: fn(this: &mut Self),
    /// Carries `'a` and `T`. As a `dyn Future + Send` marker it also makes the type `Send`
    /// but neither `Sync` nor `Unpin`. The type must be `!Unpin` because `poll` projects the
    /// pin onto the erased future.
    _phantom: PhantomData<dyn Future<Output = T> + Send + 'a>,
}

impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, STACK_SIZE> {
    /// Wraps `future` in a `StackFuture`, moving it into the inline buffer.
    ///
    /// # Panics
    ///
    /// Panics if `F` needs a larger alignment than `StackFuture` provides, or if `F` is
    /// larger than `STACK_SIZE` bytes.
    pub fn from<F>(future: F) -> Self
    where
        F: Future<Output = T> + Send + 'a,
    {
        assert!(
            mem::align_of::<F>() <= mem::align_of::<Self>(),
            "cannot create StackFuture, required alignment is {} but maximum alignment is {}",
            mem::align_of::<F>(),
            mem::align_of::<Self>()
        );
        assert!(
            Self::has_space_for_val(&future),
            "cannot create StackFuture, required size is {}, available space is {}",
            mem::size_of::<F>(),
            STACK_SIZE
        );

        let mut result = StackFuture {
            data: [MaybeUninit::uninit(); STACK_SIZE],
            poll_fn: Self::poll_inner::<F>,
            drop_fn: Self::drop_inner::<F>,
            _phantom: PhantomData,
        };
        // `#[repr(C)]` places `data` at offset 0, so the buffer is as aligned as `Self`.
        // The first assertion showed that alignment is sufficient for `F`.
        assert_eq!(result.data.as_ptr() as usize, &result as *const _ as usize);

        // SAFETY: `as_mut_ptr` checks that `F` fits in `data`, and the pointer is suitably
        // aligned (see above). `data` holds no initialized value yet, so nothing is
        // overwritten without being dropped.
        unsafe { result.as_mut_ptr::<F>().write(future) };
        result
    }

    /// Type-erased poll entry point stored in `poll_fn`; polls the `F` held in `data`.
    fn poll_inner<F: Future>(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        self.as_pin_mut_ref::<F>().poll(cx)
    }

    /// Type-erased drop entry point stored in `drop_fn`; drops the `F` held in `data` in place.
    fn drop_inner<F>(&mut self) {
        // SAFETY: `drop_fn` is only set to `drop_inner::<F>` for the same `F` that `from`
        // wrote into `data`, and it is only invoked from `Drop::drop`.
        unsafe { ptr::drop_in_place(self.as_mut_ptr::<F>()) }
    }

    /// Returns a pointer to `data` reinterpreted as `*mut F`.
    ///
    /// # Panics
    ///
    /// Panics if `F` does not fit in `STACK_SIZE` bytes.
    fn as_mut_ptr<F>(&mut self) -> *mut F {
        assert!(Self::has_space_for::<F>());
        self.data.as_mut_ptr().cast()
    }

    /// Projects the pin on `self` onto the `F` stored in `data`.
    fn as_pin_mut_ref<F>(self: Pin<&mut Self>) -> Pin<&mut F> {
        // SAFETY: the `F` lives inside `self` and is never moved out of `data`, so pinning
        // `self` also pins it. `StackFuture` is `!Unpin`, so safe code cannot move it either.
        // Callers must name the same `F` that `from` stored.
        unsafe { self.map_unchecked_mut(|this| &mut *this.as_mut_ptr()) }
    }

    /// Number of bytes of `data` needed to store an `F`.
    fn required_space<F>() -> usize {
        mem::size_of::<F>()
    }

    /// Whether an `F` fits in `STACK_SIZE` bytes (size only; alignment is checked in `from`).
    fn has_space_for<F>() -> bool {
        Self::required_space::<F>() <= STACK_SIZE
    }

    /// Same as [`Self::has_space_for`], but infers `F` from a value.
    fn has_space_for_val<F>(_: &F) -> bool {
        Self::has_space_for::<F>()
    }
}

impl<'a, T, const STACK_SIZE: usize> Future for StackFuture<'a, T, STACK_SIZE> {
    type Output = T;

    /// Polls the wrapped future through the type-erased `poll_fn`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Copy the fn pointer out first so `self` can be moved into the call.
        let poll_fn = self.poll_fn;
        poll_fn(self, cx)
    }
}

impl<'a, T, const STACK_SIZE: usize> Drop for StackFuture<'a, T, STACK_SIZE> {
    fn drop(&mut self) {
        // Runs the erased future's destructor; the `data` bytes themselves need no drop.
        (self.drop_fn)(self);
    }
}
#[cfg(test)]
mod tests {
    use crate::StackFuture;
    use core::task::Poll;
    use futures::executor::block_on;
    use futures::pin_mut;
    use futures::Future;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::task::Context;
    use std::task::Wake;

    #[test]
    fn create_and_run() {
        let f = StackFuture::<'_, _, 8>::from(async { 5 });
        assert_eq!(block_on(f), 5);
    }

    /// Uninhabited output type for futures that never complete.
    enum Never {}

    /// A future that is always pending and never registers for a wake-up.
    struct SuspendPoint;

    impl Future for SuspendPoint {
        type Output = Never;

        fn poll(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Self::Output> {
            Poll::Pending
        }
    }

    /// Waker for tests whose futures never call `wake`; being woken indicates a test bug.
    struct PanicWaker;

    impl Wake for PanicWaker {
        fn wake(self: Arc<Self>) {
            unimplemented!()
        }
    }

    #[test]
    fn destructor_runs() {
        let destructed = AtomicBool::new(false);
        {
            let f = async {
                struct DropMe<'a>(&'a AtomicBool);
                impl Drop for DropMe<'_> {
                    fn drop(&mut self) {
                        self.0.store(true, Ordering::SeqCst);
                    }
                }
                // Bind the guard to a named variable. `let _ = ...` would drop it
                // immediately, and the test would pass even if `StackFuture` never ran the
                // inner future's destructor.
                let _drop_me = DropMe(&destructed);
                SuspendPoint.await
            };
            let f = StackFuture::<'_, _, 32>::from(f);
            let waker = Arc::new(PanicWaker).into();
            let mut cx = Context::from_waker(&waker);
            pin_mut!(f);
            assert!(f.as_mut().poll(&mut cx).is_pending());
            // The guard is held across the suspension point, so it must still be alive.
            assert!(!destructed.load(Ordering::SeqCst));
        }
        // Dropping the `StackFuture` at the end of the scope must drop the inner future.
        assert!(destructed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_alignment() {
        #[repr(align(8))]
        struct BigAlignment(u32);

        impl Future for BigAlignment {
            type Output = Never;

            fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Pending
            }
        }

        let mut f = StackFuture::<'_, _, 1016>::from(BigAlignment(42));
        assert!(is_aligned(f.as_mut_ptr::<BigAlignment>(), 8));
    }

    /// Whether `ptr` is a multiple of `alignment` (which must be a power of two).
    fn is_aligned<T>(ptr: *mut T, alignment: usize) -> bool {
        (ptr as usize) & (alignment - 1) == 0
    }
}