// compio_runtime/future/stream.rs

1use std::{
2    marker::PhantomData,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, SetLen};
8use compio_driver::{
9    BufferPool, BufferRef, Extra, Key, OpCode, PushEntry, TakeBuffer,
10    op::{RecvFromMultiResult, RecvMsgMultiResult},
11};
12use futures_util::{Stream, StreamExt, stream::FusedStream};
13
14use crate::{ContextExt, Runtime};
15
16pin_project_lite::pin_project! {
17    /// Returned [`Stream`] for [`Runtime::submit_multi`].
18    ///
19    /// When this is dropped and the operation hasn't finished yet, it will try to
20    /// cancel the operation.
21    pub struct SubmitMulti<T: OpCode> {
22        runtime: Runtime,
23        state: Option<State<T>>,
24    }
25
26    impl<T: OpCode> PinnedDrop for SubmitMulti<T> {
27        fn drop(this: Pin<&mut Self>) {
28            let this = this.project();
29            if let Some(State::Submitted { key }) = this.state.take() {
30                this.runtime.cancel(key);
31            }
32        }
33    }
34}
35
36enum State<T: OpCode> {
37    Idle { op: T },
38    Submitted { key: Key<T> },
39    Finished { op: T },
40}
41
42impl<T: OpCode> State<T> {
43    fn submitted(key: Key<T>) -> Self {
44        State::Submitted { key }
45    }
46}
47
48impl<T: OpCode> SubmitMulti<T> {
49    pub(crate) fn new(runtime: Runtime, op: T) -> Self {
50        SubmitMulti {
51            runtime,
52            state: Some(State::Idle { op }),
53        }
54    }
55
56    /// Try to take the inner op from the stream.
57    ///
58    /// Returns `Ok(T)` if the stream:
59    ///
60    /// - has not been polled yet, or
61    /// - is finished and the op is returned by the driver
62    ///
63    /// Returns `Err(Self)` if it's still running.
64    pub fn try_take(mut self) -> Result<T, Self> {
65        match self.state.take() {
66            Some(State::Finished { op }) | Some(State::Idle { op }) => Ok(op),
67            state => {
68                debug_assert!(state.is_some());
69                self.state = state;
70                Err(self)
71            }
72        }
73    }
74}
75
76impl<T: OpCode + 'static> Stream for SubmitMulti<T> {
77    type Item = BufResult<usize, Extra>;
78
79    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        let this = self.project();
81
82        loop {
83            match this.state.take().expect("State error, this is a bug") {
84                State::Idle { op } => {
85                    let extra = cx.as_extra(|| this.runtime.default_extra());
86                    match this.runtime.submit_raw(op, extra) {
87                        PushEntry::Pending(key) => {
88                            if let Some(cancel) = cx.get_cancel() {
89                                cancel.register(&key);
90                            }
91
92                            *this.state = Some(State::submitted(key))
93                        }
94                        PushEntry::Ready(BufResult(res, op)) => {
95                            *this.state = Some(State::Finished { op });
96                            let extra = this.runtime.default_extra();
97
98                            return Poll::Ready(Some(BufResult(res, extra)));
99                        }
100                    }
101                }
102
103                State::Submitted { key, .. } => {
104                    if let Some(res) = this.runtime.poll_multishot(cx.get_waker(), &key) {
105                        *this.state = Some(State::submitted(key));
106
107                        return Poll::Ready(Some(res));
108                    };
109
110                    match this.runtime.poll_task_with_extra(cx.get_waker(), key) {
111                        PushEntry::Pending(key) => {
112                            *this.state = Some(State::submitted(key));
113
114                            return Poll::Pending;
115                        }
116                        PushEntry::Ready((BufResult(res, op), extra)) => {
117                            *this.state = Some(State::Finished { op });
118
119                            return Poll::Ready(Some(BufResult(res, extra)));
120                        }
121                    }
122                }
123
124                State::Finished { op } => {
125                    *this.state = Some(State::Finished { op });
126
127                    return Poll::Ready(None);
128                }
129            }
130        }
131    }
132}
133
134impl<T: OpCode + 'static> FusedStream for SubmitMulti<T> {
135    fn is_terminated(&self) -> bool {
136        matches!(self.state, None | Some(State::Finished { .. }))
137    }
138}
139
140impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
141where
142    <T as TakeBuffer>::Buffer: HandleBufferRef<Param = ()>,
143{
144    /// Convert this stream into one that iterates the buffers from the results.
145    pub fn into_managed(self, buffer_pool: BufferPool) -> SubmitMultiManaged<T, T::Buffer> {
146        SubmitMultiManaged::new(self, buffer_pool, ())
147    }
148}
149
150impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
151where
152    <T as TakeBuffer>::Buffer: HandleBufferRef,
153{
154    /// Convert this stream into one that iterates the buffers from the results,
155    /// with a param to construct the result item.
156    pub fn into_managed_with(
157        self,
158        buffer_pool: BufferPool,
159        param: <<T as TakeBuffer>::Buffer as HandleBufferRef>::Param,
160    ) -> SubmitMultiManaged<T, T::Buffer> {
161        SubmitMultiManaged::new(self, buffer_pool, param)
162    }
163}
164
165/// A wrapper around [`SubmitMulti`] that iterates the buffers from the results.
166pub struct SubmitMultiManaged<T: OpCode, B = BufferRef>
167where
168    B: HandleBufferRef + 'static,
169{
170    inner: Option<SubmitMulti<T>>,
171    buffer_pool: BufferPool,
172    param: <B as HandleBufferRef>::Param,
173    _p: PhantomData<&'static B>,
174}
175
176impl<T: OpCode, B: HandleBufferRef + 'static> SubmitMultiManaged<T, B> {
177    fn new(
178        stream: SubmitMulti<T>,
179        buffer_pool: BufferPool,
180        param: <B as HandleBufferRef>::Param,
181    ) -> Self {
182        Self {
183            inner: Some(stream),
184            buffer_pool,
185            param,
186            _p: PhantomData,
187        }
188    }
189}
190
191impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> Stream
192    for SubmitMultiManaged<T, B>
193{
194    type Item = std::io::Result<B>;
195
196    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
197        if let Some(inner) = self.inner.as_mut() {
198            let buffer = match std::task::ready!(inner.poll_next_unpin(cx)) {
199                Some(BufResult(res, extra)) => {
200                    if inner.is_terminated() {
201                        let mut b = self
202                            .inner
203                            .take()
204                            .and_then(|s| s.try_take().ok())
205                            .and_then(|op| op.take_buffer());
206                        let res = res?;
207                        if let Some(ref mut b) = b {
208                            unsafe { b.advance_to(res) }
209                        }
210                        b
211                    } else {
212                        let b = self.buffer_pool.take(extra.buffer_id()?)?;
213                        let res = res?;
214                        if let Some(mut b) = b {
215                            unsafe {
216                                SetLen::advance_to(&mut b, res);
217                                Some(B::from_buffer_ref(b, self.param))
218                            }
219                        } else {
220                            None
221                        }
222                    }
223                }
224                None => self
225                    .inner
226                    .take()
227                    .and_then(|s| s.try_take().ok())
228                    .and_then(|op| op.take_buffer()),
229            };
230            Poll::Ready(buffer.map(Ok))
231        } else {
232            Poll::Ready(None)
233        }
234    }
235}
236
237impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> FusedStream
238    for SubmitMultiManaged<T, B>
239{
240    fn is_terminated(&self) -> bool {
241        self.inner.as_ref().is_none_or(|s| s.is_terminated())
242    }
243}
244
245mod private {
246    use super::*;
247
248    pub trait Sealed {}
249
250    impl Sealed for BufferRef {}
251    impl Sealed for RecvFromMultiResult {}
252    impl Sealed for RecvMsgMultiResult {}
253}
254
255#[doc(hidden)]
256pub trait HandleBufferRef: private::Sealed {
257    type Param: Copy + Unpin;
258
259    unsafe fn from_buffer_ref(buffer: BufferRef, param: Self::Param) -> Self;
260
261    unsafe fn advance_to(&mut self, len: usize);
262}
263
264impl HandleBufferRef for BufferRef {
265    type Param = ();
266
267    unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
268        buffer
269    }
270
271    unsafe fn advance_to(&mut self, len: usize) {
272        unsafe { SetLen::advance_to(self, len) }
273    }
274}
275
276impl HandleBufferRef for RecvFromMultiResult {
277    type Param = ();
278
279    unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
280        unsafe { RecvFromMultiResult::new(buffer) }
281    }
282
283    unsafe fn advance_to(&mut self, _: usize) {}
284}
285
286impl HandleBufferRef for RecvMsgMultiResult {
287    type Param = usize;
288
289    unsafe fn from_buffer_ref(buffer: BufferRef, clen: usize) -> Self {
290        unsafe { RecvMsgMultiResult::new(buffer, clen) }
291    }
292
293    unsafe fn advance_to(&mut self, _: usize) {}
294}