msquic_async/
stream.rs

1use crate::buffer::{StreamRecvBuffer, WriteBuffer};
2use crate::connection::ConnectionError;
3
4use std::collections::VecDeque;
5use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex, RwLock};
9use std::task::{ready, Context, Poll, Waker};
10
11use bytes::Bytes;
12use libc::c_void;
13use rangemap::RangeSet;
14use thiserror::Error;
15use tracing::trace;
16
/// Directionality of a QUIC stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamType {
    /// Both endpoints can send and receive on the stream.
    Bidirectional,
    /// Only the endpoint that opened the stream sends; the peer only receives.
    Unidirectional,
}
22
23/// A stream represents a bidirectional or unidirectional stream.
24#[derive(Debug)]
25pub struct Stream(Arc<StreamInstance>);
26
27impl Stream {
28    pub(crate) fn open(
29        msquic_conn: &msquic::Connection,
30        stream_type: StreamType,
31    ) -> Result<Self, StartError> {
32        let flags = if stream_type == StreamType::Unidirectional {
33            msquic::StreamOpenFlags::UNIDIRECTIONAL
34        } else {
35            msquic::StreamOpenFlags::NONE
36        };
37        let inner = Arc::new(StreamInner::new(
38            stream_type,
39            StreamSendState::Closed,
40            StreamRecvState::Closed,
41            true,
42        ));
43        let inner_in_ev = inner.clone();
44        let msquic_stream = msquic::Stream::open(msquic_conn, flags, move |stream_ref, ev| {
45            inner_in_ev.callback_handler_impl(stream_ref, ev)
46        })
47        .map_err(StartError::OtherError)?;
48        let stream_handle = unsafe { msquic_stream.as_raw() };
49        let instance = Arc::new(StreamInstance {
50            inner,
51            msquic_stream,
52        });
53        trace!(
54            "StreamInstance({:p}, Inner: {:p}, HQUIC: {:p}) Open by local",
55            instance,
56            instance.inner,
57            stream_handle
58        );
59
60        Ok(Self(instance))
61    }
62
63    pub(crate) fn from_raw(handle: msquic::ffi::HQUIC, stream_type: StreamType) -> Self {
64        let msquic_stream = unsafe { msquic::Stream::from_raw(handle) };
65        let send_state = if stream_type == StreamType::Bidirectional {
66            StreamSendState::StartComplete
67        } else {
68            StreamSendState::Closed
69        };
70        let inner = Arc::new(StreamInner::new(
71            stream_type,
72            send_state,
73            StreamRecvState::StartComplete,
74            false,
75        ));
76        let inner_in_ev = inner.clone();
77        msquic_stream.set_callback_handler(move |stream_ref, ev| {
78            inner_in_ev.callback_handler_impl(stream_ref, ev)
79        });
80        let stream_handle = unsafe { msquic_stream.as_raw() };
81        let stream = Self(Arc::new(StreamInstance {
82            inner,
83            msquic_stream,
84        }));
85        trace!(
86            "StreamInstance({:p}, Inner: {:p}, HQUIC: {:p}, id: {:?}) Start by peer",
87            stream.0,
88            stream.0.inner,
89            stream_handle,
90            stream.id()
91        );
92        stream
93    }
94
95    pub(crate) fn poll_start(
96        &mut self,
97        cx: &mut Context,
98        failed_on_block: bool,
99    ) -> Poll<Result<(), StartError>> {
100        let mut exclusive = self.0.inner.exclusive.lock().unwrap();
101        trace!(
102            "Stream(Inner: {:p}) poll_start state={:?}",
103            self.0.inner,
104            exclusive.state
105        );
106        match exclusive.state {
107            StreamState::Open => {
108                let res = self
109                    .0
110                    .msquic_stream
111                    .start(
112                        msquic::StreamStartFlags::SHUTDOWN_ON_FAIL
113                            | msquic::StreamStartFlags::INDICATE_PEER_ACCEPT
114                            | if failed_on_block {
115                                msquic::StreamStartFlags::FAIL_BLOCKED
116                            } else {
117                                msquic::StreamStartFlags::NONE
118                            },
119                    )
120                    .map_err(StartError::OtherError);
121                trace!(
122                    "Stream(Inner: {:p}) poll_start start={:?}",
123                    self.0.inner,
124                    res
125                );
126                res?;
127                exclusive.state = StreamState::Start;
128                if self.0.inner.shared.stream_type == StreamType::Bidirectional {
129                    exclusive.recv_state = StreamRecvState::Start;
130                }
131                exclusive.send_state = StreamSendState::Start;
132            }
133            StreamState::Start => {}
134            _ => {
135                if let Some(start_status) = &exclusive.start_status {
136                    if start_status.is_ok() {
137                        return Poll::Ready(Ok(()));
138                    }
139                    return Poll::Ready(Err(match start_status.try_as_status_code().unwrap() {
140                        msquic::StatusCode::QUIC_STATUS_STREAM_LIMIT_REACHED => {
141                            StartError::LimitReached
142                        }
143                        msquic::StatusCode::QUIC_STATUS_ABORTED
144                        | msquic::StatusCode::QUIC_STATUS_INVALID_STATE => {
145                            StartError::ConnectionLost(
146                                exclusive.conn_error.as_ref().expect("conn_error").clone(),
147                            )
148                        }
149                        _ => StartError::OtherError(start_status.clone()),
150                    }));
151                } else {
152                    return Poll::Ready(Ok(()));
153                }
154            }
155        }
156        exclusive.start_waiters.push(cx.waker().clone());
157        Poll::Pending
158    }
159
160    /// Returns the stream ID.
161    pub fn id(&self) -> Option<u64> {
162        self.0.id()
163    }
164
165    /// Splits the stream into a read stream and a write stream.
166    pub fn split(self) -> (Option<ReadStream>, Option<WriteStream>) {
167        match (
168            self.0.inner.shared.stream_type,
169            self.0.inner.shared.local_open,
170        ) {
171            (StreamType::Unidirectional, true) => (None, Some(WriteStream(self.0))),
172            (StreamType::Unidirectional, false) => (Some(ReadStream(self.0)), None),
173            (StreamType::Bidirectional, _) => {
174                (Some(ReadStream(self.0.clone())), Some(WriteStream(self.0)))
175            }
176        }
177    }
178
179    /// Poll to read from the stream into buf.
180    pub fn poll_read(
181        &mut self,
182        cx: &mut Context<'_>,
183        buf: &mut [u8],
184    ) -> Poll<Result<usize, ReadError>> {
185        self.0.poll_read(cx, buf)
186    }
187
188    /// Poll to read the next segment of data.
189    pub fn poll_read_chunk(
190        &self,
191        cx: &mut Context<'_>,
192    ) -> Poll<Result<Option<StreamRecvBuffer>, ReadError>> {
193        self.0.poll_read_chunk(cx)
194    }
195
196    /// Read the next segment of data.
197    pub fn read_chunk(&self) -> ReadChunk<'_> {
198        self.0.read_chunk()
199    }
200
201    /// Poll to write to the stream from buf.
202    pub fn poll_write(
203        &mut self,
204        cx: &mut Context<'_>,
205        buf: &[u8],
206        fin: bool,
207    ) -> Poll<Result<usize, WriteError>> {
208        self.0.poll_write(cx, buf, fin)
209    }
210
211    /// Poll to write a bytes to the stream directly.
212    pub fn poll_write_chunk(
213        &mut self,
214        cx: &mut Context<'_>,
215        chunk: &Bytes,
216        fin: bool,
217    ) -> Poll<Result<usize, WriteError>> {
218        self.0.poll_write_chunk(cx, chunk, fin)
219    }
220
221    /// Write a bytes to the stream directly.
222    pub fn write_chunk<'a>(&'a mut self, chunk: &'a Bytes, fin: bool) -> WriteChunk<'a> {
223        self.0.write_chunk(chunk, fin)
224    }
225
226    /// Poll to write the list of bytes to the stream directly.
227    pub fn poll_write_chunks(
228        &mut self,
229        cx: &mut Context<'_>,
230        chunks: &[Bytes],
231        fin: bool,
232    ) -> Poll<Result<usize, WriteError>> {
233        self.0.poll_write_chunks(cx, chunks, fin)
234    }
235
236    /// Write the list of bytes to the stream directly.
237    pub fn write_chunks<'a>(&'a mut self, chunks: &'a [Bytes], fin: bool) -> WriteChunks<'a> {
238        self.0.write_chunks(chunks, fin)
239    }
240
241    /// Poll to finish writing to the stream.
242    pub fn poll_finish_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WriteError>> {
243        self.0.poll_finish_write(cx)
244    }
245
246    /// Poll to abort writing to the stream.
247    pub fn poll_abort_write(
248        &mut self,
249        cx: &mut Context<'_>,
250        error_code: u64,
251    ) -> Poll<Result<(), WriteError>> {
252        self.0.poll_abort_write(cx, error_code)
253    }
254
255    /// Abort writing to the stream.
256    pub fn abort_write(&mut self, error_code: u64) -> Result<(), WriteError> {
257        self.0.abort_write(error_code)
258    }
259
260    /// Poll to abort reading from the stream.
261    pub fn poll_abort_read(
262        &mut self,
263        cx: &mut Context<'_>,
264        error_code: u64,
265    ) -> Poll<Result<(), ReadError>> {
266        self.0.poll_abort_read(cx, error_code)
267    }
268
269    /// Abort reading from the stream.
270    pub fn abort_read(&mut self, error_code: u64) -> Result<(), ReadError> {
271        self.0.abort_read(error_code)
272    }
273}
274
275/// A stream that can only be read from.
276#[derive(Debug)]
277pub struct ReadStream(Arc<StreamInstance>);
278
279impl ReadStream {
280    /// Returns the stream ID.
281    pub fn id(&self) -> Option<u64> {
282        self.0.id()
283    }
284
285    /// Poll to read from the stream into buf.
286    pub fn poll_read(
287        &mut self,
288        cx: &mut Context<'_>,
289        buf: &mut [u8],
290    ) -> Poll<Result<usize, ReadError>> {
291        self.0.poll_read(cx, buf)
292    }
293
294    /// Poll to read the next segment of data.
295    pub fn poll_read_chunk(
296        &self,
297        cx: &mut Context<'_>,
298    ) -> Poll<Result<Option<StreamRecvBuffer>, ReadError>> {
299        self.0.poll_read_chunk(cx)
300    }
301
302    /// Read the next segment of data.
303    pub fn read_chunk(&self) -> ReadChunk<'_> {
304        self.0.read_chunk()
305    }
306
307    /// Poll to abort reading from the stream.
308    pub fn poll_abort_read(
309        &mut self,
310        cx: &mut Context<'_>,
311        error_code: u64,
312    ) -> Poll<Result<(), ReadError>> {
313        self.0.poll_abort_read(cx, error_code)
314    }
315
316    /// Abort reading from the stream.
317    pub fn abort_read(&mut self, error_code: u64) -> Result<(), ReadError> {
318        self.0.abort_read(error_code)
319    }
320}
321
322/// A stream that can only be written to.
323#[derive(Debug)]
324pub struct WriteStream(Arc<StreamInstance>);
325
326impl WriteStream {
327    /// Returns the stream ID.
328    pub fn id(&self) -> Option<u64> {
329        self.0.id()
330    }
331
332    /// Poll to write to the stream from buf.
333    pub fn poll_write(
334        &mut self,
335        cx: &mut Context<'_>,
336        buf: &[u8],
337        fin: bool,
338    ) -> Poll<Result<usize, WriteError>> {
339        self.0.poll_write(cx, buf, fin)
340    }
341
342    /// Poll to write a bytes to the stream directly.
343    pub fn poll_write_chunk(
344        &mut self,
345        cx: &mut Context<'_>,
346        chunk: &Bytes,
347        fin: bool,
348    ) -> Poll<Result<usize, WriteError>> {
349        self.0.poll_write_chunk(cx, chunk, fin)
350    }
351
352    /// Write a bytes to the stream directly.
353    pub fn write_chunk<'a>(&'a mut self, chunk: &'a Bytes, fin: bool) -> WriteChunk<'a> {
354        self.0.write_chunk(chunk, fin)
355    }
356
357    /// Poll to write the list of bytes to the stream directly.
358    pub fn poll_write_chunks(
359        &mut self,
360        cx: &mut Context<'_>,
361        chunks: &[Bytes],
362        fin: bool,
363    ) -> Poll<Result<usize, WriteError>> {
364        self.0.poll_write_chunks(cx, chunks, fin)
365    }
366
367    /// Write the list of bytes to the stream directly.
368    pub fn write_chunks<'a>(&'a mut self, chunks: &'a [Bytes], fin: bool) -> WriteChunks<'a> {
369        self.0.write_chunks(chunks, fin)
370    }
371
372    /// Poll to finish writing to the stream.
373    pub fn poll_finish_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WriteError>> {
374        self.0.poll_finish_write(cx)
375    }
376
377    /// Poll to abort writing to the stream.
378    pub fn poll_abort_write(
379        &mut self,
380        cx: &mut Context<'_>,
381        error_code: u64,
382    ) -> Poll<Result<(), WriteError>> {
383        self.0.poll_abort_write(cx, error_code)
384    }
385
386    /// Abort writing to the stream.
387    pub fn abort_write(&mut self, error_code: u64) -> Result<(), WriteError> {
388        self.0.abort_write(error_code)
389    }
390}
391
392#[derive(Debug)]
393pub(crate) struct StreamInstance {
394    inner: Arc<StreamInner>,
395    msquic_stream: msquic::Stream,
396}
397
398impl StreamInstance {
399    pub(crate) fn id(&self) -> Option<u64> {
400        let id = { *self.inner.shared.id.read().unwrap() };
401        if id.is_some() {
402            id
403        } else {
404            let res = unsafe {
405                msquic::Api::get_param_auto::<u64>(
406                    self.msquic_stream.as_raw(),
407                    msquic::PARAM_STREAM_ID,
408                )
409            };
410            if let Ok(id) = res {
411                self.inner.shared.id.write().unwrap().replace(id);
412                Some(id)
413            } else {
414                None
415            }
416        }
417    }
418
419    fn poll_read(
420        self: &Arc<Self>,
421        cx: &mut Context<'_>,
422        buf: &mut [u8],
423    ) -> Poll<Result<usize, ReadError>> {
424        self.poll_read_generic(cx, |recv_buffers, read_complete_buffers| {
425            let mut read = 0;
426            let mut fin = false;
427            loop {
428                if read == buf.len() {
429                    return ReadStatus::Readable(read);
430                }
431
432                match recv_buffers
433                    .front_mut()
434                    .and_then(|x| x.get_bytes_upto_size(buf.len() - read))
435                {
436                    Some(slice) => {
437                        let len = slice.len();
438                        buf[read..read + len].copy_from_slice(slice);
439                        read += len;
440                    }
441                    None => {
442                        if let Some(mut recv_buffer) = recv_buffers.pop_front() {
443                            recv_buffer.set_stream(self.clone());
444                            fin = recv_buffer.fin();
445                            read_complete_buffers.push(recv_buffer);
446                            continue;
447                        } else {
448                            return (if read > 0 { Some(read) } else { None }, fin).into();
449                        }
450                    }
451                }
452            }
453        })
454        .map(|res| res.map(|x| x.unwrap_or(0)))
455    }
456
457    fn poll_read_chunk(
458        self: &Arc<Self>,
459        cx: &mut Context<'_>,
460    ) -> Poll<Result<Option<StreamRecvBuffer>, ReadError>> {
461        self.poll_read_generic(cx, |recv_buffers, _| {
462            recv_buffers
463                .pop_front()
464                .map(|mut recv_buffer| {
465                    let fin = recv_buffer.fin();
466                    recv_buffer.set_stream(self.clone());
467                    (Some(recv_buffer), fin)
468                })
469                .unwrap_or((None, false))
470                .into()
471        })
472    }
473
474    fn read_chunk(self: &Arc<Self>) -> ReadChunk<'_> {
475        ReadChunk { stream: self }
476    }
477
478    fn poll_read_generic<T, U>(
479        &self,
480        cx: &mut Context<'_>,
481        mut read_fn: T,
482    ) -> Poll<Result<Option<U>, ReadError>>
483    where
484        T: FnMut(&mut VecDeque<StreamRecvBuffer>, &mut Vec<StreamRecvBuffer>) -> ReadStatus<U>,
485    {
486        let res;
487        let mut read_complete_buffers = Vec::new();
488        {
489            let mut exclusive = self.inner.exclusive.lock().unwrap();
490            match exclusive.recv_state {
491                StreamRecvState::Closed => {
492                    return Poll::Ready(Err(ReadError::Closed));
493                }
494                StreamRecvState::Start => {
495                    exclusive.start_waiters.push(cx.waker().clone());
496                    return Poll::Pending;
497                }
498                StreamRecvState::StartComplete => {}
499                StreamRecvState::Shutdown => {
500                    return Poll::Ready(Ok(None));
501                }
502                StreamRecvState::ShutdownComplete => {
503                    if let Some(conn_error) = &exclusive.conn_error {
504                        return Poll::Ready(Err(ReadError::ConnectionLost(conn_error.clone())));
505                    } else if let Some(error_code) = &exclusive.recv_error_code {
506                        return Poll::Ready(Err(ReadError::Reset(*error_code)));
507                    } else {
508                        return Poll::Ready(Ok(None));
509                    }
510                }
511            }
512
513            let status = read_fn(&mut exclusive.recv_buffers, &mut read_complete_buffers);
514
515            res = match status {
516                ReadStatus::Readable(read) | ReadStatus::Blocked(Some(read)) => {
517                    Poll::Ready(Ok(Some(read)))
518                }
519                ReadStatus::Finished(read) => {
520                    exclusive.recv_state = StreamRecvState::Shutdown;
521                    Poll::Ready(Ok(read))
522                }
523                ReadStatus::Blocked(None) => {
524                    exclusive.read_waiters.push(cx.waker().clone());
525                    Poll::Pending
526                }
527            };
528        }
529        res
530    }
531
532    fn poll_write(
533        &self,
534        cx: &mut Context<'_>,
535        buf: &[u8],
536        fin: bool,
537    ) -> Poll<Result<usize, WriteError>> {
538        self.poll_write_generic(cx, |write_buf| {
539            let written = write_buf.put_slice(buf);
540            if written == buf.len() && !fin {
541                WriteStatus::Writable(written)
542            } else {
543                (Some(written), fin).into()
544            }
545        })
546        .map(|res| res.map(|x| x.unwrap_or(0)))
547    }
548
549    fn poll_write_chunk(
550        &self,
551        cx: &mut Context<'_>,
552        chunk: &Bytes,
553        fin: bool,
554    ) -> Poll<Result<usize, WriteError>> {
555        self.poll_write_generic(cx, |write_buf| {
556            let written = write_buf.put_zerocopy(chunk);
557            if written == chunk.len() && !fin {
558                WriteStatus::Writable(written)
559            } else {
560                (Some(written), fin).into()
561            }
562        })
563        .map(|res| res.map(|x| x.unwrap_or(0)))
564    }
565
566    fn write_chunk<'a>(&'a self, chunk: &'a Bytes, fin: bool) -> WriteChunk<'a> {
567        WriteChunk {
568            stream: self,
569            chunk,
570            fin,
571        }
572    }
573
574    fn poll_write_chunks(
575        &self,
576        cx: &mut Context<'_>,
577        chunks: &[Bytes],
578        fin: bool,
579    ) -> Poll<Result<usize, WriteError>> {
580        self.poll_write_generic(cx, |write_buf| {
581            let (mut total_len, mut total_written) = (0, 0);
582            for buf in chunks {
583                total_len += buf.len();
584                total_written += write_buf.put_zerocopy(buf);
585            }
586            if total_written == total_len && !fin {
587                WriteStatus::Writable(total_written)
588            } else {
589                (Some(total_written), fin).into()
590            }
591        })
592        .map(|res| res.map(|x| x.unwrap_or(0)))
593    }
594
595    fn write_chunks<'a>(&'a self, chunks: &'a [Bytes], fin: bool) -> WriteChunks<'a> {
596        WriteChunks {
597            stream: self,
598            chunks,
599            fin,
600        }
601    }
602
603    fn poll_write_generic<T, U>(
604        &self,
605        _cx: &mut Context<'_>,
606        mut write_fn: T,
607    ) -> Poll<Result<Option<U>, WriteError>>
608    where
609        T: FnMut(&mut WriteBuffer) -> WriteStatus<U>,
610    {
611        let mut exclusive = self.inner.exclusive.lock().unwrap();
612        match exclusive.send_state {
613            StreamSendState::Closed => {
614                return Poll::Ready(Err(WriteError::Closed));
615            }
616            StreamSendState::Start => {
617                exclusive.start_waiters.push(_cx.waker().clone());
618                return Poll::Pending;
619            }
620            StreamSendState::StartComplete => {}
621            StreamSendState::Shutdown => {
622                return Poll::Ready(Err(WriteError::Finished));
623            }
624            StreamSendState::ShutdownComplete => {
625                if let Some(conn_error) = &exclusive.conn_error {
626                    return Poll::Ready(Err(WriteError::ConnectionLost(conn_error.clone())));
627                } else if let Some(error_code) = &exclusive.send_error_code {
628                    return Poll::Ready(Err(WriteError::Stopped(*error_code)));
629                } else {
630                    return Poll::Ready(Err(WriteError::Finished));
631                }
632            }
633        }
634        let mut write_buf = exclusive.write_pool.pop().unwrap_or(WriteBuffer::new());
635        let status = write_fn(&mut write_buf);
636        let buffers = unsafe {
637            let (data, len) = write_buf.get_buffers();
638            std::slice::from_raw_parts(data, len)
639        };
640        match status {
641            WriteStatus::Writable(val) | WriteStatus::Blocked(Some(val)) => {
642                match unsafe {
643                    self.msquic_stream.send(
644                        buffers,
645                        msquic::SendFlags::NONE,
646                        write_buf.into_raw() as *const _,
647                    )
648                }
649                .map_err(WriteError::OtherError)
650                {
651                    Ok(()) => Poll::Ready(Ok(Some(val))),
652                    Err(e) => Poll::Ready(Err(e)),
653                }
654            }
655            WriteStatus::Blocked(None) => unreachable!(),
656            WriteStatus::Finished(val) => {
657                match unsafe {
658                    self.msquic_stream.send(
659                        buffers,
660                        msquic::SendFlags::FIN,
661                        write_buf.into_raw() as *const _,
662                    )
663                }
664                .map_err(WriteError::OtherError)
665                {
666                    Ok(()) => {
667                        exclusive.send_state = StreamSendState::Shutdown;
668                        Poll::Ready(Ok(val))
669                    }
670                    Err(e) => Poll::Ready(Err(e)),
671                }
672            }
673        }
674    }
675
676    fn poll_finish_write(&self, cx: &mut Context<'_>) -> Poll<Result<(), WriteError>> {
677        let mut exclusive = self.inner.exclusive.lock().unwrap();
678        match exclusive.send_state {
679            StreamSendState::Start => {
680                exclusive.start_waiters.push(cx.waker().clone());
681                return Poll::Pending;
682            }
683            StreamSendState::StartComplete => {
684                match self
685                    .msquic_stream
686                    .shutdown(msquic::StreamShutdownFlags::GRACEFUL, 0)
687                    .map_err(WriteError::OtherError)
688                {
689                    Ok(()) => {
690                        exclusive.send_state = StreamSendState::Shutdown;
691                    }
692                    Err(e) => return Poll::Ready(Err(e)),
693                }
694            }
695            StreamSendState::Shutdown => {}
696            StreamSendState::ShutdownComplete => {
697                if let Some(conn_error) = &exclusive.conn_error {
698                    return Poll::Ready(Err(WriteError::ConnectionLost(conn_error.clone())));
699                } else if let Some(error_code) = &exclusive.send_error_code {
700                    return Poll::Ready(Err(WriteError::Stopped(*error_code)));
701                } else {
702                    return Poll::Ready(Ok(()));
703                }
704            }
705            _ => {
706                return Poll::Ready(Err(WriteError::Closed));
707            }
708        }
709        exclusive.write_shutdown_waiters.push(cx.waker().clone());
710        Poll::Pending
711    }
712
713    fn poll_abort_write(
714        &self,
715        cx: &mut Context<'_>,
716        error_code: u64,
717    ) -> Poll<Result<(), WriteError>> {
718        let mut exclusive = self.inner.exclusive.lock().unwrap();
719        match exclusive.send_state {
720            StreamSendState::Start => {
721                exclusive.start_waiters.push(cx.waker().clone());
722                return Poll::Pending;
723            }
724            StreamSendState::StartComplete => {
725                match self
726                    .msquic_stream
727                    .shutdown(msquic::StreamShutdownFlags::ABORT_SEND, error_code)
728                    .map_err(WriteError::OtherError)
729                {
730                    Ok(()) => {
731                        exclusive.send_state = StreamSendState::Shutdown;
732                    }
733                    Err(e) => return Poll::Ready(Err(e)),
734                }
735            }
736            StreamSendState::Shutdown => {}
737            StreamSendState::ShutdownComplete => {
738                if let Some(conn_error) = &exclusive.conn_error {
739                    return Poll::Ready(Err(WriteError::ConnectionLost(conn_error.clone())));
740                } else if let Some(error_code) = &exclusive.send_error_code {
741                    return Poll::Ready(Err(WriteError::Stopped(*error_code)));
742                } else {
743                    return Poll::Ready(Ok(()));
744                }
745            }
746            _ => {
747                return Poll::Ready(Err(WriteError::Closed));
748            }
749        }
750        exclusive.write_shutdown_waiters.push(cx.waker().clone());
751        Poll::Pending
752    }
753
754    fn abort_write(&self, error_code: u64) -> Result<(), WriteError> {
755        let mut exclusive = self.inner.exclusive.lock().unwrap();
756        match exclusive.send_state {
757            StreamSendState::StartComplete => {
758                self.msquic_stream
759                    .shutdown(msquic::StreamShutdownFlags::ABORT_SEND, error_code)
760                    .map_err(WriteError::OtherError)?;
761                exclusive.send_state = StreamSendState::Shutdown;
762                Ok(())
763            }
764            _ => Err(WriteError::Closed),
765        }
766    }
767
768    fn poll_abort_read(
769        &self,
770        cx: &mut Context<'_>,
771        error_code: u64,
772    ) -> Poll<Result<(), ReadError>> {
773        let mut exclusive = self.inner.exclusive.lock().unwrap();
774        match exclusive.recv_state {
775            StreamRecvState::Start => {
776                exclusive.start_waiters.push(cx.waker().clone());
777                Poll::Pending
778            }
779            StreamRecvState::StartComplete => {
780                match self
781                    .msquic_stream
782                    .shutdown(msquic::StreamShutdownFlags::ABORT_RECEIVE, error_code)
783                    .map_err(ReadError::OtherError)
784                {
785                    Ok(()) => {
786                        exclusive.recv_state = StreamRecvState::ShutdownComplete;
787                        exclusive
788                            .read_waiters
789                            .drain(..)
790                            .for_each(|waker| waker.wake());
791                        Poll::Ready(Ok(()))
792                    }
793                    Err(e) => Poll::Ready(Err(e)),
794                }
795            }
796            StreamRecvState::ShutdownComplete => {
797                if let Some(conn_error) = &exclusive.conn_error {
798                    Poll::Ready(Err(ReadError::ConnectionLost(conn_error.clone())))
799                } else if let Some(error_code) = &exclusive.recv_error_code {
800                    Poll::Ready(Err(ReadError::Reset(*error_code)))
801                } else {
802                    Poll::Ready(Ok(()))
803                }
804            }
805            _ => Poll::Ready(Err(ReadError::Closed)),
806        }
807    }
808
809    fn abort_read(&self, error_code: u64) -> Result<(), ReadError> {
810        let mut exclusive = self.inner.exclusive.lock().unwrap();
811        match exclusive.recv_state {
812            StreamRecvState::StartComplete => {
813                self.msquic_stream
814                    .shutdown(msquic::StreamShutdownFlags::ABORT_RECEIVE, error_code)
815                    .map_err(ReadError::OtherError)?;
816                exclusive.recv_state = StreamRecvState::ShutdownComplete;
817            }
818            _ => {
819                return Err(ReadError::Closed);
820            }
821        }
822        Ok(())
823    }
824
825    pub(crate) fn read_complete(&self, buffer: &StreamRecvBuffer) {
826        let buffer_range = buffer.range();
827        trace!(
828            "StreamInstance({:p}) read complete offset={} len={}",
829            self,
830            buffer_range.start,
831            buffer_range.end - buffer_range.start
832        );
833
834        let mut exclusive = self.inner.exclusive.lock().unwrap();
835        if !buffer_range.is_empty() {
836            exclusive.read_complete_map.insert(buffer_range);
837        }
838        let complete_len = if let Some(complete_range) = exclusive.read_complete_map.first() {
839            trace!(
840                "StreamInstance({:p}) complete read offset={} len={}",
841                self,
842                complete_range.start,
843                complete_range.end - complete_range.start
844            );
845
846            if complete_range.start == 0 && exclusive.read_complete_cursor < complete_range.end {
847                let complete_len = complete_range.end - exclusive.read_complete_cursor;
848                exclusive.read_complete_cursor = complete_range.end;
849                Some(complete_len)
850            } else if complete_range.start == 0
851                && exclusive.read_complete_cursor == complete_range.end
852                && buffer.offset() == complete_range.end
853                && buffer.is_empty()
854                && buffer.fin()
855            {
856                Some(0)
857            } else {
858                None
859            }
860        } else if buffer.is_empty() && buffer.fin() {
861            Some(0)
862        } else {
863            None
864        };
865        if let Some(complete_len) = complete_len {
866            trace!(
867                "StreamInstance({:p}) call receive_complete len={}",
868                self,
869                complete_len
870            );
871            self.msquic_stream.receive_complete(complete_len as u64);
872        }
873    }
874}
875
876impl Drop for StreamInstance {
877    fn drop(&mut self) {
878        trace!("StreamInstance({:p}) dropping", self);
879        let exclusive = self.inner.exclusive.lock().unwrap();
880        match exclusive.state {
881            StreamState::Start | StreamState::StartComplete => {
882                trace!(
883                    "StreamInstance(Inner: {:p}) shutdown while dropping",
884                    self.inner
885                );
886                // let _ = self.msquic_stream.shutdown(
887                //     msquic::StreamShutdownFlags::ABORT_SEND
888                //         | msquic::StreamShutdownFlags::ABORT_RECEIVE
889                //         | msquic::StreamShutdownFlags::IMMEDIATE,
890                //     0,
891                // );
892            }
893            _ => {}
894        }
895    }
896}
897
/// State shared between a `StreamInstance` and the msquic event callback
/// (`Stream::open` clones the same `Arc<StreamInner>` into both).
#[derive(Debug)]
struct StreamInner {
    /// Mutable stream state. It is touched both by the `handle_event_*`
    /// callbacks and by `StreamInstance` methods, hence the mutex.
    exclusive: Mutex<StreamInnerExclusive>,
    /// Stream type, origin and id.
    pub(crate) shared: StreamInnerShared,
}
903
/// Mutable part of [`StreamInner`], always accessed under its mutex.
struct StreamInnerExclusive {
    /// Lifecycle state of the stream as a whole.
    state: StreamState,
    /// Status reported by the `StartComplete` event, once delivered.
    start_status: Option<msquic::Status>,
    /// State of the receive direction.
    recv_state: StreamRecvState,
    /// Buffers queued by `handle_event_receive`.
    recv_buffers: VecDeque<StreamRecvBuffer>,
    /// Total number of bytes reported by `Receive` events.
    recv_len: usize,
    /// Byte ranges the application has finished reading (see `read_complete`).
    read_complete_map: RangeSet<usize>,
    /// Offset up to which `receive_complete` has been reported to msquic.
    read_complete_cursor: usize,
    /// State of the send direction.
    send_state: StreamSendState,
    /// Write buffers returned (after `reset`) by `SendComplete`; presumably
    /// reused by later writes — confirm against the write path.
    write_pool: Vec<WriteBuffer>,
    /// Error code from `PeerSendAborted`.
    recv_error_code: Option<u64>,
    /// Error code from `PeerReceiveAborted`.
    send_error_code: Option<u64>,
    /// Connection error recorded at `ShutdownComplete` when the connection shut down.
    conn_error: Option<ConnectionError>,
    /// Tasks waiting for the stream start to resolve.
    start_waiters: Vec<Waker>,
    /// Tasks waiting for received data or end of the receive direction.
    read_waiters: Vec<Waker>,
    /// Tasks waiting for the send direction to shut down.
    write_shutdown_waiters: Vec<Waker>,
}
921
/// Stream metadata that is fixed at construction, except for the id.
struct StreamInnerShared {
    /// Bidirectional or unidirectional.
    stream_type: StreamType,
    /// `true` when this endpoint opened the stream (`Stream::open` passes `true`).
    local_open: bool,
    /// Stream id; `handle_event_start_complete` fills it in on a successful start.
    id: RwLock<Option<u64>>,
}
927
/// Lifecycle state of the stream as a whole.
#[derive(Debug, PartialEq)]
enum StreamState {
    /// Initial state assigned by `StreamInner::new`.
    Open,
    /// NOTE(review): not assigned in this section; presumably set once a start
    /// has been requested — confirm.
    Start,
    /// The start succeeded and the peer accepted the stream.
    StartComplete,
    /// msquic delivered the `ShutdownComplete` event.
    ShutdownComplete,
}
935
/// State of the receive direction.
#[derive(Debug, PartialEq)]
enum StreamRecvState {
    /// Receive side not open (the value `Stream::open` starts with).
    Closed,
    /// NOTE(review): not assigned in this section — confirm where it is set.
    Start,
    /// Set for bidirectional streams once the peer accepted the stream.
    StartComplete,
    /// NOTE(review): not assigned in this section — confirm where it is set.
    Shutdown,
    /// The peer shut down or aborted its send side, or the stream completed shutdown.
    ShutdownComplete,
}
944
/// State of the send direction.
#[derive(Debug, PartialEq)]
enum StreamSendState {
    /// Send side not open (the value `Stream::open` starts with).
    Closed,
    /// NOTE(review): not assigned in this section — confirm where it is set.
    Start,
    /// Set once the peer accepted the stream.
    StartComplete,
    /// NOTE(review): not assigned in this section — confirm where it is set.
    Shutdown,
    /// Set on `PeerReceiveAborted`, `SendShutdownComplete` or `ShutdownComplete`.
    ShutdownComplete,
}
953
954impl StreamInner {
955    fn new(
956        stream_type: StreamType,
957        send_state: StreamSendState,
958        recv_state: StreamRecvState,
959        local_open: bool,
960    ) -> Self {
961        Self {
962            exclusive: Mutex::new(StreamInnerExclusive {
963                state: StreamState::Open,
964                start_status: None,
965                recv_state,
966                recv_buffers: VecDeque::new(),
967                recv_len: 0,
968                read_complete_map: RangeSet::new(),
969                read_complete_cursor: 0,
970                send_state,
971                write_pool: Vec::new(),
972                recv_error_code: None,
973                send_error_code: None,
974                conn_error: None,
975                start_waiters: Vec::new(),
976                read_waiters: Vec::new(),
977                write_shutdown_waiters: Vec::new(),
978            }),
979            shared: StreamInnerShared {
980                local_open,
981                id: RwLock::new(None),
982                stream_type,
983            },
984        }
985    }
986
987    fn handle_event_start_complete(
988        &self,
989        status: msquic::Status,
990        id: u64,
991        peer_accepted: bool,
992    ) -> Result<(), msquic::Status> {
993        if status.is_ok() {
994            self.shared.id.write().unwrap().replace(id);
995        }
996        trace!(
997            "StreamInner({:p}, id={:?}) start complete status={:?}, peer_accepted={}, id={}",
998            self,
999            self.shared.id.read(),
1000            status,
1001            peer_accepted,
1002            id,
1003        );
1004        let mut exclusive = self.exclusive.lock().unwrap();
1005        exclusive.start_status = Some(status.clone());
1006        if status.is_ok() && peer_accepted {
1007            exclusive.state = StreamState::StartComplete;
1008            if self.shared.stream_type == StreamType::Bidirectional {
1009                exclusive.recv_state = StreamRecvState::StartComplete;
1010            }
1011            exclusive.send_state = StreamSendState::StartComplete;
1012        }
1013
1014        if status.0 == msquic::StatusCode::QUIC_STATUS_STREAM_LIMIT_REACHED.into() || peer_accepted
1015        {
1016            exclusive
1017                .start_waiters
1018                .drain(..)
1019                .for_each(|waker| waker.wake());
1020        }
1021        Ok(())
1022    }
1023
1024    fn handle_event_receive(
1025        &self,
1026        absolute_offset: u64,
1027        total_buffer_length: &mut u64,
1028        buffers: &[msquic::BufferRef],
1029        flags: msquic::ReceiveFlags,
1030    ) -> Result<(), msquic::Status> {
1031        trace!(
1032            "StreamInner({:p}, id={:?}) Receive {} offsets {} bytes, fin {}",
1033            self,
1034            self.shared.id.read(),
1035            absolute_offset,
1036            total_buffer_length,
1037            (flags & msquic::ReceiveFlags::FIN) == msquic::ReceiveFlags::FIN
1038        );
1039
1040        let recv_buffer = StreamRecvBuffer::new(
1041            absolute_offset as usize,
1042            buffers,
1043            (flags & msquic::ReceiveFlags::FIN) == msquic::ReceiveFlags::FIN,
1044        );
1045
1046        let mut exclusive = self.exclusive.lock().unwrap();
1047        exclusive.recv_len += *total_buffer_length as usize;
1048        exclusive.recv_buffers.push_back(recv_buffer);
1049        exclusive
1050            .read_waiters
1051            .drain(..)
1052            .for_each(|waker| waker.wake());
1053        Err(msquic::StatusCode::QUIC_STATUS_PENDING.into())
1054    }
1055
1056    fn handle_event_send_complete(
1057        &self,
1058        _canceled: bool,
1059        client_context: *const c_void,
1060    ) -> Result<(), msquic::Status> {
1061        trace!(
1062            "StreamInner({:p}, id={:?}) Send complete",
1063            self,
1064            self.shared.id.read()
1065        );
1066
1067        let mut write_buf = unsafe { WriteBuffer::from_raw(client_context) };
1068        let mut exclusive = self.exclusive.lock().unwrap();
1069        write_buf.reset();
1070        exclusive.write_pool.push(write_buf);
1071        Ok(())
1072    }
1073
1074    fn handle_event_peer_send_shutdown(&self) -> Result<(), msquic::Status> {
1075        trace!(
1076            "StreamInner({:p}, id={:?}) Peer send shutdown",
1077            self,
1078            self.shared.id.read()
1079        );
1080        let mut exclusive = self.exclusive.lock().unwrap();
1081        exclusive.recv_state = StreamRecvState::ShutdownComplete;
1082        exclusive
1083            .read_waiters
1084            .drain(..)
1085            .for_each(|waker| waker.wake());
1086        Ok(())
1087    }
1088
1089    fn handle_event_peer_send_aborted(&self, error_code: u64) -> Result<(), msquic::Status> {
1090        trace!(
1091            "StreamInner({:p}, id={:?}) Peer send aborted",
1092            self,
1093            self.shared.id.read()
1094        );
1095        let mut exclusive = self.exclusive.lock().unwrap();
1096        exclusive.recv_state = StreamRecvState::ShutdownComplete;
1097        exclusive.recv_error_code = Some(error_code);
1098        exclusive
1099            .read_waiters
1100            .drain(..)
1101            .for_each(|waker| waker.wake());
1102        Ok(())
1103    }
1104
1105    fn handle_event_peer_receive_aborted(&self, error_code: u64) -> Result<(), msquic::Status> {
1106        trace!(
1107            "StreamInner({:p}, id={:?}) Peer receive aborted",
1108            self,
1109            self.shared.id.read()
1110        );
1111        let mut exclusive = self.exclusive.lock().unwrap();
1112        exclusive.send_state = StreamSendState::ShutdownComplete;
1113        exclusive.send_error_code = Some(error_code);
1114        exclusive
1115            .write_shutdown_waiters
1116            .drain(..)
1117            .for_each(|waker| waker.wake());
1118        Ok(())
1119    }
1120
1121    fn handle_event_send_shutdown_complete(&self, _graceful: bool) -> Result<(), msquic::Status> {
1122        trace!(
1123            "StreamInner({:p}, id={:?}) Send shutdown complete",
1124            self,
1125            self.shared.id.read()
1126        );
1127        let mut exclusive = self.exclusive.lock().unwrap();
1128        exclusive.send_state = StreamSendState::ShutdownComplete;
1129        exclusive
1130            .write_shutdown_waiters
1131            .drain(..)
1132            .for_each(|waker| waker.wake());
1133        Ok(())
1134    }
1135
1136    #[allow(clippy::too_many_arguments)]
1137    fn handle_event_shutdown_complete(
1138        &self,
1139        msquic_stream: msquic::StreamRef,
1140        connection_shutdown: bool,
1141        app_close_in_progress: bool,
1142        connection_shutdown_by_app: bool,
1143        connection_closed_remotely: bool,
1144        connection_error_code: u64,
1145        connection_close_status: msquic::Status,
1146    ) -> Result<(), msquic::Status> {
1147        trace!(
1148            "StreamInner({:p}, id={:?}) Shutdown complete",
1149            self,
1150            self.shared.id.read()
1151        );
1152        {
1153            let mut exclusive = self.exclusive.lock().unwrap();
1154
1155            if !exclusive.recv_buffers.is_empty() {
1156                trace!(
1157                    "StreamInner({:p}) read complete {}",
1158                    self,
1159                    exclusive.recv_len - exclusive.read_complete_cursor
1160                );
1161                exclusive.recv_buffers.clear();
1162                if !app_close_in_progress {
1163                    msquic_stream.receive_complete(
1164                        (exclusive.recv_len - exclusive.read_complete_cursor) as u64,
1165                    );
1166                }
1167            }
1168
1169            exclusive.state = StreamState::ShutdownComplete;
1170            exclusive.recv_state = StreamRecvState::ShutdownComplete;
1171            exclusive.send_state = StreamSendState::ShutdownComplete;
1172            if connection_shutdown {
1173                match (connection_shutdown_by_app, connection_closed_remotely) {
1174                    (true, true) => {
1175                        exclusive.conn_error =
1176                            Some(ConnectionError::ShutdownByPeer(connection_error_code));
1177                    }
1178                    (true, false) => {
1179                        exclusive.conn_error = Some(ConnectionError::ShutdownByLocal);
1180                    }
1181                    (false, true) | (false, false) => {
1182                        exclusive.conn_error = Some(ConnectionError::ShutdownByTransport(
1183                            connection_close_status,
1184                            connection_error_code,
1185                        ));
1186                    }
1187                }
1188            }
1189            exclusive
1190                .start_waiters
1191                .drain(..)
1192                .for_each(|waker| waker.wake());
1193            exclusive
1194                .read_waiters
1195                .drain(..)
1196                .for_each(|waker| waker.wake());
1197        }
1198        // unsafe {
1199        //     Arc::from_raw(self as *const _);
1200        // }
1201        Ok(())
1202    }
1203
1204    fn handle_event_ideal_send_buffer_size(&self, _byte_count: u64) -> Result<(), msquic::Status> {
1205        trace!(
1206            "StreamInner({:p}, id={:?}) Ideal send buffer size",
1207            self,
1208            self.shared.id.read()
1209        );
1210        Ok(())
1211    }
1212
1213    fn handle_event_peer_accepted(&self) -> Result<(), msquic::Status> {
1214        trace!(
1215            "StreamInner({:p}, id={:?}) Peer accepted",
1216            self,
1217            self.shared.id.read()
1218        );
1219        let mut exclusive = self.exclusive.lock().unwrap();
1220        exclusive.state = StreamState::StartComplete;
1221        if self.shared.stream_type == StreamType::Bidirectional {
1222            exclusive.recv_state = StreamRecvState::StartComplete;
1223        }
1224        exclusive.send_state = StreamSendState::StartComplete;
1225        exclusive
1226            .start_waiters
1227            .drain(..)
1228            .for_each(|waker| waker.wake());
1229        Ok(())
1230    }
1231
1232    fn callback_handler_impl(
1233        &self,
1234        msquic_stream: msquic::StreamRef,
1235        ev: msquic::StreamEvent,
1236    ) -> Result<(), msquic::Status> {
1237        match ev {
1238            msquic::StreamEvent::StartComplete {
1239                status,
1240                id,
1241                peer_accepted,
1242            } => self.handle_event_start_complete(status, id, peer_accepted),
1243            msquic::StreamEvent::Receive {
1244                absolute_offset,
1245                total_buffer_length,
1246                buffers,
1247                flags,
1248            } => self.handle_event_receive(absolute_offset, total_buffer_length, buffers, flags),
1249            msquic::StreamEvent::SendComplete {
1250                cancelled,
1251                client_context,
1252            } => self.handle_event_send_complete(cancelled, client_context),
1253            msquic::StreamEvent::PeerSendShutdown => self.handle_event_peer_send_shutdown(),
1254            msquic::StreamEvent::PeerSendAborted { error_code } => {
1255                self.handle_event_peer_send_aborted(error_code)
1256            }
1257            msquic::StreamEvent::PeerReceiveAborted { error_code } => {
1258                self.handle_event_peer_receive_aborted(error_code)
1259            }
1260            msquic::StreamEvent::SendShutdownComplete { graceful } => {
1261                self.handle_event_send_shutdown_complete(graceful)
1262            }
1263            msquic::StreamEvent::ShutdownComplete {
1264                connection_shutdown,
1265                app_close_in_progress,
1266                connection_shutdown_by_app,
1267                connection_closed_remotely,
1268                connection_error_code,
1269                connection_close_status,
1270            } => self.handle_event_shutdown_complete(
1271                msquic_stream,
1272                connection_shutdown,
1273                app_close_in_progress,
1274                connection_shutdown_by_app,
1275                connection_closed_remotely,
1276                connection_error_code,
1277                connection_close_status,
1278            ),
1279            msquic::StreamEvent::IdealSendBufferSize { byte_count } => {
1280                self.handle_event_ideal_send_buffer_size(byte_count)
1281            }
1282            msquic::StreamEvent::PeerAccepted => self.handle_event_peer_accepted(),
1283            _ => {
1284                trace!("StreamInner({:p}) Other callback", self);
1285                Ok(())
1286            }
1287        }
1288    }
1289}
1290
1291impl Drop for StreamInner {
1292    fn drop(&mut self) {
1293        trace!("StreamInner({:p}) dropping", self);
1294    }
1295}
1296
1297impl fmt::Debug for StreamInnerExclusive {
1298    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1299        f.debug_struct("Exclusive")
1300            .field("state", &self.state)
1301            .field("recv_state", &self.recv_state)
1302            .field("send_state", &self.send_state)
1303            .finish()
1304    }
1305}
1306
1307impl fmt::Debug for StreamInnerShared {
1308    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1309        f.debug_struct("Shared")
1310            .field("type", &self.stream_type)
1311            .field("id", &self.id)
1312            .finish()
1313    }
1314}
1315pub struct ReadChunk<'a> {
1316    stream: &'a Arc<StreamInstance>,
1317}
1318
1319impl Future for ReadChunk<'_> {
1320    type Output = Result<Option<StreamRecvBuffer>, ReadError>;
1321
1322    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1323        self.stream.poll_read_chunk(cx)
1324    }
1325}
1326
1327pub struct WriteChunk<'a> {
1328    stream: &'a StreamInstance,
1329    chunk: &'a Bytes,
1330    fin: bool,
1331}
1332
1333impl Future for WriteChunk<'_> {
1334    type Output = Result<usize, WriteError>;
1335
1336    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1337        self.stream.poll_write_chunk(cx, self.chunk, self.fin)
1338    }
1339}
1340
1341pub struct WriteChunks<'a> {
1342    stream: &'a StreamInstance,
1343    chunks: &'a [Bytes],
1344    fin: bool,
1345}
1346
1347impl Future for WriteChunks<'_> {
1348    type Output = Result<usize, WriteError>;
1349
1350    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1351        self.stream.poll_write_chunks(cx, self.chunks, self.fin)
1352    }
1353}
1354
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for Stream {
    /// Reads stream data into the *unfilled* part of `buf`.
    ///
    /// Bytes already marked as filled in `buf` are preserved: new data is
    /// written after them, and the filled cursor is advanced by the number of
    /// bytes read. The unfilled region is initialized first. Without that, a
    /// `ReadBuf` with uninitialized spare capacity (e.g. one built with
    /// `ReadBuf::uninit`) would yield a zero-length read, which callers treat
    /// as EOF.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let len = ready!(Self::poll_read(
            self.get_mut(),
            cx,
            buf.initialize_unfilled()
        ))?;
        buf.advance(len);
        Poll::Ready(Ok(()))
    }
}
1367
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for Stream {
    /// Writes `buf` without FIN, converting stream errors into `std::io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Self::poll_write(self.get_mut(), cx, buf, false).map_err(Into::into)
    }

    /// Always completes immediately with `Ok(())`.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the send side via `poll_finish_write`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
        Self::poll_finish_write(self.get_mut(), cx)
            .map_ok(|_| ())
            .map_err(Into::into)
    }
}
1388
#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for ReadStream {
    /// Reads stream data into the *unfilled* part of `buf`.
    ///
    /// Bytes already marked as filled in `buf` are preserved: new data is
    /// written after them, and the filled cursor is advanced by the number of
    /// bytes read. The unfilled region is initialized first. Without that, a
    /// `ReadBuf` with uninitialized spare capacity (e.g. one built with
    /// `ReadBuf::uninit`) would yield a zero-length read, which callers treat
    /// as EOF.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let len = ready!(Self::poll_read(
            self.get_mut(),
            cx,
            buf.initialize_unfilled()
        ))?;
        buf.advance(len);
        Poll::Ready(Ok(()))
    }
}
1401
#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for WriteStream {
    /// Writes `buf` without FIN, converting stream errors into `std::io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Self::poll_write(self.get_mut(), cx, buf, false).map_err(Into::into)
    }

    /// Always completes immediately with `Ok(())`.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the send side via `poll_finish_write`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
        Self::poll_finish_write(self.get_mut(), cx)
            .map_ok(|_| ())
            .map_err(Into::into)
    }
}
1422
1423impl futures_io::AsyncRead for Stream {
1424    fn poll_read(
1425        self: Pin<&mut Self>,
1426        cx: &mut Context<'_>,
1427        buf: &mut [u8],
1428    ) -> Poll<std::io::Result<usize>> {
1429        let len = ready!(Self::poll_read(self.get_mut(), cx, buf))?;
1430        Poll::Ready(Ok(len))
1431    }
1432}
1433
1434impl futures_io::AsyncWrite for Stream {
1435    fn poll_write(
1436        self: Pin<&mut Self>,
1437        cx: &mut Context<'_>,
1438        buf: &[u8],
1439    ) -> Poll<std::io::Result<usize>> {
1440        let len = ready!(Self::poll_write(self.get_mut(), cx, buf, false))?;
1441        Poll::Ready(Ok(len))
1442    }
1443
1444    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1445        Poll::Ready(Ok(()))
1446    }
1447
1448    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1449        ready!(Self::poll_finish_write(self.get_mut(), cx))?;
1450        Poll::Ready(Ok(()))
1451    }
1452}
1453
1454impl futures_io::AsyncRead for ReadStream {
1455    fn poll_read(
1456        self: Pin<&mut Self>,
1457        cx: &mut Context<'_>,
1458        buf: &mut [u8],
1459    ) -> Poll<std::io::Result<usize>> {
1460        let len = ready!(Self::poll_read(self.get_mut(), cx, buf))?;
1461        Poll::Ready(Ok(len))
1462    }
1463}
1464
1465impl futures_io::AsyncWrite for WriteStream {
1466    fn poll_write(
1467        self: Pin<&mut Self>,
1468        cx: &mut Context<'_>,
1469        buf: &[u8],
1470    ) -> Poll<std::io::Result<usize>> {
1471        let len = ready!(Self::poll_write(self.get_mut(), cx, buf, false))?;
1472        Poll::Ready(Ok(len))
1473    }
1474
1475    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1476        Poll::Ready(Ok(()))
1477    }
1478
1479    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1480        ready!(Self::poll_finish_write(self.get_mut(), cx))?;
1481        Poll::Ready(Ok(()))
1482    }
1483}
1484
/// Outcome of a read attempt.
enum ReadStatus<T> {
    /// Data is available.
    Readable(T),
    /// The receive side is finished; carries any final data.
    Finished(Option<T>),
    /// Nothing more can be read right now; carries any data read so far.
    Blocked(Option<T>),
}

