use crate::rt::Cleanup;
use crate::rt::async_support::ReturnCode;
use crate::rt::async_support::waitable::{WaitableOp, WaitableOperation};
use std::alloc::Layout;
use std::fmt;
use std::future::{Future, IntoFuture};
use std::marker;
use std::mem::{self, ManuallyDrop};
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicU32, Ordering::Relaxed};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
/// The set of primitive operations this module needs in order to drive one
/// kind of future handle.
///
/// NOTE(review): the methods look like thin wrappers around the
/// component-model `future.*` intrinsics plus generated lift/lower glue —
/// confirm against the bindings generator.
pub trait FutureOps {
/// Rust-level value carried by the future.
type Payload;
/// Creates a new future and returns both handles packed into one `u64`.
/// `raw_future_new` reads the reader handle from the low 32 bits and the
/// writer handle from the high 32 bits.
fn new(&mut self) -> u64;
/// Layout of the scratch buffer that a payload is lowered into or lifted
/// from (it is what `Cleanup::new` is called with in this file).
fn elem_layout(&mut self) -> Layout;
/// Moves `payload` into the buffer at `dst`.
unsafe fn lower(&mut self, payload: Self::Payload, dst: *mut u8);
/// Called on the buffer at `dst` after a write completed successfully.
/// NOTE(review): presumably frees list allocations owned by the lowered
/// value — confirm.
unsafe fn dealloc_lists(&mut self, dst: *mut u8);
/// Reconstructs a payload from the buffer at `dst`.
unsafe fn lift(&mut self, dst: *mut u8) -> Self::Payload;
/// Starts writing the lowered value at `val` to `future`; returns a status code.
unsafe fn start_write(&mut self, future: u32, val: *const u8) -> u32;
/// Starts reading from `future` into the buffer at `val`; returns a status code.
unsafe fn start_read(&mut self, future: u32, val: *mut u8) -> u32;
/// Requests cancellation of an in-flight read on `future`; returns a status code.
unsafe fn cancel_read(&mut self, future: u32) -> u32;
/// Requests cancellation of an in-flight write on `future`; returns a status code.
unsafe fn cancel_write(&mut self, future: u32) -> u32;
/// Releases the readable end identified by `future`.
unsafe fn drop_readable(&mut self, future: u32);
/// Releases the writable end identified by `future`.
unsafe fn drop_writable(&mut self, future: u32);
}
/// Table of function pointers implementing the future operations for a
/// payload type `T`.
///
/// `impl FutureOps for &FutureVtable<T>` forwards each trait method to the
/// field of the same name (and `elem_layout` to `layout`).
#[doc(hidden)]
pub struct FutureVtable<T> {
/// Layout of the buffer a `T` is lowered into / lifted from.
pub layout: Layout,
/// Moves `value` into the buffer at `dst`.
pub lower: unsafe fn(value: T, dst: *mut u8),
/// Post-write cleanup hook for the buffer at `dst`.
pub dealloc_lists: unsafe fn(dst: *mut u8),
/// Rebuilds a `T` from the buffer at `dst`.
pub lift: unsafe fn(dst: *mut u8) -> T,
/// Starts a write of the buffer `val` on handle `future`; returns a status code.
pub start_write: unsafe extern "C" fn(future: u32, val: *const u8) -> u32,
/// Starts a read on handle `future` into the buffer `val`; returns a status code.
pub start_read: unsafe extern "C" fn(future: u32, val: *mut u8) -> u32,
/// Cancels an in-flight write on `future`; returns a status code.
pub cancel_write: unsafe extern "C" fn(future: u32) -> u32,
/// Cancels an in-flight read on `future`; returns a status code.
pub cancel_read: unsafe extern "C" fn(future: u32) -> u32,
/// Releases the writable end `future`.
pub drop_writable: unsafe extern "C" fn(future: u32),
/// Releases the readable end `future`.
pub drop_readable: unsafe extern "C" fn(future: u32),
/// Creates a future; the result packs the reader handle in the low 32 bits
/// and the writer handle in the high 32 bits (as decoded by `raw_future_new`).
pub new: unsafe extern "C" fn() -> u64,
}
/// Forwards every [`FutureOps`] operation to the identically named entry of
/// the borrowed [`FutureVtable`].
impl<T> FutureOps for &FutureVtable<T> {
    type Payload = T;

