// compio-runtime 0.12.0
//
// High-level runtime for compio
// Documentation
use std::{
    cell::RefCell,
    marker::PhantomData,
    pin::Pin,
    rc::Rc,
    task::{Context, Poll},
};

use compio_buf::{BufResult, SetLen};
use compio_driver::{
    BufferPool, BufferRef, Extra, Key, OpCode, Proactor, PushEntry, TakeBuffer,
    op::{RecvFromMultiResult, RecvMsgMultiResult},
};
use futures_util::{Stream, StreamExt, stream::FusedStream};

use crate::{
    ContextExt,
    future::{poll_multishot, poll_task_with_extra, submit_raw},
};

pin_project_lite::pin_project! {
    /// The [`Stream`] handed back by [`Runtime::submit_multi`].
    ///
    /// If it is dropped while the operation is still in flight, the driver is
    /// asked to cancel that operation.
    pub struct SubmitMulti<T: OpCode> {
        // Shared handle to the proactor the operation is pushed to.
        driver: Rc<RefCell<Proactor>>,
        // Stage of the operation. Methods take the value out while they work
        // on it and put a value back before returning.
        state: Option<State<T>>,
    }

    impl<T: OpCode> PinnedDrop for SubmitMulti<T> {
        fn drop(this: Pin<&mut Self>) {
            let me = this.project();
            // Only an in-flight operation owns a driver key that has to be
            // cancelled; every other stage has nothing to clean up here.
            match me.state.take() {
                Some(State::Submitted { key }) => {
                    me.driver.borrow_mut().cancel(key);
                }
                Some(State::Idle { .. } | State::Finished { .. }) | None => {}
            }
        }
    }
}

/// Stage of a multishot operation tracked by [`SubmitMulti`].
enum State<T: OpCode> {
    /// The op has not been handed to the driver yet.
    Idle { op: T },
    /// The op was pushed to the driver and is identified by `key`.
    Submitted { key: Key<T> },
    /// The driver has handed the op back; no further results will be produced.
    Finished { op: T },
}

impl<T: OpCode> State<T> {
    /// Wraps a driver `key` in the [`State::Submitted`] variant.
    fn submitted(key: Key<T>) -> Self {
        Self::Submitted { key }
    }
}

impl<T: OpCode> SubmitMulti<T> {
    /// Builds a stream for `op` that has not been submitted to `driver` yet.
    pub(crate) fn new(driver: Rc<RefCell<Proactor>>, op: T) -> Self {
        let state = Some(State::Idle { op });
        Self { driver, state }
    }

    /// Attempts to recover the wrapped op, consuming the stream.
    ///
    /// Yields `Ok(T)` when the stream:
    ///
    /// - was never polled, or
    /// - has finished and the driver has handed the op back.
    ///
    /// Yields `Err(Self)`, giving the stream back untouched, while the
    /// operation is still running.
    pub fn try_take(mut self) -> Result<T, Self> {
        match self.state.take() {
            Some(State::Idle { op } | State::Finished { op }) => Ok(op),
            running => {
                // The state is only ever absent transiently inside a method.
                debug_assert!(running.is_some());
                self.state = running;
                Err(self)
            }
        }
    }
}

