1use std::{
2 cell::RefCell,
3 marker::PhantomData,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, SetLen};
10use compio_driver::{
11 BufferPool, BufferRef, Extra, Key, OpCode, Proactor, PushEntry, TakeBuffer,
12 op::{RecvFromMultiResult, RecvMsgMultiResult},
13};
14use futures_util::{Stream, StreamExt, stream::FusedStream};
15
16use crate::{
17 ContextExt,
18 future::{poll_multishot, poll_task_with_extra, submit_raw},
19};
20
pin_project_lite::pin_project! {
    /// A [`Stream`] over the completions of an operation `T` submitted to a [`Proactor`].
    ///
    /// The operation is not submitted when the stream is created; it is submitted on the
    /// first poll. Every completion (intermediate multishot results as well as the final
    /// one) is yielded as a `BufResult<usize, Extra>`. Once the final completion has been
    /// yielded, further polls return `None`.
    pub struct SubmitMulti<T: OpCode> {
        // Shared handle to the driver the operation is submitted to.
        driver: Rc<RefCell<Proactor>>,
        // Current lifecycle stage; only `None` while `poll_next` is mid-transition.
        state: Option<State<T>>,
    }

    impl<T: OpCode> PinnedDrop for SubmitMulti<T> {
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();
            // An operation still registered with the driver is cancelled through its key.
            // Idle or finished operations are owned by the state and are simply dropped.
            if let Some(State::Submitted { key }) = this.state.take() {
                this.driver.borrow_mut().cancel(key);
            }
        }
    }
}
40
/// Lifecycle of the operation owned by [`SubmitMulti`].
enum State<T: OpCode> {
    /// Created but not yet submitted; the operation is owned here.
    Idle { op: T },
    /// Submitted to the driver, which is addressed through `key`.
    Submitted { key: Key<T> },
    /// The final completion has been observed and the operation was handed back.
    Finished { op: T },
}
46
47impl<T: OpCode> State<T> {
48 fn submitted(key: Key<T>) -> Self {
49 State::Submitted { key }
50 }
51}
52
53impl<T: OpCode> SubmitMulti<T> {
54 pub(crate) fn new(driver: Rc<RefCell<Proactor>>, op: T) -> Self {
55 SubmitMulti {
56 driver,
57 state: Some(State::Idle { op }),
58 }
59 }
60
61 pub fn try_take(mut self) -> Result<T, Self> {
70 match self.state.take() {
71 Some(State::Finished { op }) | Some(State::Idle { op }) => Ok(op),
72 state => {
73 debug_assert!(state.is_some());
74 self.state = state;
75 Err(self)
76 }
77 }
78 }
79}
80
impl<T: OpCode + 'static> Stream for SubmitMulti<T> {
    // One item per completion: the operation's result plus the `Extra` reported with it.
    type Item = BufResult<usize, Extra>;

    /// Advances the operation through its states.
    ///
    /// - `Idle`: submit the operation.
    ///   - If it completes immediately, move to `Finished` and yield the result with the
    ///     driver's default `Extra`.
    ///   - Otherwise store the key and keep looping as `Submitted`.
    /// - `Submitted`:
    ///   - Yield a multishot result if `poll_multishot` returns one; the operation stays
    ///     submitted.
    ///   - Otherwise poll the task. A ready task hands the operation back, moves the stream
    ///     to `Finished`, and yields the final result.
    /// - `Finished`: yield `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        loop {
            // Every branch below stores a state back before returning or looping, so `None`
            // can only be observed if an earlier poll unwound midway.
            match this.state.take().expect("State error, this is a bug") {
                State::Idle { op } => {
                    // NOTE(review): `as_extra` presumably prefers an `Extra` carried by the
                    // context and only falls back to the driver default via the closure.
                    let extra = cx.as_extra(|| this.driver.borrow().default_extra());
                    let entry = submit_raw(&mut this.driver.borrow_mut(), op, extra);
                    match entry {
                        PushEntry::Pending(key) => {
                            // Let the context's cancel handle (if any) know about this key.
                            if let Some(cancel) = cx.get_cancel() {
                                cancel.register(&key);
                            }

                            *this.state = Some(State::submitted(key))
                        }
                        PushEntry::Ready(BufResult(res, op)) => {
                            // Completed synchronously: there is no driver-reported `Extra`,
                            // so the default one is attached to the result.
                            *this.state = Some(State::Finished { op });
                            let extra = this.driver.borrow().default_extra();

                            return Poll::Ready(Some(BufResult(res, extra)));
                        }
                    }
                }

                State::Submitted { key, .. } => {
                    // Intermediate multishot results are yielded before the task itself
                    // is polled; the operation remains submitted.
                    if let Some(res) =
                        poll_multishot(&mut this.driver.borrow_mut(), cx.get_waker(), &key)
                    {
                        *this.state = Some(State::submitted(key));

                        return Poll::Ready(Some(res));
                    };

                    let entry =
                        poll_task_with_extra(&mut this.driver.borrow_mut(), cx.get_waker(), key);
                    match entry {
                        PushEntry::Pending(key) => {
                            *this.state = Some(State::submitted(key));

                            return Poll::Pending;
                        }
                        PushEntry::Ready((BufResult(res, op), extra)) => {
                            // The final completion returns ownership of the operation.
                            *this.state = Some(State::Finished { op });

                            return Poll::Ready(Some(BufResult(res, extra)));
                        }
                    }
                }

                State::Finished { op } => {
                    // Keep the operation around (e.g. for `try_take`) and report exhaustion.
                    *this.state = Some(State::Finished { op });

                    return Poll::Ready(None);
                }
            }
        }
    }
}
143
144impl<T: OpCode + 'static> FusedStream for SubmitMulti<T> {
145 fn is_terminated(&self) -> bool {
146 matches!(self.state, None | Some(State::Finished { .. }))
147 }
148}
149
150impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
151where
152 <T as TakeBuffer>::Buffer: HandleBufferRef<Param = ()>,
153{
154 pub fn into_managed(self, buffer_pool: BufferPool) -> SubmitMultiManaged<T, T::Buffer> {
156 SubmitMultiManaged::new(self, buffer_pool, ())
157 }
158}
159
160impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
161where
162 <T as TakeBuffer>::Buffer: HandleBufferRef,
163{
164 pub fn into_managed_with(
167 self,
168 buffer_pool: BufferPool,
169 param: <<T as TakeBuffer>::Buffer as HandleBufferRef>::Param,
170 ) -> SubmitMultiManaged<T, T::Buffer> {
171 SubmitMultiManaged::new(self, buffer_pool, param)
172 }
173}
174
/// Adapter over [`SubmitMulti`] that turns each completion into an optional buffer `B`.
///
/// Intermediate completions look up their buffer in `buffer_pool` by the id reported in
/// [`Extra`]. The final completion takes the buffer from the operation itself through
/// [`TakeBuffer`].
pub struct SubmitMultiManaged<T: OpCode, B = BufferRef>
where
    B: HandleBufferRef + 'static,
{
    // `None` once the final completion has been handled and the operation consumed.
    inner: Option<SubmitMulti<T>>,
    // Pool that intermediate completions take their buffers from.
    buffer_pool: BufferPool,
    // Forwarded to `HandleBufferRef::from_buffer_ref` for every pooled buffer.
    param: <B as HandleBufferRef>::Param,
    _p: PhantomData<&'static B>,
}
185
186impl<T: OpCode, B: HandleBufferRef + 'static> SubmitMultiManaged<T, B> {
187 fn new(
188 stream: SubmitMulti<T>,
189 buffer_pool: BufferPool,
190 param: <B as HandleBufferRef>::Param,
191 ) -> Self {
192 Self {
193 inner: Some(stream),
194 buffer_pool,
195 param,
196 _p: PhantomData,
197 }
198 }
199}
200
201impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> Stream
202 for SubmitMultiManaged<T, B>
203{
204 type Item = std::io::Result<Option<B>>;
205
206 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207 if let Some(inner) = self.inner.as_mut() {
208 let buffer = match std::task::ready!(inner.poll_next_unpin(cx)) {
209 Some(BufResult(res, extra)) => {
210 if inner.is_terminated() {
211 let mut b = self
212 .inner
213 .take()
214 .and_then(|s| s.try_take().ok())
215 .and_then(|op| op.take_buffer());
216 let res = res?;
217 if let Some(ref mut b) = b {
218 unsafe { b.advance_to(res) }
219 }
220 b
221 } else {
222 let b = self.buffer_pool.take(extra.buffer_id()?)?;
223 let res = res?;
224 if let Some(mut b) = b {
225 unsafe {
226 SetLen::advance_to(&mut b, res);
227 Some(B::from_buffer_ref(b, self.param))
228 }
229 } else {
230 None
231 }
232 }
233 }
234 None => self
235 .inner
236 .take()
237 .and_then(|s| s.try_take().ok())
238 .and_then(|op| op.take_buffer()),
239 };
240 Poll::Ready(Some(Ok(buffer)))
241 } else {
242 Poll::Ready(None)
243 }
244 }
245}
246
247impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> FusedStream
248 for SubmitMultiManaged<T, B>
249{
250 fn is_terminated(&self) -> bool {
251 self.inner.as_ref().is_none_or(|s| s.is_terminated())
252 }
253}
254
255mod private {
256 use super::*;
257
258 pub trait Sealed {}
259
260 impl Sealed for BufferRef {}
261 impl Sealed for RecvFromMultiResult {}
262 impl Sealed for RecvMsgMultiResult {}
263}
264
/// Conversion from a pooled [`BufferRef`] into the item type yielded by
/// [`SubmitMultiManaged`].
///
/// Sealed: implemented only for [`BufferRef`], [`RecvFromMultiResult`] and
/// [`RecvMsgMultiResult`].
#[doc(hidden)]
pub trait HandleBufferRef: private::Sealed {
    /// Extra value needed by [`from_buffer_ref`](Self::from_buffer_ref).
    type Param: Copy + Unpin;

