1use std::{
15 cmp,
16 future::Future,
17 pin::Pin,
18 task::{Context, Poll},
19};
20
21use bitflags::bitflags;
22use parking_lot::RwLock;
23use tokio::{
24 io::{AsyncRead, AsyncWrite, ReadBuf},
25 sync::{broadcast, mpsc, oneshot},
26};
27use tokio_util::bytes::{Buf, Bytes};
28
29use crate::{
30 alloc::StreamId,
31 error::Error,
32 frame::Frame,
33 msg::{self, Message},
34};
35
36type WriteFrameFuture = Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + Sync>>;
38
/// One logical stream inside a multiplexed session.
///
/// Incoming data arrives as `Frame`s on `frame_rx`; the end of the peer's data is
/// signalled through `remote_fin_rx`. Outgoing data and FINs are handed off as
/// `Message`s on `msg_tx`. When both halves are closed the stream id is reported
/// on `close_tx`. NOTE(review): presumably all four channels are owned by the
/// session that created this stream — confirm against the caller of `new`.
pub struct Stream {
    /// Identifier of this stream, attached to every FIN/PSH it sends.
    stream_id: StreamId,
    /// Which halves are still usable (`R` = read, `W` = write).
    status: RwLock<StreamFlags>,
    /// Unconsumed bytes of the most recently received frame.
    read_buf: Bytes,
    /// A write already reported as accepted by `poll_write` whose frame has not
    /// been delivered yet; later `poll_write`/`poll_flush` calls drive it.
    cur_write_fut: Option<WriteFrameFuture>,

    /// Not read anywhere in this file; only held.
    /// NOTE(review): presumably kept to stay subscribed to a shutdown broadcast — confirm.
    _shutdown_rx: broadcast::Receiver<()>,