/// Drives the multishot operation: the first poll submits the op, later polls
/// yield each result reported for it, and the stream ends once the driver has
/// returned the op.
impl<T: OpCode + 'static> Stream for SubmitMulti<T> {
    /// One result of the operation together with its driver [`Extra`] data.
    type Item = BufResult<usize, Extra>;

    /// Advances the state machine.
    ///
    /// The state is taken out at the top of every iteration and written back
    /// on every path before returning or looping again.
    ///
    /// # Panics
    ///
    /// Panics with "State error, this is a bug" if the state is missing on
    /// entry.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        loop {
            match this.state.take().expect("State error, this is a bug") {
                State::Idle { op } => {
                    // NOTE(review): `as_extra` presumably prefers extra data carried
                    // by the context and only falls back to the closure — confirm
                    // against `ContextExt`.
                    let extra = cx.as_extra(|| this.driver.borrow().default_extra());
                    // The `borrow_mut` guard is a temporary of this statement, so it
                    // is released before the driver is borrowed again below.
                    let entry = submit_raw(&mut this.driver.borrow_mut(), op, extra);
                    match entry {
                        PushEntry::Pending(key) => {
                            // Register the key with the context's cancel handle, if any.
                            if let Some(cancel) = cx.get_cancel() {
                                cancel.register(&key);
                            }

                            // No return here: the loop re-enters in `Submitted` and
                            // polls the freshly submitted op right away.
                            *this.state = Some(State::submitted(key))
                        }
                        PushEntry::Ready(BufResult(res, op)) => {
                            // The op completed during submission; keep it so that
                            // `try_take` can hand it back, and report the result
                            // with the driver's default extra.
                            *this.state = Some(State::Finished { op });
                            let extra = this.driver.borrow().default_extra();

                            return Poll::Ready(Some(BufResult(res, extra)));
                        }
                    }
                }

                State::Submitted { key, .. } => {
                    // First look for a result from `poll_multishot`; the op stays
                    // submitted when one is found.
                    if let Some(res) =
                        poll_multishot(&mut this.driver.borrow_mut(), cx.get_waker(), &key)
                    {
                        *this.state = Some(State::submitted(key));

                        return Poll::Ready(Some(res));
                    };

                    // Nothing from `poll_multishot`: check whether the op itself
                    // has completed.
                    let entry =
                        poll_task_with_extra(&mut this.driver.borrow_mut(), cx.get_waker(), key);
                    match entry {
                        PushEntry::Pending(key) => {
                            *this.state = Some(State::submitted(key));

                            return Poll::Pending;
                        }
                        PushEntry::Ready((BufResult(res, op), extra)) => {
                            // Final result: the driver returned the op.
                            *this.state = Some(State::Finished { op });

                            return Poll::Ready(Some(BufResult(res, extra)));
                        }
                    }
                }

                State::Finished { op } => {
                    // Put the op back so that it remains available to `try_take`.
                    *this.state = Some(State::Finished { op });

                    return Poll::Ready(None);
                }
            }
        }
    }
}

impl<T: OpCode + 'static> FusedStream for SubmitMulti<T> {
    /// Reports `true` once the op has been handed back by the driver, or when
    /// the state is missing altogether.
    fn is_terminated(&self) -> bool {
        match &self.state {
            None | Some(State::Finished { .. }) => true,
            Some(State::Idle { .. } | State::Submitted { .. }) => false,
        }
    }
}

impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
where
    <T as TakeBuffer>::Buffer: HandleBufferRef<Param = ()>,
{
    /// Turns this stream into one that yields the buffers belonging to the
    /// results, looking them up in `buffer_pool`.
    pub fn into_managed(self, buffer_pool: BufferPool) -> SubmitMultiManaged<T, T::Buffer> {
        // This buffer type needs no extra construction parameter.
        let unit_param = ();
        SubmitMultiManaged::new(self, buffer_pool, unit_param)
    }
}

impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
where
    <T as TakeBuffer>::Buffer: HandleBufferRef,
{
    /// Turns this stream into one that yields the buffers belonging to the
    /// results, passing `param` along whenever a result item is constructed.
    pub fn into_managed_with(
        self,
        buffer_pool: BufferPool,
        param: <<T as TakeBuffer>::Buffer as HandleBufferRef>::Param,
    ) -> SubmitMultiManaged<T, T::Buffer> {
        let stream = self;
        SubmitMultiManaged::new(stream, buffer_pool, param)
    }
}

/// A wrapper around [`SubmitMulti`] that iterates the buffers from the results.
///
/// `B` is the buffer type yielded to the caller; it defaults to [`BufferRef`].
pub struct SubmitMultiManaged<T: OpCode, B = BufferRef>
where
    B: HandleBufferRef + 'static,
{
    /// The wrapped stream; set to `None` once it has been consumed.
    inner: Option<SubmitMulti<T>>,
    /// Pool from which the buffer of each intermediate result is taken.
    buffer_pool: BufferPool,
    /// Extra value passed to `B::from_buffer_ref` for every constructed item.
    param: <B as HandleBufferRef>::Param,
    /// Ties the type to `B` without storing a value of it.
    _p: PhantomData<&'static B>,
}

