kvarn_quinn/
send_stream.rs

1use std::{
2    future::Future,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use proto::{ConnectionError, FinishError, StreamId, Written};
10use thiserror::Error;
11use tokio::sync::oneshot;
12
13use crate::{
14    connection::{ConnectionRef, UnknownStream},
15    VarInt,
16};
17
18/// A stream that can only be used to send data
19///
20/// If dropped, streams that haven't been explicitly [`reset()`] will continue to (re)transmit
21/// previously written data until it has been fully acknowledged or the connection is closed.
22///
23/// [`reset()`]: SendStream::reset
24#[derive(Debug)]
25pub struct SendStream {
26    conn: ConnectionRef,
27    stream: StreamId,
28    is_0rtt: bool,
29    finishing: Option<oneshot::Receiver<Option<WriteError>>>,
30}
31
32impl SendStream {
33    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
34        Self {
35            conn,
36            stream,
37            is_0rtt,
38            finishing: None,
39        }
40    }
41
42    /// Write bytes to the stream
43    ///
44    /// Yields the number of bytes written on success. Congestion and flow control may cause this to
45    /// be shorter than `buf.len()`, indicating that only a prefix of `buf` was written.
46    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
47        Write { stream: self, buf }.await
48    }
49
50    /// Convenience method to write an entire buffer to the stream
51    pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
52        WriteAll { stream: self, buf }.await
53    }
54
55    /// Write chunks to the stream
56    ///
57    /// Yields the number of bytes and chunks written on success.
58    /// Congestion and flow control may cause this to be shorter than `buf.len()`,
59    /// indicating that only a prefix of `bufs` was written
60    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
61        WriteChunks { stream: self, bufs }.await
62    }
63
64    /// Convenience method to write a single chunk in its entirety to the stream
65    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
66        WriteChunk {
67            stream: self,
68            buf: [buf],
69        }
70        .await
71    }
72
73    /// Convenience method to write an entire list of chunks to the stream
74    pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
75        WriteAllChunks {
76            stream: self,
77            bufs,
78            offset: 0,
79        }
80        .await
81    }
82
83    fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
84    where
85        F: FnOnce(&mut proto::SendStream) -> Result<R, proto::WriteError>,
86    {
87        use proto::WriteError::*;
88        let mut conn = self.conn.state.lock("SendStream::poll_write");
89        if self.is_0rtt {
90            conn.check_0rtt()
91                .map_err(|()| WriteError::ZeroRttRejected)?;
92        }
93        if let Some(ref x) = conn.error {
94            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
95        }
96
97        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
98            Ok(result) => result,
99            Err(Blocked) => {
100                conn.blocked_writers.insert(self.stream, cx.waker().clone());
101                return Poll::Pending;
102            }
103            Err(Stopped(error_code)) => {
104                return Poll::Ready(Err(WriteError::Stopped(error_code)));
105            }
106            Err(UnknownStream) => {
107                return Poll::Ready(Err(WriteError::UnknownStream));
108            }
109        };
110
111        conn.wake();
112        Poll::Ready(Ok(result))
113    }
114
115    /// Shut down the send stream gracefully.
116    ///
117    /// No new data may be written after calling this method. Completes when the peer has
118    /// acknowledged all sent data, retransmitting data as needed.
119    pub async fn finish(&mut self) -> Result<(), WriteError> {
120        Finish { stream: self }.await
121    }
122
123    #[doc(hidden)]
124    pub fn poll_finish(&mut self, cx: &mut Context) -> Poll<Result<(), WriteError>> {
125        let mut conn = self.conn.state.lock("poll_finish");
126        if self.is_0rtt {
127            conn.check_0rtt()
128                .map_err(|()| WriteError::ZeroRttRejected)?;
129        }
130        if self.finishing.is_none() {
131            conn.inner
132                .send_stream(self.stream)
133                .finish()
134                .map_err(|e| match e {
135                    FinishError::UnknownStream => WriteError::UnknownStream,
136                    FinishError::Stopped(error_code) => WriteError::Stopped(error_code),
137                })?;
138            let (send, recv) = oneshot::channel();
139            self.finishing = Some(recv);
140            conn.finishing.insert(self.stream, send);
141            conn.wake();
142        }
143        match Pin::new(self.finishing.as_mut().unwrap())
144            .poll(cx)
145            .map(|x| x.unwrap())
146        {
147            Poll::Ready(x) => {
148                self.finishing = None;
149                Poll::Ready(x.map_or(Ok(()), Err))
150            }
151            Poll::Pending => {
152                // To ensure that finished streams can be detected even after the connection is
153                // closed, we must only check for connection errors after determining that the
154                // stream has not yet been finished. Note that this relies on holding the connection
155                // lock so that it is impossible for the stream to become finished between the above
156                // poll call and this check.
157                if let Some(ref x) = conn.error {
158                    return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
159                }
160                Poll::Pending
161            }
162        }
163    }
164
165    /// Close the send stream immediately.
166    ///
167    /// No new data can be written after calling this method. Locally buffered data is dropped, and
168    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
169    /// already been made to finish the stream, the peer may still receive all written data.
170    pub fn reset(&mut self, error_code: VarInt) -> Result<(), UnknownStream> {
171        let mut conn = self.conn.state.lock("SendStream::reset");
172        if self.is_0rtt && conn.check_0rtt().is_err() {
173            return Ok(());
174        }
175        conn.inner.send_stream(self.stream).reset(error_code)?;
176        conn.wake();
177        Ok(())
178    }
179
180    /// Set the priority of the send stream
181    ///
182    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
183    /// higher priority will be transmitted before data from streams with lower priority. Changing
184    /// the priority of a stream with pending data may only take effect after that data has been
185    /// transmitted. Using many different priority levels per connection may have a negative
186    /// impact on performance.
187    pub fn set_priority(&self, priority: i32) -> Result<(), UnknownStream> {
188        let mut conn = self.conn.state.lock("SendStream::set_priority");
189        conn.inner.send_stream(self.stream).set_priority(priority)?;
190        Ok(())
191    }
192
193    /// Get the priority of the send stream
194    pub fn priority(&self) -> Result<i32, UnknownStream> {
195        let mut conn = self.conn.state.lock("SendStream::priority");
196        Ok(conn.inner.send_stream(self.stream).priority()?)
197    }
198
199    /// Completes if/when the peer stops the stream, yielding the error code
200    pub async fn stopped(&mut self) -> Result<VarInt, StoppedError> {
201        Stopped { stream: self }.await
202    }
203
204    #[doc(hidden)]
205    pub fn poll_stopped(&mut self, cx: &mut Context) -> Poll<Result<VarInt, StoppedError>> {
206        let mut conn = self.conn.state.lock("SendStream::poll_stopped");
207
208        if self.is_0rtt {
209            conn.check_0rtt()
210                .map_err(|()| StoppedError::ZeroRttRejected)?;
211        }
212
213        match conn.inner.send_stream(self.stream).stopped() {
214            Err(_) => Poll::Ready(Err(StoppedError::UnknownStream)),
215            Ok(Some(error_code)) => Poll::Ready(Ok(error_code)),
216            Ok(None) => {
217                conn.stopped.insert(self.stream, cx.waker().clone());
218                Poll::Pending
219            }
220        }
221    }
222
223    /// Get the identity of this stream
224    pub fn id(&self) -> StreamId {
225        self.stream
226    }
227}
228
/// `futures-io` adapter: writes go through `execute_poll`, closing maps to `poll_finish`, and
/// flushing is a no-op that reports success immediately. Errors are converted to `io::Error`.
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.execute_poll(cx, |s| s.write(buf))
            .map_err(io::Error::from)
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.poll_finish(cx).map_err(io::Error::from)
    }
}
243
/// Tokio adapter: writes go through `execute_poll`, shutdown maps to `poll_finish`, and
/// flushing is a no-op that reports success immediately. Errors are converted to `io::Error`.
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncWrite for SendStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.execute_poll(cx, |s| s.write(buf))
            .map_err(io::Error::from)
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.poll_finish(cx).map_err(io::Error::from)
    }
}
262
263impl Drop for SendStream {
264    fn drop(&mut self) {
265        let mut conn = self.conn.state.lock("SendStream::drop");
266
267        // clean up any previously registered wakers
268        conn.finishing.remove(&self.stream);
269        conn.stopped.remove(&self.stream);
270        conn.blocked_writers.remove(&self.stream);
271
272        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
273            return;
274        }
275        if self.finishing.is_none() {
276            match conn.inner.send_stream(self.stream).finish() {
277                Ok(()) => conn.wake(),
278                Err(FinishError::Stopped(reason)) => {
279                    if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
280                        conn.wake();
281                    }
282                }
283                // Already finished or reset, which is fine.
284                Err(FinishError::UnknownStream) => {}
285            }
286        }
287    }
288}
289
290/// Future produced by `SendStream::finish`
291#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
292struct Finish<'a> {
293    stream: &'a mut SendStream,
294}
295
296impl Future for Finish<'_> {
297    type Output = Result<(), WriteError>;
298
299    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
300        self.get_mut().stream.poll_finish(cx)
301    }
302}
303
304/// Future produced by `SendStream::stopped`
305#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
306struct Stopped<'a> {
307    stream: &'a mut SendStream,
308}
309
310impl Future for Stopped<'_> {
311    type Output = Result<VarInt, StoppedError>;
312
313    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
314        self.get_mut().stream.poll_stopped(cx)
315    }
316}
317
318/// Future produced by [`SendStream::write()`].
319///
320/// [`SendStream::write()`]: crate::SendStream::write
321#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
322struct Write<'a> {
323    stream: &'a mut SendStream,
324    buf: &'a [u8],
325}
326
327impl<'a> Future for Write<'a> {
328    type Output = Result<usize, WriteError>;
329    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
330        let this = self.get_mut();
331        let buf = this.buf;
332        this.stream.execute_poll(cx, |s| s.write(buf))
333    }
334}
335
336/// Future produced by [`SendStream::write_all()`].
337///
338/// [`SendStream::write_all()`]: crate::SendStream::write_all
339#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
340struct WriteAll<'a> {
341    stream: &'a mut SendStream,
342    buf: &'a [u8],
343}
344
345impl<'a> Future for WriteAll<'a> {
346    type Output = Result<(), WriteError>;
347    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
348        let this = self.get_mut();
349        loop {
350            if this.buf.is_empty() {
351                return Poll::Ready(Ok(()));
352            }
353            let buf = this.buf;
354            let n = ready!(this.stream.execute_poll(cx, |s| s.write(buf)))?;
355            this.buf = &this.buf[n..];
356        }
357    }
358}
359
360/// Future produced by [`SendStream::write_chunks()`].
361///
362/// [`SendStream::write_chunks()`]: crate::SendStream::write_chunks
363#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
364struct WriteChunks<'a> {
365    stream: &'a mut SendStream,
366    bufs: &'a mut [Bytes],
367}
368
369impl<'a> Future for WriteChunks<'a> {
370    type Output = Result<Written, WriteError>;
371    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
372        let this = self.get_mut();
373        let bufs = &mut *this.bufs;
374        this.stream.execute_poll(cx, |s| s.write_chunks(bufs))
375    }
376}
377
378/// Future produced by [`SendStream::write_chunk()`].
379///
380/// [`SendStream::write_chunk()`]: crate::SendStream::write_chunk
381#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
382struct WriteChunk<'a> {
383    stream: &'a mut SendStream,
384    buf: [Bytes; 1],
385}
386
387impl<'a> Future for WriteChunk<'a> {
388    type Output = Result<(), WriteError>;
389    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
390        let this = self.get_mut();
391        loop {
392            if this.buf[0].is_empty() {
393                return Poll::Ready(Ok(()));
394            }
395            let bufs = &mut this.buf[..];
396            ready!(this.stream.execute_poll(cx, |s| s.write_chunks(bufs)))?;
397        }
398    }
399}
400
401/// Future produced by [`SendStream::write_all_chunks()`].
402///
403/// [`SendStream::write_all_chunks()`]: crate::SendStream::write_all_chunks
404#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
405struct WriteAllChunks<'a> {
406    stream: &'a mut SendStream,
407    bufs: &'a mut [Bytes],
408    offset: usize,
409}
410
411impl<'a> Future for WriteAllChunks<'a> {
412    type Output = Result<(), WriteError>;
413    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
414        let this = self.get_mut();
415        loop {
416            if this.offset == this.bufs.len() {
417                return Poll::Ready(Ok(()));
418            }
419            let bufs = &mut this.bufs[this.offset..];
420            let written = ready!(this.stream.execute_poll(cx, |s| s.write_chunks(bufs)))?;
421            this.offset += written.chunks;
422        }
423    }
424}
425
426/// Errors that arise from writing to a stream
427#[derive(Debug, Error, Clone, PartialEq, Eq)]
428pub enum WriteError {
429    /// The peer is no longer accepting data on this stream
430    ///
431    /// Carries an application-defined error code.
432    #[error("sending stopped by peer: error {0}")]
433    Stopped(VarInt),
434    /// The connection was lost
435    #[error("connection lost")]
436    ConnectionLost(#[from] ConnectionError),
437    /// The stream has already been finished or reset
438    #[error("unknown stream")]
439    UnknownStream,
440    /// This was a 0-RTT stream and the server rejected it
441    ///
442    /// Can only occur on clients for 0-RTT streams, which can be opened using
443    /// [`Connecting::into_0rtt()`].
444    ///
445    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
446    #[error("0-RTT rejected")]
447    ZeroRttRejected,
448}
449
450/// Errors that arise while monitoring for a send stream stop from the peer
451#[derive(Debug, Error, Clone, PartialEq, Eq)]
452pub enum StoppedError {
453    /// The connection was lost
454    #[error("connection lost")]
455    ConnectionLost(#[from] ConnectionError),
456    /// The stream has already been finished or reset
457    #[error("unknown stream")]
458    UnknownStream,
459    /// This was a 0-RTT stream and the server rejected it
460    ///
461    /// Can only occur on clients for 0-RTT streams, which can be opened using
462    /// [`Connecting::into_0rtt()`].
463    ///
464    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
465    #[error("0-RTT rejected")]
466    ZeroRttRejected,
467}
468
469impl From<WriteError> for io::Error {
470    fn from(x: WriteError) -> Self {
471        use self::WriteError::*;
472        let kind = match x {
473            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
474            ConnectionLost(_) | UnknownStream => io::ErrorKind::NotConnected,
475        };
476        Self::new(kind, x)
477    }
478}