Skip to main content

ant_quic/high_level/
send_stream.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{
9    future::{Future, poll_fn},
10    io,
11    pin::{Pin, pin},
12    task::{Context, Poll},
13};
14
15use crate::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
16use bytes::Bytes;
17use thiserror::Error;
18
19use super::connection::{ConnectionRef, State};
20use crate::VarInt;
21
22/// A stream that can only be used to send data
23///
24/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
25/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
26/// connection is closed.
27///
28/// # Cancellation
29///
30/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
31/// ready will always result in no data being written to the stream. This is true of methods which
32/// succeed immediately when any progress is made, and is not true of methods which might need to
33/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
34/// cancel-safe.
35///
36/// [`reset()`]: SendStream::reset
37/// [`finish()`]: SendStream::finish
38#[derive(Debug)]
39pub struct SendStream {
40    conn: ConnectionRef,
41    stream: StreamId,
42    is_0rtt: bool,
43}
44
45impl SendStream {
46    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
47        Self {
48            conn,
49            stream,
50            is_0rtt,
51        }
52    }
53
54    /// Write bytes to the stream
55    ///
56    /// Yields the number of bytes written on success. Congestion and flow control may cause this to
57    /// be shorter than `buf.len()`, indicating that only a prefix of `buf` was written.
58    ///
59    /// This operation is cancel-safe.
60    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
61        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
62    }
63
64    /// Convenience method to write an entire buffer to the stream
65    ///
66    /// This operation is *not* cancel-safe.
67    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
68        while !buf.is_empty() {
69            let written = self.write(buf).await?;
70            buf = &buf[written..];
71        }
72        Ok(())
73    }
74
75    /// Write chunks to the stream
76    ///
77    /// Yields the number of bytes and chunks written on success.
78    /// Congestion and flow control may cause this to be shorter than `buf.len()`,
79    /// indicating that only a prefix of `bufs` was written
80    ///
81    /// This operation is cancel-safe.
82    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
83        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
84    }
85
86    /// Convenience method to write a single chunk in its entirety to the stream
87    ///
88    /// This operation is *not* cancel-safe.
89    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
90        self.write_all_chunks(&mut [buf]).await?;
91        Ok(())
92    }
93
94    /// Convenience method to write an entire list of chunks to the stream
95    ///
96    /// This operation is *not* cancel-safe.
97    pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
98        while !bufs.is_empty() {
99            let written = self.write_chunks(bufs).await?;
100            bufs = &mut bufs[written.chunks..];
101        }
102        Ok(())
103    }
104
105    fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
106    where
107        F: FnOnce(&mut crate::SendStream) -> Result<R, crate::WriteError>,
108    {
109        use crate::WriteError::*;
110        let mut conn = self.conn.state.lock("SendStream::poll_write");
111        if self.is_0rtt {
112            conn.check_0rtt()
113                .map_err(|()| WriteError::ZeroRttRejected)?;
114        }
115        if let Some(ref x) = conn.error {
116            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
117        }
118
119        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
120            Ok(result) => result,
121            Err(Blocked) => {
122                conn.blocked_writers.insert(self.stream, cx.waker().clone());
123                return Poll::Pending;
124            }
125            Err(Stopped(error_code)) => {
126                return Poll::Ready(Err(WriteError::Stopped(error_code)));
127            }
128            Err(ClosedStream) => {
129                return Poll::Ready(Err(WriteError::ClosedStream));
130            }
131            Err(ConnectionClosed) => {
132                return Poll::Ready(Err(WriteError::ClosedStream));
133            }
134        };
135
136        conn.wake();
137        Poll::Ready(Ok(result))
138    }
139
140    /// Notify the peer that no more data will ever be written to this stream
141    ///
142    /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
143    /// may still be called after `finish` to abandon transmission of any stream data that might
144    /// still be buffered.
145    ///
146    /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
147    ///
148    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
149    /// called. This error is harmless and serves only to indicate that the caller may have
150    /// incorrect assumptions about the stream's state.
151    pub fn finish(&mut self) -> Result<(), ClosedStream> {
152        let mut conn = self.conn.state.lock("finish");
153        match conn.inner.send_stream(self.stream).finish() {
154            Ok(()) => {
155                conn.wake();
156                Ok(())
157            }
158            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
159            // Harmless. If the application needs to know about stopped streams at this point, it
160            // should call `stopped`.
161            Err(FinishError::Stopped(_)) => Ok(()),
162            Err(FinishError::ConnectionClosed) => Err(ClosedStream::default()),
163        }
164    }
165
166    /// Close the send stream immediately.
167    ///
168    /// No new data can be written after calling this method. Locally buffered data is dropped, and
169    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
170    /// already been made to finish the stream, the peer may still receive all written data.
171    ///
172    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
173    /// called. This error is harmless and serves only to indicate that the caller may have
174    /// incorrect assumptions about the stream's state.
175    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
176        let mut conn = self.conn.state.lock("SendStream::reset");
177        if self.is_0rtt && conn.check_0rtt().is_err() {
178            return Ok(());
179        }
180        conn.inner.send_stream(self.stream).reset(error_code)?;
181        conn.wake();
182        Ok(())
183    }
184
185    /// Set the priority of the send stream
186    ///
187    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
188    /// higher priority will be transmitted before data from streams with lower priority. Changing
189    /// the priority of a stream with pending data may only take effect after that data has been
190    /// transmitted. Using many different priority levels per connection may have a negative
191    /// impact on performance.
192    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
193        let mut conn = self.conn.state.lock("SendStream::set_priority");
194        conn.inner.send_stream(self.stream).set_priority(priority)?;
195        Ok(())
196    }
197
198    /// Get the priority of the send stream
199    pub fn priority(&self) -> Result<i32, ClosedStream> {
200        let mut conn = self.conn.state.lock("SendStream::priority");
201        conn.inner.send_stream(self.stream).priority()
202    }
203
204    /// Completes when the peer stops the stream or reads the stream to completion
205    ///
206    /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
207    /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
208    /// of all stream data (although not necessarily the processing of it), after which the peer
209    /// closing the stream is no longer meaningful.
210    ///
211    /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
212    /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
213    /// may introduce more latency than using an application-level response of some sort.
214    pub fn stopped(
215        &self,
216    ) -> impl Future<Output = Result<Option<VarInt>, StoppedError>> + Send + Sync + 'static + use<>
217    {
218        let conn = self.conn.clone();
219        let stream = self.stream;
220        let is_0rtt = self.is_0rtt;
221        async move {
222            loop {
223                // The `Notify::notified` future needs to be created while the lock is being held,
224                // otherwise a wakeup could be missed if triggered inbetween releasing the lock
225                // and creating the future.
226                // The lock may only be held in a block without `await`s, otherwise the future
227                // becomes `!Send`. `Notify::notified` is lifetime-bound to `Notify`, therefore
228                // we need to declare `notify` outside of the block, and initialize it inside.
229                let notify;
230                {
231                    let mut conn = conn.state.lock("SendStream::stopped");
232                    if let Some(output) = send_stream_stopped(&mut conn, stream, is_0rtt) {
233                        return output;
234                    }
235
236                    notify = conn.stopped.entry(stream).or_default().clone();
237                    notify.notified()
238                }
239                .await
240            }
241        }
242    }
243
244    /// Get the identity of this stream
245    pub fn id(&self) -> StreamId {
246        self.stream
247    }
248
249    /// Attempt to write bytes from buf into the stream.
250    ///
251    /// On success, returns Poll::Ready(Ok(num_bytes_written)).
252    ///
253    /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
254    /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
255    /// stream becomes writable or is closed.
256    pub fn poll_write(
257        self: Pin<&mut Self>,
258        cx: &mut Context,
259        buf: &[u8],
260    ) -> Poll<Result<usize, WriteError>> {
261        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
262    }
263}
264
265/// Check if a send stream is stopped.
266///
267/// Returns `Some` if the stream is stopped or the connection is closed.
268/// Returns `None` if the stream is not stopped.
269fn send_stream_stopped(
270    conn: &mut State,
271    stream: StreamId,
272    is_0rtt: bool,
273) -> Option<Result<Option<VarInt>, StoppedError>> {
274    if is_0rtt && conn.check_0rtt().is_err() {
275        return Some(Err(StoppedError::ZeroRttRejected));
276    }
277    match conn.inner.send_stream(stream).stopped() {
278        Err(ClosedStream { .. }) => Some(Ok(None)),
279        Ok(Some(error_code)) => Some(Ok(Some(error_code))),
280        Ok(None) => conn.error.clone().map(|error| Err(error.into())),
281    }
282}
283
284impl tokio::io::AsyncWrite for SendStream {
285    fn poll_write(
286        self: Pin<&mut Self>,
287        cx: &mut Context<'_>,
288        buf: &[u8],
289    ) -> Poll<io::Result<usize>> {
290        self.poll_write(cx, buf).map_err(Into::into)
291    }
292
293    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
294        Poll::Ready(Ok(()))
295    }
296
297    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
298        Poll::Ready(self.get_mut().finish().map_err(Into::into))
299    }
300}
301
302impl Drop for SendStream {
303    fn drop(&mut self) {
304        let mut conn = self.conn.state.lock("SendStream::drop");
305
306        // clean up any previously registered wakers
307        conn.blocked_writers.remove(&self.stream);
308
309        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
310            return;
311        }
312        match conn.inner.send_stream(self.stream).finish() {
313            Ok(()) => conn.wake(),
314            Err(FinishError::Stopped(reason)) => {
315                if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
316                    conn.wake();
317                }
318            }
319            // Already finished or reset, which is fine.
320            Err(FinishError::ClosedStream) => {}
321            Err(FinishError::ConnectionClosed) => {}
322        }
323    }
324}
325
326/// Errors that arise from writing to a stream
327#[derive(Debug, Error, Clone, PartialEq, Eq)]
328pub enum WriteError {
329    /// The peer is no longer accepting data on this stream
330    ///
331    /// Carries an application-defined error code.
332    #[error("sending stopped by peer: error {0}")]
333    Stopped(VarInt),
334    /// The connection was lost
335    #[error("connection lost")]
336    ConnectionLost(#[from] ConnectionError),
337    /// The stream has already been finished or reset
338    #[error("closed stream")]
339    ClosedStream,
340    /// This was a 0-RTT stream and the server rejected it
341    ///
342    /// Can only occur on clients for 0-RTT streams, which can be opened using
343    /// [`Connecting::into_0rtt()`].
344    ///
345    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
346    #[error("0-RTT rejected")]
347    ZeroRttRejected,
348}
349
350impl From<ClosedStream> for WriteError {
351    #[inline]
352    fn from(_: ClosedStream) -> Self {
353        Self::ClosedStream
354    }
355}
356
357impl From<StoppedError> for WriteError {
358    fn from(x: StoppedError) -> Self {
359        match x {
360            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
361            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
362        }
363    }
364}
365
366impl From<WriteError> for io::Error {
367    fn from(x: WriteError) -> Self {
368        use WriteError::*;
369        let kind = match x {
370            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
371            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
372        };
373        Self::new(kind, x)
374    }
375}
376
377/// Errors that arise while monitoring for a send stream stop from the peer
378#[derive(Debug, Error, Clone, PartialEq, Eq)]
379pub enum StoppedError {
380    /// The connection was lost
381    #[error("connection lost")]
382    ConnectionLost(#[from] ConnectionError),
383    /// This was a 0-RTT stream and the server rejected it
384    ///
385    /// Can only occur on clients for 0-RTT streams, which can be opened using
386    /// [`Connecting::into_0rtt()`].
387    ///
388    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
389    #[error("0-RTT rejected")]
390    ZeroRttRejected,
391}
392
393impl From<StoppedError> for io::Error {
394    fn from(x: StoppedError) -> Self {
395        use StoppedError::*;
396        let kind = match x {
397            ZeroRttRejected => io::ErrorKind::ConnectionReset,
398            ConnectionLost(_) => io::ErrorKind::NotConnected,
399        };
400        Self::new(kind, x)
401    }
402}