quiche_tokio/
stream.rs

1use super::connection;
2use std::future::Future;
3
4#[derive(Copy, Clone)]
5pub struct StreamID(pub(super) u64);
6
7impl std::fmt::Debug for StreamID {
8    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9        f.write_fmt(format_args!(
10            "StreamID({}, bidi={}, {})",
11            self.stream_id(),
12            self.is_bidi(),
13            if self.is_server() { "server" } else { "client" }
14        ))
15    }
16}
17impl StreamID {
18    pub fn new(stream_id: u64, bidi: bool, is_server: bool) -> Self {
19        let client_flag = if is_server { 1 } else { 0 };
20        let bidi_flag = if bidi { 0 } else { 2 };
21        Self(stream_id << 2 | bidi_flag | client_flag)
22    }
23
24    pub fn full_stream_id(&self) -> u64 {
25        self.0
26    }
27
28    pub fn stream_id(&self) -> u64 {
29        self.0 >> 2
30    }
31
32    pub fn is_server(&self) -> bool {
33        self.0 & 1 == 1
34    }
35
36    pub fn is_bidi(&self) -> bool {
37        self.0 & 2 == 0
38    }
39
40    pub fn can_read(&self, is_server: bool) -> bool {
41        self.is_bidi() || (is_server && (self.0 & 1 == 0)) || (!is_server && (self.0 & 1 == 1))
42    }
43
44    pub fn can_write(&self, is_server: bool) -> bool {
45        self.is_bidi() || (is_server && (self.0 & 1 == 1)) || (!is_server && (self.0 & 1 == 0))
46    }
47}
48
/// Outcome of one `StreamRecv` request: the bytes handed back by the connection.
type ReadOutput = std::io::Result<Vec<u8>>;
/// Outcome of one `StreamSend` request: the byte count reported by the connection.
type WriteOutput = std::io::Result<usize>;
/// Slot for an operation still in flight between `poll_*` calls.
///
/// Holds a boxed, pinned future so it can be stored and re-polled, or `None`
/// when nothing is pending.
type StreamFut<T> = Option<std::pin::Pin<Box<dyn Future<Output = T> + Send + Sync>>>;
/// Async handle to a single stream of a connection.
///
/// Implements tokio's `AsyncRead`/`AsyncWrite` by sending `connection::Control`
/// requests over `control_tx` and awaiting a oneshot reply for each. Cloning
/// produces a handle with no operation in flight that shares `read_fin` with the
/// original.
pub struct Stream {
    /// This endpoint's role. Combined with `stream_id` to decide whether the
    /// stream may be read and/or written.
    is_server: bool,
    stream_id: StreamID,
    /// Connection-wide state. This file only reads `connection_error` from it,
    /// to log why the connection could not be reached.
    shared_state: std::sync::Arc<connection::SharedConnectionState>,
    /// Request channel to the connection.
    /// NOTE(review): presumably drained by the connection driver task; confirm.
    control_tx: tokio::sync::mpsc::Sender<connection::Control>,
    /// In-flight read, resumed by the next `poll_read`.
    async_read: StreamFut<ReadOutput>,
    /// In-flight write, resumed by the next `poll_write`.
    async_write: StreamFut<WriteOutput>,
    /// In-flight FIN-only write, resumed by the next `poll_shutdown`.
    async_shutdown: StreamFut<WriteOutput>,
    /// Set once a read reply reports FIN. `poll_read` consults it to report EOF.
    /// Shared between clones of this handle.
    read_fin: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
62
63impl Stream {
64    pub(crate) fn new(
65        is_server: bool,
66        stream_id: StreamID,
67        shared_state: std::sync::Arc<connection::SharedConnectionState>,
68        control_tx: tokio::sync::mpsc::Sender<connection::Control>,
69    ) -> Self {
70        Self {
71            is_server,
72            stream_id,
73            shared_state,
74            control_tx,
75            async_read: None,
76            async_write: None,
77            async_shutdown: None,
78            read_fin: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
79        }
80    }
81
82    pub fn stream_id(&self) -> StreamID {
83        self.stream_id
84    }
85}
86
87impl Clone for Stream {
88    fn clone(&self) -> Self {
89        Self {
90            is_server: self.is_server,
91            stream_id: self.stream_id,
92            shared_state: self.shared_state.clone(),
93            control_tx: self.control_tx.clone(),
94            async_read: None,
95            async_write: None,
96            async_shutdown: None,
97            read_fin: self.read_fin.clone(),
98        }
99    }
100}
101
102impl std::fmt::Debug for Stream {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("Stream")
105            .field("is_server", &self.is_server)
106            .field("stream_id", &self.stream_id)
107            .field("shared_state", &self.shared_state)
108            .finish_non_exhaustive()
109    }
110}
111
112impl Stream {
113    pub fn is_bidi(&self) -> bool {
114        self.stream_id.is_bidi()
115    }
116
117    pub fn can_read(&self) -> bool {
118        self.stream_id.can_read(self.is_server)
119    }
120
121    pub fn can_write(&self) -> bool {
122        self.stream_id.can_write(self.is_server)
123    }
124
125    async fn _read(
126        stream_id: StreamID,
127        shared_state: std::sync::Arc<connection::SharedConnectionState>,
128        control_tx: tokio::sync::mpsc::Sender<connection::Control>,
129        len: usize,
130        read_fin: std::sync::Arc<std::sync::atomic::AtomicBool>,
131    ) -> ReadOutput {
132        let (tx, rx) = tokio::sync::oneshot::channel();
133        if control_tx
134            .send(connection::Control::StreamRecv {
135                stream_id: stream_id.0,
136                len,
137                resp: tx,
138            })
139            .await
140            .is_err()
141        {
142            warn!(
143                "Connection error: {:?}",
144                shared_state.connection_error.read().await
145            );
146            return Err(std::io::ErrorKind::ConnectionReset.into());
147        }
148        match rx.await {
149            Ok(Ok((r, fin))) => {
150                if fin {
151                    read_fin.store(true, std::sync::atomic::Ordering::Relaxed);
152                }
153                Ok(r)
154            }
155            Ok(Err(e)) => Err(e.into()),
156            Err(_) => {
157                warn!(
158                    "Connection error: {:?}",
159                    shared_state.connection_error.read().await
160                );
161                Err(std::io::ErrorKind::ConnectionReset.into())
162            }
163        }
164    }
165
166    async fn _write(
167        stream_id: StreamID,
168        shared_state: std::sync::Arc<connection::SharedConnectionState>,
169        control_tx: tokio::sync::mpsc::Sender<connection::Control>,
170        data: Vec<u8>,
171        fin: bool,
172    ) -> WriteOutput {
173        let (tx, rx) = tokio::sync::oneshot::channel();
174        if control_tx
175            .send(connection::Control::StreamSend {
176                stream_id: stream_id.0,
177                data,
178                fin,
179                resp: tx,
180            })
181            .await
182            .is_err()
183        {
184            warn!(
185                "Connection error: {:?}",
186                shared_state.connection_error.read().await
187            );
188            return Err(std::io::ErrorKind::ConnectionReset.into());
189        }
190        match rx.await {
191            Ok(r) => r.map_err(|e| e.into()),
192            Err(_) => {
193                warn!(
194                    "Connection error: {:?}",
195                    shared_state.connection_error.read().await
196                );
197                Err(std::io::ErrorKind::ConnectionReset.into())
198            }
199        }
200    }
201}
202
203impl tokio::io::AsyncRead for Stream {
204    fn poll_read(
205        mut self: std::pin::Pin<&mut Self>,
206        cx: &mut std::task::Context<'_>,
207        buf: &mut tokio::io::ReadBuf<'_>,
208    ) -> std::task::Poll<std::io::Result<()>> {
209        if !self.can_read() {
210            return std::task::Poll::Ready(Err(std::io::Error::new(
211                std::io::ErrorKind::Unsupported,
212                "Write-only stream",
213            )));
214        }
215        if self.read_fin.load(std::sync::atomic::Ordering::Acquire) {
216            return std::task::Poll::Ready(Ok(()));
217        }
218        if let Some(fut) = &mut self.async_read {
219            return match fut.as_mut().poll(cx) {
220                std::task::Poll::Pending => std::task::Poll::Pending,
221                std::task::Poll::Ready(Ok(r)) => {
222                    self.async_read = None;
223                    buf.put_slice(&r);
224                    std::task::Poll::Ready(Ok(()))
225                }
226                std::task::Poll::Ready(Err(e)) => {
227                    self.async_read = None;
228                    std::task::Poll::Ready(Err(e))
229                }
230            };
231        }
232        let mut fut = Box::pin(Self::_read(
233            self.stream_id,
234            self.shared_state.clone(),
235            self.control_tx.clone(),
236            buf.remaining(),
237            self.read_fin.clone(),
238        ));
239        match fut.as_mut().poll(cx) {
240            std::task::Poll::Pending => {
241                self.async_read.replace(fut);
242                std::task::Poll::Pending
243            }
244            std::task::Poll::Ready(Ok(r)) => {
245                buf.put_slice(&r);
246                std::task::Poll::Ready(Ok(()))
247            }
248            std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(e)),
249        }
250    }
251}
252
253impl tokio::io::AsyncWrite for Stream {
254    fn poll_write(
255        mut self: std::pin::Pin<&mut Self>,
256        cx: &mut std::task::Context<'_>,
257        buf: &[u8],
258    ) -> std::task::Poll<std::io::Result<usize>> {
259        if !self.can_write() {
260            return std::task::Poll::Ready(Err(std::io::Error::new(
261                std::io::ErrorKind::Unsupported,
262                "Read-only stream",
263            )));
264        }
265        if let Some(fut) = &mut self.async_write {
266            return match fut.as_mut().poll(cx) {
267                std::task::Poll::Pending => std::task::Poll::Pending,
268                std::task::Poll::Ready(r) => {
269                    self.async_write = None;
270                    std::task::Poll::Ready(r)
271                }
272            };
273        }
274        let mut fut = Box::pin(Self::_write(
275            self.stream_id,
276            self.shared_state.clone(),
277            self.control_tx.clone(),
278            buf.to_vec(),
279            false,
280        ));
281        match fut.as_mut().poll(cx) {
282            std::task::Poll::Pending => {
283                self.async_write.replace(fut);
284                std::task::Poll::Pending
285            }
286            std::task::Poll::Ready(r) => std::task::Poll::Ready(r),
287        }
288    }
289
290    fn poll_flush(
291        self: std::pin::Pin<&mut Self>,
292        _cx: &mut std::task::Context<'_>,
293    ) -> std::task::Poll<std::io::Result<()>> {
294        std::task::Poll::Ready(Ok(()))
295    }
296
297    fn poll_shutdown(
298        mut self: std::pin::Pin<&mut Self>,
299        cx: &mut std::task::Context<'_>,
300    ) -> std::task::Poll<std::io::Result<()>> {
301        if let Some(fut) = &mut self.async_shutdown {
302            return match fut.as_mut().poll(cx) {
303                std::task::Poll::Pending => std::task::Poll::Pending,
304                std::task::Poll::Ready(r) => {
305                    self.async_shutdown = None;
306                    std::task::Poll::Ready(r.map(|_| ()))
307                }
308            };
309        }
310        let mut fut = Box::pin(Self::_write(
311            self.stream_id,
312            self.shared_state.clone(),
313            self.control_tx.clone(),
314            Vec::new(),
315            true,
316        ));
317        match fut.as_mut().poll(cx) {
318            std::task::Poll::Pending => {
319                self.async_shutdown.replace(fut);
320                std::task::Poll::Pending
321            }
322            std::task::Poll::Ready(r) => std::task::Poll::Ready(r.map(|_| ())),
323        }
324    }
325}