1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, SetLen};
8use compio_driver::{
9 BufferPool, BufferRef, Extra, Key, OpCode, PushEntry, TakeBuffer,
10 op::{RecvFromMultiResult, RecvMsgMultiResult},
11};
12use futures_util::{Stream, StreamExt, stream::FusedStream};
13
14use crate::{ContextExt, Runtime};
15
16pin_project_lite::pin_project! {
17 pub struct SubmitMulti<T: OpCode> {
22 runtime: Runtime,
23 state: Option<State<T>>,
24 }
25
26 impl<T: OpCode> PinnedDrop for SubmitMulti<T> {
27 fn drop(this: Pin<&mut Self>) {
28 let this = this.project();
29 if let Some(State::Submitted { key }) = this.state.take() {
30 this.runtime.cancel(key);
31 }
32 }
33 }
34}
35
36enum State<T: OpCode> {
37 Idle { op: T },
38 Submitted { key: Key<T> },
39 Finished { op: T },
40}
41
42impl<T: OpCode> State<T> {
43 fn submitted(key: Key<T>) -> Self {
44 State::Submitted { key }
45 }
46}
47
48impl<T: OpCode> SubmitMulti<T> {
49 pub(crate) fn new(runtime: Runtime, op: T) -> Self {
50 SubmitMulti {
51 runtime,
52 state: Some(State::Idle { op }),
53 }
54 }
55
56 pub fn try_take(mut self) -> Result<T, Self> {
65 match self.state.take() {
66 Some(State::Finished { op }) | Some(State::Idle { op }) => Ok(op),
67 state => {
68 debug_assert!(state.is_some());
69 self.state = state;
70 Err(self)
71 }
72 }
73 }
74}
75
76impl<T: OpCode + 'static> Stream for SubmitMulti<T> {
77 type Item = BufResult<usize, Extra>;
78
79 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80 let this = self.project();
81
82 loop {
83 match this.state.take().expect("State error, this is a bug") {
84 State::Idle { op } => {
85 let extra = cx.as_extra(|| this.runtime.default_extra());
86 match this.runtime.submit_raw(op, extra) {
87 PushEntry::Pending(key) => {
88 if let Some(cancel) = cx.get_cancel() {
89 cancel.register(&key);
90 }
91
92 *this.state = Some(State::submitted(key))
93 }
94 PushEntry::Ready(BufResult(res, op)) => {
95 *this.state = Some(State::Finished { op });
96 let extra = this.runtime.default_extra();
97
98 return Poll::Ready(Some(BufResult(res, extra)));
99 }
100 }
101 }
102
103 State::Submitted { key, .. } => {
104 if let Some(res) = this.runtime.poll_multishot(cx.get_waker(), &key) {
105 *this.state = Some(State::submitted(key));
106
107 return Poll::Ready(Some(res));
108 };
109
110 match this.runtime.poll_task_with_extra(cx.get_waker(), key) {
111 PushEntry::Pending(key) => {
112 *this.state = Some(State::submitted(key));
113
114 return Poll::Pending;
115 }
116 PushEntry::Ready((BufResult(res, op), extra)) => {
117 *this.state = Some(State::Finished { op });
118
119 return Poll::Ready(Some(BufResult(res, extra)));
120 }
121 }
122 }
123
124 State::Finished { op } => {
125 *this.state = Some(State::Finished { op });
126
127 return Poll::Ready(None);
128 }
129 }
130 }
131 }
132}
133
134impl<T: OpCode + 'static> FusedStream for SubmitMulti<T> {
135 fn is_terminated(&self) -> bool {
136 matches!(self.state, None | Some(State::Finished { .. }))
137 }
138}
139
140impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
141where
142 <T as TakeBuffer>::Buffer: HandleBufferRef<Param = ()>,
143{
144 pub fn into_managed(self, buffer_pool: BufferPool) -> SubmitMultiManaged<T, T::Buffer> {
146 SubmitMultiManaged::new(self, buffer_pool, ())
147 }
148}
149
150impl<T: OpCode + TakeBuffer + 'static> SubmitMulti<T>
151where
152 <T as TakeBuffer>::Buffer: HandleBufferRef,
153{
154 pub fn into_managed_with(
157 self,
158 buffer_pool: BufferPool,
159 param: <<T as TakeBuffer>::Buffer as HandleBufferRef>::Param,
160 ) -> SubmitMultiManaged<T, T::Buffer> {
161 SubmitMultiManaged::new(self, buffer_pool, param)
162 }
163}
164
165pub struct SubmitMultiManaged<T: OpCode, B = BufferRef>
167where
168 B: HandleBufferRef + 'static,
169{
170 inner: Option<SubmitMulti<T>>,
171 buffer_pool: BufferPool,
172 param: <B as HandleBufferRef>::Param,
173 _p: PhantomData<&'static B>,
174}
175
176impl<T: OpCode, B: HandleBufferRef + 'static> SubmitMultiManaged<T, B> {
177 fn new(
178 stream: SubmitMulti<T>,
179 buffer_pool: BufferPool,
180 param: <B as HandleBufferRef>::Param,
181 ) -> Self {
182 Self {
183 inner: Some(stream),
184 buffer_pool,
185 param,
186 _p: PhantomData,
187 }
188 }
189}
190
191impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> Stream
192 for SubmitMultiManaged<T, B>
193{
194 type Item = std::io::Result<B>;
195
196 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
197 if let Some(inner) = self.inner.as_mut() {
198 let buffer = match std::task::ready!(inner.poll_next_unpin(cx)) {
199 Some(BufResult(res, extra)) => {
200 if inner.is_terminated() {
201 let mut b = self
202 .inner
203 .take()
204 .and_then(|s| s.try_take().ok())
205 .and_then(|op| op.take_buffer());
206 let res = res?;
207 if let Some(ref mut b) = b {
208 unsafe { b.advance_to(res) }
209 }
210 b
211 } else {
212 let b = self.buffer_pool.take(extra.buffer_id()?)?;
213 let res = res?;
214 if let Some(mut b) = b {
215 unsafe {
216 SetLen::advance_to(&mut b, res);
217 Some(B::from_buffer_ref(b, self.param))
218 }
219 } else {
220 None
221 }
222 }
223 }
224 None => self
225 .inner
226 .take()
227 .and_then(|s| s.try_take().ok())
228 .and_then(|op| op.take_buffer()),
229 };
230 Poll::Ready(buffer.map(Ok))
231 } else {
232 Poll::Ready(None)
233 }
234 }
235}
236
237impl<T: OpCode + TakeBuffer<Buffer = B> + 'static, B: HandleBufferRef> FusedStream
238 for SubmitMultiManaged<T, B>
239{
240 fn is_terminated(&self) -> bool {
241 self.inner.as_ref().is_none_or(|s| s.is_terminated())
242 }
243}
244
245mod private {
246 use super::*;
247
248 pub trait Sealed {}
249
250 impl Sealed for BufferRef {}
251 impl Sealed for RecvFromMultiResult {}
252 impl Sealed for RecvMsgMultiResult {}
253}
254
255#[doc(hidden)]
256pub trait HandleBufferRef: private::Sealed {
257 type Param: Copy + Unpin;
258
259 unsafe fn from_buffer_ref(buffer: BufferRef, param: Self::Param) -> Self;
260
261 unsafe fn advance_to(&mut self, len: usize);
262}
263
264impl HandleBufferRef for BufferRef {
265 type Param = ();
266
267 unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
268 buffer
269 }
270
271 unsafe fn advance_to(&mut self, len: usize) {
272 unsafe { SetLen::advance_to(self, len) }
273 }
274}
275
276impl HandleBufferRef for RecvFromMultiResult {
277 type Param = ();
278
279 unsafe fn from_buffer_ref(buffer: BufferRef, _: Self::Param) -> Self {
280 unsafe { RecvFromMultiResult::new(buffer) }
281 }
282
283 unsafe fn advance_to(&mut self, _: usize) {}
284}
285
286impl HandleBufferRef for RecvMsgMultiResult {
287 type Param = usize;
288
289 unsafe fn from_buffer_ref(buffer: BufferRef, clen: usize) -> Self {
290 unsafe { RecvMsgMultiResult::new(buffer, clen) }
291 }
292
293 unsafe fn advance_to(&mut self, _: usize) {}
294}