impl<T> From<(Option<T>, bool)> for ReadStatus<T> {
    /// `(data, true)` becomes `Finished(data)`; `(data, false)` becomes `Blocked(data)`.
    fn from((read, finished): (Option<T>, bool)) -> Self {
        if finished {
            Self::Finished(read)
        } else {
            Self::Blocked(read)
        }
    }
}
1499
/// Outcome of a write attempt.
enum WriteStatus<T> {
    /// The write can proceed / was accepted.
    Writable(T),
    /// The send side is finished; carries any final result.
    Finished(Option<T>),
    /// Nothing more can be written right now; carries any partial result.
    Blocked(Option<T>),
}

impl<T> From<(Option<T>, bool)> for WriteStatus<T> {
    /// `(data, true)` becomes `Finished(data)`; `(data, false)` becomes `Blocked(data)`.
    fn from((write, finished): (Option<T>, bool)) -> Self {
        if finished {
            Self::Finished(write)
        } else {
            Self::Blocked(write)
        }
    }
}
1514
/// Errors that can occur while opening or starting a stream.
#[derive(Debug, Error, Clone)]
pub enum StartError {
    /// The owning connection has not been started.
    #[error("connection not started yet")]
    ConnectionNotStarted,
    /// The stream count limit was reached (presumably mapped from
    /// `QUIC_STATUS_STREAM_LIMIT_REACHED` — confirm at the call site).
    #[error("reach stream count limit")]
    LimitReached,
    /// The connection was lost; wraps the connection-level error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// Any other msquic failure, e.g. `msquic::Stream::open` failing in `Stream::open`.
    #[error("other error: status {0:?}")]
    OtherError(msquic::Status),
}
1526
/// Errors returned by stream read operations.
///
/// Converted to `std::io::Error` by the `From` impl that follows.
#[derive(Debug, Error, Clone)]
pub enum ReadError {
    /// The stream's receive side is not open.
    #[error("stream not opened for reading")]
    Closed,
    /// The peer aborted its send side with the given application error code.
    #[error("stream reset by peer: error {0}")]
    Reset(u64),
    /// The connection was lost; wraps the connection-level error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// Any other msquic failure.
    #[error("other error: status {0:?}")]
    OtherError(msquic::Status),
}
1538
1539impl From<ReadError> for std::io::Error {
1540    fn from(e: ReadError) -> Self {
1541        let kind = match e {
1542            ReadError::Closed => std::io::ErrorKind::NotConnected,
1543            ReadError::Reset(_) => std::io::ErrorKind::ConnectionReset,
1544            ReadError::ConnectionLost(ConnectionError::ConnectionClosed) => {
1545                std::io::ErrorKind::NotConnected
1546            }
1547            ReadError::ConnectionLost(_) => std::io::ErrorKind::ConnectionAborted,
1548            ReadError::OtherError(_) => std::io::ErrorKind::Other,
1549        };
1550        Self::new(kind, e)
1551    }
1552}
1553
/// Errors returned by stream write operations.
///
/// Converted to `std::io::Error` by the `From` impl that follows.
#[derive(Debug, Error, Clone)]
pub enum WriteError {
    /// The stream's send side is not open.
    #[error("stream not opened for writing")]
    Closed,
    /// The send side has already been finished.
    #[error("stream finished")]
    Finished,
    /// The peer aborted its receive side with the given application error code.
    #[error("stream stopped by peer: error {0}")]
    Stopped(u64),
    /// The connection was lost; wraps the connection-level error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// Any other msquic failure.
    #[error("other error: status {0:?}")]
    OtherError(msquic::Status),
}
1567
1568impl From<WriteError> for std::io::Error {
1569    fn from(e: WriteError) -> Self {
1570        let kind = match e {
1571            WriteError::Closed
1572            | WriteError::Finished
1573            | WriteError::ConnectionLost(ConnectionError::ConnectionClosed) => {
1574                std::io::ErrorKind::NotConnected
1575            }
1576            WriteError::Stopped(_) => std::io::ErrorKind::ConnectionReset,
1577            WriteError::ConnectionLost(_) => std::io::ErrorKind::ConnectionAborted,
1578            WriteError::OtherError(_) => std::io::ErrorKind::Other,
1579        };
1580        Self::new(kind, e)
1581    }
1582}