1use std::{
15 cmp,
16 future::Future,
17 pin::Pin,
18 sync::OnceLock,
19 task::{Context, Poll},
20};
21
22use bitflags::bitflags;
23use parking_lot::RwLock;
24use tokio::{
25 io::{AsyncRead, AsyncWrite, ReadBuf},
26 sync::{broadcast, mpsc, oneshot},
27};
28use tokio_util::bytes::{Buf, Bytes};
29
30use crate::{
31 alloc::StreamId,
32 error::Error,
33 frame::Frame,
34 msg::{self, Message},
35};
36
37type WriteFrameFuture = Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + Sync>>;
39
/// A single stream identified by a `StreamId`, readable via `AsyncRead` and
/// writable via `AsyncWrite`.
///
/// Outgoing data and FINs go out through `msg_tx`; incoming data arrives as
/// `Frame`s on `frame_rx`.
pub struct Stream {
    /// Identifier attached to every message this stream sends.
    stream_id: StreamId,
    /// Which directions (read `R` / write `W`) are still open.
    status: RwLock<StreamFlags>,
    /// Bytes received but not yet handed to a reader.
    read_buf: Bytes,
    /// PSH send accepted by `poll_write` that has not finished yet.
    cur_write_fut: Option<WriteFrameFuture>,

    /// Never read in this file.
    /// NOTE(review): presumably held so the session's shutdown broadcast keeps
    /// a live subscriber for this stream — confirm.
    _shutdown_rx: broadcast::Receiver<()>,

    /// Outgoing message channel used by the `msg::send_psh` / `msg::send_fin` helpers.
    msg_tx: mpsc::UnboundedSender<Message>,
    /// Incoming data frames for this stream.
    frame_rx: mpsc::Receiver<Frame>,
    /// Completion of this oneshot (value or sender dropped) ends the read side.
    remote_fin_rx: oneshot::Receiver<()>,
    /// Channel on which `deny_rw` reports this stream's id as closed.
    close_tx: mpsc::UnboundedSender<StreamId>,

