Skip to main content

compio_runtime/future/
stream.rs

1use std::{
2    cell::RefCell,
3    marker::PhantomData,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, SetLen};
10use compio_driver::{
11    BufferPool, BufferRef, Extra, Key, OpCode, Proactor, PushEntry, TakeBuffer,
12    op::{RecvFromMultiResult, RecvMsgMultiResult},
13};
14use futures_util::{Stream, StreamExt, stream::FusedStream};
15
16use crate::{
17    ContextExt,
18    future::{poll_multishot, poll_task_with_extra, submit_raw},
19};
20
pin_project_lite::pin_project! {
    /// Returned [`Stream`] for [`Runtime::submit_multi`].
    ///
    /// When this is dropped and the operation hasn't finished yet, it will try to
    /// cancel the operation.
    pub struct SubmitMulti<T: OpCode> {
        // Shared handle to the proactor the op is submitted to and polled on.
        driver: Rc<RefCell<Proactor>>,
        // Lifecycle of the op. `None` only while `poll_next` has taken it out
        // (it is always put back before returning), or after `try_take` has
        // moved the op out.
        state: Option<State<T>>,
    }

    impl<T: OpCode> PinnedDrop for SubmitMulti<T> {
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();
            // Only an in-flight op (identified by its key) is cancelled; in the
            // Idle/Finished states the taken state (and the op in it) is simply
            // dropped at the end of this statement.
            if let Some(State::Submitted { key }) = this.state.take() {
                this.driver.borrow_mut().cancel(key);
            }
        }
    }
}
40
41enum State<T: OpCode> {
42    Idle { op: T },
43    Submitted { key: Key<T> },
44    Finished { op: T },
45}
46
47impl<T: OpCode> State<T> {
48    fn submitted(key: Key<T>) -> Self {
49        State::Submitted { key }
50    }
51}
52
53impl<T: OpCode> SubmitMulti<T> {
54    pub(crate) fn new(driver: Rc<RefCell<Proactor>>, op: T) -> Self {
55        SubmitMulti {
56            driver,
57            state: Some(State::Idle { op }),
58        }
59    }
60
61    /// Try to take the inner op from the stream.
62    ///
63    /// Returns `Ok(T)` if the stream:
64    ///
65    /// - has not been polled yet, or
66    /// - is finished and the op is returned by the driver
67    ///
68    /// Returns `Err(Self)` if it's still running.
69    pub fn try_take(mut self) -> Result<T, Self> {
70        match self.state.take() {
71            Some(State::Finished { op }) | Some(State::Idle { op }) => Ok(op),
72            state => {
73                debug_assert!(state.is_some());
74                self.state = state;
75                Err(self)
76            }
77        }
78    }
79}
80
impl<T: OpCode + 'static> Stream for SubmitMulti<T> {
    /// One completion of the operation: its result plus the [`Extra`] reported
    /// with it.
    type Item = BufResult<usize, Extra>;

    /// Drives the op through `Idle -> Submitted -> Finished`.
    ///
    /// The first poll submits the op. While it is in flight, each available
    /// multishot result is yielded and the op stays submitted. When the driver
    /// hands the op back, its final result is yielded and every later poll
    /// returns `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        // Every arm puts a state back before returning. The loop only repeats
        // after a fresh submission, so the new key is polled right away.
        loop {
            match this.state.take().expect("State error, this is a bug") {
                State::Idle { op } => {
                    // `extra` is computed before the `borrow_mut` below, so the
                    // closure's shared borrow of the driver never overlaps it.
                    let extra = cx.as_extra(|| this.driver.borrow().default_extra());
                    let entry = submit_raw(&mut this.driver.borrow_mut(), op, extra);
                    match entry {
                        PushEntry::Pending(key) => {
                            // Register the key with the cancel handle attached to
                            // the context, if there is one.
                            if let Some(cancel) = cx.get_cancel() {
                                cancel.register(&key);
                            }

                            *this.state = Some(State::submitted(key))
                        }
                        PushEntry::Ready(BufResult(res, op)) => {
                            // Completed synchronously: `submit_raw` returns no
                            // completion extra, so a default one is reported.
                            *this.state = Some(State::Finished { op });
                            let extra = this.driver.borrow().default_extra();

                            return Poll::Ready(Some(BufResult(res, extra)));
                        }
                    }
                }

                State::Submitted { key, .. } => {
                    // First yield any multishot result already available for
                    // this key; the op stays submitted.
                    if let Some(res) =
                        poll_multishot(&mut this.driver.borrow_mut(), cx.get_waker(), &key)
                    {
                        *this.state = Some(State::submitted(key));

                        return Poll::Ready(Some(res));
                    };

                    // Otherwise check whether the op itself has completed. This
                    // consumes `key`, which is handed back while still pending.
                    let entry =
                        poll_task_with_extra(&mut this.driver.borrow_mut(), cx.get_waker(), key);
                    match entry {
                        PushEntry::Pending(key) => {
                            *this.state = Some(State::submitted(key));

                            return Poll::Pending;
                        }
                        PushEntry::Ready((BufResult(res, op), extra)) => {
                            // Final completion: the op is owned by us again.
                            *this.state = Some(State::Finished { op });

                            return Poll::Ready(Some(BufResult(res, extra)));
                        }
                    }
                }

                State::Finished { op } => {
                    // Keep the op around for `try_take` and signal the end.
                    *this.state = Some(State::Finished { op });

                    return Poll::Ready(None);
                }
            }
        }
    }
}
143
144impl<T: OpCode + 'static> FusedStream for SubmitMulti<T> {
145    fn is_terminated(&self) -> bool {
146        matches!(self.state, None | Some(State::Finished { .. }))
147    }
148}
149
150impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
151where
152    <T as TakeBuffer>::Buffer: HandleBufferRef<Param = ()>,
153{
154    /// Convert this stream into one that iterates the buffers from the results.
155    pub fn into_managed(self, buffer_pool: BufferPool) -> SubmitMultiManaged<T, T::Buffer> {
156        SubmitMultiManaged::new(self, buffer_pool, ())
157    }
158}
159
160impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
161where
162    <T as TakeBuffer>::Buffer: HandleBufferRef,
163{
164    /// Convert this stream into one that iterates the buffers from the results,
165    /// with a param to construct the result item.
166    pub fn into_managed_with(
167        self,
168        buffer_pool: BufferPool,
169        param: <<T as TakeBuffer>::Buffer as HandleBufferRef>::Param,
170    ) -> SubmitMultiManaged<T, T::Buffer> {
171        SubmitMultiManaged::new(self, buffer_pool, param)
172    }
173}
174
175/// A wrapper around [`SubmitMulti`] that iterates the buffers from the results.
176pub struct SubmitMultiManaged<T: OpCode, B = BufferRef>
177where
178    B: HandleBufferRef + 'static,
179{
180    inner: Option<SubmitMulti<T>>,
181    buffer_pool: BufferPool,
182    param: <B as HandleBufferRef>::Param,
183    _p: PhantomData<&'static B>,
184}
185
186impl<T: OpCode, B: HandleBufferRef + 'static> SubmitMultiManaged<T, B> {
187    fn new(
188        stream: SubmitMulti<T>,
189        buffer_pool: BufferPool,
190        param: <B as HandleBufferRef>::Param,
191    ) -> Self {
192        Self {
193            inner: Some(stream),
194            buffer_pool,
195            param,
196            _p: PhantomData,
197        }
198    }
199}
200
201impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> Stream
202    for SubmitMultiManaged<T, B>
203{
204    type Item = std::io::Result<Option<B>>;
205
206    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207        if let Some(inner) = self.inner.as_mut() {
208            let buffer = match std::task::ready!(inner.poll_next_unpin(cx)) {
209                Some(BufResult(res, extra)) => {
210                    if inner.is_terminated() {
211                        let mut b = self
212                            .inner
213                            .take()
214                            .and_then(|s| s.try_take().ok())
215                            .and_then(|op| op.take_buffer());
216                        let res = res?;
217                        if let Some(ref mut b) = b {
218                            unsafe { b.advance_to(res) }
219                        }
220                        b
221                    } else {
222                        let b = self.buffer_pool.take(extra.buffer_id()?)?;
223                        let res = res?;
224                        if let Some(mut b) = b {
225                            unsafe {
226                                SetLen::advance_to(&mut b, res);
227                                Some(B::from_buffer_ref(b, self.param))
228                            }
229                        } else {
230                            None
231                        }
232                    }
233                }
234                None => self
235                    .inner
236                    .take()
237                    .and_then(|s| s.try_take().ok())
238                    .and_then(|op| op.take_buffer()),
239            };
240            Poll::Ready(Some(Ok(buffer)))
241        } else {
242            Poll::Ready(None)
243        }
244    }
245}
246
247impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> FusedStream
248    for SubmitMultiManaged<T, B>
249{
250    fn is_terminated(&self) -> bool {
251        self.inner.as_ref().is_none_or(|s| s.is_terminated())
252    }
253}
254
255mod private {
256    use super::*;
257
258    pub trait Sealed {}
259
260    impl Sealed for BufferRef {}
261    impl Sealed for RecvFromMultiResult {}
262    impl Sealed for RecvMsgMultiResult {}
263}
264
/// Conversion from a pool [`BufferRef`] into the item type yielded by
/// [`SubmitMultiManaged`].
///
/// Sealed: implemented only for [`BufferRef`], [`RecvFromMultiResult`] and
/// [`RecvMsgMultiResult`].
#[doc(hidden)]
pub trait HandleBufferRef: private::Sealed {
    /// Extra argument needed to build `Self`; `Copy` so the same value can be
    /// reused for every buffer.
    type Param: Copy + Unpin;

