use super::cabi;
use std::ffi::c_void;
use std::future::Future;
use std::marker;
use std::mem;
use std::pin::Pin;
use std::ptr;
use std::task::{Context, Poll, Waker};
/// Drives a [`WaitableOp`] from its initial `Start` state through to a final result, or to
/// a cancellation value if it is abandoned first.
///
/// A raw pointer to `completion_status` is handed to the task runtime when a waker is
/// registered, so this type is `!Unpin` and must stay pinned once polled. Dropping an
/// operation that has not reached `Done` cancels it.
pub struct WaitableOperation<S: WaitableOp> {
    // The operation implementation (all phase transitions are delegated to it).
    op: S,
    // Which phase the operation is currently in.
    state: WaitableOperationState<S>,
    // Pinned slot written by the runtime's wake callback.
    completion_status: CompletionStatus,
}
/// Slot shared with the task runtime's wake callback.
///
/// `register_waker` passes a raw pointer to this struct to the runtime. The callback stores
/// the reported status code in `code` and wakes the stored `waker`.
struct CompletionStatus {
    // Status code delivered by the wake callback; consumed by `poll_complete` / `cancel`.
    code: Option<u32>,
    // Waker of the task that last registered interest; taken when the callback fires.
    waker: Option<Waker>,
    // Makes the struct `!Unpin`, because the runtime holds a raw pointer to it.
    _pinned: marker::PhantomPinned,
}
/// An asynchronous operation built on a waitable handle, driven by [`WaitableOperation`].
///
/// The operation moves from `Start` to `InProgress` to a final `Result`. If it is abandoned
/// early, it is turned into a `Cancel` value instead.
///
/// # Safety
///
/// NOTE(review): the contract that makes this trait `unsafe` to implement is not stated in
/// this file. From how `WaitableOperation` uses it, implementors presumably must return valid
/// waitable handles. They must also guarantee that the code returned by `in_progress_cancel`
/// makes `in_progress_update` return `Ok`. Confirm and document this precisely.
pub unsafe trait WaitableOp {
    /// Input state consumed by `start` (or by `start_cancelled` if never started).
    type Start;
    /// State held while the operation is waiting on its waitable.
    type InProgress;
    /// Value produced when the operation completes.
    type Result;
    /// Value produced when the operation is cancelled.
    type Cancel;
    /// Begins the operation.
    ///
    /// Returns an initial status code, which is fed immediately to `in_progress_update`,
    /// and the in-progress state.
    fn start(&mut self, state: Self::Start) -> (u32, Self::InProgress);
    /// Processes a status code.
    ///
    /// Returns `Ok` when the operation has finished. Otherwise returns `Err` with the
    /// (possibly updated) state so the operation keeps waiting.
    fn in_progress_update(
        &mut self,
        state: Self::InProgress,
        code: u32,
    ) -> Result<Self::Result, Self::InProgress>;
    /// Produces the cancellation value for an operation that was never started.
    fn start_cancelled(&mut self, state: Self::Start) -> Self::Cancel;
    /// Returns the waitable handle to register with, or unregister from, the task runtime.
    fn in_progress_waitable(&mut self, state: &Self::InProgress) -> u32;
    /// Requests cancellation of an in-progress operation.
    ///
    /// Returns a status code that is passed to `in_progress_update`, which is required to
    /// complete the operation with `Ok`.
    fn in_progress_cancel(&mut self, state: &mut Self::InProgress) -> u32;
    /// Converts a completed result into the cancellation value.
    ///
    /// Used when the operation finished before, or while, being cancelled.
    fn result_into_cancel(&mut self, result: Self::Result) -> Self::Cancel;
}
/// Phase of a [`WaitableOperation`].
enum WaitableOperationState<S: WaitableOp> {
    /// Not yet started; holds the input for `WaitableOp::start`.
    Start(S::Start),
    /// Started and waiting on a waitable; holds the op's in-progress state.
    InProgress(S::InProgress),
    /// Finished (result returned) or cancelled.
    ///
    /// Also used as a temporary placeholder while a payload is moved out with
    /// `mem::replace`.
    Done,
}
impl<S> WaitableOperation<S>
where
    S: WaitableOp,
{
    /// Creates a new operation in the `Start` state.
    ///
    /// Nothing is executed until the first poll or cancellation.
    pub fn new(op: S, state: S::Start) -> WaitableOperation<S> {
        WaitableOperation {
            op,
            state: WaitableOperationState::Start(state),
            completion_status: CompletionStatus {
                code: None,
                waker: None,
                _pinned: marker::PhantomPinned,
            },
        }
    }
    /// Pin projection.
    ///
    /// Splits a pinned `self` into plain `&mut` access to `op` and `state` and a pinned
    /// reference to `completion_status`.
    fn pin_project(
        self: Pin<&mut Self>,
    ) -> (
        &mut S,
        &mut WaitableOperationState<S>,
        Pin<&mut CompletionStatus>,
    ) {
        // SAFETY: this projection treats `op` and `state` as NOT structurally pinned. They are
        // handed out as plain `&mut`, so `WaitableOp` implementations must not rely on being
        // pinned. `completion_status` stays pinned because the runtime holds a raw pointer
        // to it.
        unsafe {
            let me = self.get_unchecked_mut();
            (
                &mut me.op,
                &mut me.state,
                Pin::new_unchecked(&mut me.completion_status),
            )
        }
    }
    /// Registers this operation's `completion_status` with the current task for
    /// `waitable`.
    ///
    /// Stores a clone of `cx`'s waker first, so the wake callback can wake this task.
    ///
    /// # Panics
    ///
    /// Panics if no task is currently set, or if the task's version is older than
    /// `WASIP3_TASK_V1`. Also panics if a different `CompletionStatus` was already
    /// registered for `waitable`.
    pub fn register_waker(self: Pin<&mut Self>, waitable: u32, cx: &mut Context) {
        let (_, _, mut completion_status) = self.pin_project();
        // Any previously delivered code must have been consumed before re-registering.
        debug_assert!(completion_status.as_mut().code_mut().is_none());
        *completion_status.as_mut().waker_mut() = Some(cx.waker().clone());
        unsafe {
            // `wasip3_task_set` appears to swap in a new task pointer and return the previous
            // one. The task is taken out here and restored at the end of this block.
            let task = cabi::wasip3_task_set(ptr::null_mut());
            assert!(!task.is_null());
            assert!((*task).version >= cabi::WASIP3_TASK_V1);
            let ptr: *mut CompletionStatus = completion_status.get_unchecked_mut();
            // `prev` is the pointer previously registered for this waitable, if any.
            // Re-registering is only accepted when it is this same `CompletionStatus`.
            let prev = ((*task).waitable_register)((*task).ptr, waitable, cabi_wake, ptr.cast());
            if !prev.is_null() {
                assert_eq!(ptr, prev.cast());
            }
            cabi::wasip3_task_set(task);
        }
        // Callback invoked by the runtime for `waitable`. It records the status code and
        // wakes the task whose waker was stored above.
        unsafe extern "C" fn cabi_wake(ptr: *mut c_void, code: u32) {
            // SAFETY: `ptr` is the `CompletionStatus` pointer passed to `waitable_register`.
            // The owning operation is pinned, so presumably it has not moved. Dropping it
            // runs `cancel`, which unregisters any registration that has not fired.
            let ptr: &mut CompletionStatus = unsafe { &mut *ptr.cast::<CompletionStatus>() };
            ptr.code = Some(code);
            // `register_waker` always stores a waker before registering, so this is `Some`.
            ptr.waker.take().unwrap().wake()
        }
    }
    /// Removes the current task's registration for `waitable`.
    ///
    /// # Panics
    ///
    /// Panics if no task is currently set, or if the task's version is older than
    /// `WASIP3_TASK_V1`. Also panics if the removed registration pointed at a different
    /// `CompletionStatus`.
    pub fn unregister_waker(self: Pin<&mut Self>, waitable: u32) {
        unsafe {
            // Take the task pointer out; it is restored at the end of this block.
            let task = cabi::wasip3_task_set(ptr::null_mut());
            assert!(!task.is_null());
            assert!((*task).version >= cabi::WASIP3_TASK_V1);
            let prev = ((*task).waitable_unregister)((*task).ptr, waitable);
            if !prev.is_null() {
                let ptr: *mut CompletionStatus = self.pin_project().2.get_unchecked_mut();
                assert_eq!(ptr, prev.cast());
            }
            cabi::wasip3_task_set(task);
        }
    }
    /// Advances the operation.
    ///
    /// On the first call the operation is started. On later calls any status code delivered
    /// by the wake callback is consumed. If the operation is still pending, `cx`'s waker is
    /// registered again.
    ///
    /// # Panics
    ///
    /// Panics if called after the operation has already completed.
    pub fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<S::Result> {
        use WaitableOperationState::*;
        let (op, state, completion_status) = self.as_mut().pin_project();
        let optional_code = match state {
            Start(_) => {
                // Move the start payload out, leaving `Done` briefly as a placeholder.
                let Start(s) = mem::replace(state, Done) else {
                    unreachable!()
                };
                let (code, s) = op.start(s);
                *state = InProgress(s);
                Some(code)
            }
            // `None` here means no callback has fired since the last registration.
            InProgress(_) => completion_status.code_mut().take(),
            Done => panic!("cannot re-poll after operation completes"),
        };
        self.poll_complete_with_code(Some(cx), optional_code)
    }
    /// Feeds `optional_code` (if any) to `in_progress_update`.
    ///
    /// Returns `Ready` when the operation finishes, leaving the state as `Done`. Otherwise
    /// returns `Pending`, re-registering the waker only when `cx` is `Some`. Must only be
    /// called while the state is `InProgress`.
    fn poll_complete_with_code(
        mut self: Pin<&mut Self>,
        cx: Option<&mut Context>,
        optional_code: Option<u32>,
    ) -> Poll<S::Result> {
        use WaitableOperationState::*;
        let (op, state, _completion_status) = self.as_mut().pin_project();
        if let Some(code) = optional_code {
            let InProgress(in_progress) = mem::replace(state, Done) else {
                unreachable!()
            };
            match op.in_progress_update(in_progress, code) {
                // State stays `Done` (from the `mem::replace` above).
                Ok(result) => return Poll::Ready(result),
                Err(in_progress) => *state = InProgress(in_progress),
            }
        }
        let in_progress = match state {
            InProgress(s) => s,
            _ => unreachable!(),
        };
        // `cancel` passes `None` here because it drives completion itself and does not
        // want a new registration.
        if let Some(cx) = cx {
            let handle = op.in_progress_waitable(in_progress);
            self.register_waker(handle, cx);
        }
        Poll::Pending
    }
    /// Cancels the operation and returns the op's `Cancel` value.
    ///
    /// - Not yet started: `start_cancelled` is called and nothing is ever issued.
    /// - In progress: a status code already delivered is processed first. If that finishes
    ///   the operation, `result_into_cancel` is returned. Otherwise, when no code was
    ///   pending, the waker registration is removed. Then `in_progress_cancel` is issued,
    ///   and its code must finish the operation.
    ///
    /// # Panics
    ///
    /// Panics if the operation already completed. Also panics if the code returned by
    /// `in_progress_cancel` does not complete it.
    pub fn cancel(mut self: Pin<&mut Self>) -> S::Cancel {
        use WaitableOperationState::*;
        let (op, state, mut completion_status) = self.as_mut().pin_project();
        let in_progress = match state {
            Start(_) => {
                let Start(s) = mem::replace(state, Done) else {
                    unreachable!()
                };
                return op.start_cancelled(s);
            }
            InProgress(s) => s,
            Done => panic!("cannot cancel operation after completing it"),
        };
        match completion_status.as_mut().code_mut().take() {
            // The callback already fired. NOTE(review): presumably the runtime has dropped
            // the registration at that point, since only the `None` arm unregisters.
            // Confirm this.
            Some(code) => {
                match self.as_mut().poll_complete_with_code(None, Some(code)) {
                    Poll::Ready(result) => {
                        return self.as_mut().pin_project().0.result_into_cancel(result);
                    }
                    Poll::Pending => {}
                }
            }
            // Still registered with no pending code: remove the registration so the callback
            // cannot fire into this soon-to-be-finished operation.
            None => {
                let waitable = op.in_progress_waitable(in_progress);
                self.as_mut().unregister_waker(waitable);
            }
        }
        // Re-project because the earlier borrows were consumed above; the state is still
        // `InProgress` on every path that reaches here.
        let (op, InProgress(in_progress), _) = self.as_mut().pin_project() else {
            unreachable!()
        };
        let code = op.in_progress_cancel(in_progress);
        match self.as_mut().poll_complete_with_code(None, Some(code)) {
            Poll::Ready(result) => self.as_mut().pin_project().0.result_into_cancel(result),
            Poll::Pending => unreachable!(),
        }
    }
    /// Returns `true` once the operation has completed or been cancelled.
    pub fn is_done(&self) -> bool {
        matches!(self.state, WaitableOperationState::Done)
    }
}
impl<S: WaitableOp> Future for WaitableOperation<S> {
type Output = S::Result;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<S::Result> {
self.poll_complete(cx)
}
}
/// Dropping an unfinished operation cancels it, so no registration or in-flight work
/// outlives the operation.
impl<S: WaitableOp> Drop for WaitableOperation<S> {
    fn drop(&mut self) {
        if !self.is_done() {
            // SAFETY: `self` is being destroyed and is never moved again after this point,
            // so pinning it for the duration of `cancel` upholds the pinning contract.
            let this = unsafe { Pin::new_unchecked(self) };
            // The cancellation value is not needed during drop; it is discarded here.
            this.cancel();
        }
    }
}
impl CompletionStatus {
    /// Pinned projection to the status-code slot.
    fn code_mut(self: Pin<&mut Self>) -> &mut Option<u32> {
        // SAFETY: `code` is never treated as pinned, and a `&mut` to it cannot be used to
        // move the enclosing `CompletionStatus`.
        let this = unsafe { Pin::get_unchecked_mut(self) };
        &mut this.code
    }

    /// Pinned projection to the stored waker slot.
    fn waker_mut(self: Pin<&mut Self>) -> &mut Option<Waker> {
        // SAFETY: `waker` is never treated as pinned, and a `&mut` to it cannot be used to
        // move the enclosing `CompletionStatus`.
        let this = unsafe { Pin::get_unchecked_mut(self) };
        &mut this.waker
    }
}