    fn new(&mut self) -> u64 {
        let create = self.new;
        // NOTE(review): this safe method relies on the vtable entry having no
        // caller-side preconditions — confirm with the code that builds vtables.
        unsafe { create() }
    }

    fn elem_layout(&mut self) -> Layout {
        let vtable: &FutureVtable<T> = self;
        vtable.layout
    }

    unsafe fn lower(&mut self, payload: Self::Payload, dst: *mut u8) {
        let lower = self.lower;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { lower(payload, dst) }
    }

    unsafe fn dealloc_lists(&mut self, dst: *mut u8) {
        let dealloc = self.dealloc_lists;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { dealloc(dst) }
    }

    unsafe fn lift(&mut self, dst: *mut u8) -> Self::Payload {
        let lift = self.lift;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { lift(dst) }
    }

    unsafe fn start_write(&mut self, future: u32, val: *const u8) -> u32 {
        let start = self.start_write;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { start(future, val) }
    }

    unsafe fn start_read(&mut self, future: u32, val: *mut u8) -> u32 {
        let start = self.start_read;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { start(future, val) }
    }

    unsafe fn cancel_read(&mut self, future: u32) -> u32 {
        let cancel = self.cancel_read;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { cancel(future) }
    }

    unsafe fn cancel_write(&mut self, future: u32) -> u32 {
        let cancel = self.cancel_write;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { cancel(future) }
    }

    unsafe fn drop_readable(&mut self, future: u32) {
        let release = self.drop_readable;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { release(future) }
    }