    /// Wraps a pooled buffer into `Self`.
    ///
    /// # Safety
    ///
    /// NOTE(review): implementations forward to unsafe constructors such as
    /// `RecvFromMultiResult::new`; the caller must uphold their requirements. Confirm the
    /// exact contract there.
    unsafe fn from_buffer_ref(buffer: BufferRef, param: Self::Param) -> Self;

    /// Records that the first `len` bytes hold data (a no-op for the message result types).
    ///
    /// # Safety
    ///
    /// Presumably the first `len` bytes must actually be initialized. The `BufferRef`
    /// implementation forwards to `SetLen::advance_to`. TODO confirm.
    unsafe fn advance_to(&mut self, len: usize);

    /// Whether the item carries no data. `SubmitMultiStream` ends the stream on an empty item.
    fn is_empty(&self) -> bool;
}
275
276impl HandleBufferRef for BufferRef {
277 type Param = ();
278
279 unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
280 buffer
281 }
282
283 unsafe fn advance_to(&mut self, len: usize) {
284 unsafe { SetLen::advance_to(self, len) }
285 }
286
287 fn is_empty(&self) -> bool {
288 <[u8]>::is_empty(self)
291 }
292}
293
294impl HandleBufferRef for RecvFromMultiResult {
295 type Param = ();
296
297 unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
298 unsafe { RecvFromMultiResult::new(buffer) }
299 }
300
301 unsafe fn advance_to(&mut self, _: usize) {}
302
303 fn is_empty(&self) -> bool {
304 false
305 }
306}
307
308impl HandleBufferRef for RecvMsgMultiResult {
309 type Param = usize;
310
311 unsafe fn from_buffer_ref(buffer: BufferRef, clen: usize) -> Self {
312 unsafe { RecvMsgMultiResult::new(buffer, clen) }
313 }
314
315 unsafe fn advance_to(&mut self, _: usize) {}
316
317 fn is_empty(&self) -> bool {
318 false
319 }
320}
321
/// A stream of buffers that creates a new managed multishot operation via `create_op`
/// whenever the previous one runs out.
pub struct SubmitMultiStream<F, T: OpCode, B = BufferRef>
where
    B: HandleBufferRef + 'static,
{
    // Factory called to obtain a fresh operation when none is active.
    create_op: F,
    // The operation currently being polled, if any.
    op: Option<SubmitMultiManaged<T, B>>,
}
331
332impl<F, T: OpCode, B: HandleBufferRef + 'static> SubmitMultiStream<F, T, B> {
333 pub fn new(create_op: F) -> Self {
336 Self {
337 create_op,
338 op: None,
339 }
340 }
341}
342
343impl<
344 F: (Fn() -> std::io::Result<SubmitMultiManaged<T, B>>) + Unpin,
345 T: OpCode + TakeBuffer<Buffer = B> + 'static,
346 B: HandleBufferRef,
347> Stream for SubmitMultiStream<F, T, B>
348{
349 type Item = std::io::Result<B>;
350
351 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
352 loop {
353 match &mut self.op {
354 Some(op) => match std::task::ready!(Pin::new(op).poll_next(cx)) {
355 Some(Ok(Some(buffer))) => {
356 if buffer.is_empty() {
357 break Poll::Ready(None);
358 } else {
359 break Poll::Ready(Some(Ok(buffer)));
360 }
361 }
362 Some(Ok(None)) => break Poll::Ready(None),
363 Some(Err(e)) => break Poll::Ready(Some(Err(e))),
364 None => self.op = None,
365 },
366 None => match (self.create_op)() {
367 Ok(op) => self.op = Some(op),
368 Err(e) => break Poll::Ready(Some(Err(e))),
369 },
370 }
371 }
372 }
373}