    /// Ensures the body of `close()` (clear `W`, send FIN) runs at most once.
    close_once: OnceLock<()>,
}
60
61bitflags! {
62 struct StreamFlags: u8 {
66 const R = 1 << 0;
68 const W = 1 << 1;
70
71 const V = Self::R.bits() | Self::W.bits();
73 }
74}
75
76impl Stream {
77 pub fn close(&self) {
83 self.close_once.get_or_init(|| {
84 self.deny_rw(StreamFlags::W);
85 let _ = msg::send_fin(self.msg_tx.clone(), self.stream_id);
86 });
87 }
88
89 pub(crate) fn new(
91 stream_id: StreamId,
92 shutdown_rx: broadcast::Receiver<()>,
93 msg_tx: mpsc::UnboundedSender<Message>,
94 frame_rx: mpsc::Receiver<Frame>,
95 close_tx: mpsc::UnboundedSender<StreamId>,
96 remote_fin_rx: oneshot::Receiver<()>,
97 ) -> Self {
98 Self {
99 stream_id,
100 status: RwLock::new(StreamFlags::V),
101 read_buf: Bytes::new(),
102 cur_write_fut: None,
103 _shutdown_rx: shutdown_rx,
104 msg_tx,
105 frame_rx,
106 close_tx,
107 remote_fin_rx,
108 close_once: OnceLock::new(),
109 }
110 }
111
112 fn deny_rw(&self, flags: StreamFlags) {
118 let mut status_guard = self.status.write();
119 *status_guard -= flags & StreamFlags::V;
120
121 if !status_guard.contains(StreamFlags::V) {
122 let _ = self.close_tx.send(self.stream_id);
123 }
124 }
125}
126
127impl Drop for Stream {
128 fn drop(&mut self) {
129 if self.status.read().contains(StreamFlags::W) {
130 let _ = msg::send_fin(self.msg_tx.clone(), self.stream_id);
131 }
132 self.deny_rw(StreamFlags::V);
133 }
134}
135
136impl AsyncRead for Stream {
137 fn poll_read(
138 self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 buf: &mut ReadBuf<'_>,
141 ) -> Poll<std::io::Result<()>> {
142 let this = self.get_mut();
143 loop {
144 if !this.status.read().contains(StreamFlags::R) {
145 return Poll::Ready(Err(std::io::Error::new(
146 std::io::ErrorKind::BrokenPipe,
147 "stream has been closed",
148 )));
149 }
150
151 if !this.read_buf.is_empty() {
152 let to_copy = cmp::min(this.read_buf.len(), buf.remaining());
153 buf.put_slice(&this.read_buf[..to_copy]);
154 this.read_buf.advance(to_copy);
155 return Poll::Ready(Ok(()));
156 }
157
158 match Pin::new(&mut this.frame_rx).poll_recv(cx) {
159 Poll::Ready(Some(frame)) => {
160 this.read_buf = Bytes::from(frame.data);
161 continue;
162 }
163 Poll::Pending => {
164 match Pin::new(&mut this.remote_fin_rx).poll(cx) {
165 Poll::Ready(_) => {
166 this.deny_rw(StreamFlags::R);
167 return Poll::Ready(Ok(()));
168 }
169 Poll::Pending => {}
170 }
171 return Poll::Pending;
172 }
173 Poll::Ready(None) => {
174 return Poll::Ready(Ok(()));
175 }
176 }
177 }
178 }
179}
180
181impl AsyncWrite for Stream {
182 fn poll_write(
183 mut self: Pin<&mut Self>,
184 cx: &mut Context<'_>,
185 buf: &[u8],
186 ) -> Poll<Result<usize, std::io::Error>> {
187 if !self.status.read().contains(StreamFlags::W) {
188 return Poll::Ready(Err(std::io::Error::new(
189 std::io::ErrorKind::BrokenPipe,
190 "stream is closed for writing",
191 )));
192 }
193
194 if self.cur_write_fut.is_none() {
195 let msg_tx = self.msg_tx.clone();
196 let stream_id = self.stream_id;
197 let data = buf.to_vec();
198
199 self.cur_write_fut =
200 Some(
201 Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
202 as WriteFrameFuture,
203 );
204 return Poll::Ready(Ok(buf.len()));
205 }
206
207 match self.cur_write_fut.as_mut().unwrap().as_mut().poll(cx) {
208 Poll::Pending => Poll::Pending,
209 Poll::Ready(Ok(_)) => {
210 let msg_tx = self.msg_tx.clone();
211 let stream_id = self.stream_id;
212 let data = buf.to_vec();
213
214 self.cur_write_fut =
215 Some(
216 Box::pin(async move { msg::send_psh(msg_tx, stream_id, &data).await })
217 as WriteFrameFuture,
218 );
219 Poll::Ready(Ok(buf.len()))
220 }
221 Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e.to_string()))),
222 }
223 }
224
225 fn poll_flush(
226 mut self: Pin<&mut Self>,
227 cx: &mut Context<'_>,
228 ) -> Poll<Result<(), std::io::Error>> {
229 if let Some(fut) = self.cur_write_fut.as_mut() {
230 match fut.as_mut().poll(cx) {
231 Poll::Pending => Poll::Pending,
232 Poll::Ready(Ok(_)) => {
233 self.cur_write_fut = None;
234 Poll::Ready(Ok(()))
235 }
236 Poll::Ready(Err(e)) => {
237 self.cur_write_fut = None;
238 Poll::Ready(Err(std::io::Error::other(e.to_string())))
239 }
240 }
241 } else {
242 Poll::Ready(Ok(()))
243 }
244 }
245
246 fn poll_shutdown(
247 mut self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 ) -> Poll<Result<(), std::io::Error>> {
250 let res = match &mut self.cur_write_fut {
251 Some(fut) => match fut.as_mut().poll(cx) {
252 Poll::Pending => Poll::Pending,
253 Poll::Ready(Ok(_)) => {
254 self.cur_write_fut = None;
255 Poll::Ready(Ok(()))
256 }
257 Poll::Ready(Err(e)) => {
258 self.cur_write_fut = None;
259 Poll::Ready(Err(std::io::Error::other(e.to_string())))
260 }
261 },
262 None => Poll::Ready(Ok(())),
263 };
264 if !res.is_pending() {
265 self.close();
266 }
267 res
268 }
269}