use std::{
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll, Waker},
};
use compio_buf::BufResult;
use compio_driver::{Extra, Key, OpCode, PushEntry};
use futures_util::future::FusedFuture;
use crate::{
CancelToken, Runtime,
waker::{get_ext, get_waker},
};
/// Runtime-specific accessors on a task [`Context`].
///
/// The implementation for `Context` reads this data from the extension that
/// `crate::waker::get_ext` finds on the context's waker; the `Option`-returning
/// methods yield `None` when that lookup finds nothing.
pub(crate) trait ContextExt {
/// Returns the waker produced by `crate::waker::get_waker` for this context's waker.
fn get_waker(&self) -> &Waker;
/// Returns the cancel token carried by the waker's extension, if there is one.
fn get_cancel(&mut self) -> Option<&CancelToken>;
/// Builds an [`Extra`] from `default()` and lets the waker's extension adjust it.
///
/// Returns `None`, without calling `default`, when the waker carries no extension.
fn as_extra(&mut self, default: impl FnOnce() -> Extra) -> Option<Extra>;
}
impl ContextExt for Context<'_> {
    /// Delegates to `crate::waker::get_waker` on this context's waker.
    fn get_waker(&self) -> &Waker {
        let raw = self.waker();
        get_waker(raw)
    }

    /// Looks up the waker's extension and asks it for its cancel token.
    fn get_cancel(&mut self) -> Option<&CancelToken> {
        let ext = get_ext(self.waker())?;
        ext.get_cancel()
    }

    /// Only builds the default `Extra` when an extension exists, then lets the
    /// extension fill in its settings via `set_extra`.
    fn as_extra(&mut self, default: impl FnOnce() -> Extra) -> Option<Extra> {
        get_ext(self.waker()).map(|ext| {
            let mut extra = default();
            ext.set_extra(&mut extra);
            extra
        })
    }
}
pin_project_lite::pin_project! {
/// Future that submits an [`OpCode`] to a [`Runtime`] and resolves when it completes.
///
/// With `E = ()` it resolves to the operation's `BufResult`; with `E = Extra`
/// (obtained via `Submit::with_extra`) it resolves to `(BufResult, Extra)`.
///
/// If it is dropped while the operation is in flight, it asks the runtime to cancel
/// the operation through `Runtime::cancel`.
pub struct Submit<T: OpCode, E = ()> {
// Runtime used to submit, poll, and cancel the operation.
runtime: Runtime,
// `Some(Idle)` before submission, `Some(Submitted)` while in flight, and `None`
// once the output has been returned or the state was moved out by `with_extra`.
state: Option<State<T, E>>,
}
// Cancels an operation that was submitted but whose output was never taken.
// Taking the state leaves `None`, so the key is handed to `cancel` exactly once.
impl<T: OpCode, E> PinnedDrop for Submit<T, E> {
fn drop(this: Pin<&mut Self>) {
let this = this.project();
if let Some(State::Submitted { key, .. }) = this.state.take() {
this.runtime.cancel(key);
}
}
}
}
/// Lifecycle stage of the operation owned by a `Submit` future.
enum State<T: OpCode, E> {
/// Not yet handed to the driver; the future still owns the operation itself.
Idle { op: T },
/// Handed to the driver; `key` is what `poll_task`/`cancel` use to refer to it.
/// `E` only appears in `PhantomData`, tagging which `Submit<T, E>` owns this state.
Submitted { key: Key<T>, _p: PhantomData<E> },
}
impl<T: OpCode, E> State<T, E> {
    /// Wraps a driver key in the `Submitted` state; `E` is inferred from the caller.
    fn submitted(key: Key<T>) -> Self {
        Self::Submitted { key, _p: PhantomData }
    }
}
impl<T: OpCode> Submit<T, ()> {
    /// Creates a future that submits `op` to `runtime` when it is first polled.
    pub(crate) fn new(runtime: Runtime, op: T) -> Self {
        let state = Some(State::Idle { op });
        Self { runtime, state }
    }

