ant_quic/high_level/
send_stream.rs

1use std::{
2    future::{Future, poll_fn},
3    io,
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7
8use crate::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
9use bytes::Bytes;
10use thiserror::Error;
11
12use super::connection::{ConnectionRef, State};
13use crate::VarInt;
14
15/// A stream that can only be used to send data
16///
17/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
18/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
19/// connection is closed.
20///
21/// # Cancellation
22///
23/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
24/// ready will always result in no data being written to the stream. This is true of methods which
25/// succeed immediately when any progress is made, and is not true of methods which might need to
26/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
27/// cancel-safe.
28///
29/// [`reset()`]: SendStream::reset
30/// [`finish()`]: SendStream::finish
31#[derive(Debug)]
32pub struct SendStream {
33    conn: ConnectionRef,
34    stream: StreamId,
35    is_0rtt: bool,
36}
37
38impl SendStream {
39    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
40        Self {
41            conn,
42            stream,
43            is_0rtt,
44        }
45    }
46
47    /// Write bytes to the stream
48    ///
49    /// Yields the number of bytes written on success. Congestion and flow control may cause this to
50    /// be shorter than `buf.len()`, indicating that only a prefix of `buf` was written.
51    ///
52    /// This operation is cancel-safe.
53    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
54        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
55    }
56
57    /// Convenience method to write an entire buffer to the stream
58    ///
59    /// This operation is *not* cancel-safe.
60    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
61        while !buf.is_empty() {
62            let written = self.write(buf).await?;
63            buf = &buf[written..];
64        }
65        Ok(())
66    }
67
68    /// Write chunks to the stream
69    ///
70    /// Yields the number of bytes and chunks written on success.
71    /// Congestion and flow control may cause this to be shorter than `buf.len()`,
72    /// indicating that only a prefix of `bufs` was written
73    ///
74    /// This operation is cancel-safe.
75    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
76        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
77    }
78
79    /// Convenience method to write a single chunk in its entirety to the stream
80    ///
81    /// This operation is *not* cancel-safe.
82    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
83        self.write_all_chunks(&mut [buf]).await?;
84        Ok(())
85    }
86
87    /// Convenience method to write an entire list of chunks to the stream
88    ///
89    /// This operation is *not* cancel-safe.
90    pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
91        while !bufs.is_empty() {
92            let written = self.write_chunks(bufs).await?;
93            bufs = &mut bufs[written.chunks..];
94        }
95        Ok(())
96    }
97
98    fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
99    where
100        F: FnOnce(&mut crate::SendStream) -> Result<R, crate::WriteError>,
101    {
102        use crate::WriteError::*;
103        let mut conn = self.conn.state.lock("SendStream::poll_write");
104        if self.is_0rtt {
105            conn.check_0rtt()
106                .map_err(|()| WriteError::ZeroRttRejected)?;
107        }
108        if let Some(ref x) = conn.error {
109            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
110        }
111
112        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
113            Ok(result) => result,
114            Err(Blocked) => {
115                conn.blocked_writers.insert(self.stream, cx.waker().clone());
116                return Poll::Pending;
117            }
118            Err(Stopped(error_code)) => {
119                return Poll::Ready(Err(WriteError::Stopped(error_code)));
120            }
121            Err(ClosedStream) => {
122                return Poll::Ready(Err(WriteError::ClosedStream));
123            }
124            Err(ConnectionClosed) => {
125                return Poll::Ready(Err(WriteError::ClosedStream));
126            }
127        };
128
129        conn.wake();
130        Poll::Ready(Ok(result))
131    }
132
133    /// Notify the peer that no more data will ever be written to this stream
134    ///
135    /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
136    /// may still be called after `finish` to abandon transmission of any stream data that might
137    /// still be buffered.
138    ///
139    /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
140    ///
141    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
142    /// called. This error is harmless and serves only to indicate that the caller may have
143    /// incorrect assumptions about the stream's state.
144    pub fn finish(&mut self) -> Result<(), ClosedStream> {
145        let mut conn = self.conn.state.lock("finish");
146        match conn.inner.send_stream(self.stream).finish() {
147            Ok(()) => {
148                conn.wake();
149                Ok(())
150            }
151            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
152            // Harmless. If the application needs to know about stopped streams at this point, it
153            // should call `stopped`.
154            Err(FinishError::Stopped(_)) => Ok(()),
155            Err(FinishError::ConnectionClosed) => Err(ClosedStream::default()),
156        }
157    }
158
159    /// Close the send stream immediately.
160    ///
161    /// No new data can be written after calling this method. Locally buffered data is dropped, and
162    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
163    /// already been made to finish the stream, the peer may still receive all written data.
164    ///
165    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
166    /// called. This error is harmless and serves only to indicate that the caller may have
167    /// incorrect assumptions about the stream's state.
168    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
169        let mut conn = self.conn.state.lock("SendStream::reset");
170        if self.is_0rtt && conn.check_0rtt().is_err() {
171            return Ok(());
172        }
173        conn.inner.send_stream(self.stream).reset(error_code)?;
174        conn.wake();
175        Ok(())
176    }
177
178    /// Set the priority of the send stream
179    ///
180    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
181    /// higher priority will be transmitted before data from streams with lower priority. Changing
182    /// the priority of a stream with pending data may only take effect after that data has been
183    /// transmitted. Using many different priority levels per connection may have a negative
184    /// impact on performance.
185    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
186        let mut conn = self.conn.state.lock("SendStream::set_priority");
187        conn.inner.send_stream(self.stream).set_priority(priority)?;
188        Ok(())
189    }
190
191    /// Get the priority of the send stream
192    pub fn priority(&self) -> Result<i32, ClosedStream> {
193        let mut conn = self.conn.state.lock("SendStream::priority");
194        conn.inner.send_stream(self.stream).priority()
195    }
196
197    /// Completes when the peer stops the stream or reads the stream to completion
198    ///
199    /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
200    /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
201    /// of all stream data (although not necessarily the processing of it), after which the peer
202    /// closing the stream is no longer meaningful.
203    ///
204    /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
205    /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
206    /// may introduce more latency than using an application-level response of some sort.
207    pub fn stopped(
208        &self,
209    ) -> impl Future<Output = Result<Option<VarInt>, StoppedError>> + Send + Sync + 'static + use<>
210    {
211        let conn = self.conn.clone();
212        let stream = self.stream;
213        let is_0rtt = self.is_0rtt;
214        async move {
215            loop {
216                // The `Notify::notified` future needs to be created while the lock is being held,
217                // otherwise a wakeup could be missed if triggered inbetween releasing the lock
218                // and creating the future.
219                // The lock may only be held in a block without `await`s, otherwise the future
220                // becomes `!Send`. `Notify::notified` is lifetime-bound to `Notify`, therefore
221                // we need to declare `notify` outside of the block, and initialize it inside.
222                let notify;
223                {
224                    let mut conn = conn.state.lock("SendStream::stopped");
225                    if let Some(output) = send_stream_stopped(&mut conn, stream, is_0rtt) {
226                        return output;
227                    }
228
229                    notify = conn.stopped.entry(stream).or_default().clone();
230                    notify.notified()
231                }
232                .await
233            }
234        }
235    }
236
237    /// Get the identity of this stream
238    pub fn id(&self) -> StreamId {
239        self.stream
240    }
241
242    /// Attempt to write bytes from buf into the stream.
243    ///
244    /// On success, returns Poll::Ready(Ok(num_bytes_written)).
245    ///
246    /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
247    /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
248    /// stream becomes writable or is closed.
249    pub fn poll_write(
250        self: Pin<&mut Self>,
251        cx: &mut Context,
252        buf: &[u8],
253    ) -> Poll<Result<usize, WriteError>> {
254        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
255    }
256}
257
258/// Check if a send stream is stopped.
259///
260/// Returns `Some` if the stream is stopped or the connection is closed.
261/// Returns `None` if the stream is not stopped.
262fn send_stream_stopped(
263    conn: &mut State,
264    stream: StreamId,
265    is_0rtt: bool,
266) -> Option<Result<Option<VarInt>, StoppedError>> {
267    if is_0rtt && conn.check_0rtt().is_err() {
268        return Some(Err(StoppedError::ZeroRttRejected));
269    }
270    match conn.inner.send_stream(stream).stopped() {
271        Err(ClosedStream { .. }) => Some(Ok(None)),
272        Ok(Some(error_code)) => Some(Ok(Some(error_code))),
273        Ok(None) => conn.error.clone().map(|error| Err(error.into())),
274    }
275}
276
277/* TODO: Enable when futures-io feature is added
278#[cfg(feature = "futures-io")]
279impl futures_io::AsyncWrite for SendStream {
280    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
281        self.poll_write(cx, buf).map_err(Into::into)
282    }
283
284    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
285        Poll::Ready(Ok(()))
286    }
287
288    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
289        Poll::Ready(self.get_mut().finish().map_err(Into::into))
290    }
291}
292*/
293
294impl tokio::io::AsyncWrite for SendStream {
295    fn poll_write(
296        self: Pin<&mut Self>,
297        cx: &mut Context<'_>,
298        buf: &[u8],
299    ) -> Poll<io::Result<usize>> {
300        self.poll_write(cx, buf).map_err(Into::into)
301    }
302
303    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
304        Poll::Ready(Ok(()))
305    }
306
307    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
308        Poll::Ready(self.get_mut().finish().map_err(Into::into))
309    }
310}
311
312impl Drop for SendStream {
313    fn drop(&mut self) {
314        let mut conn = self.conn.state.lock("SendStream::drop");
315
316        // clean up any previously registered wakers
317        conn.blocked_writers.remove(&self.stream);
318
319        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
320            return;
321        }
322        match conn.inner.send_stream(self.stream).finish() {
323            Ok(()) => conn.wake(),
324            Err(FinishError::Stopped(reason)) => {
325                if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
326                    conn.wake();
327                }
328            }
329            // Already finished or reset, which is fine.
330            Err(FinishError::ClosedStream) => {}
331            Err(FinishError::ConnectionClosed) => {}
332        }
333    }
334}
335
336/// Errors that arise from writing to a stream
337#[derive(Debug, Error, Clone, PartialEq, Eq)]
338pub enum WriteError {
339    /// The peer is no longer accepting data on this stream
340    ///
341    /// Carries an application-defined error code.
342    #[error("sending stopped by peer: error {0}")]
343    Stopped(VarInt),
344    /// The connection was lost
345    #[error("connection lost")]
346    ConnectionLost(#[from] ConnectionError),
347    /// The stream has already been finished or reset
348    #[error("closed stream")]
349    ClosedStream,
350    /// This was a 0-RTT stream and the server rejected it
351    ///
352    /// Can only occur on clients for 0-RTT streams, which can be opened using
353    /// [`Connecting::into_0rtt()`].
354    ///
355    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
356    #[error("0-RTT rejected")]
357    ZeroRttRejected,
358}
359
360impl From<ClosedStream> for WriteError {
361    #[inline]
362    fn from(_: ClosedStream) -> Self {
363        Self::ClosedStream
364    }
365}
366
367impl From<StoppedError> for WriteError {
368    fn from(x: StoppedError) -> Self {
369        match x {
370            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
371            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
372        }
373    }
374}
375
376impl From<WriteError> for io::Error {
377    fn from(x: WriteError) -> Self {
378        use WriteError::*;
379        let kind = match x {
380            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
381            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
382        };
383        Self::new(kind, x)
384    }
385}
386
387/// Errors that arise while monitoring for a send stream stop from the peer
388#[derive(Debug, Error, Clone, PartialEq, Eq)]
389pub enum StoppedError {
390    /// The connection was lost
391    #[error("connection lost")]
392    ConnectionLost(#[from] ConnectionError),
393    /// This was a 0-RTT stream and the server rejected it
394    ///
395    /// Can only occur on clients for 0-RTT streams, which can be opened using
396    /// [`Connecting::into_0rtt()`].
397    ///
398    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
399    #[error("0-RTT rejected")]
400    ZeroRttRejected,
401}
402
403impl From<StoppedError> for io::Error {
404    fn from(x: StoppedError) -> Self {
405        use StoppedError::*;
406        let kind = match x {
407            ZeroRttRejected => io::ErrorKind::ConnectionReset,
408            ConnectionLost(_) => io::ErrorKind::NotConnected,
409        };
410        Self::new(kind, x)
411    }
412}