impl<T: OpCode, B: HandleBufferRef + 'static> SubmitMultiManaged<T, B> {
    fn new(
        stream: SubmitMulti<T>,
        buffer_pool: BufferPool,
        param: <B as HandleBufferRef>::Param,
    ) -> Self {
        Self {
            inner: Some(stream),
            buffer_pool,
            param,
            _p: PhantomData,
        }
    }
}

impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> Stream
    for SubmitMultiManaged<T, B>
{
    /// `Ok(Some(buffer))` for a result that has a buffer, `Ok(None)` for one
    /// without, `Err` when the operation or the buffer lookup failed.
    type Item = std::io::Result<Option<B>>;

    /// Polls the wrapped stream and maps each result to its buffer.
    ///
    /// While the wrapped stream is still running, the buffer is looked up in
    /// the pool. Once it has terminated, the wrapped stream is dropped from
    /// `self` and the buffer is taken from the op itself.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(stream) = self.inner.as_mut() else {
            // The wrapped stream has already been consumed.
            return Poll::Ready(None);
        };

        let item = match std::task::ready!(stream.poll_next_unpin(cx)) {
            // Intermediate result: the op is still owned by the driver.
            Some(BufResult(outcome, extra)) if !stream.is_terminated() => {
                let picked = self.buffer_pool.take(extra.buffer_id()?)?;
                let len = outcome?;
                match picked {
                    // SAFETY: NOTE(review): relies on `len` being the length the
                    // operation reported for this buffer — confirm against the
                    // driver contract.
                    Some(mut raw) => unsafe {
                        SetLen::advance_to(&mut raw, len);
                        Some(B::from_buffer_ref(raw, self.param))
                    },
                    None => None,
                }
            }
            // Final result: detach the wrapped stream first so that the next
            // poll ends this stream even if `outcome` turns out to be an error.
            Some(BufResult(outcome, _extra)) => {
                let mut owned = self
                    .inner
                    .take()
                    .and_then(|finished| finished.try_take().ok())
                    .and_then(|op| op.take_buffer());
                let len = outcome?;
                if let Some(buf) = owned.as_mut() {
                    // SAFETY: NOTE(review): same assumption as above about `len`.
                    unsafe { <B as HandleBufferRef>::advance_to(buf, len) }
                }
                owned
            }
            // The wrapped stream was already exhausted: hand over whatever
            // buffer the op still holds, without adjusting its length.
            None => self
                .inner
                .take()
                .and_then(|finished| finished.try_take().ok())
                .and_then(|op| op.take_buffer()),
        };

        Poll::Ready(Some(Ok(item)))
    }
}

impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> FusedStream
    for SubmitMultiManaged<T, B>
{
    /// Terminated when the wrapped stream is gone or is itself terminated.
    fn is_terminated(&self) -> bool {
        match self.inner.as_ref() {
            None => true,
            Some(stream) => stream.is_terminated(),
        }
    }
}

mod private {
    use super::*;

    pub trait Sealed {}

    impl Sealed for BufferRef {}
    impl Sealed for RecvFromMultiResult {}
    impl Sealed for RecvMsgMultiResult {}
}

/// Glue between a pooled [`BufferRef`] and the buffer type yielded by a
/// managed multishot stream.
///
/// The `private::Sealed` supertrait limits the set of implementors.
#[doc(hidden)]
pub trait HandleBufferRef: private::Sealed {
    /// Extra value needed to build `Self` from a [`BufferRef`]; `()` when
    /// nothing is needed.
    type Param: Copy + Unpin;

    /// Builds `Self` from a pooled `buffer` and `param`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not stated in this file; presumably the
    /// buffer must already hold a completed result of the matching operation —
    /// confirm against the implementors' constructors.
    unsafe fn from_buffer_ref(buffer: BufferRef, param: Self::Param) -> Self;

    /// Sets the filled length of the buffer to `len`; some implementations
    /// ignore it.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `len` must not exceed the initialized part of
    /// the buffer — confirm against `SetLen::advance_to`.
    unsafe fn advance_to(&mut self, len: usize);