    /// Outbound channel for FIN/PSH messages produced by this stream.
    msg_tx: mpsc::Sender<Message>,
    /// Inbound data frames for this stream.
    frame_rx: mpsc::Receiver<Frame>,
    /// Completes when the peer has finished sending; `poll_read` treats any
    /// completion (value or dropped sender) as end of input.
    remote_fin_rx: oneshot::Receiver<()>,
    /// Receives this stream's id once the stream is no longer readable or writable.
    close_tx: mpsc::UnboundedSender<StreamId>,
}
57
58bitflags! {
59 struct StreamFlags: u8 {
63 const R = 1 << 0;
65 const W = 1 << 1;
67
68 const V = Self::R.bits() | Self::W.bits();
70 }
71}
72
73impl Stream {
74 pub async fn close(&self) {
80 let _ = msg::send_fin(self.msg_tx.clone(), self.stream_id).await;
81 self.deny_rw(StreamFlags::W);
82 }
83
84 pub(crate) fn new(
86 stream_id: StreamId,
87 shutdown_rx: broadcast::Receiver<()>,
88 msg_tx: mpsc::Sender<Message>,
89 frame_rx: mpsc::Receiver<Frame>,
90 close_tx: mpsc::UnboundedSender<StreamId>,
91 remote_fin_rx: oneshot::Receiver<()>,
92 ) -> Self {
93 Self {
94 stream_id,
95 status: RwLock::new(StreamFlags::V),
96 read_buf: Bytes::new(),
97 cur_write_fut: None,
98 _shutdown_rx: shutdown_rx,
99 msg_tx,
100 frame_rx,
101 close_tx,
102 remote_fin_rx,
103 }
104 }
105
106 fn deny_rw(&self, flags: StreamFlags) {
112 let mut status_guard = self.status.write();
113 *status_guard -= flags & StreamFlags::V;
114
115 if !status_guard.contains(StreamFlags::V) {
116 let _ = self.close_tx.send(self.stream_id);
117 }
118 }
119}
120
121impl Drop for Stream {
122 fn drop(&mut self) {
123 self.deny_rw(StreamFlags::V);
124 let msg_tx = self.msg_tx.clone();
125 let stream_id = self.stream_id;
126 tokio::spawn(async move {
127 let _ = msg::send_fin(msg_tx, stream_id).await;
128 });
129 }
130}
131
132impl AsyncRead for Stream {
133 fn poll_read(
134 self: Pin<&mut Self>,
135 cx: &mut Context<'_>,
136 buf: &mut ReadBuf<'_>,
137 ) -> Poll<std::io::Result<()>> {
138 let this = self.get_mut();
139 loop {
140 if !this.status.read().contains(StreamFlags::R) {
141 return Poll::Ready(Err(std::io::Error::new(
142 std::io::ErrorKind::BrokenPipe,
143 "stream has been closed",
144 )));
145 }
146
147 if !this.read_buf.is_empty() {
148 let to_copy = cmp::min(this.read_buf.len(), buf.remaining());
149 buf.put_slice(&this.read_buf[..to_copy]);
150 this.read_buf.advance(to_copy);
151 return Poll::Ready(Ok(()));
152 }
153
154 match Pin::new(&mut this.remote_fin_rx).poll(cx) {
155 Poll::Ready(_) => {
156 this.deny_rw(StreamFlags::R);
157 return Poll::Ready(Ok(()));
158 }
159 Poll::Pending => {}
160 }
161
162 match Pin::new(&mut this.frame_rx).poll_recv(cx) {
163 Poll::Ready(Some(frame)) => {
164 this.read_buf = Bytes::from(frame.data);
165 continue;
166 }
167 Poll::Pending => {
168 return Poll::Pending;
169 }
170 Poll::Ready(None) => {
171 unreachable!()
172 }
173 }
174 }
175 }
176}
177
178impl AsyncWrite for Stream {
179 fn poll_write(
180 mut self: Pin<&mut Self>,
181 cx: &mut Context<'_>,
182 buf: &[u8],
183 ) -> Poll<Result<usize, std::io::Error>> {
184 if !self.status.read().contains(StreamFlags::W) {
185 return Poll::Ready(Err(std::io::Error::new(
186 std::io::ErrorKind::BrokenPipe,
187 "stream is closed for writing",
188 )));
189 }
190
191 if self.cur_write_fut.is_none() {
192 let msg_tx = self.msg_tx.clone();
193 let stream_id = self.stream_id;
194 let data = buf.to_vec();
195
196 self.cur_write_fut =
197 Some(
198 Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
199 as WriteFrameFuture,
200 );
201 return Poll::Ready(Ok(buf.len()));
202 }
203
204 match self.cur_write_fut.as_mut().unwrap().as_mut().poll(cx) {
205 Poll::Pending => Poll::Pending,
206 Poll::Ready(Ok(_)) => {
207 let msg_tx = self.msg_tx.clone();
208 let stream_id = self.stream_id;
209 let data = buf.to_vec();
210
211 self.cur_write_fut =
212 Some(
213 Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
214 as WriteFrameFuture,
215 );
216 Poll::Ready(Ok(buf.len()))
217 }
218 Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e.to_string()))),
219 }
220 }
221
222 fn poll_flush(
223 mut self: Pin<&mut Self>,
224 cx: &mut Context<'_>,
225 ) -> Poll<Result<(), std::io::Error>> {
226 if let Some(fut) = self.cur_write_fut.as_mut() {
227 match fut.as_mut().poll(cx) {
228 Poll::Pending => Poll::Pending,
229 Poll::Ready(Ok(_)) => {
230 self.cur_write_fut = None;
231 Poll::Ready(Ok(()))
232 }
233 Poll::Ready(Err(e)) => {
234 self.cur_write_fut = None;
235 Poll::Ready(Err(std::io::Error::other(e.to_string())))
236 }
237 }
238 } else {
239 Poll::Ready(Ok(()))
240 }
241 }
242
243 fn poll_shutdown(
244 self: Pin<&mut Self>,
245 _cx: &mut Context<'_>,
246 ) -> Poll<Result<(), std::io::Error>> {
247 Poll::Ready(Ok(()))
248 }
249}