compio_quic/
send_stream.rs

1use std::{
2    io,
3    sync::Arc,
4    task::{Context, Poll},
5};
6
7use compio_buf::{BufResult, IoBuf, bytes::Bytes};
8use compio_io::AsyncWrite;
9use futures_util::{future::poll_fn, ready};
10use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
11use thiserror::Error;
12
13use crate::{ConnectionError, ConnectionInner, StoppedError};
14
15/// A stream that can only be used to send data.
16///
17/// If dropped, streams that haven't been explicitly [`reset()`] will be
18/// implicitly [`finish()`]ed, continuing to (re)transmit previously written
19/// data until it has been fully acknowledged or the connection is closed.
20///
21/// # Cancellation
22///
23/// A `write` method is said to be *cancel-safe* when dropping its future before
24/// the future becomes ready will always result in no data being written to the
25/// stream. This is true of methods which succeed immediately when any progress
26/// is made, and is not true of methods which might need to perform multiple
27/// writes internally before succeeding. Each `write` method documents whether
28/// it is cancel-safe.
29///
30/// [`reset()`]: SendStream::reset
31/// [`finish()`]: SendStream::finish
32#[derive(Debug)]
33pub struct SendStream {
34    conn: Arc<ConnectionInner>,
35    stream: StreamId,
36    is_0rtt: bool,
37}
38
39impl SendStream {
40    pub(crate) fn new(conn: Arc<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
41        Self {
42            conn,
43            stream,
44            is_0rtt,
45        }
46    }
47
48    /// Get the identity of this stream
49    pub fn id(&self) -> StreamId {
50        self.stream
51    }
52
53    /// Notify the peer that no more data will ever be written to this stream.
54    ///
55    /// It is an error to write to a stream after `finish()`ing it. [`reset()`]
56    /// may still be called after `finish` to abandon transmission of any stream
57    /// data that might still be buffered.
58    ///
59    /// To wait for the peer to receive all buffered stream data, see
60    /// [`stopped()`].
61    ///
62    /// May fail if [`finish()`] or  [`reset()`] was previously called.This
63    /// error is harmless and serves only to indicate that the caller may have
64    /// incorrect assumptions about the stream's state.
65    ///
66    /// [`reset()`]: Self::reset
67    /// [`stopped()`]: Self::stopped
68    /// [`finish()`]: Self::finish
69    pub fn finish(&mut self) -> Result<(), ClosedStream> {
70        let mut state = self.conn.state();
71        match state.conn.send_stream(self.stream).finish() {
72            Ok(()) => {
73                state.wake();
74                Ok(())
75            }
76            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
77            // Harmless. If the application needs to know about stopped streams at this point,
78            // it should call `stopped`.
79            Err(FinishError::Stopped(_)) => Ok(()),
80        }
81    }
82
83    /// Close the stream immediately.
84    ///
85    /// No new data can be written after calling this method. Locally buffered
86    /// data is dropped, and previously transmitted data will no longer be
87    /// retransmitted if lost. If an attempt has already been made to finish
88    /// the stream, the peer may still receive all written data.
89    ///
90    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was
91    /// previously called. This error is harmless and serves only to
92    /// indicate that the caller may have incorrect assumptions about the
93    /// stream's state.
94    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
95        let mut state = self.conn.state();
96        if self.is_0rtt && !state.check_0rtt() {
97            return Ok(());
98        }
99        state.conn.send_stream(self.stream).reset(error_code)?;
100        state.wake();
101        Ok(())
102    }
103
104    /// Set the priority of the stream.
105    ///
106    /// Every stream has an initial priority of 0. Locally buffered data
107    /// from streams with higher priority will be transmitted before data
108    /// from streams with lower priority. Changing the priority of a stream
109    /// with pending data may only take effect after that data has been
110    /// transmitted. Using many different priority levels per connection may
111    /// have a negative impact on performance.
112    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
113        self.conn
114            .state()
115            .conn
116            .send_stream(self.stream)
117            .set_priority(priority)
118    }
119
120    /// Get the priority of the stream
121    pub fn priority(&self) -> Result<i32, ClosedStream> {
122        self.conn.state().conn.send_stream(self.stream).priority()
123    }
124
125    /// Completes when the peer stops the stream or reads the stream to
126    /// completion.
127    ///
128    /// Yields `Some` with the stop error code if the peer stops the stream.
129    /// Yields `None` if the local side [`finish()`](Self::finish)es the stream
130    /// and then the peer acknowledges receipt of all stream data (although not
131    /// necessarily the processing of it), after which the peer closing the
132    /// stream is no longer meaningful.
133    ///
134    /// For a variety of reasons, the peer may not send acknowledgements
135    /// immediately upon receiving data. As such, relying on `stopped` to
136    /// know when the peer has read a stream to completion may introduce
137    /// more latency than using an application-level response of some sort.
138    pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
139        poll_fn(|cx| {
140            let mut state = self.conn.state();
141            if self.is_0rtt && !state.check_0rtt() {
142                return Poll::Ready(Err(StoppedError::ZeroRttRejected));
143            }
144            match state.conn.send_stream(self.stream).stopped() {
145                Err(_) => Poll::Ready(Ok(None)),
146                Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
147                Ok(None) => {
148                    if let Some(e) = &state.error {
149                        return Poll::Ready(Err(e.clone().into()));
150                    }
151                    state.stopped.insert(self.stream, cx.waker().clone());
152                    Poll::Pending
153                }
154            }
155        })
156        .await
157    }
158
159    fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
160    where
161        F: FnOnce(quinn_proto::SendStream) -> Result<R, quinn_proto::WriteError>,
162    {
163        let mut state = self.conn.try_state()?;
164        if self.is_0rtt && !state.check_0rtt() {
165            return Poll::Ready(Err(WriteError::ZeroRttRejected));
166        }
167        match f(state.conn.send_stream(self.stream)) {
168            Ok(r) => {
169                state.wake();
170                Poll::Ready(Ok(r))
171            }
172            Err(e) => match e.try_into() {
173                Ok(e) => Poll::Ready(Err(e)),
174                Err(()) => {
175                    state.writable.insert(self.stream, cx.waker().clone());
176                    Poll::Pending
177                }
178            },
179        }
180    }
181
182    /// Write bytes to the stream.
183    ///
184    /// Yields the number of bytes written on success. Congestion and flow
185    /// control may cause this to be shorter than `buf.len()`, indicating
186    /// that only a prefix of `buf` was written.
187    ///
188    /// This operation is cancel-safe.
189    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
190        poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf))).await
191    }
192
193    /// Convenience method to write an entire buffer to the stream.
194    ///
195    /// This operation is *not* cancel-safe.
196    pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
197        let mut count = 0;
198        poll_fn(|cx| {
199            loop {
200                if count == buf.len() {
201                    return Poll::Ready(Ok(()));
202                }
203                let n =
204                    ready!(self.execute_poll_write(cx, |mut stream| stream.write(&buf[count..])))?;
205                count += n;
206            }
207        })
208        .await
209    }
210
211    /// Write chunks to the stream.
212    ///
213    /// Yields the number of bytes and chunks written on success.
214    /// Congestion and flow control may cause this to be shorter than
215    /// `buf.len()`, indicating that only a prefix of `bufs` was written.
216    ///
217    /// This operation is cancel-safe.
218    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
219        poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await
220    }
221
222    /// Convenience method to write an entire list of chunks to the stream.
223    ///
224    /// This operation is *not* cancel-safe.
225    pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
226        let mut chunks = 0;
227        poll_fn(|cx| {
228            loop {
229                if chunks == bufs.len() {
230                    return Poll::Ready(Ok(()));
231                }
232                let written = ready!(self.execute_poll_write(cx, |mut stream| {
233                    stream.write_chunks(&mut bufs[chunks..])
234                }))?;
235                chunks += written.chunks;
236            }
237        })
238        .await
239    }
240}
241
242impl Drop for SendStream {
243    fn drop(&mut self) {
244        let mut state = self.conn.state();
245
246        // clean up any previously registered wakers
247        state.stopped.remove(&self.stream);
248        state.writable.remove(&self.stream);
249
250        if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
251            return;
252        }
253        match state.conn.send_stream(self.stream).finish() {
254            Ok(()) => state.wake(),
255            Err(FinishError::Stopped(reason)) => {
256                if state.conn.send_stream(self.stream).reset(reason).is_ok() {
257                    state.wake();
258                }
259            }
260            // Already finished or reset, which is fine.
261            Err(FinishError::ClosedStream) => {}
262        }
263    }
264}
265
266/// Errors that arise from writing to a stream
267#[derive(Debug, Error, Clone, PartialEq, Eq)]
268pub enum WriteError {
269    /// The peer is no longer accepting data on this stream
270    ///
271    /// Carries an application-defined error code.
272    #[error("sending stopped by peer: error {0}")]
273    Stopped(VarInt),
274    /// The connection was lost
275    #[error("connection lost")]
276    ConnectionLost(#[from] ConnectionError),
277    /// The stream has already been finished or reset
278    #[error("closed stream")]
279    ClosedStream,
280    /// This was a 0-RTT stream and the server rejected it
281    ///
282    /// Can only occur on clients for 0-RTT streams, which can be opened using
283    /// [`Connecting::into_0rtt()`].
284    ///
285    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
286    #[error("0-RTT rejected")]
287    ZeroRttRejected,
288    /// Error when the stream is not ready, because it is still sending
289    /// data from a previous call
290    #[cfg(feature = "h3")]
291    #[error("stream not ready")]
292    NotReady,
293}
294
295impl TryFrom<quinn_proto::WriteError> for WriteError {
296    type Error = ();
297
298    fn try_from(value: quinn_proto::WriteError) -> Result<Self, Self::Error> {
299        use quinn_proto::WriteError::*;
300        match value {
301            Stopped(e) => Ok(Self::Stopped(e)),
302            ClosedStream => Ok(Self::ClosedStream),
303            Blocked => Err(()),
304        }
305    }
306}
307
308impl From<StoppedError> for WriteError {
309    fn from(x: StoppedError) -> Self {
310        match x {
311            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
312            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
313        }
314    }
315}
316
317impl From<WriteError> for io::Error {
318    fn from(x: WriteError) -> Self {
319        use WriteError::*;
320        let kind = match x {
321            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
322            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
323            #[cfg(feature = "h3")]
324            NotReady => io::ErrorKind::Other,
325        };
326        Self::new(kind, x)
327    }
328}
329
330impl AsyncWrite for SendStream {
331    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
332        let res = self.write(buf.as_slice()).await.map_err(Into::into);
333        BufResult(res, buf)
334    }
335
336    async fn flush(&mut self) -> io::Result<()> {
337        Ok(())
338    }
339
340    async fn shutdown(&mut self) -> io::Result<()> {
341        self.finish()?;
342        Ok(())
343    }
344}
345
#[cfg(feature = "io-compat")]
impl futures_util::AsyncWrite for SendStream {
    /// Polls a single write of `buf`; see `SendStream::write`.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let polled = this.execute_poll_write(cx, |mut stream| stream.write(buf));
        polled.map_err(io::Error::from)
    }

    /// No local buffering, so there is never anything to flush.
    fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream immediately; see `SendStream::finish`.
    fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let finished = self.get_mut().finish();
        Poll::Ready(finished.map_err(io::Error::from))
    }
}
367
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use compio_buf::bytes::Buf;
    use h3::quic::{self, StreamErrorIncoming, WriteBuf};

    use super::*;

    impl From<WriteError> for StreamErrorIncoming {
        /// Maps a peer stop to `StreamTerminated`, a lost connection to
        /// `ConnectionErrorIncoming`, and everything else to `Unknown`.
        fn from(e: WriteError) -> Self {
            match e {
                WriteError::Stopped(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                WriteError::ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },
                // Listed explicitly instead of `_` so that adding a new
                // `WriteError` variant forces a decision about its h3 mapping.
                e @ (WriteError::ClosedStream
                | WriteError::ZeroRttRejected
                | WriteError::NotReady) => Self::Unknown(Box::new(e)),
            }
        }
    }

    /// A wrapper around `SendStream` that implements `quic::SendStream` and
    /// `quic::SendStreamUnframed`.
    pub struct SendStream<B> {
        // The underlying QUIC send stream.
        inner: super::SendStream,
        // Data queued by `send_data`, drained by `poll_ready`. While this is
        // `Some`, the stream is not ready for further data.
        buf: Option<WriteBuf<B>>,
    }

    impl<B> SendStream<B> {
        pub(crate) fn new(conn: Arc<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                inner: super::SendStream::new(conn, stream, is_0rtt),
                buf: None,
            }
        }
    }

    impl<B> quic::SendStream<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes out any data queued by `send_data`.
        ///
        /// Resolves once the queue is fully written and cleared. On a write
        /// error the error is returned and the unsent remainder stays queued.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            if let Some(data) = &mut self.buf {
                while data.has_remaining() {
                    let n = ready!(
                        self.inner
                            .execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
                    )?;
                    data.advance(n);
                }
            }
            self.buf = None;
            Poll::Ready(Ok(()))
        }

        /// Queues `data` to be written by the next `poll_ready` calls.
        ///
        /// Fails with [`WriteError::NotReady`] if earlier data is still queued.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            if self.buf.is_some() {
                return Err(WriteError::NotReady.into());
            }
            self.buf = Some(data.into());
            Ok(())
        }

        /// Finishes the underlying stream. If it was already finished or
        /// reset, this reports `WriteError::ClosedStream`.
        fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            Poll::Ready(
                self.inner
                    .finish()
                    .map_err(|_| WriteError::ClosedStream.into()),
            )
        }

        /// Resets the underlying stream. Codes that do not fit in a `VarInt`
        /// are clamped to `VarInt::MAX`; errors from an already closed
        /// stream are ignored.
        fn reset(&mut self, reset_code: u64) {
            self.inner
                .reset(reset_code.try_into().unwrap_or(VarInt::MAX))
                .ok();
        }

        /// Returns the stream ID in h3's representation.
        fn send_id(&self) -> quic::StreamId {
            u64::from(self.inner.stream)
                .try_into()
                .expect("QUIC stream IDs are 62-bit varints and always fit h3's StreamId")
        }
    }

    impl<B> quic::SendStreamUnframed<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes as much of `buf` as the stream accepts now and advances
        /// `buf` by that amount.
        ///
        /// Must only be called while no `send_data` payload is queued,
        /// i.e. after `poll_ready` has resolved.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            // Data queued by `send_data` would be interleaved with this write;
            // calling `poll_send` in that state signifies a bug in the caller.
            debug_assert!(
                self.buf.is_none(),
                "poll_send called while send stream is not ready"
            );

            let n = ready!(
                self.inner
                    .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
            )?;
            buf.advance(n);
            Poll::Ready(Ok(n))
        }
    }
}