    /// Converts this future into one that also yields the operation's [`Extra`].
    ///
    /// Works at any stage: an unsubmitted op stays unsubmitted, an in-flight op keeps
    /// its key, and an already-finished future stays finished.
    pub fn with_extra(mut self) -> Submit<T, Extra> {
        // `Submit` has a `Drop` impl (generated from `PinnedDrop`), so its fields
        // cannot be moved out. Clone the runtime and take the state instead; leaving
        // `None` behind means dropping `self` does not cancel the handed-over op.
        let runtime = self.runtime.clone();
        let state = self.state.take().map(|state| match state {
            State::Idle { op } => State::Idle { op },
            State::Submitted { key, .. } => State::submitted(key),
        });
        Submit { runtime, state }
    }
}
impl<T: OpCode + 'static> Future for Submit<T, ()> {
    type Output = BufResult<usize, T>;

    /// Submits the operation on the first poll, then polls the driver for completion.
    ///
    /// # Panics
    ///
    /// Panics with "Cannot poll after ready" if polled again after returning `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // The state is taken out up front and only put back while the operation is
        // still pending, so a completed future is left holding `None`.
        let key = match this.state.take().expect("Cannot poll after ready") {
            State::Submitted { key, .. } => key,
            State::Idle { op } => {
                let extra = cx.as_extra(|| this.runtime.default_extra());
                let key = match this.runtime.submit_raw(op, extra) {
                    // Completed synchronously: there is nothing left to poll.
                    PushEntry::Ready(res) => return Poll::Ready(res),
                    PushEntry::Pending(key) => key,
                };
                // Make the in-flight operation reachable from an attached cancel token.
                if let Some(cancel) = cx.get_cancel() {
                    cancel.register(&key);
                }
                // Poll straight away so `poll_task` receives the waker on this call.
                key
            }
        };
        match this.runtime.poll_task(cx.get_waker(), key) {
            PushEntry::Ready(res) => Poll::Ready(res),
            PushEntry::Pending(key) => {
                *this.state = Some(State::submitted(key));
                Poll::Pending
            }
        }
    }
}
impl<T: OpCode + 'static> Future for Submit<T, Extra> {
    type Output = (BufResult<usize, T>, Extra);

    /// Submits the operation on the first poll, then polls the driver for completion,
    /// yielding the result together with its [`Extra`].
    ///
    /// # Panics
    ///
    /// Panics with "Cannot poll after ready" if polled again after returning `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // The state is taken out up front and only put back while the operation is
        // still pending, so a completed future is left holding `None`.
        let key = match this.state.take().expect("Cannot poll after ready") {
            State::Submitted { key, .. } => key,
            State::Idle { op } => {
                let extra = cx.as_extra(|| this.runtime.default_extra());
                let key = match this.runtime.submit_raw(op, extra) {
                    // NOTE(review): on synchronous completion the `extra` handed to
                    // `submit_raw` is not returned; a fresh `default_extra()` is
                    // reported instead. Confirm this is intended.
                    PushEntry::Ready(res) => {
                        return Poll::Ready((res, this.runtime.default_extra()));
                    }
                    PushEntry::Pending(key) => key,
                };
                // Make the in-flight operation reachable from an attached cancel token.
                if let Some(cancel) = cx.get_cancel() {
                    cancel.register(&key);
                }
                // Poll straight away so the driver receives the waker on this call.
                key
            }
        };
        match this.runtime.poll_task_with_extra(cx.get_waker(), key) {
            PushEntry::Ready(output) => Poll::Ready(output),
            PushEntry::Pending(key) => {
                *this.state = Some(State::submitted(key));
                Poll::Pending
            }
        }
    }
}
impl<T: OpCode, E> FusedFuture for Submit<T, E>
where
    Self: Future,
{
    /// Reports whether the future has already produced its output.
    ///
    /// `poll` only restores the state while the operation is still pending, so an
    /// empty state means there is nothing left to poll.
    fn is_terminated(&self) -> bool {
        self.state.is_none()
    }
}