    /// Whether the buffer carries no data.
    fn is_empty(&self) -> bool;
}

impl HandleBufferRef for BufferRef {
    type Param = ();

    unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
        buffer
    }

    unsafe fn advance_to(&mut self, len: usize) {
        unsafe { SetLen::advance_to(self, len) }
    }

    fn is_empty(&self) -> bool {
        // A fallback buffer pool takes the buffer before the operation, so it
        // can return an empty buffer when EOF is reached.
        <[u8]>::is_empty(self)
    }
}

impl HandleBufferRef for RecvFromMultiResult {
    type Param = ();

    /// Wraps the pooled buffer in a `RecvFromMultiResult`.
    unsafe fn from_buffer_ref(buffer: BufferRef, _unit: Self::Param) -> Self {
        // SAFETY: the caller upholds the contract of this `unsafe fn`, which is
        // forwarded unchanged to the constructor.
        unsafe { Self::new(buffer) }
    }

    /// Does nothing for this type; the length is ignored.
    unsafe fn advance_to(&mut self, _len: usize) {}

    /// Always reports the result as non-empty.
    fn is_empty(&self) -> bool {
        false
    }
}

impl HandleBufferRef for RecvMsgMultiResult {
    /// The `clen` value forwarded to `RecvMsgMultiResult::new`.
    type Param = usize;

    /// Wraps the pooled buffer in a `RecvMsgMultiResult` built with `clen`.
    unsafe fn from_buffer_ref(buffer: BufferRef, clen: usize) -> Self {
        // SAFETY: the caller upholds the contract of this `unsafe fn`, which is
        // forwarded unchanged to the constructor.
        unsafe { Self::new(buffer, clen) }
    }

    /// Does nothing for this type; the length is ignored.
    unsafe fn advance_to(&mut self, _len: usize) {}

    /// Always reports the result as non-empty.
    fn is_empty(&self) -> bool {
        false
    }
}

/// A wrapper around [`SubmitMultiManaged`] that submits the operation
/// automatically till the stream is finished.
///
/// Whenever the current managed stream is exhausted, `create_op` is called
/// again to obtain a new one.
pub struct SubmitMultiStream<F, T: OpCode, B = BufferRef>
where
    B: HandleBufferRef + 'static,
{
    /// Factory that produces a new managed stream on demand.
    create_op: F,
    /// The managed stream currently being polled; `None` until the first poll
    /// and after the current one has been exhausted.
    op: Option<SubmitMultiManaged<T, B>>,
}

impl<F, T: OpCode, B: HandleBufferRef + 'static> SubmitMultiStream<F, T, B> {
    /// Create a new [`SubmitMultiStream`] with a closure that creates the
    /// operation.
    pub fn new(create_op: F) -> Self {
        Self {
            create_op,
            op: None,
        }
    }
}

impl<
    F: (Fn() -> std::io::Result<SubmitMultiManaged<T, B>>) + Unpin,
    T: OpCode + TakeBuffer<Buffer = B> + 'static,
    B: HandleBufferRef,
> Stream for SubmitMultiStream<F, T, B>
{
    /// A non-empty buffer, or the error that stopped an operation.
    type Item = std::io::Result<B>;

    /// Yields buffers from the current operation, creating a new operation
    /// whenever the current one is exhausted.
    ///
    /// The stream ends on the first result without a buffer or with an empty
    /// one.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let Some(active) = self.op.as_mut() else {
                // No operation at hand: build one and poll it straight away.
                match (self.create_op)() {
                    Ok(fresh) => self.op = Some(fresh),
                    Err(e) => return Poll::Ready(Some(Err(e))),
                }
                continue;
            };

            match std::task::ready!(Pin::new(active).poll_next(cx)) {
                Some(Ok(Some(buffer))) if !buffer.is_empty() => {
                    return Poll::Ready(Some(Ok(buffer)));
                }
                // Either no buffer at all or an empty one: end of the stream.
                Some(Ok(_)) => return Poll::Ready(None),
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                // Current operation exhausted; the next iteration creates a new one.
                None => self.op = None,
            }
        }
    }
}