1use super::connection;
2use std::future::Future;
3
4#[derive(Copy, Clone)]
5pub struct StreamID(pub(super) u64);
6
7impl std::fmt::Debug for StreamID {
8 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9 f.write_fmt(format_args!(
10 "StreamID({}, bidi={}, {})",
11 self.stream_id(),
12 self.is_bidi(),
13 if self.is_server() { "server" } else { "client" }
14 ))
15 }
16}
17impl StreamID {
18 pub fn new(stream_id: u64, bidi: bool, is_server: bool) -> Self {
19 let client_flag = if is_server { 1 } else { 0 };
20 let bidi_flag = if bidi { 0 } else { 2 };
21 Self(stream_id << 2 | bidi_flag | client_flag)
22 }
23
24 pub fn full_stream_id(&self) -> u64 {
25 self.0
26 }
27
28 pub fn stream_id(&self) -> u64 {
29 self.0 >> 2
30 }
31
32 pub fn is_server(&self) -> bool {
33 self.0 & 1 == 1
34 }
35
36 pub fn is_bidi(&self) -> bool {
37 self.0 & 2 == 0
38 }
39
40 pub fn can_read(&self, is_server: bool) -> bool {
41 self.is_bidi() || (is_server && (self.0 & 1 == 0)) || (!is_server && (self.0 & 1 == 1))
42 }
43
44 pub fn can_write(&self, is_server: bool) -> bool {
45 self.is_bidi() || (is_server && (self.0 & 1 == 1)) || (!is_server && (self.0 & 1 == 0))
46 }
47}
48
/// Outcome of one stream read: the received bytes, or an I/O error.
type ReadOutput = std::io::Result<Vec<u8>>;

/// Outcome of one stream write: the `usize` reported back for the write, or an I/O error.
type WriteOutput = std::io::Result<usize>;

/// Slot for an operation that may still be in flight.
///
/// `None` means nothing is parked; `Some` holds the boxed, pinned future that
/// has to be polled again to obtain a `T`.
type StreamFut<T> =
    Option<std::pin::Pin<Box<dyn Future<Output = T> + Send + Sync>>>;