    unsafe fn drop_writable(&mut self, future: u32) {
        let release = self.drop_writable;
        // SAFETY: the caller of this `unsafe fn` upholds the entry's contract.
        unsafe { release(future) }
    }
}
/// Creates a new future and returns its typed writer and reader ends.
///
/// `default` produces the value that the writer sends if it is dropped
/// without an explicit write (see `Drop for FutureWriter`).
///
/// # Safety
///
/// NOTE(review): presumably `vtable` must faithfully describe how to
/// lift/lower `T` and reference the matching intrinsics — confirm with the
/// generated bindings that call this.
pub unsafe fn future_new<T>(
    default: fn() -> T,
    vtable: &'static FutureVtable<T>,
) -> (FutureWriter<T>, FutureReader<T>) {
    // SAFETY: the caller's obligations are forwarded unchanged.
    let (raw_writer, reader) = unsafe { raw_future_new(vtable) };
    // SAFETY: `raw_writer` was just created from the same vtable.
    let writer = unsafe { FutureWriter::new(raw_writer, default) };
    (writer, reader)
}
/// Creates a new future through `ops` and wraps both handles.
///
/// `ops.new()` returns the two handles packed in a `u64`: the reader lives in
/// the low 32 bits and the writer in the high 32 bits. The writer receives a
/// clone of `ops`; the reader takes ownership of the original.
///
/// # Safety
///
/// NOTE(review): presumably `ops` must be a correct implementation for its
/// payload type so the returned handles can be driven safely — confirm.
pub unsafe fn raw_future_new<O>(mut ops: O) -> (RawFutureWriter<O>, RawFutureReader<O>)
where
    O: FutureOps + Clone,
{
    let packed = ops.new();
    // Truncating casts are intentional: each half of `packed` is one handle.
    let (writer, reader) = ((packed >> 32) as u32, packed as u32);
    rtdebug!("future.new() = [{writer}, {reader}]");
    // SAFETY: both handles were freshly produced by `ops` itself.
    let tx = unsafe { RawFutureWriter::new(writer, ops.clone()) };
    let rx = unsafe { RawFutureReader::new(reader, ops) };
    (tx, rx)
}
/// Typed writable end of a future: a [`RawFutureWriter`] bound to a `'static`
/// [`FutureVtable`], plus the machinery to send a default value on drop.
pub struct FutureWriter<T: 'static> {
/// Kept in `ManuallyDrop` because both `write` and `Drop` move the raw
/// writer out with `ManuallyDrop::take`.
raw: ManuallyDrop<RawFutureWriter<&'static FutureVtable<T>>>,
/// When `true`, `Drop` writes `default()` into the future instead of just
/// dropping the raw writer. Every construction site visible in this file
/// sets it to `true`.
should_write_default_value: bool,
/// Produces the value that `Drop` writes.
default: fn() -> T,
}
impl<T> FutureWriter<T> {
unsafe fn new(raw: RawFutureWriter<&'static FutureVtable<T>>, default: fn() -> T) -> Self {
Self {
raw: ManuallyDrop::new(raw),
default,
should_write_default_value: true,
}
}
pub fn write(mut self, value: T) -> FutureWrite<T> {
let raw = unsafe { ManuallyDrop::take(&mut self.raw).write(value) };
let default = self.default;
mem::forget(self);
FutureWrite { raw, default }
}
}
/// Debug output shows only the underlying handle, so `T` need not be `Debug`.
impl<T> fmt::Debug for FutureWriter<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("FutureWriter");
        out.field("handle", &self.raw.handle);
        out.finish()
    }
}
/// On drop, either sends the default value or simply releases the raw writer,
/// depending on `should_write_default_value`.
impl<T> Drop for FutureWriter<T> {
    fn drop(&mut self) {
        if !self.should_write_default_value {
            // SAFETY: `raw` is not accessed again after this point.
            unsafe { ManuallyDrop::drop(&mut self.raw) }
            return;
        }

        let make_default = self.default;
        // SAFETY: `raw` is taken exactly once; `self` is being destroyed.
        let writer = unsafe { ManuallyDrop::take(&mut self.raw) };
        // Fire-and-forget write of the default value.
        writer.write_and_forget(make_default());
    }
}
/// In-flight write returned by `FutureWriter::write`.
///
/// Dropping it before the operation is done attempts to cancel the write
/// (see `Drop for FutureWrite`).
pub struct FutureWrite<T: 'static> {
/// The untyped write operation being driven.
raw: RawFutureWrite<&'static FutureVtable<T>>,
/// Carried along so `cancel` can rebuild a `FutureWriter`.
default: fn() -> T,
}
/// Outcome of `FutureWrite::cancel`.
#[derive(Debug)]
pub enum FutureWriteCancel<T: 'static> {
/// The write had already completed; neither the value nor a writer is returned.
AlreadySent,
/// The write ended with the `DROPPED` status; the unsent value is handed back.
Dropped(T),
/// The write was cancelled; the value and a fresh writer are handed back.
Cancelled(T, FutureWriter<T>),
}
/// Polling a typed write simply polls the wrapped raw write.
impl<T: 'static> Future for FutureWrite<T> {
    type Output = Result<(), FutureWriteError<T>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let raw = self.pin_project();
        Future::poll(raw, cx)
    }
}
impl<T: 'static> FutureWrite<T> {
    /// Projects the pin onto the inner raw write.
    fn pin_project(self: Pin<&mut Self>) -> Pin<&mut RawFutureWrite<&'static FutureVtable<T>>> {
        // SAFETY: `raw` is treated as structurally pinned; it is never moved
        // out of a pinned `FutureWrite`.
        unsafe { self.map_unchecked_mut(|this| &mut this.raw) }
    }

    /// Attempts to cancel the write and reports what happened to the value.
    ///
    /// When cancellation succeeds the raw writer is re-wrapped into a
    /// [`FutureWriter`] that will again send the default value on drop.
    pub fn cancel(self: Pin<&mut Self>) -> FutureWriteCancel<T> {
        let default = self.default;
        let outcome = self.pin_project().cancel();
        match outcome {
            RawFutureWriteCancel::AlreadySent => FutureWriteCancel::AlreadySent,
            RawFutureWriteCancel::Dropped(val) => FutureWriteCancel::Dropped(val),
            RawFutureWriteCancel::Cancelled(val, raw) => {
                let writer = FutureWriter {
                    raw: ManuallyDrop::new(raw),
                    default,
                    should_write_default_value: true,
                };
                FutureWriteCancel::Cancelled(val, writer)
            }
        }
    }
}
/// Cancels a still-pending write when the future is dropped.
///
/// Going through `Self::cancel` means a successful cancellation yields a
/// `FutureWriter`, which is dropped immediately and thereby runs its own
/// drop logic.
impl<T: 'static> Drop for FutureWrite<T> {
    fn drop(&mut self) {
        if !self.raw.op.is_done() {
            // SAFETY: `self` is not moved for the remainder of `drop`.
            let pinned = unsafe { Pin::new_unchecked(self) };
            let _ = pinned.cancel();
        }
    }
}
/// Untyped writable end of a future: a handle plus the operations used to
/// drive it. Dropping it calls `drop_writable` on the handle.
pub struct RawFutureWriter<O: FutureOps> {
/// The writable-end handle.
handle: u32,
/// Operations used to lower values and talk to the handle.
ops: O,
}
impl<O: FutureOps> RawFutureWriter<O> {
/// Wraps an existing writable `handle` together with the `ops` used to drive it.
///
/// # Safety
///
/// NOTE(review): presumably `handle` must be a live writable-future handle
/// owned by the caller and compatible with `ops` — confirm with callers.
unsafe fn new(handle: u32, ops: O) -> Self {
Self { handle, ops }
}
/// Consumes the writer and packages it with `value` into a [`RawFutureWrite`].
pub fn write(self, value: O::Payload) -> RawFutureWrite<O> {
RawFutureWrite {
op: WaitableOperation::new(FutureWriteOp(marker::PhantomData), (self, value)),
}
}
/// Writes `value` without handing a future back to the caller.
///
/// The write is stored inside an `Arc<DeferredWrite>` that doubles as its own
/// waker; calling `wake` right away performs the first poll.
pub fn write_and_forget(self, value: O::Payload)
where
O: 'static,
{
// The `return` is needed because the helper items are declared below it.
return Arc::new(DeferredWrite {
write: Mutex::new(self.write(value)),
})
.wake();
// Owns the pending write; the mutex gives `wake` mutable access through
// the shared `Arc`.
struct DeferredWrite<O: FutureOps> {
write: Mutex<RawFutureWrite<O>>,
}
// NOTE(review): these impls are unconditional (no bounds on `O`);
// presumably justified by the execution environment — confirm.
unsafe impl<O: FutureOps> Send for DeferredWrite<O> {}
unsafe impl<O: FutureOps> Sync for DeferredWrite<O> {}
impl<O: FutureOps + 'static> Wake for DeferredWrite<O> {
fn wake(self: Arc<Self>) {
// Poll the write once, using a clone of this very `Arc` as the waker.
// The waker and the lock guard are dropped at the end of the block,
// before the reference counts are inspected.
let poll = {
let waker = Waker::from(self.clone());
let mut cx = Context::from_waker(&waker);
let mut write = self.write.lock().unwrap();
unsafe { Pin::new_unchecked(&mut *write).poll(&mut cx) }
};
if poll.is_ready() {
// Finished: nothing else may still hold a reference.
assert_eq!(Arc::strong_count(&self), 1);
} else {
// Pending: something must have kept a clone of the waker, otherwise
// the write would never be polled again.
assert!(Arc::strong_count(&self) > 1);
}
assert_eq!(Arc::weak_count(&self), 0);
}
}
}
}
/// Debug output shows only the handle; `O` need not be `Debug`.
impl<O: FutureOps> fmt::Debug for RawFutureWriter<O> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RawFutureWriter");
        out.field("handle", &self.handle);
        out.finish()
    }
}
/// Releases the writable end of the future when the writer goes away.
impl<O: FutureOps> Drop for RawFutureWriter<O> {
    fn drop(&mut self) {
        rtdebug!("future.drop-writable({})", self.handle);
        // SAFETY: the handle is owned by this writer and is not used again
        // after this call.
        unsafe { self.ops.drop_writable(self.handle) }
    }
}
/// Untyped in-flight write, returned by `RawFutureWriter::write`.
pub struct RawFutureWrite<O: FutureOps> {
/// The waitable operation driving the write.
op: WaitableOperation<FutureWriteOp<O>>,
}
/// Zero-sized marker type that carries the `WaitableOp` implementation for writes.
struct FutureWriteOp<O>(marker::PhantomData<O>);
/// Final state of a write operation.
enum WriteComplete<T> {
/// Status was `COMPLETED`: the value was sent.
Written,
/// Status was `DROPPED`: the value was lifted back out of the buffer.
Dropped(T),
/// Status was `CANCELLED`: the value was lifted back out of the buffer.
Cancelled(T),
}
/// Write-side state machine plugged into [`WaitableOperation`].
unsafe impl<O: FutureOps> WaitableOp for FutureWriteOp<O> {
// Not yet started: the writer plus the value still in Rust form.
type Start = (RawFutureWriter<O>, O::Payload);
// Started: the writer plus the guard for the buffer the value was lowered
// into. NOTE(review): `None` presumably means no allocation was needed —
// confirm in `Cleanup::new`.
type InProgress = (RawFutureWriter<O>, Option<Cleanup>);
// Finished: the outcome plus the writer.
type Result = (WriteComplete<O::Payload>, RawFutureWriter<O>);
type Cancel = RawFutureWriteCancel<O>;
// Allocates a buffer with the payload layout, lowers `value` into it and
// starts the write; returns the status code and the in-progress state.
fn start(&mut self, (mut writer, value): Self::Start) -> (u32, Self::InProgress) {
let (ptr, cleanup) = Cleanup::new(writer.ops.elem_layout());
let code = unsafe {
writer.ops.lower(value, ptr);
writer.ops.start_write(writer.handle, ptr)
};
rtdebug!("future.write({}, {ptr:?}) = {code:#x}", writer.handle);
(code, (writer, cleanup))
}
// Cancelled before anything was started: hand the value and writer straight back.
fn start_cancelled(&mut self, (writer, value): Self::Start) -> Self::Cancel {
RawFutureWriteCancel::Cancelled(value, writer)
}
// Interprets a status code: `Err` keeps the operation in progress, `Ok`
// finishes it.
fn in_progress_update(
&mut self,
(mut writer, cleanup): Self::InProgress,
code: u32,
) -> Result<Self::Result, Self::InProgress> {
// Buffer address, or null when there is no cleanup guard.
let ptr = cleanup
.as_ref()
.map(|c| c.ptr.as_ptr())
.unwrap_or(ptr::null_mut());
match code {
super::BLOCKED => Err((writer, cleanup)),
// The value was not delivered: lift it back out of the buffer so the
// caller regains ownership. `cleanup` is dropped on return.
super::DROPPED | super::CANCELLED => {
let value = unsafe { writer.ops.lift(ptr) };
let status = if code == super::DROPPED {
WriteComplete::Dropped(value)
} else {
WriteComplete::Cancelled(value)
};
Ok((status, writer))
}
// The value was delivered: run the post-write cleanup hook on the
// buffer. `cleanup` is dropped on return.
super::COMPLETED => {
unsafe {
writer.ops.dealloc_lists(ptr);
}
Ok((WriteComplete::Written, writer))
}
other => unreachable!("unexpected code {other:?}"),
}
}
// The handle to wait on is the writer's own handle.
fn in_progress_waitable(&mut self, (writer, _): &Self::InProgress) -> u32 {
writer.handle
}
// Requests cancellation and returns the resulting status code.
fn in_progress_cancel(&mut self, (writer, _): &mut Self::InProgress) -> u32 {
let code = unsafe { writer.ops.cancel_write(writer.handle) };
rtdebug!("future.cancel-write({}) = {code:#x}", writer.handle);
code
}
// Maps a finished write to the cancel result. The writer is only returned
// for `Cancelled`; in the other two cases it is dropped here.
fn result_into_cancel(&mut self, (result, writer): Self::Result) -> Self::Cancel {
match result {
WriteComplete::Written => RawFutureWriteCancel::AlreadySent,
WriteComplete::Dropped(val) => RawFutureWriteCancel::Dropped(val),
WriteComplete::Cancelled(val) => RawFutureWriteCancel::Cancelled(val, writer),
}
}
}
/// Resolves to `Ok(())` once the value was sent, or hands the value back in a
/// [`FutureWriteError`] when the write ended as dropped or cancelled.
impl<O: FutureOps> Future for RawFutureWrite<O> {
    type Output = Result<(), FutureWriteError<O::Payload>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The writer that comes back with the result is dropped when this
        // function returns.
        let (outcome, _writer) = match self.pin_project().poll_complete(cx) {
            Poll::Ready(done) => done,
            Poll::Pending => return Poll::Pending,
        };
        let output = match outcome {
            WriteComplete::Written => Ok(()),
            WriteComplete::Dropped(value) | WriteComplete::Cancelled(value) => {
                Err(FutureWriteError { value })
            }
        };
        Poll::Ready(output)
    }
}
impl<O: FutureOps> RawFutureWrite<O> {
    /// Projects the pin onto the inner waitable operation.
    fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation<FutureWriteOp<O>>> {
        // SAFETY: `op` is treated as structurally pinned; it is never moved
        // out of a pinned `RawFutureWrite`.
        unsafe { self.map_unchecked_mut(|this| &mut this.op) }
    }