    /// Builds `Self` from a buffer taken from the pool.
    ///
    /// # Safety
    ///
    /// NOTE(review): the impls forward to unsafe constructors such as
    /// `RecvMsgMultiResult::new`; presumably `buffer` must hold the output of
    /// the matching completed operation — confirm against those constructors.
    unsafe fn from_buffer_ref(buffer: BufferRef, param: Self::Param) -> Self;

    /// Applies the length reported by a completion to the buffer.
    ///
    /// For [`BufferRef`] this forwards to `SetLen::advance_to`; the other impls
    /// ignore `len`.
    ///
    /// # Safety
    ///
    /// NOTE(review): inherits `SetLen::advance_to`'s contract; presumably the
    /// first `len` bytes must be initialized — confirm.
    unsafe fn advance_to(&mut self, len: usize);

    /// Returns `true` if this item marks end of stream (no data received).
    fn is_empty(&self) -> bool;
}
275
276impl HandleBufferRef for BufferRef {
277    type Param = ();
278
279    unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
280        buffer
281    }
282
283    unsafe fn advance_to(&mut self, len: usize) {
284        unsafe { SetLen::advance_to(self, len) }
285    }
286
287    fn is_empty(&self) -> bool {
288        // A fallback buffer pool takes the buffer before the operation, so it
289        // can return an empty buffer when EOF is reached.
290        <[u8]>::is_empty(self)
291    }
292}
293
294impl HandleBufferRef for RecvFromMultiResult {
295    type Param = ();
296
297    unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
298        unsafe { RecvFromMultiResult::new(buffer) }
299    }
300
301    unsafe fn advance_to(&mut self, _: usize) {}
302
303    fn is_empty(&self) -> bool {
304        false
305    }
306}
307
308impl HandleBufferRef for RecvMsgMultiResult {
309    type Param = usize;
310
311    unsafe fn from_buffer_ref(buffer: BufferRef, clen: usize) -> Self {
312        unsafe { RecvMsgMultiResult::new(buffer, clen) }
313    }
314
315    unsafe fn advance_to(&mut self, _: usize) {}
316
317    fn is_empty(&self) -> bool {
318        false
319    }
320}
321
322/// A wrapper around [`SubmitMultiManaged`] that submits the operation
323/// automatically till the stream is finished.
324pub struct SubmitMultiStream<F, T: OpCode, B = BufferRef>
325where
326    B: HandleBufferRef + 'static,
327{
328    create_op: F,
329    op: Option<SubmitMultiManaged<T, B>>,
330}
331
332impl<F, T: OpCode, B: HandleBufferRef + 'static> SubmitMultiStream<F, T, B> {
333    /// Create a new [`SubmitMultiStream`] with a closure that creates the
334    /// operation.
335    pub fn new(create_op: F) -> Self {
336        Self {
337            create_op,
338            op: None,
339        }
340    }
341}
342
343impl<
344    F: (Fn() -> std::io::Result<SubmitMultiManaged<T, B>>) + Unpin,
345    T: OpCode + TakeBuffer<Buffer = B> + 'static,
346    B: HandleBufferRef,
347> Stream for SubmitMultiStream<F, T, B>
348{
349    type Item = std::io::Result<B>;
350
351    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
352        loop {
353            match &mut self.op {
354                Some(op) => match std::task::ready!(Pin::new(op).poll_next(cx)) {
355                    Some(Ok(Some(buffer))) => {
356                        if buffer.is_empty() {
357                            break Poll::Ready(None);
358                        } else {
359                            break Poll::Ready(Some(Ok(buffer)));
360                        }
361                    }
362                    Some(Ok(None)) => break Poll::Ready(None),
363                    Some(Err(e)) => break Poll::Ready(Some(Err(e))),
364                    None => self.op = None,
365                },
366                None => match (self.create_op)() {
367                    Ok(op) => self.op = Some(op),
368                    Err(e) => break Poll::Ready(Some(Err(e))),
369                },
370            }
371        }
372    }
373}