52pub struct Stream {
53 is_server: bool,
54 stream_id: StreamID,
55 shared_state: std::sync::Arc<connection::SharedConnectionState>,
56 control_tx: tokio::sync::mpsc::Sender<connection::Control>,
57 async_read: StreamFut<ReadOutput>,
58 async_write: StreamFut<WriteOutput>,
59 async_shutdown: StreamFut<WriteOutput>,
60 read_fin: std::sync::Arc<std::sync::atomic::AtomicBool>,
61}
62
63impl Stream {
64 pub(crate) fn new(
65 is_server: bool,
66 stream_id: StreamID,
67 shared_state: std::sync::Arc<connection::SharedConnectionState>,
68 control_tx: tokio::sync::mpsc::Sender<connection::Control>,
69 ) -> Self {
70 Self {
71 is_server,
72 stream_id,
73 shared_state,
74 control_tx,
75 async_read: None,
76 async_write: None,
77 async_shutdown: None,
78 read_fin: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
79 }
80 }
81
82 pub fn stream_id(&self) -> StreamID {
83 self.stream_id
84 }
85}
86
87impl Clone for Stream {
88 fn clone(&self) -> Self {
89 Self {
90 is_server: self.is_server,
91 stream_id: self.stream_id,
92 shared_state: self.shared_state.clone(),
93 control_tx: self.control_tx.clone(),
94 async_read: None,
95 async_write: None,
96 async_shutdown: None,
97 read_fin: self.read_fin.clone(),
98 }
99 }
100}
101
102impl std::fmt::Debug for Stream {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("Stream")
105 .field("is_server", &self.is_server)
106 .field("stream_id", &self.stream_id)
107 .field("shared_state", &self.shared_state)
108 .finish_non_exhaustive()
109 }
110}
111
112impl Stream {
113 pub fn is_bidi(&self) -> bool {
114 self.stream_id.is_bidi()
115 }
116
117 pub fn can_read(&self) -> bool {
118 self.stream_id.can_read(self.is_server)
119 }
120
121 pub fn can_write(&self) -> bool {
122 self.stream_id.can_write(self.is_server)
123 }
124
125 async fn _read(
126 stream_id: StreamID,
127 shared_state: std::sync::Arc<connection::SharedConnectionState>,
128 control_tx: tokio::sync::mpsc::Sender<connection::Control>,
129 len: usize,
130 read_fin: std::sync::Arc<std::sync::atomic::AtomicBool>,
131 ) -> ReadOutput {
132 let (tx, rx) = tokio::sync::oneshot::channel();
133 if control_tx
134 .send(connection::Control::StreamRecv {
135 stream_id: stream_id.0,
136 len,
137 resp: tx,
138 })
139 .await
140 .is_err()
141 {
142 warn!(
143 "Connection error: {:?}",
144 shared_state.connection_error.read().await
145 );
146 return Err(std::io::ErrorKind::ConnectionReset.into());
147 }
148 match rx.await {
149 Ok(Ok((r, fin))) => {
150 if fin {
151 read_fin.store(true, std::sync::atomic::Ordering::Relaxed);
152 }
153 Ok(r)
154 }
155 Ok(Err(e)) => Err(e.into()),
156 Err(_) => {
157 warn!(
158 "Connection error: {:?}",
159 shared_state.connection_error.read().await
160 );
161 Err(std::io::ErrorKind::ConnectionReset.into())
162 }
163 }
164 }
165
166 async fn _write(
167 stream_id: StreamID,
168 shared_state: std::sync::Arc<connection::SharedConnectionState>,
169 control_tx: tokio::sync::mpsc::Sender<connection::Control>,
170 data: Vec<u8>,
171 fin: bool,
172 ) -> WriteOutput {
173 let (tx, rx) = tokio::sync::oneshot::channel();
174 if control_tx
175 .send(connection::Control::StreamSend {
176 stream_id: stream_id.0,
177 data,
178 fin,
179 resp: tx,
180 })
181 .await
182 .is_err()
183 {
184 warn!(
185 "Connection error: {:?}",
186 shared_state.connection_error.read().await
187 );
188 return Err(std::io::ErrorKind::ConnectionReset.into());
189 }
190 match rx.await {
191 Ok(r) => r.map_err(|e| e.into()),
192 Err(_) => {
193 warn!(
194 "Connection error: {:?}",
195 shared_state.connection_error.read().await
196 );
197 Err(std::io::ErrorKind::ConnectionReset.into())
198 }
199 }
200 }
201}
202
203impl tokio::io::AsyncRead for Stream {
204 fn poll_read(
205 mut self: std::pin::Pin<&mut Self>,
206 cx: &mut std::task::Context<'_>,
207 buf: &mut tokio::io::ReadBuf<'_>,
208 ) -> std::task::Poll<std::io::Result<()>> {
209 if !self.can_read() {
210 return std::task::Poll::Ready(Err(std::io::Error::new(
211 std::io::ErrorKind::Unsupported,
212 "Write-only stream",
213 )));
214 }
215 if self.read_fin.load(std::sync::atomic::Ordering::Acquire) {
216 return std::task::Poll::Ready(Ok(()));
217 }
218 if let Some(fut) = &mut self.async_read {
219 return match fut.as_mut().poll(cx) {
220 std::task::Poll::Pending => std::task::Poll::Pending,
221 std::task::Poll::Ready(Ok(r)) => {
222 self.async_read = None;
223 buf.put_slice(&r);
224 std::task::Poll::Ready(Ok(()))
225 }
226 std::task::Poll::Ready(Err(e)) => {
227 self.async_read = None;
228 std::task::Poll::Ready(Err(e))
229 }
230 };
231 }
232 let mut fut = Box::pin(Self::_read(
233 self.stream_id,
234 self.shared_state.clone(),
235 self.control_tx.clone(),
236 buf.remaining(),
237 self.read_fin.clone(),
238 ));
239 match fut.as_mut().poll(cx) {
240 std::task::Poll::Pending => {
241 self.async_read.replace(fut);
242 std::task::Poll::Pending
243 }
244 std::task::Poll::Ready(Ok(r)) => {
245 buf.put_slice(&r);
246 std::task::Poll::Ready(Ok(()))
247 }
248 std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(e)),
249 }
250 }
251}
252
253impl tokio::io::AsyncWrite for Stream {
254 fn poll_write(
255 mut self: std::pin::Pin<&mut Self>,
256 cx: &mut std::task::Context<'_>,
257 buf: &[u8],
258 ) -> std::task::Poll<std::io::Result<usize>> {
259 if !self.can_write() {
260 return std::task::Poll::Ready(Err(std::io::Error::new(
261 std::io::ErrorKind::Unsupported,
262 "Read-only stream",
263 )));
264 }
265 if let Some(fut) = &mut self.async_write {
266 return match fut.as_mut().poll(cx) {
267 std::task::Poll::Pending => std::task::Poll::Pending,
268 std::task::Poll::Ready(r) => {
269 self.async_write = None;
270 std::task::Poll::Ready(r)
271 }
272 };
273 }
274 let mut fut = Box::pin(Self::_write(
275 self.stream_id,
276 self.shared_state.clone(),
277 self.control_tx.clone(),
278 buf.to_vec(),
279 false,
280 ));
281 match fut.as_mut().poll(cx) {
282 std::task::Poll::Pending => {
283 self.async_write.replace(fut);
284 std::task::Poll::Pending
285 }
286 std::task::Poll::Ready(r) => std::task::Poll::Ready(r),
287 }
288 }
289
290 fn poll_flush(
291 self: std::pin::Pin<&mut Self>,
292 _cx: &mut std::task::Context<'_>,
293 ) -> std::task::Poll<std::io::Result<()>> {
294 std::task::Poll::Ready(Ok(()))
295 }
296
297 fn poll_shutdown(
298 mut self: std::pin::Pin<&mut Self>,
299 cx: &mut std::task::Context<'_>,
300 ) -> std::task::Poll<std::io::Result<()>> {
301 if let Some(fut) = &mut self.async_shutdown {
302 return match fut.as_mut().poll(cx) {
303 std::task::Poll::Pending => std::task::Poll::Pending,
304 std::task::Poll::Ready(r) => {
305 self.async_shutdown = None;
306 std::task::Poll::Ready(r.map(|_| ()))
307 }
308 };
309 }
310 let mut fut = Box::pin(Self::_write(
311 self.stream_id,
312 self.shared_state.clone(),
313 self.control_tx.clone(),
314 Vec::new(),
315 true,
316 ));
317 match fut.as_mut().poll(cx) {
318 std::task::Poll::Pending => {
319 self.async_shutdown.replace(fut);
320 std::task::Poll::Pending
321 }
322 std::task::Poll::Ready(r) => std::task::Poll::Ready(r.map(|_| ())),
323 }
324 }
325}