    /// Attempts to cancel the write; see [`RawFutureWriteCancel`] for the
    /// possible outcomes.
    pub fn cancel(self: Pin<&mut Self>) -> RawFutureWriteCancel<O> {
        let op = self.pin_project();
        op.cancel()
    }
}
/// Error produced when a write did not deliver its value.
///
/// The undelivered value is handed back in `value`.
pub struct FutureWriteError<T> {
    /// The value that could not be sent.
    pub value: T,
}

/// Debug output deliberately omits `value`, so `T` need not be `Debug`.
impl<T> fmt::Debug for FutureWriteError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("FutureWriteError");
        out.finish_non_exhaustive()
    }
}

impl<T> fmt::Display for FutureWriteError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` is what `Display for str` does, so width/alignment flags are
        // still honoured.
        f.pad("read end dropped")
    }
}

impl<T> std::error::Error for FutureWriteError<T> {}
/// Outcome of `RawFutureWrite::cancel`.
#[derive(Debug)]
pub enum RawFutureWriteCancel<O: FutureOps> {
/// The write had already completed; nothing is handed back.
AlreadySent,
/// The write ended with the `DROPPED` status; the unsent value is handed back.
Dropped(O::Payload),
/// The write was cancelled (or never started); the value and the writer
/// are handed back.
Cancelled(O::Payload, RawFutureWriter<O>),
}
/// Typed readable end of a future: a [`RawFutureReader`] driven by a `'static` [`FutureVtable`].
pub type FutureReader<T> = RawFutureReader<&'static FutureVtable<T>>;
/// Untyped readable end of a future. Dropping it calls `drop_readable`
/// unless the handle was removed with `take_handle`.
pub struct RawFutureReader<O: FutureOps> {
/// The readable-end handle; `u32::MAX` is the sentinel for "handle taken".
/// Atomic so `take_handle` can clear it through `&self`.
handle: AtomicU32,
/// Operations used to lift values and talk to the handle.
ops: O,
}
/// Debug output shows only the handle; `O` need not be `Debug`.
impl<O: FutureOps> fmt::Debug for RawFutureReader<O> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RawFutureReader");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl<O: FutureOps> RawFutureReader<O> {
    /// Wraps an existing readable `handle` together with the `ops` used to
    /// drive it.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `handle` must be a live readable-future
    /// handle owned by the caller and compatible with `ops` — confirm.
    pub unsafe fn new(handle: u32, ops: O) -> Self {
        let handle = AtomicU32::new(handle);
        Self { handle, ops }
    }

    /// Removes and returns the handle, leaving the `u32::MAX` sentinel behind
    /// so that `Drop` no longer releases it.
    ///
    /// # Panics
    ///
    /// Panics if the handle was already taken.
    #[doc(hidden)]
    pub fn take_handle(&self) -> u32 {
        let taken = self.handle();
        self.handle.store(u32::MAX, Relaxed);
        taken
    }

    /// Returns the handle, panicking if it was already taken.
    fn handle(&self) -> u32 {
        self.opt_handle().unwrap()
    }

    /// Returns the handle, or `None` once it has been taken.
    fn opt_handle(&self) -> Option<u32> {
        let raw = self.handle.load(Relaxed);
        (raw != u32::MAX).then_some(raw)
    }
}
/// Awaiting a reader turns it into a [`RawFutureRead`] that resolves to the
/// payload.
impl<O: FutureOps> IntoFuture for RawFutureReader<O> {
    type Output = O::Payload;
    type IntoFuture = RawFutureRead<O>;

    fn into_future(self) -> Self::IntoFuture {
        let op = WaitableOperation::new(FutureReadOp(marker::PhantomData), self);
        RawFutureRead { op }
    }
}
/// Releases the readable end unless the handle was taken beforehand.
impl<O: FutureOps> Drop for RawFutureReader<O> {
    fn drop(&mut self) {
        if let Some(handle) = self.opt_handle() {
            rtdebug!("future.drop-readable({handle})");
            // SAFETY: the handle is still owned by this reader and is not
            // used again after this call.
            unsafe { self.ops.drop_readable(handle) }
        }
    }
}
/// Typed in-flight read: a [`RawFutureRead`] driven by a `'static` [`FutureVtable`].
pub type FutureRead<T> = RawFutureRead<&'static FutureVtable<T>>;
/// Untyped in-flight read, produced by `RawFutureReader::into_future`.
pub struct RawFutureRead<O: FutureOps> {
/// The waitable operation driving the read.
op: WaitableOperation<FutureReadOp<O>>,
}
/// Zero-sized marker type that carries the `WaitableOp` implementation for reads.
struct FutureReadOp<O>(marker::PhantomData<O>);
/// Final state of a read operation.
enum ReadComplete<T> {
/// The read completed and the value was lifted out of the buffer.
Value(T),
/// The read was cancelled before a value arrived.
Cancelled,
}
/// Read-side state machine plugged into [`WaitableOperation`].
unsafe impl<O: FutureOps> WaitableOp for FutureReadOp<O> {
// Not yet started: just the reader.
type Start = RawFutureReader<O>;
// Started: the reader plus the guard for the buffer the value is read into.
// NOTE(review): `None` presumably means no allocation was needed — confirm
// in `Cleanup::new`.
type InProgress = (RawFutureReader<O>, Option<Cleanup>);
// Finished: the outcome plus the reader.
type Result = (ReadComplete<O::Payload>, RawFutureReader<O>);
// Cancel result: `Ok(value)` if the read won the race, otherwise the reader
// is handed back.
type Cancel = Result<O::Payload, RawFutureReader<O>>;
// Allocates a buffer with the payload layout and starts the read into it.
// `reader.handle()` panics if the handle was already taken.
fn start(&mut self, mut reader: Self::Start) -> (u32, Self::InProgress) {
let (ptr, cleanup) = Cleanup::new(reader.ops.elem_layout());
let code = unsafe { reader.ops.start_read(reader.handle(), ptr) };
rtdebug!("future.read({}, {ptr:?}) = {code:#x}", reader.handle());
(code, (reader, cleanup))
}
// Cancelled before anything was started: hand the reader straight back.
fn start_cancelled(&mut self, state: Self::Start) -> Self::Cancel {
Err(state)
}
// Interprets a status code: `Err` keeps the operation in progress, `Ok`
// finishes it. Any code other than the three handled below panics.
fn in_progress_update(
&mut self,
(mut reader, cleanup): Self::InProgress,
code: u32,
) -> Result<Self::Result, Self::InProgress> {
match ReturnCode::decode(code) {
ReturnCode::Blocked => Err((reader, cleanup)),
// Nothing was read; `cleanup` is dropped on return and the reader is kept.
ReturnCode::Cancelled(0) => Ok((ReadComplete::Cancelled, reader)),
// A value arrived: lift it out of the buffer (null when there is no
// cleanup guard). `cleanup` is dropped on return.
ReturnCode::Completed(0) => {
let ptr = cleanup
.as_ref()
.map(|c| c.ptr.as_ptr())
.unwrap_or(ptr::null_mut());
let value = unsafe { reader.ops.lift(ptr) };
Ok((ReadComplete::Value(value), reader))
}
other => panic!("unexpected code {other:?}"),
}
}
// The handle to wait on is the reader's own handle.
fn in_progress_waitable(&mut self, (reader, _): &Self::InProgress) -> u32 {
reader.handle()
}
// Requests cancellation and returns the resulting status code.
fn in_progress_cancel(&mut self, (reader, _): &mut Self::InProgress) -> u32 {
let code = unsafe { reader.ops.cancel_read(reader.handle()) };
rtdebug!("future.cancel-read({}) = {code:#x}", reader.handle());
code
}
// Maps a finished read to the cancel result. When a value was read the
// reader is dropped here; when cancelled the reader is returned.
fn result_into_cancel(&mut self, (value, reader): Self::Result) -> Self::Cancel {
match value {
ReadComplete::Value(value) => Ok(value),
ReadComplete::Cancelled => Err(reader),
}
}
}
/// Resolves to the payload once the read completes.
///
/// # Panics
///
/// Panics if polled to completion after the read was cancelled.
impl<O: FutureOps> Future for RawFutureRead<O> {
    type Output = O::Payload;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.pin_project().poll_complete(cx) {
            Poll::Pending => Poll::Pending,
            // The reader returned alongside the value is dropped here.
            Poll::Ready((ReadComplete::Value(val), _reader)) => Poll::Ready(val),
            Poll::Ready((ReadComplete::Cancelled, _reader)) => {
                panic!("cannot poll after cancelling")
            }
        }
    }
}
impl<O: FutureOps> RawFutureRead<O> {
    /// Projects the pin onto the inner waitable operation.
    fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation<FutureReadOp<O>>> {
        // SAFETY: `op` is treated as structurally pinned; it is never moved
        // out of a pinned `RawFutureRead`.
        unsafe { self.map_unchecked_mut(|this| &mut this.op) }
    }

    /// Attempts to cancel the read.
    ///
    /// Returns `Ok(value)` if a value had already arrived, otherwise
    /// `Err(reader)` with the reader handed back.
    pub fn cancel(self: Pin<&mut Self>) -> Result<O::Payload, RawFutureReader<O>> {
        let op = self.pin_project();
        op.cancel()
    }
}